diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index a65bd9373d1b7..08e6b6de663c1 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -157,7 +157,7 @@ if (WIN32) if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -205,7 +205,7 @@ if (WIN32) else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index cf948c7e4ed21..ffd18ee78b7bb 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1657,6 +1657,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR}) endif() + if (WIN32) + set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $,$,$>) + add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME}) + endif() # delegate to gradle's test runner diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index fd813eff2f575..038692729356d 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -42,6 +42,8 @@ final class OnnxRuntime { private static final int ORT_API_VERSION_13 = 13; // Post 1.13 builds of the ORT API private static final int ORT_API_VERSION_14 = 14; + // Post 1.22 builds of the ORT API + private static final int ORT_API_VERSION_23 = 23; // The initial release of the ORT training API. private static final int ORT_TRAINING_API_VERSION_1 = 1; @@ -103,6 +105,9 @@ final class OnnxRuntime { /** The Training API handle. */ static long ortTrainingApiHandle; + /** The Compile API handle. */ + static long ortCompileApiHandle; + /** Is training enabled in the native library */ static boolean trainingEnabled; @@ -174,12 +179,13 @@ static synchronized void init() throws IOException { } load(ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23); if (ortApiHandle == 0L) { throw new IllegalStateException( "There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded"); } - ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14); + ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23); + ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle); trainingEnabled = ortTrainingApiHandle != 0L; providers = initialiseProviders(ortApiHandle); version = initialiseVersion(); @@ -497,6 +503,14 @@ private static EnumSet initialiseProviders(long ortApiHandle) { */ private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber); + /** + * Get a reference to the compile API struct. + * + * @param apiHandle The ORT API struct pointer. + * @return A pointer to the compile API struct. + */ + private static native long initialiseCompileAPIBase(long apiHandle); + /** * Gets the array of available providers. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 8382ef06e26e5..497772baf5357 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -8,7 +8,11 @@ import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; import java.util.EnumSet; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.logging.Logger; @@ -442,6 +446,48 @@ public static EnumSet getAvailableProviders() { return OnnxRuntime.providers.clone(); } + /** + * Registers an execution provider library with this OrtEnvironment. + * + * @param registrationName The name to register the library with (used to remove it later with + * {@link #unregisterExecutionProviderLibrary(String)}). + * @param libraryPath The path to the library binary on disk. + * @throws OrtException If the library could not be registered. + */ + public void registerExecutionProviderLibrary(String registrationName, String libraryPath) + throws OrtException { + registerExecutionProviderLibrary( + OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath); + } + + /** + * Unregisters an execution provider library from this OrtEnvironment. + * + * @param registrationName The name the library was registered under. + * @throws OrtException If the library could not be removed. + */ + public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException { + unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName); + } + + /** + * Get the list of all execution provider and device combinations that are available. + * + * @see OrtSession.SessionOptions#addExecutionProvider(List, Map) + * @return The list of execution provider and device combinations. + * @throws OrtException If the devices could not be listed. + */ + public List getEpDevices() throws OrtException { + long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle); + + List devicesList = new ArrayList<>(); + for (long deviceHandle : deviceHandles) { + devicesList.add(new OrtEpDevice(deviceHandle)); + } + + return Collections.unmodifiableList(devicesList); + } + /** * Creates the native object. * @@ -476,6 +522,40 @@ private static native long createHandle( */ private static native long getDefaultAllocator(long apiHandle) throws OrtException; + /** + * Registers the specified execution provider with this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @param libraryPath The path to the execution provider binary. + * @throws OrtException If the registration failed. + */ + private static native void registerExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName, String libraryPath) + throws OrtException; + + /** + * Removes the specified execution provider from this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @throws OrtException If the removal failed. + */ + private static native void unregisterExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName) throws OrtException; + + /** + * Gets handles for the EP device tuples available in this OrtEnvironment. + * + * @param apiHandle The API handle to use. + * @param nativeHandle The OrtEnvironment handle. + * @return An array of OrtEpDevice handles. + * @throws OrtException If the call failed. + */ + private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; + /** * Closes the OrtEnvironment, frees the handle. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java new file mode 100644 index 0000000000000..f63dec1dbaf83 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; + +/** A tuple of Execution Provider information and the hardware device. */ +public final class OrtEpDevice { + + private final long nativeHandle; + + private final String epName; + private final String epVendor; + private final Map epMetadata; + private final Map epOptions; + private final OrtHardwareDevice device; + + /** + * Construct an OrtEpDevice tuple from the native pointer. + * + * @param nativeHandle The native pointer. + */ + OrtEpDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle); + this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.epMetadata = OrtUtil.convertToMap(metadata); + String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle); + this.epOptions = OrtUtil.convertToMap(options); + this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle)); + } + + /** + * Return the native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the EP name. + * + * @return The EP name. + */ + public String getName() { + return epName; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return epVendor; + } + + /** + * Gets an unmodifiable view on the EP metadata. + * + * @return The EP metadata. + */ + public Map getMetadata() { + return epMetadata; + } + + /** + * Gets an unmodifiable view on the EP options. + * + * @return The EP options. + */ + public Map getOptions() { + return epOptions; + } + + /** + * Gets the device information. + * + * @return The device information. + */ + public OrtHardwareDevice getDevice() { + return device; + } + + @Override + public String toString() { + return "OrtEpDevice{" + + "epName='" + + epName + + '\'' + + ", epVendor='" + + epVendor + + '\'' + + ", epMetadata=" + + epMetadata + + ", epOptions=" + + epOptions + + ", device=" + + device + + '}'; + } + + private static native String getName(long apiHandle, long nativeHandle); + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native String[][] getOptions(long apiHandle, long nativeHandle); + + private static native long getDeviceHandle(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java b/java/src/main/java/ai/onnxruntime/OrtFlags.java similarity index 88% rename from java/src/main/java/ai/onnxruntime/providers/OrtFlags.java rename to java/src/main/java/ai/onnxruntime/OrtFlags.java index 73d3eeae6499c..f57fd945dbeec 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java +++ b/java/src/main/java/ai/onnxruntime/OrtFlags.java @@ -1,8 +1,8 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ -package ai.onnxruntime.providers; +package ai.onnxruntime; import java.util.EnumSet; diff --git a/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java new file mode 100644 index 0000000000000..bd99f5599fd14 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; +import java.util.logging.Logger; + +/** Hardware information for a specific device. */ +public final class OrtHardwareDevice { + + /** The hardware device types. */ + // Must be updated in concert with the native OrtHardwareDeviceType enum in the C API + public enum OrtHardwareDeviceType { + /** A CPU device. */ + CPU(0), + /** A GPU device. */ + GPU(1), + /** A NPU (Neural Processing Unit) device. */ + NPU(2); + private final int value; + + private static final Logger logger = Logger.getLogger(OrtHardwareDeviceType.class.getName()); + private static final OrtHardwareDeviceType[] values = new OrtHardwareDeviceType[3]; + + static { + for (OrtHardwareDeviceType ot : OrtHardwareDeviceType.values()) { + values[ot.value] = ot; + } + } + + OrtHardwareDeviceType(int value) { + this.value = value; + } + + /** + * Gets the native value associated with this device type. + * + * @return The native value. + */ + public int getValue() { + return value; + } + + /** + * Maps from the C API's int enum to the Java enum. + * + * @param deviceType The index of the Java enum. + * @return The Java enum. + */ + public static OrtHardwareDeviceType mapFromInt(int deviceType) { + if ((deviceType >= 0) && (deviceType < values.length)) { + return values[deviceType]; + } else { + logger.warning("Unknown device type '" + deviceType + "' setting to CPU"); + return CPU; + } + } + } + + private final long nativeHandle; + + private final OrtHardwareDeviceType type; + private final int vendorId; + private final String vendor; + private final int deviceId; + private final Map metadata; + + OrtHardwareDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.type = + OrtHardwareDeviceType.mapFromInt(getDeviceType(OnnxRuntime.ortApiHandle, nativeHandle)); + this.vendorId = getVendorId(OnnxRuntime.ortApiHandle, nativeHandle); + this.vendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + this.deviceId = getDeviceId(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.metadata = OrtUtil.convertToMap(metadata); + } + + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the device type. + * + * @return The device type. + */ + public OrtHardwareDeviceType getType() { + return type; + } + + /** + * Gets the vendor ID number. + * + * @return The vendor ID number. + */ + public int getVendorId() { + return vendorId; + } + + /** + * Gets the device ID number. + * + * @return The device ID number. + */ + public int getDeviceId() { + return deviceId; + } + + /** + * Gets an unmodifiable view on the device metadata. + * + * @return The device metadata. + */ + public Map getMetadata() { + return metadata; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return vendor; + } + + @Override + public String toString() { + return "OrtHardwareDevice{" + + "type=" + + type + + ", vendorId=" + + vendorId + + ", vendor='" + + vendor + + '\'' + + ", deviceId=" + + deviceId + + ", metadata=" + + metadata + + '}'; + } + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native int getDeviceType(long apiHandle, long nativeHandle); + + private static native int getDeviceId(long apiHandle, long nativeHandle); + + private static native int getVendorId(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java new file mode 100644 index 0000000000000..09b3064b72b93 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.nio.ByteBuffer; +import java.util.EnumSet; + +/** Configuration options for compiling ONNX models. */ +public final class OrtModelCompilationOptions implements AutoCloseable { + /** Flags representing options when compiling a model. */ + public enum OrtCompileApiFlags implements OrtFlags { + /** Default. Do not enable any additional compilation options. */ + NONE(0), + + /** + * Force compilation to return an error (ORT_FAIL) if no nodes were compiled. Otherwise, a model + * with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. + */ + ERROR_IF_NO_NODES_COMPILED(1), + + /** + * Force compilation to return an error (ORT_FAIL) if a file with the same filename as the + * output model exists. Otherwise, compilation will automatically overwrite the output file if + * it exists. + */ + ERROR_IF_OUTPUT_FILE_EXISTS(1 << 1); + + /** The native value of the enum. */ + public final int value; + + OrtCompileApiFlags(int value) { + this.value = value; + } + + @Override + public int getValue() { + return value; + } + } + + private final long nativeHandle; + private boolean closed = false; + + // Used to ensure the byte buffer doesn't get GC'd before the model is compiled. + private ByteBuffer buffer; + + OrtModelCompilationOptions(long nativeHandle) { + this.nativeHandle = nativeHandle; + } + + /** + * Creates a model compilation options from an existing SessionOptions. + * + *

An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX + * model. The OrtSessionOptions object has the execution providers with which the model will be + * compiled. + * + * @param env The OrtEnvironment. + * @param sessionOptions The session options to use. + * @return A constructed model compilation options instance. + * @throws OrtException If the construction failed. + */ + public static OrtModelCompilationOptions createFromSessionOptions( + OrtEnvironment env, OrtSession.SessionOptions sessionOptions) throws OrtException { + long handle = + createFromSessionOptions( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + sessionOptions.getNativeHandle()); + return new OrtModelCompilationOptions(handle); + } + + /** + * Checks if the OrtModelCompilationOptions is closed, if so throws {@link IllegalStateException}. + */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtModelCompilationOptions."); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortCompileApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close a closed OrtModelCompilationOptions."); + } + } + + /** + * Sets the file path to the input ONNX model. + * + *

The input model's location must be set either to a path on disk with this method, or by + * supplying an in-memory reference with {@link #setInputModelFromBuffer}. + * + * @param inputModelPath The path to the model on disk. + * @throws OrtException If the set failed. + */ + public void setInputModelPath(String inputModelPath) throws OrtException { + checkClosed(); + setInputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, inputModelPath); + } + + /** + * Uses the supplied buffer as the input ONNX model. + * + *

The input model's location must be set either to an in-memory reference with this method, or + * by supplying a path on disk with {@link #setInputModelPath(String)}. + * + *

If the {@link ByteBuffer} is not direct it is copied into a direct buffer. In either case + * this object holds a reference to the buffer to prevent it from being GC'd. + * + * @param inputModelBuffer The buffer. + * @throws OrtException If the buffer could not be set. + */ + public void setInputModelFromBuffer(ByteBuffer inputModelBuffer) throws OrtException { + checkClosed(); + if (!inputModelBuffer.isDirect()) { + // if it's not a direct buffer, copy it. + buffer = ByteBuffer.allocateDirect(inputModelBuffer.remaining()); + int tmpPos = inputModelBuffer.position(); + buffer.put(inputModelBuffer); + buffer.rewind(); + inputModelBuffer.position(tmpPos); + } else { + buffer = inputModelBuffer; + } + int bufferPos = buffer.position(); + int bufferRemaining = buffer.remaining(); + setInputModelFromBuffer( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + buffer, + bufferPos, + bufferRemaining); + } + + /** + * Sets the file path for the output compiled ONNX model. + * + *

If this is unset it will append `_ctx` to the file name, e.g., my_model.onnx becomes + * my_model_ctx.onnx. + * + * @param outputModelPath The output model path. + * @throws OrtException If the path could not be set. + */ + public void setOutputModelPath(String outputModelPath) throws OrtException { + checkClosed(); + setOutputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, outputModelPath); + } + + /** + * Optionally sets the file that stores initializers for the compiled ONNX model. If unset then + * initializers are stored inside the model. + * + *

Only initializers for nodes that were not compiled are stored in the external initializers + * file. Compiled nodes contain their initializer data within the `ep_cache_context` attribute of + * EPContext nodes. + * + * @see OrtModelCompilationOptions#setEpContextEmbedMode + * @param outputExternalInitializersPath Path to the file. + * @param sizeThreshold Initializers larger than this threshold are stored in the file. + * @throws OrtException If the path could not be set. + */ + public void setOutputExternalInitializersPath( + String outputExternalInitializersPath, long sizeThreshold) throws OrtException { + checkClosed(); + // check positive + setOutputExternalInitializersPath( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + outputExternalInitializersPath, + sizeThreshold); + } + + /** + * Enables or disables the embedding of EPContext binary data into the ep_cache_context attribute + * of EPContext nodes. + * + *

Defaults to false. When enabled, the `ep_cache_context` attribute of EPContext nodes will + * store the context binary data, which may include weights for compiled subgraphs. When disabled, + * the `ep_cache_context` attribute of EPContext nodes will contain the path to the file + * containing the context binary data. The path is set by the execution provider creating the + * EPContext node. + * + *

For more details see the EPContext design + * document. + * + * @param embedEpContext True to embed EPContext binary data into the EPContext node's + * ep_cache_context attribute. + * @throws OrtException If the set operation failed. + */ + public void setEpContextEmbedMode(boolean embedEpContext) throws OrtException { + checkClosed(); + setEpContextEmbedMode( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, embedEpContext); + } + + /** + * Sets the specified compilation flags. + * + * @param flags The compilation flags. + * @throws OrtException If the set operation failed. + */ + public void setCompilationFlags(EnumSet flags) throws OrtException { + checkClosed(); + setCompilationFlags( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + OrtFlags.aggregateToInt(flags)); + } + + /** + * Compiles the ONNX model with the configuration described by this instance of + * OrtModelCompilationOptions. + * + * @throws OrtException If the compilation failed. + */ + public void compileModel() throws OrtException { + checkClosed(); + // Safe as the environment must exist to create one of these objects. + OrtEnvironment env = OrtEnvironment.getEnvironment(); + compileModel( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + nativeHandle); + } + + private static native long createFromSessionOptions( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; + + private static native void close(long compileApiHandle, long nativeHandle); + + private static native void setInputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String inputModelPath) + throws OrtException; + + private static native void setInputModelFromBuffer( + long apiHandle, + long compileApiHandle, + long nativeHandle, + ByteBuffer inputBuffer, + long bufferPos, + long bufferRemaining) + throws OrtException; + + private static native void setOutputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String outputModelPath) + throws OrtException; + + private static native void setOutputExternalInitializersPath( + long apiHandle, + long compileApiHandle, + long nativeHandle, + String externalInitializersPath, + long sizeThreshold) + throws OrtException; + + private static native void setEpContextEmbedMode( + long apiHandle, long compileApiHandle, long nativeHandle, boolean embedEpContext) + throws OrtException; + + private static native void setCompilationFlags( + long apiHandle, long compileApiHandle, long nativeHandle, int flags) throws OrtException; + + private static native void compileModel( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index a399d5080ca16..42dc90b71cb80 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ @@ -8,7 +8,6 @@ import ai.onnxruntime.providers.CoreMLFlags; import ai.onnxruntime.providers.NNAPIFlags; import ai.onnxruntime.providers.OrtCUDAProviderOptions; -import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; import java.nio.ByteBuffer; @@ -624,6 +623,10 @@ private native OnnxModelMetadata constructMetadata( *

Used to set the number of threads, optimisation level, computation backend and other * options. * + *

The order execution providers are added to an options instance is the order they will be + * considered for op node assignment, with the EP added first having priority. The CPU EP is a + * fallback and added by default. + * *

Modifying this after the session has been constructed will have no effect. * *

The SessionOptions object must not be closed until all sessions which use it are closed, as @@ -730,7 +733,7 @@ public SessionOptions() { @Override public void close() { if (!closed) { - if (customLibraryHandles.size() > 0) { + if (!customLibraryHandles.isEmpty()) { long[] longArray = new long[customLibraryHandles.size()]; for (int i = 0; i < customLibraryHandles.size(); i++) { longArray[i] = customLibraryHandles.get(i); @@ -917,10 +920,10 @@ public void registerCustomOpLibrary(String path) throws OrtException { * *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more - * information on custom ops. See - * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 - * for an example of a custom op library registration function. + *

See Add + * Custom Op for more information on custom ops. See an example of a custom op library + * registration function here. * * @param registrationFuncName The name of the registration function to call. * @throws OrtException If there was an error finding or calling the registration function. @@ -1273,10 +1276,47 @@ public void addCoreML(EnumSet flags) throws OrtException { addCoreML(OnnxRuntime.ortApiHandle, nativeHandle, OrtFlags.aggregateToInt(flags)); } + /** + * Adds the specified execution provider and device tuples as an execution backend. + * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * + * @param devices The EP and device tuples. Each element must use the same EP, though they can + * use different devices. + * @param providerOptions Configuration options for the execution provider. Refer to the + * specific execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addExecutionProvider(List devices, Map providerOptions) + throws OrtException { + checkClosed(); + if (devices.isEmpty()) { + throw new IllegalArgumentException("Must supply at least one OrtEpDevice"); + } + long[] deviceHandles = new long[devices.size()]; + for (int i = 0; i < devices.size(); i++) { + deviceHandles[i] = devices.get(i).getNativeHandle(); + } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); + // This is valid as the environment must have been created to create the OrtEpDevice list. + long envHandle = OrtEnvironment.getEnvironment().getNativeHandle(); + addExecutionProvider( + OnnxRuntime.ortApiHandle, + envHandle, + nativeHandle, + deviceHandles, + optsArray[0], + optsArray[1]); + } + /** * Adds the named execution provider (backend) as an execution backend. This generic function * only allows a subset of execution providers. * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * * @param providerName The name of the execution provider. * @param providerOptions Configuration options for the execution provider. Refer to the * specific execution provider's documentation. @@ -1285,20 +1325,9 @@ public void addCoreML(EnumSet flags) throws OrtException { private void addExecutionProvider(String providerName, Map providerOptions) throws OrtException { checkClosed(); - String[] providerOptionKey = new String[providerOptions.size()]; - String[] providerOptionVal = new String[providerOptions.size()]; - int i = 0; - for (Map.Entry entry : providerOptions.entrySet()) { - providerOptionKey[i] = entry.getKey(); - providerOptionVal[i] = entry.getValue(); - i++; - } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); addExecutionProvider( - OnnxRuntime.ortApiHandle, - nativeHandle, - providerName, - providerOptionKey, - providerOptionVal); + OnnxRuntime.ortApiHandle, nativeHandle, providerName, optsArray[0], optsArray[1]); } /** @@ -1484,6 +1513,15 @@ private native void addExecutionProvider( String[] providerOptionKey, String[] providerOptionVal) throws OrtException; + + private native void addExecutionProvider( + long apiHandle, + long envHandle, + long nativeHandle, + long[] deviceHandles, + String[] providerOptionKey, + String[] providerOptionVal) + throws OrtException; } /** Used to control logging and termination of a call to {@link OrtSession#run}. */ diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 2f44236e4ef67..ee91fdb292baa 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -16,6 +16,9 @@ import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Logger; /** Util code for interacting with Java arrays. */ @@ -370,6 +373,52 @@ public static boolean validateShape(long[] shape) { return valid && shape.length <= TensorInfo.MAX_DIMENSIONS; } + /** + * Converts the output of a OrtKeyValuePairs into a Java unmodifiable HashMap. + * + * @param zippedString The zipped keys and values. + * @return An unmodifiable Map. + */ + static Map convertToMap(String[][] zippedString) { + if (zippedString.length != 2) { + throw new IllegalArgumentException("Invalid zipped string, must have two arrays."); + } else if (zippedString[0].length != zippedString[1].length) { + throw new IllegalArgumentException( + "Invalid zipped string, must have two arrays of the same length."); + } + Map map = new HashMap<>(capacityFromSize(zippedString[0].length)); + for (int i = 0; i < zippedString[0].length; i++) { + map.put(zippedString[0][i], zippedString[1][i]); + } + return Collections.unmodifiableMap(map); + } + + /** + * Converts a Java string map into a pair of arrays suitable for constructing a native + * OrtKeyValuePairs object. + * + * @param map A map from string to string, with no null keys or values. + * @return A pair of String arrays. + */ + static String[][] unpackMap(Map map) { + String[] keys = new String[map.size()]; + String[] values = new String[map.size()]; + int i = 0; + for (Map.Entry entry : map.entrySet()) { + if (entry.getKey() == null || entry.getValue() == null) { + throw new IllegalArgumentException( + "Invalid map, keys and values must not be null, found key = " + + entry.getKey() + + ", value = " + + entry.getValue()); + } + keys[i] = entry.getKey(); + values[i] = entry.getValue(); + i++; + } + return new String[][] {keys, values}; + } + /** * Flatten a multidimensional String array into a single dimensional String array, reading it in a * multidimensional row-major order. diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index 22bf940844774..15fe459dad7c8 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the CoreML provider. */ public enum CoreMLFlags implements OrtFlags { /** diff --git a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java index eeaf6cc8d53bc..dd30684078717 100644 --- a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the NNAPI provider. */ public enum NNAPIFlags implements OrtFlags { /** Enables fp16 support. */ diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..752b99d6cd7dc 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1014,6 +1014,36 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca } } +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp) { + // extract pair arrays + const char* const* keys = NULL; + const char* const* values = NULL; + size_t numKeys = 0; + api->GetKeyValuePairs(kvp, &keys, &values, &numKeys); + jsize jNumKeys = safecast_size_t_to_jsize(numKeys); + + // create Java String[] + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray keyArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + jobjectArray valueArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + + // populate Java arrays + for (jsize i = 0; i < jNumKeys; i++) { + jstring key = (*jniEnv)->NewStringUTF(jniEnv, keys[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, keyArray, i, key); + jstring value = (*jniEnv)->NewStringUTF(jniEnv, values[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, valueArray, i, value); + } + + // create Java String[][] + jclass stringArrClazz = (*jniEnv)->GetObjectClass(jniEnv, keyArray); + jobjectArray pair = (*jniEnv)->NewObjectArray(jniEnv, 2, stringArrClazz, 0); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 0, keyArray); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 1, valueArray); + + return pair; +} + jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) { jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 7f41e06371f2a..040fd41264c10 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -78,6 +78,8 @@ jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue); +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp); + jint throwOrtException(JNIEnv *env, int messageId, const char *message); jint convertErrorCode(OrtErrorCode code); diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index 659f34e1fb66f..d8f5f1a3cb2db 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -32,6 +32,19 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseTrainingAPIBas return (jlong) trainingApi; } +/* + * Class: ai_onnxruntime_OnnxRuntime + * Method: initialiseCompileAPIBase + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseCompileAPIBase + (JNIEnv * jniEnv, jclass clazz, jlong apiHandle) { + (void)jniEnv; (void)clazz; // required JNI parameters not needed by functions which don't call back into Java. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = api->GetCompileApi(); + return (jlong) compileApi; +} + /* * Class: ai_onnxruntime_OnnxRuntime * Method: getAvailableProviders diff --git a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c index e1b1ff1c05fe1..77b096d62ec76 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c +++ b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c @@ -60,6 +60,76 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_getDefaultAllocator return (jlong)allocator; } +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: registerExecutionProviderLibrary + * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_registerExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name, jstring libraryPath) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, libraryPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, libraryPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, libraryPath, cPath); +#endif + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: unregisterExecutionProviderLibrary + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_unregisterExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); + checkOrtStatus(jniEnv, api, api->UnregisterExecutionProviderLibrary(env, cName)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: getEpDevices + * Signature: (JJ)[J + */ +JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + size_t numDevices = 0; + const OrtEpDevice* const* devicesArr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetEpDevices(env, &devicesArr, &numDevices)); + if (code != ORT_OK) { + return NULL; + } else { + jsize numDevicesInt = safecast_size_t_to_jsize(numDevices); + jlongArray outputArr = (*jniEnv)->NewLongArray(jniEnv, numDevicesInt); + (*jniEnv)->SetLongArrayRegion(jniEnv, outputArr, 0, numDevicesInt, (jlong*)devicesArr); + return outputArr; + } +} + /* * Class: ai_onnxruntime_OrtEnvironment * Method: close diff --git a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c new file mode 100644 index 0000000000000..5a1e3092b0fb9 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtEpDevice.h" + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getName + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* name = api->EpDevice_EpName(epDevice); + jstring nameStr = (*jniEnv)->NewStringUTF(jniEnv, name); + return nameStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* vendor = api->EpDevice_EpVendor(epDevice); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpMetadata(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getOptions + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpOptions(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getDeviceHandle + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEpDevice_getDeviceHandle + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtHardwareDevice* device = api->EpDevice_Device(epDevice); + return (jlong) device; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c new file mode 100644 index 0000000000000..3191a89c26ba1 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtHardwareDevice.h" + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const char* vendor = api->HardwareDevice_Vendor(device); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->HardwareDevice_Metadata(device); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceType + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceType + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + OrtHardwareDeviceType type = api->HardwareDevice_Type(device); + jint output = 0; + // Must be kept aligned with the Java OrtHardwareDeviceType enum. + switch (type) { + case OrtHardwareDeviceType_CPU: + output = 0; + break; + case OrtHardwareDeviceType_GPU: + output = 1; + break; + case OrtHardwareDeviceType_NPU: + output = 2; + break; + default: + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Unexpected device type found. Only CPU, GPU and NPU are supported."); + break; + } + return output; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_DeviceId(device); + return (jint) id; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendorId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendorId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_VendorId(device); + return (jint) id; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c new file mode 100644 index 0000000000000..4f79383d09766 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtModelCompilationOptions.h" + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: createFromSessionOptions + * Signature: (JJJJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_createFromSessionOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong sessionOptionsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + const OrtSessionOptions* sessionOptions = (const OrtSessionOptions*) sessionOptionsHandle; + OrtModelCompilationOptions* output = NULL; + checkOrtStatus(jniEnv, api, compileApi->CreateModelCompilationOptionsFromSessionOptions(env, sessionOptions, &output)); + return (jlong) output; +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_close + (JNIEnv * jniEnv, jclass jclazz, jlong compileApiHandle, jlong nativeHandle) { + (void)jniEnv; (void)jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + compileApi->ReleaseModelCompilationOptions((OrtModelCompilationOptions *)nativeHandle); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring modelPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelFromBuffer + * Signature: (JJJLjava/nio/ByteBuffer;JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelFromBuffer + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jobject buffer, jlong bufferPos, jlong bufferRemaining) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + // Cast to pointers + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelFromBuffer(compOpts, bufferArr, bufferRemaining)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring outputPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, outputPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, outputPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, outputPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, outputPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputExternalInitializersPath + * Signature: (JJJLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputExternalInitializersPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring initializersPath, jlong threshold) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, initializersPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, initializersPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, newString, threshold)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, initializersPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, cPath, threshold)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, initializersPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setEpContextEmbedMode + * Signature: (JJJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setEpContextEmbedMode + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jboolean embedMode) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetEpContextEmbedMode(compOpts, (bool) embedMode)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setCompilationFlags + * Signature: (JJJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setCompilationFlags + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jint flags) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetFlags(compOpts, flags)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: compileModel + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_compileModel + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong nativeHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->CompileModel(env, compOpts)); +} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff6b7fa703e6e..95bcdf7af9746 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -718,11 +718,11 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addROC } /* - * Class:: ai_onnxruntime_OrtSession_SessionOptions + * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addExecutionProvider - * Signature: (JILjava/lang/String)V + * Signature: (JJLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider( +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJLjava_lang_String_2_3Ljava_lang_String_2_3Ljava_lang_String_2( JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring jepName, jobjectArray configKeyArr, jobjectArray configValueArr) { (void)jobj; @@ -756,3 +756,50 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExe free((void*)jkeyArray); free((void*)jvalueArray); } + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addExecutionProvider + * Signature: (JJJ[J[Ljava/lang/String;[Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJJ_3J_3Ljava_lang_String_2_3Ljava_lang_String_2 + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jlong optionsHandle, jlongArray deviceHandleArr, jobjectArray configKeyArr, jobjectArray configValueArr) { + (void)jobj; + + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*) envHandle; + OrtSessionOptions* options = (OrtSessionOptions*)optionsHandle; + jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, deviceHandleArr); + jsize keyCount = (*jniEnv)->GetArrayLength(jniEnv, configKeyArr); + + const char** keyArray = (const char**)allocarray(keyCount, sizeof(const char*)); + const char** valueArray = (const char**)allocarray(keyCount, sizeof(const char*)); + jstring* jkeyArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + jstring* jvalueArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *)); + + jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, deviceHandleArr, NULL); + for (jsize i = 0; i < deviceCount; i++) { + devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i]; + } + (*jniEnv)->ReleaseLongArrayElements(jniEnv, deviceHandleArr, deviceHandleElements, JNI_ABORT); + + for (jsize i = 0; i < keyCount; i++) { + jkeyArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configKeyArr, i)); + jvalueArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configValueArr, i)); + keyArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jkeyArray[i], NULL); + valueArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jvalueArray[i], NULL); + } + + checkOrtStatus(jniEnv, api, api->SessionOptionsAppendExecutionProvider_V2(options, env, devicePtrs, deviceCount, keyArray, valueArray, keyCount)); + + for (jsize i = 0; i < keyCount; i++) { + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jkeyArray[i], keyArray[i]); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jvalueArray[i], valueArray[i]); + } + free((void*)devicePtrs); + free((void*)keyArray); + free((void*)valueArray); + free((void*)jkeyArray); + free((void*)jvalueArray); +} diff --git a/java/src/test/java/ai/onnxruntime/CompileApiTest.java b/java/src/test/java/ai/onnxruntime/CompileApiTest.java new file mode 100644 index 0000000000000..b70f4dca5cbd0 --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/CompileApiTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** Test for the compilation API. */ +public class CompileApiTest { + private final OrtEnvironment env = OrtEnvironment.getEnvironment(); + + @Test + public void basicUsage() throws OrtException, IOException { + SessionOptions so = new SessionOptions(); + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.setInputModelPath("model.onnx"); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + compileOptions.setOutputExternalInitializersPath("external_data.bin", 512); + compileOptions.setEpContextEmbedMode(true); + } + + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + byte[] modelBytes = Files.readAllBytes(modelPath); + ByteBuffer modelBuffer = ByteBuffer.wrap(modelBytes); + compileOptions.setInputModelFromBuffer(modelBuffer); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + File f = new File("compiled_model.onnx"); + + compileOptions.compileModel(); + + // Check the compiled model is valid + try (OrtSession session = env.createSession(f.toString(), so)) { + Assertions.assertNotNull(session); + } + + f.delete(); + } + } +} diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java new file mode 100644 index 0000000000000..ec4c977508c8c --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtHardwareDevice.OrtHardwareDeviceType; +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +/** Tests for {@link OrtEpDevice} and {@link OrtHardwareDevice}. */ +@EnabledOnOs(value = OS.WINDOWS) +public class EpDeviceTest { + private final OrtEnvironment ortEnv = OrtEnvironment.getEnvironment(); + + private void readHardwareDeviceValues(OrtHardwareDevice device) { + OrtHardwareDeviceType type = device.getType(); + + Assertions.assertTrue( + type == OrtHardwareDeviceType.CPU + || type == OrtHardwareDeviceType.GPU + || type == OrtHardwareDeviceType.NPU); + + if (type == OrtHardwareDeviceType.CPU) { + Assertions.assertFalse(device.getVendor().isEmpty()); + } else { + Assertions.assertTrue(device.getVendorId() != 0); + Assertions.assertTrue(device.getDeviceId() != 0); + } + + Map metadata = device.getMetadata(); + Assertions.assertNotNull(metadata); + for (Map.Entry kvp : metadata.entrySet()) { + Assertions.assertFalse(kvp.getKey().isEmpty()); + } + } + + @Test + public void getEpDevices() throws OrtException { + List epDevices = ortEnv.getEpDevices(); + Assertions.assertNotNull(epDevices); + Assertions.assertFalse(epDevices.isEmpty()); + for (OrtEpDevice epDevice : epDevices) { + Assertions.assertFalse(epDevice.getName().isEmpty()); + Assertions.assertFalse(epDevice.getVendor().isEmpty()); + Map metadata = epDevice.getMetadata(); + Assertions.assertNotNull(metadata); + Map options = epDevice.getOptions(); + Assertions.assertNotNull(options); + readHardwareDeviceValues(epDevice.getDevice()); + } + } + + @Test + public void registerUnregisterLibrary() throws OrtException { + String libFullPath = TestHelpers.getResourcePath("/example_plugin_ep.dll").toString(); + Assertions.assertTrue( + new File(libFullPath).exists(), "Expected lib " + libFullPath + " does not exist."); + + // example plugin ep uses the registration name as the ep name + String epName = "java_ep"; + + // register. shouldn't throw + ortEnv.registerExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + List epDevices = ortEnv.getEpDevices(); + boolean found = epDevices.stream().anyMatch(a -> a.getName().equals(epName)); + Assertions.assertTrue(found); + + // unregister + ortEnv.unregisterExecutionProviderLibrary(epName); + } + + @Test + public void appendToSessionOptionsV2() { + Consumer>> runTest = + (Supplier> options) -> { + try (SessionOptions sessionOptions = new SessionOptions()) { + sessionOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE); + + List epDevices = ortEnv.getEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't + // break. + List selectedEpDevices = + epDevices.stream() + .filter(a -> a.getName().equals("CPUExecutionProvider")) + .collect(Collectors.toList()); + + Map epOptions = options.get(); + sessionOptions.addExecutionProvider(selectedEpDevices, epOptions); + + Path model = TestHelpers.getResourcePath("/squeezenet.onnx"); + String modelPath = model.toString(); + + // session should load successfully + try (OrtSession session = ortEnv.createSession(modelPath, sessionOptions)) { + Assertions.assertNotNull(session); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + }; + + // empty options + runTest.accept(Collections::emptyMap); + + // dummy options + runTest.accept(() -> Collections.singletonMap("random_key", "value")); + } +}