Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,30 @@ public List<OrtEpDevice> getEpDevices() throws OrtException {
return Collections.unmodifiableList(devicesList);
}

/**
* Checks the supplied model info string against the list of {@link OrtEpDevice}s to see if the
* model is compatible.
*
* @param epDevices The EP-Device tuples to use.
* @param modelInfo The model info string to check.
* @return The model compatibility.
* @throws OrtException If the call failed.
*/
public OrtCompiledModelCompatibility getModelCompatibilityForEpDevices(
List<OrtEpDevice> epDevices, String modelInfo) throws OrtException {
if (epDevices == null || epDevices.isEmpty()) {
throw new IllegalArgumentException("Must supply at least one OrtEpDevice");
}
long[] deviceHandles = new long[epDevices.size()];
for (int i = 0; i < epDevices.size(); i++) {
deviceHandles[i] = epDevices.get(i).getNativeHandle();
}

int output =
getModelCompatibilityForEpDevices(OnnxRuntime.ortApiHandle, deviceHandles, modelInfo);
return OrtCompiledModelCompatibility.mapFromInt(output);
}

/**
* Creates the native object.
*
Expand Down Expand Up @@ -556,6 +580,18 @@ private static native void unregisterExecutionProviderLibrary(
*/
private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException;

/**
* Checks if a model is compatible with the supplied list of EP device handles.
*
* @param apiHandle The API handle to use.
* @param epHandles An array of OrtEpDevice handles.
* @param modelInfo The model info string.
* @return An int representing the {@link OrtCompiledModelCompatibility} value.
* @throws OrtException If the call failed.
*/
private static native int getModelCompatibilityForEpDevices(
long apiHandle, long[] epHandles, String modelInfo) throws OrtException;

/**
* Closes the OrtEnvironment, frees the handle.
*
Expand All @@ -580,6 +616,59 @@ private static native void setTelemetry(long apiHandle, long nativeHandle, boole
@Override
public void close() {}

/** Enum representing a compiled model's compatibility with a set of {@link OrtEpDevice}s. */
public enum OrtCompiledModelCompatibility {
/** The EP is not applicable for the model. */
EP_NOT_APPLICABLE(0),
/** The EP supports the model optimally. */
EP_SUPPORTED_OPTIMAL(1),
/** The EP supports the model, but the model would perform better if recompiled. */
EP_SUPPORTED_PREFER_RECOMPILATION(2),
/** The EP does not support the model. */
EP_UNSUPPORTED(3);

private final int value;

private static final Logger logger =
Logger.getLogger(OrtCompiledModelCompatibility.class.getName());
private static final OrtCompiledModelCompatibility[] values =
new OrtCompiledModelCompatibility[4];

static {
for (OrtCompiledModelCompatibility ot : OrtCompiledModelCompatibility.values()) {
values[ot.value] = ot;
}
}

OrtCompiledModelCompatibility(int value) {
this.value = value;
}

/**
* Gets the native value associated with this model compatibility value.
*
* @return The native value.
*/
public int getValue() {
return value;
}

/**
* Maps from the C API's int enum to the Java enum.
*
* @param logLevel The index of the Java enum.
* @return The Java enum.
*/
public static OrtCompiledModelCompatibility mapFromInt(int logLevel) {
if ((logLevel >= 0) && (logLevel < values.length)) {
return values[logLevel];
} else {
logger.warning("Unknown model compatibility " + logLevel + " setting to EP_UNSUPPORTED");
return EP_UNSUPPORTED;
}
}
}

/**
* Controls the global thread pools in the environment. Only used if the session is constructed
* using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set.
Expand Down
34 changes: 17 additions & 17 deletions java/src/main/java/ai/onnxruntime/OrtEpDevice.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ public final class OrtEpDevice {
*/
OrtEpDevice(long nativeHandle) {
this.nativeHandle = nativeHandle;
this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle);
this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle);
String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
this.epName = getEpName(OnnxRuntime.ortApiHandle, nativeHandle);
this.epVendor = getEpVendor(OnnxRuntime.ortApiHandle, nativeHandle);
String[][] metadata = getEpMetadata(OnnxRuntime.ortApiHandle, nativeHandle);
this.epMetadata = OrtUtil.convertToMap(metadata);
String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle);
String[][] options = getEpOptions(OnnxRuntime.ortApiHandle, nativeHandle);
this.epOptions = OrtUtil.convertToMap(options);
this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle));
}
Expand All @@ -43,38 +43,38 @@ long getNativeHandle() {
}

/**
* Gets the EP name.
* Gets the Execution Provider name.
*
* @return The EP name.
*/
public String getName() {
public String getEpName() {
return epName;
}

/**
* Gets the vendor name.
* Gets the Execution Provider vendor name.
*
* @return The vendor name.
* @return The EP vendor name.
*/
public String getVendor() {
public String getEpVendor() {
return epVendor;
}

/**
* Gets an unmodifiable view on the EP metadata.
* Gets an unmodifiable view on the Execution Provider metadata.
*
* @return The EP metadata.
*/
public Map<String, String> getMetadata() {
public Map<String, String> getEpMetadata() {
return epMetadata;
}

/**
* Gets an unmodifiable view on the EP options.
* Gets an unmodifiable view on the Execution Provider options.
*
* @return The EP options.
*/
public Map<String, String> getOptions() {
public Map<String, String> getEpOptions() {
return epOptions;
}

Expand Down Expand Up @@ -105,13 +105,13 @@ public String toString() {
+ '}';
}

private static native String getName(long apiHandle, long nativeHandle);
private static native String getEpName(long apiHandle, long nativeHandle);

private static native String getVendor(long apiHandle, long nativeHandle);
private static native String getEpVendor(long apiHandle, long nativeHandle);

private static native String[][] getMetadata(long apiHandle, long nativeHandle);
private static native String[][] getEpMetadata(long apiHandle, long nativeHandle);

private static native String[][] getOptions(long apiHandle, long nativeHandle);
private static native String[][] getEpOptions(long apiHandle, long nativeHandle);

private static native long getDeviceHandle(long apiHandle, long nativeHandle);
}
38 changes: 37 additions & 1 deletion java/src/main/native/OrtJniUtil.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand Down Expand Up @@ -109,6 +109,42 @@ jint convertFromOrtSparseFormat(OrtSparseFormat format) {
}
}

/**
* Must be kept in sync with convertToCompiledModelCompatibility.
*/
jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat) {
switch (compat) {
case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE:
return 0;
case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL:
return 1;
case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION:
return 2;
case OrtCompiledModelCompatibility_EP_UNSUPPORTED:
return 3;
default:
// if this value is observed the enum has changed and the code should be updated.
return -1;
}
}

/**
* Must be kept in sync with convertFromCompiledModelCompatibility.
*/
OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat) {
switch (compat) {
case 0:
return OrtCompiledModelCompatibility_EP_NOT_APPLICABLE;
case 1:
return OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL;
case 2:
return OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION;
case 3:
default:
return OrtCompiledModelCompatibility_EP_UNSUPPORTED;
}
}

/**
* Must be kept in sync with convertToONNXDataFormat
*/
Expand Down
6 changes: 5 additions & 1 deletion java/src/main/native/OrtJniUtil.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand Down Expand Up @@ -34,6 +34,10 @@ OrtSparseFormat convertToOrtSparseFormat(jint format);

jint convertFromOrtSparseFormat(OrtSparseFormat format);

jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat);

OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat);

jint convertFromONNXDataFormat(ONNXTensorElementDataType type);

ONNXTensorElementDataType convertToONNXDataFormat(jint type);
Expand Down
38 changes: 37 additions & 1 deletion java/src/main/native/ai_onnxruntime_OrtEnvironment.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand Down Expand Up @@ -130,6 +130,42 @@ JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices
}
}

/*
* Class: ai_onnxruntime_OrtEnvironment
* Method: getModelCompatibilityForEpDevices
* Signature: (J[JLjava/lang/String;)I
*/
JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtEnvironment_getModelCompatibilityForEpDevices
(JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlongArray epHandles, jstring modelInfo) {
(void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;

// convert pointers for EpDevice handles
jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, epHandles);
const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *));
jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, epHandles, NULL);
for (jsize i = 0; i < deviceCount; i++) {
devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i];
}
(*jniEnv)->ReleaseLongArrayElements(jniEnv, epHandles, deviceHandleElements, JNI_ABORT);

// get utf-8 string
const char* modelStr = (*jniEnv)->GetStringUTFChars(jniEnv, modelInfo, NULL);

OrtCompiledModelCompatibility compatibility;
OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetModelCompatibilityForEpDevices(devicePtrs, deviceCount, modelStr, &compatibility));

// cleanup
(*jniEnv)->ReleaseStringUTFChars(jniEnv, modelInfo, modelStr);
free((void*)devicePtrs);
if (code != ORT_OK) {
return -1;
} else {
jint returnVal = convertFromCompiledModelCompatibility(compatibility);
return returnVal;
}
}

/*
* Class: ai_onnxruntime_OrtEnvironment
* Method: close
Expand Down
8 changes: 4 additions & 4 deletions java/src/main/native/ai_onnxruntime_OrtEpDevice.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* Method: getName
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpName
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
Expand All @@ -27,7 +27,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName
* Method: getVendor
* Signature: (JJ)Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor
JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpVendor
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
Expand All @@ -42,7 +42,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor
* Method: getMetadata
* Signature: (JJ)[[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpMetadata
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
Expand All @@ -57,7 +57,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata
* Method: getOptions
* Signature: (JJ)[[Ljava/lang/String;
*/
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions
JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpOptions
(JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) {
(void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*) apiHandle;
Expand Down
Loading
Loading