From 7d10e8cdaec7fcaa8f4f1e3220455a8db43affb6 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 12 Sep 2025 11:26:17 -0400 Subject: [PATCH 1/2] Adding support for OrtCompiledModelCompatibility. --- .../java/ai/onnxruntime/OrtEnvironment.java | 85 +++++++++++++++++++ java/src/main/native/OrtJniUtil.c | 38 ++++++++- java/src/main/native/OrtJniUtil.h | 6 +- .../native/ai_onnxruntime_OrtEnvironment.c | 38 ++++++++- .../java/ai/onnxruntime/EpDeviceTest.java | 30 +++++++ 5 files changed, 194 insertions(+), 3 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 497772baf5357..c55d50b7d08b0 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -488,6 +488,30 @@ public List getEpDevices() throws OrtException { return Collections.unmodifiableList(devicesList); } + /** + * Checks the supplied model info string against the list of {@link OrtEpDevice}s to see if the + * model is compatible. + * + * @param epDevices The EP-Device tuples to use. + * @param modelInfo The model info string to check. + * @return The model compatibility. + * @throws OrtException If the call failed. + */ + public OrtCompiledModelCompatibility getModelCompatibilityForEpDevices( + List epDevices, String modelInfo) throws OrtException { + if (epDevices == null || epDevices.isEmpty()) { + throw new IllegalArgumentException("Must supply at least one OrtEpDevice"); + } + long[] deviceHandles = new long[epDevices.size()]; + for (int i = 0; i < epDevices.size(); i++) { + deviceHandles[i] = epDevices.get(i).getNativeHandle(); + } + + int output = + getModelCompatibilityForEpDevices(OnnxRuntime.ortApiHandle, deviceHandles, modelInfo); + return OrtCompiledModelCompatibility.mapFromInt(output); + } + /** * Creates the native object. * @@ -556,6 +580,18 @@ private static native void unregisterExecutionProviderLibrary( */ private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; + /** + * Checks if a model is compatible with the supplied list of EP device handles. + * + * @param apiHandle The API handle to use. + * @param epHandles An array of OrtEpDevice handles. + * @param modelInfo The model info string. + * @return An int representing the {@link OrtCompiledModelCompatibility} value. + * @throws OrtException If the call failed. + */ + private static native int getModelCompatibilityForEpDevices( + long apiHandle, long[] epHandles, String modelInfo) throws OrtException; + /** * Closes the OrtEnvironment, frees the handle. * @@ -580,6 +616,55 @@ private static native void setTelemetry(long apiHandle, long nativeHandle, boole @Override public void close() {} + /** Enum representing a compiled model's compatibility with a set of {@link OrtEpDevice}s. */ + public enum OrtCompiledModelCompatibility { + EP_NOT_APPLICABLE(0), + EP_SUPPORTED_OPTIMAL(1), + EP_SUPPORTED_PREFER_RECOMPILATION(2), + EP_UNSUPPORTED(3); + + private final int value; + + private static final Logger logger = + Logger.getLogger(OrtCompiledModelCompatibility.class.getName()); + private static final OrtCompiledModelCompatibility[] values = + new OrtCompiledModelCompatibility[4]; + + static { + for (OrtCompiledModelCompatibility ot : OrtCompiledModelCompatibility.values()) { + values[ot.value] = ot; + } + } + + OrtCompiledModelCompatibility(int value) { + this.value = value; + } + + /** + * Gets the native value associated with this model compatibility value. + * + * @return The native value. + */ + public int getValue() { + return value; + } + + /** + * Maps from the C API's int enum to the Java enum. + * + * @param logLevel The index of the Java enum. + * @return The Java enum. + */ + public static OrtCompiledModelCompatibility mapFromInt(int logLevel) { + if ((logLevel >= 0) && (logLevel < values.length)) { + return values[logLevel]; + } else { + logger.warning("Unknown model compatibility " + logLevel + " setting to EP_UNSUPPORTED"); + return EP_UNSUPPORTED; + } + } + } + /** * Controls the global thread pools in the environment. Only used if the session is constructed * using an options with {@link OrtSession.SessionOptions#disablePerSessionThreads()} set. diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 96ea8e79bc978..eef4b731d106e 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -109,6 +109,42 @@ jint convertFromOrtSparseFormat(OrtSparseFormat format) { } } +/** + * Must be kept in sync with convertToCompiledModelCompatibility. + */ +jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat) { + switch (compat) { + case OrtCompiledModelCompatibility_EP_NOT_APPLICABLE: + return 0; + case OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL: + return 1; + case OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION: + return 2; + case OrtCompiledModelCompatibility_EP_UNSUPPORTED: + return 3; + default: + // if this value is observed the enum has changed and the code should be updated. + return -1; + } +} + +/** + * Must be kept in sync with convertFromCompiledModelCompatibility. + */ +OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat) { + switch (compat) { + case 0: + return OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + case 1: + return OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + case 2: + return OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + case 3: + default: + return OrtCompiledModelCompatibility_EP_UNSUPPORTED; + } +} + /** * Must be kept in sync with convertToONNXDataFormat */ diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 040fd41264c10..f9f4717597831 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -34,6 +34,10 @@ OrtSparseFormat convertToOrtSparseFormat(jint format); jint convertFromOrtSparseFormat(OrtSparseFormat format); +jint convertFromCompiledModelCompatibility(OrtCompiledModelCompatibility compat); + +OrtCompiledModelCompatibility convertToCompiledModelCompatibility(jint compat); + jint convertFromONNXDataFormat(ONNXTensorElementDataType type); ONNXTensorElementDataType convertToONNXDataFormat(jint type); diff --git a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c index 77b096d62ec76..8061c6454ef47 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c +++ b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -130,6 +130,42 @@ JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices } } +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: getModelCompatibilityForEpDevices + * Signature: (J[JLjava/lang/String;)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtEnvironment_getModelCompatibilityForEpDevices + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlongArray epHandles, jstring modelInfo) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + + // convert pointers for EpDevice handles + jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, epHandles); + const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *)); + jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, epHandles, NULL); + for (jsize i = 0; i < deviceCount; i++) { + devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i]; + } + (*jniEnv)->ReleaseLongArrayElements(jniEnv, epHandles, deviceHandleElements, JNI_ABORT); + + // get utf-8 string + const char* modelStr = (*jniEnv)->GetStringUTFChars(jniEnv, modelInfo, NULL); + + OrtCompiledModelCompatibility compatibility; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetModelCompatibilityForEpDevices(devicePtrs, deviceCount, modelStr, &compatibility)); + + // cleanup + (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelInfo, modelStr); + free((void*)devicePtrs); + if (code != ORT_OK) { + return -1; + } else { + jint returnVal = convertFromCompiledModelCompatibility(compatibility); + return returnVal; + } +} + /* * Class: ai_onnxruntime_OrtEnvironment * Method: close diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java index ec4c977508c8c..e4672d4a211c4 100644 --- a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java +++ b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java @@ -120,4 +120,34 @@ public void appendToSessionOptionsV2() { // dummy options runTest.accept(() -> Collections.singletonMap("random_key", "value")); } + + @Test + public void GetEpCompatibilityInvalidArgs() { + Assertions.assertThrows( + IllegalArgumentException.class, + () -> ortEnv.getModelCompatibilityForEpDevices(null, "info")); + Assertions.assertThrows( + IllegalArgumentException.class, + () -> ortEnv.getModelCompatibilityForEpDevices(Collections.emptyList(), "info")); + } + + @Test + public void GetEpCompatibilitySingleDeviceCpuProvider() throws OrtException { + List epDevices = ortEnv.getEpDevices(); + String someInfo = "arbitrary-compat-string"; + + // Use CPU device + OrtEpDevice cpu = + epDevices.stream() + .filter(d -> d.getName().equals("CPUExecutionProvider")) + .findFirst() + .get(); + Assertions.assertNotNull(cpu); + List selected = Collections.singletonList(cpu); + OrtEnvironment.OrtCompiledModelCompatibility status = + ortEnv.getModelCompatibilityForEpDevices(selected, someInfo); + + // CPU defaults to not applicable in this scenario + Assertions.assertEquals(OrtEnvironment.OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status); + } } From 92b131c0443603938f3b4a87e51dc11d06f22f23 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 12 Sep 2025 11:40:20 -0400 Subject: [PATCH 2/2] Renaming some accessors in OrtEpDevice, adding javadoc to the enum constants in OrtCompiledModelCompatibility. --- .../java/ai/onnxruntime/OrtEnvironment.java | 4 +++ .../main/java/ai/onnxruntime/OrtEpDevice.java | 34 +++++++++---------- .../main/native/ai_onnxruntime_OrtEpDevice.c | 8 ++--- .../java/ai/onnxruntime/EpDeviceTest.java | 14 ++++---- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index c55d50b7d08b0..2a9d4876c4c1a 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -618,9 +618,13 @@ public void close() {} /** Enum representing a compiled model's compatibility with a set of {@link OrtEpDevice}s. */ public enum OrtCompiledModelCompatibility { + /** The EP is not applicable for the model. */ EP_NOT_APPLICABLE(0), + /** The EP supports the model optimally. */ EP_SUPPORTED_OPTIMAL(1), + /** The EP supports the model, but the model would perform better if recompiled. */ EP_SUPPORTED_PREFER_RECOMPILATION(2), + /** The EP does not support the model. */ EP_UNSUPPORTED(3); private final int value; diff --git a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java index f63dec1dbaf83..338c907fc81da 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java +++ b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java @@ -24,11 +24,11 @@ public final class OrtEpDevice { */ OrtEpDevice(long nativeHandle) { this.nativeHandle = nativeHandle; - this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle); - this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); - String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.epName = getEpName(OnnxRuntime.ortApiHandle, nativeHandle); + this.epVendor = getEpVendor(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getEpMetadata(OnnxRuntime.ortApiHandle, nativeHandle); this.epMetadata = OrtUtil.convertToMap(metadata); - String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] options = getEpOptions(OnnxRuntime.ortApiHandle, nativeHandle); this.epOptions = OrtUtil.convertToMap(options); this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle)); } @@ -43,38 +43,38 @@ long getNativeHandle() { } /** - * Gets the EP name. + * Gets the Execution Provider name. * * @return The EP name. */ - public String getName() { + public String getEpName() { return epName; } /** - * Gets the vendor name. + * Gets the Execution Provider vendor name. * - * @return The vendor name. + * @return The EP vendor name. */ - public String getVendor() { + public String getEpVendor() { return epVendor; } /** - * Gets an unmodifiable view on the EP metadata. + * Gets an unmodifiable view on the Execution Provider metadata. * * @return The EP metadata. */ - public Map getMetadata() { + public Map getEpMetadata() { return epMetadata; } /** - * Gets an unmodifiable view on the EP options. + * Gets an unmodifiable view on the Execution Provider options. * * @return The EP options. */ - public Map getOptions() { + public Map getEpOptions() { return epOptions; } @@ -105,13 +105,13 @@ public String toString() { + '}'; } - private static native String getName(long apiHandle, long nativeHandle); + private static native String getEpName(long apiHandle, long nativeHandle); - private static native String getVendor(long apiHandle, long nativeHandle); + private static native String getEpVendor(long apiHandle, long nativeHandle); - private static native String[][] getMetadata(long apiHandle, long nativeHandle); + private static native String[][] getEpMetadata(long apiHandle, long nativeHandle); - private static native String[][] getOptions(long apiHandle, long nativeHandle); + private static native String[][] getEpOptions(long apiHandle, long nativeHandle); private static native long getDeviceHandle(long apiHandle, long nativeHandle); } diff --git a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c index 5a1e3092b0fb9..168c0534aaffd 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c +++ b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c @@ -12,7 +12,7 @@ * Method: getName * Signature: (JJ)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpName (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; @@ -27,7 +27,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName * Method: getVendor * Signature: (JJ)Ljava/lang/String; */ -JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpVendor (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; @@ -42,7 +42,7 @@ JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor * Method: getMetadata * Signature: (JJ)[[Ljava/lang/String; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpMetadata (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; @@ -57,7 +57,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata * Method: getOptions * Signature: (JJ)[[Ljava/lang/String; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getEpOptions (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*) apiHandle; diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java index e4672d4a211c4..25a21eae35be6 100644 --- a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java +++ b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java @@ -52,11 +52,11 @@ public void getEpDevices() throws OrtException { Assertions.assertNotNull(epDevices); Assertions.assertFalse(epDevices.isEmpty()); for (OrtEpDevice epDevice : epDevices) { - Assertions.assertFalse(epDevice.getName().isEmpty()); - Assertions.assertFalse(epDevice.getVendor().isEmpty()); - Map metadata = epDevice.getMetadata(); + Assertions.assertFalse(epDevice.getEpName().isEmpty()); + Assertions.assertFalse(epDevice.getEpVendor().isEmpty()); + Map metadata = epDevice.getEpMetadata(); Assertions.assertNotNull(metadata); - Map options = epDevice.getOptions(); + Map options = epDevice.getEpOptions(); Assertions.assertNotNull(options); readHardwareDeviceValues(epDevice.getDevice()); } @@ -76,7 +76,7 @@ public void registerUnregisterLibrary() throws OrtException { // check OrtEpDevice was found List epDevices = ortEnv.getEpDevices(); - boolean found = epDevices.stream().anyMatch(a -> a.getName().equals(epName)); + boolean found = epDevices.stream().anyMatch(a -> a.getEpName().equals(epName)); Assertions.assertTrue(found); // unregister @@ -96,7 +96,7 @@ public void appendToSessionOptionsV2() { // break. List selectedEpDevices = epDevices.stream() - .filter(a -> a.getName().equals("CPUExecutionProvider")) + .filter(a -> a.getEpName().equals("CPUExecutionProvider")) .collect(Collectors.toList()); Map epOptions = options.get(); @@ -139,7 +139,7 @@ public void GetEpCompatibilitySingleDeviceCpuProvider() throws OrtException { // Use CPU device OrtEpDevice cpu = epDevices.stream() - .filter(d -> d.getName().equals("CPUExecutionProvider")) + .filter(d -> d.getEpName().equals("CPUExecutionProvider")) .findFirst() .get(); Assertions.assertNotNull(cpu);