From 9adeb1bab83079ba9e3c98db45f011bfc852e3a4 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 26 Dec 2025 15:41:16 -0500 Subject: [PATCH 1/8] Porting across MemorySegment support to something that will compile with Java 8. --- java/build.gradle | 2 +- .../ai/onnxruntime/MemorySegmentShim.java | 65 ++++++ .../main/java/ai/onnxruntime/OnnxTensor.java | 194 ++++++++++++++++-- .../main/java/ai/onnxruntime/TensorInfo.java | 40 +++- .../jvm/ai/onnxruntime/MemorySegmentShim.java | 194 ++++++++++++++++++ .../main/native/ai_onnxruntime_OnnxTensor.c | 69 ++++++- .../ai/onnxruntime/MemorySegmentTest.java | 146 +++++++++++++ .../java/ai/onnxruntime/ModelGenerators.java | 71 ++++++- .../resources/java-external-embedding.onnx | 20 ++ 9 files changed, 779 insertions(+), 22 deletions(-) create mode 100644 java/src/main/android/ai/onnxruntime/MemorySegmentShim.java create mode 100644 java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java create mode 100644 java/src/test/java/ai/onnxruntime/MemorySegmentTest.java create mode 100644 java/src/test/resources/java-external-embedding.onnx diff --git a/java/build.gradle b/java/build.gradle index 64a31c89ad322..12e5f2ec70469 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -167,7 +167,7 @@ dependencies { } processTestResources { - duplicatesStrategy(DuplicatesStrategy.INCLUDE) // allows duplicates in the test resources + duplicatesStrategy = DuplicatesStrategy.INCLUDE // allows duplicates in the test resources } test { diff --git a/java/src/main/android/ai/onnxruntime/MemorySegmentShim.java b/java/src/main/android/ai/onnxruntime/MemorySegmentShim.java new file mode 100644 index 0000000000000..0662b6b68eb41 --- /dev/null +++ b/java/src/main/android/ai/onnxruntime/MemorySegmentShim.java @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Wrapper for java.lang.foreign.MemorySegment instances which throws {@link java.lang.UnsupportedOperationException} + * as FFM is not supported on Android. + */ +final class MemorySegmentShim { + private static final Logger logger = Logger.getLogger(MemorySegmentShim.class.getName()); + + /** + * Constructor which wraps a MemorySegment. Always throws on Android. + * + * @param segment The memory segment. + * @throws UnsupportedOperationException If java.lang.foreign.MemorySegment is not available in the running JDK. + */ + MemorySegmentShim(Object segment) { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + + /** + * Constructor which builds a MemorySegment using the supplied arguments. Always throws on Android. + * + * @param address The address of the memory. + * @param byteSize The size of the memory. + * @throws UnsupportedOperationException If java.lang.foreign.MemorySegment is not available in the running JDK. + */ + MemorySegmentShim(long address, long byteSize) { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + + /** + * Always throws {@link UnsupportedOperationException} on Android. + * @return The MemorySegment. + */ + Object get() { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + + /** + * Always throws {@link UnsupportedOperationException} on Android. + * @return The address of the MemorySegment. + */ + long address() { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + + /** + * Always throws {@link UnsupportedOperationException} on Android. + * @return The size of the MemorySegment in bytes. + */ + long byteSize() { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + +} diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 3f276a3670156..46945bdd1f6d6 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -23,6 +23,12 @@ public class OnnxTensor extends OnnxTensorLike { private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName()); + /** + * This reference is held for OnnxTensors backed by a java.lang.foreign.MemorySegment to ensure + * the segment does not go out of scope while the OnnxTensor exists. + */ + private final Object segment; + /** * This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does * not go out of scope while the OnnxTensor exists. @@ -36,13 +42,26 @@ public class OnnxTensor extends OnnxTensorLike { private final boolean ownsBuffer; OnnxTensor(long nativeHandle, long allocatorHandle, TensorInfo info) { - this(nativeHandle, allocatorHandle, info, null, false); + this(nativeHandle, allocatorHandle, info, (Buffer) null, false); } OnnxTensor( long nativeHandle, long allocatorHandle, TensorInfo info, Buffer buffer, boolean ownsBuffer) { super(nativeHandle, allocatorHandle, info); this.buffer = buffer; + this.segment = null; + this.ownsBuffer = ownsBuffer; + } + + OnnxTensor( + long nativeHandle, + long allocatorHandle, + TensorInfo info, + MemorySegmentShim segment, + boolean ownsBuffer) { + super(nativeHandle, allocatorHandle, info); + this.buffer = null; + this.segment = segment.get(); this.ownsBuffer = ownsBuffer; } @@ -77,7 +96,26 @@ public boolean ownsBuffer() { * @return A reference to the buffer. */ public Optional getBufferRef() { - return Optional.ofNullable(duplicate(buffer)); + if (buffer == null) { + return Optional.empty(); + } else { + return Optional.of(duplicate(buffer)); + } + } + + /** + * Returns a reference to the segment which backs this {@code OnnxTensor}. If the tensor is not + * backed by a segment (i.e., it is backed by a buffer or memory allocated by ORT) this method + * returns an empty {@link Optional}. + * + *

Changes to the segment elements will be reflected in the native {@code OrtValue}, this can + * be used to repeatedly update a single tensor for multiple different inferences without + * allocating new tensors, though the inputs must remain the same size and shape. + * + * @return A reference to the segment. + */ + public Optional getSegmentRef() { + return Optional.ofNullable(segment); } /** @@ -291,8 +329,10 @@ public synchronized void close() { * the OnnxTensor. * * @return A ByteBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public ByteBuffer getByteBuffer() { + public ByteBuffer getByteBuffer() throws OrtException { checkClosed(); if (info.type != OnnxJavaType.STRING) { ByteBuffer buffer = getBuffer(); @@ -310,8 +350,10 @@ public ByteBuffer getByteBuffer() { * into a float (i.e. it's a float, fp16 or bf16), otherwise it returns null. * * @return A FloatBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public FloatBuffer getFloatBuffer() { + public FloatBuffer getFloatBuffer() throws OrtException { checkClosed(); if (info.type == OnnxJavaType.FLOAT) { // if it's fp32 use the efficient copy. @@ -340,8 +382,10 @@ public FloatBuffer getFloatBuffer() { * double, otherwise it returns null. * * @return A DoubleBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public DoubleBuffer getDoubleBuffer() { + public DoubleBuffer getDoubleBuffer() throws OrtException { checkClosed(); if (info.type == OnnxJavaType.DOUBLE) { DoubleBuffer buffer = getBuffer().asDoubleBuffer(); @@ -359,8 +403,10 @@ public DoubleBuffer getDoubleBuffer() { * uint16, fp16 or bf16, otherwise it returns null. * * @return A ShortBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public ShortBuffer getShortBuffer() { + public ShortBuffer getShortBuffer() throws OrtException { checkClosed(); if ((info.type == OnnxJavaType.INT16) || (info.type == OnnxJavaType.FLOAT16) @@ -380,8 +426,10 @@ public ShortBuffer getShortBuffer() { * uint32, otherwise it returns null. * * @return An IntBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public IntBuffer getIntBuffer() { + public IntBuffer getIntBuffer() throws OrtException { checkClosed(); if (info.type == OnnxJavaType.INT32) { IntBuffer buffer = getBuffer().asIntBuffer(); @@ -399,8 +447,10 @@ public IntBuffer getIntBuffer() { * uint64, otherwise it returns null. * * @return A LongBuffer copy of the OnnxTensor. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - public LongBuffer getLongBuffer() { + public LongBuffer getLongBuffer() throws OrtException { checkClosed(); if (info.type == OnnxJavaType.INT64) { LongBuffer buffer = getBuffer().asLongBuffer(); @@ -419,9 +469,32 @@ public LongBuffer getLongBuffer() { * OnnxTensor#getBuffer(long,long)}. * * @return A ByteBuffer wrapping the data. + * @throws OrtException If the value could not be extracted as the Tensor is invalid, or if the + * native code encountered an error. */ - private ByteBuffer getBuffer() { - return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); + private ByteBuffer getBuffer() throws OrtException { + try { + return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); + } catch (IllegalArgumentException e) { + // thrown by the byte buffer constructor if the tensor is bigger than Integer.MAX_VALUE. + throw new OrtException( + "Cannot construct a java.nio.Buffer of this size. Message: " + e.getMessage()); + } + } + + /** + * Wraps the OrtTensor pointer in a MemorySegment. + * + *

MemorySegments are only supported on Java 22 or newer, if called in an earlier version of + * Java this method throws {@link UnsupportedOperationException}. + * + * @return A MemorySegment wrapping the data. + * @throws OrtException If the native code encountered an error. + */ + public Object getSegment() throws OrtException { + long[] info = getSegmentPointer(OnnxRuntime.ortApiHandle, nativeHandle); + MemorySegmentShim shim = new MemorySegmentShim(info[0], info[1]); + return shim.get(); } /** @@ -431,7 +504,16 @@ private ByteBuffer getBuffer() { * @param nativeHandle The OrtTensor pointer. * @return A ByteBuffer wrapping the data. */ - private native ByteBuffer getBuffer(long apiHandle, long nativeHandle); + private native ByteBuffer getBuffer(long apiHandle, long nativeHandle) throws OrtException; + + /** + * Gets the pointer and size in bytes for use in a MemorySegment. + * + * @param apiHandle The OrtApi pointer. + * @param nativeHandle The OrtTensor pointer. + * @return A two element array containing the address and size in bytes. + */ + private native long[] getSegmentPointer(long apiHandle, long nativeHandle) throws OrtException; private native float getFloat(long apiHandle, long nativeHandle, int onnxType) throws OrtException; @@ -702,6 +784,25 @@ static OnnxTensor createTensor( return createTensor(env, allocator, data, shape, OnnxJavaType.INT8); } + /** + * Create an OnnxTensor backed by a Java 22 native MemorySegment. + * + *

If called on Java 21 or older this method throws {@link UnsupportedOperationException}. + * + * @param env The current OrtEnvironment. + * @param data The tensor data in a {@code java.lang.foreign.MemorySegment}. + * @param shape The shape of tensor. + * @param type The type to use for the byte buffer elements. + * @return An OnnxTensor of the required shape. + * @throws IllegalArgumentException If the MemorySegment is not on the native heap. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + public static OnnxTensor createTensorFromMemorySegment( + OrtEnvironment env, Object data, long[] shape, OnnxJavaType type) throws OrtException { + MemorySegmentShim shim = new MemorySegmentShim(data); + return createTensor(env, env.defaultAllocator, shim, shape, type); + } + /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * @@ -720,6 +821,34 @@ public static OnnxTensor createTensor( return createTensor(env, env.defaultAllocator, data, shape, type); } + /** + * Create an OnnxTensor backed by a MemorySegment. + * + *

If called in Java 21 or earlier it throws {@link UnsupportedOperationException}. + * + * @param env The current OrtEnvironment. + * @param allocator The allocator to use. + * @param data The tensor data. + * @param shape The shape of tensor. + * @param type The type to use for the byte buffer elements. + * @return An OnnxTensor of the required shape. + * @throws IllegalArgumentException If the MemorySegment is not on the native heap. + * @throws OrtException Thrown if there is an onnx error or if the data and shape don't match. + */ + static OnnxTensor createTensor( + OrtEnvironment env, + OrtAllocator allocator, + MemorySegmentShim data, + long[] shape, + OnnxJavaType type) + throws OrtException { + if (!allocator.isClosed()) { + return createTensor(type, allocator, data, shape); + } else { + throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); + } + } + /** * Create an OnnxTensor backed by a direct ByteBuffer. The buffer should be in nativeOrder. * @@ -926,6 +1055,38 @@ private static OnnxTensor createTensor( tuple.isCopy); } + /** + * Creates a tensor wrapped around a MemorySegment. + * + * @param type The buffer type. + * @param allocator The OrtAllocator. + * @param data The data. + * @param shape The tensor shape. + * @return An OnnxTensor instance. + * @throws IllegalArgumentException If the MemorySegment is not on the native heap. + * @throws OrtException If the create call failed. + */ + private static OnnxTensor createTensor( + OnnxJavaType type, OrtAllocator allocator, MemorySegmentShim data, long[] shape) + throws OrtException { + if (!data.isNative()) { + throw new IllegalArgumentException("MemorySegment must be native to create a tensor."); + } + TensorInfo info = TensorInfo.constructFromSegment(data, shape, type); + return new OnnxTensor( + createTensorFromSegment( + OnnxRuntime.ortApiHandle, + allocator.handle, + data.address(), + data.byteSize(), + shape, + info.onnxType.value), + allocator.handle, + info, + data, + false); + } + private static native long createTensorFromBuffer( long apiHandle, long allocatorHandle, @@ -936,6 +1097,15 @@ private static native long createTensorFromBuffer( int onnxType) throws OrtException; + private static native long createTensorFromSegment( + long apiHandle, + long allocatorHandle, + long dataPtr, + long bufferSize, + long[] shape, + int onnxType) + throws OrtException; + private static native long createString(long apiHandle, long allocatorHandle, String data) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index f3e9f21ef408d..da5382df8d05a 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -332,7 +332,7 @@ public long getNumElements() { */ public Object makeCarrier() throws OrtException { // Zero length tensors are allowed to be returned. - if (!validateShape() && numElements != 0) { + if ((!validateShape() && numElements != 0) || (numElements * type.size >= Integer.MAX_VALUE)) { throw new OrtException( "This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(shape)); @@ -423,18 +423,44 @@ public static TensorInfo constructFromJavaArray(Object obj) throws OrtException */ public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJavaType type) throws OrtException { + return constructFromMemory(buffer.remaining(), shape, type); + } + + /** + * Constructs a TensorInfo from the supplied MemorySegment. + * + * @param buffer The memory segment to inspect. + * @param shape The shape of the tensor. + * @param type The Java type. + * @return A TensorInfo for a tensor. + * @throws OrtException If the supplied buffer doesn't match the shape. + */ + static TensorInfo constructFromSegment(MemorySegmentShim buffer, long[] shape, OnnxJavaType type) + throws OrtException { + return constructFromMemory(buffer.byteSize(), shape, type); + } + + /** + * Constructs a TensorInfo from the supplied information. + * + * @param memoryLength The length of the memory used for this TensorInfo. + * @param shape The shape of the tensor. + * @param type The Java type. + * @return A TensorInfo for a tensor. + * @throws OrtException If the supplied memory information doesn't match the shape. + */ + private static TensorInfo constructFromMemory(long memoryLength, long[] shape, OnnxJavaType type) + throws OrtException { if ((type == OnnxJavaType.STRING) || (type == OnnxJavaType.UNKNOWN)) { throw new OrtException("Cannot create a tensor from a string or unknown buffer."); } long elementCount = OrtUtil.elementCount(shape); - long bufferRemaining = buffer.remaining(); - // Check if size matches - if (elementCount != bufferRemaining) { + if (elementCount != memoryLength) { // if not it could be a ByteBuffer passed in, so check how many bytes there are - long elemRemaining = bufferRemaining / type.size; + long elemRemaining = memoryLength / type.size; if (elementCount != elemRemaining) { throw new OrtException( "Shape " @@ -442,7 +468,7 @@ public static TensorInfo constructFromBuffer(Buffer buffer, long[] shape, OnnxJa + ", requires " + elementCount + " elements but the buffer has " - + bufferRemaining + + memoryLength + " elements."); } } diff --git a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java new file mode 100644 index 0000000000000..5e801beabf262 --- /dev/null +++ b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.logging.Logger; + +/** + * Wrapper for java.lang.foreign.MemorySegment instances which uses reflection to access the methods + * so it can be compiled on Java 21 and earlier. Requires Java 22 or newer to use MemorySegments, + * when run on earlier versions all methods throw {@link UnsupportedOperationException}. + */ +final class MemorySegmentShim { + private static final Logger logger = Logger.getLogger(MemorySegmentShim.class.getName()); + + // Class is null if java.lang.foreign.MemorySegment is not available. + private static final Class memorySegmentClass; + + /* + * Method handles that bind to methods on java.lang.foreign.MemorySegment. + */ + private static final MethodHandle ofAddress; + private static final MethodHandle reinterpret; + private static final MethodHandle address; + private static final MethodHandle byteSize; + private static final MethodHandle isNative; + private static final MethodHandle set; // only used in tests + + static { + Class segmentClass = null; + MethodHandle tmpOfAddress = null; + MethodHandle tmpReinterpret = null; + MethodHandle tmpAddress = null; + MethodHandle tmpByteSize = null; + MethodHandle tmpIsNative = null; + MethodHandle tmpSet = null; + MethodHandles.Lookup lookup = MethodHandles.lookup(); + try { + segmentClass = Class.forName("java.lang.foreign.MemorySegment"); + Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); + // Attempt to lookup the Java 22 memory segment methods. + tmpOfAddress = + lookup.findStatic( + segmentClass, "ofAddress", MethodType.methodType(segmentClass, long.class)); + tmpReinterpret = + lookup.findVirtual( + segmentClass, "reinterpret", MethodType.methodType(segmentClass, long.class)); + tmpAddress = lookup.findVirtual(segmentClass, "address", MethodType.methodType(long.class)); + tmpByteSize = lookup.findVirtual(segmentClass, "byteSize", MethodType.methodType(long.class)); + tmpIsNative = + lookup.findVirtual(segmentClass, "isNative", MethodType.methodType(boolean.class)); + tmpSet = + lookup.findVirtual( + segmentClass, + "set", + MethodType.methodType(valueLayoutClass, long.class, float.class)); + } catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e) { + logger.fine("Running on Java 21 or earlier, MemorySegment not available"); + } + memorySegmentClass = segmentClass; + ofAddress = tmpOfAddress; + reinterpret = tmpReinterpret; + address = tmpAddress; + byteSize = tmpByteSize; + isNative = tmpIsNative; + set = tmpSet; + } + + // Only holds java.lang.foreign.MemorySegment instances + private final Object segment; + + /** + * Constructor which wraps a MemorySegment. + * + * @param segment The memory segment. + * @throws IllegalArgumentException If the supplied argument was not a + * java.lang.foreign.MemorySegment. + * @throws UnsupportedOperationException If java.lang.foreign.MemorySegment is not available in + * the running JDK. + */ + MemorySegmentShim(Object segment) { + if (memorySegmentClass != null) { + if (memorySegmentClass.isInstance(segment)) { + this.segment = segment; + } else { + throw new IllegalArgumentException( + "Segment argument was not a java.lang.foreign.MemorySegment, found " + + segment.getClass()); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } + + /** + * Constructor which builds a MemorySegment using the supplied arguments. + * + * @param address The address of the memory. + * @param byteSize The size of the memory. + * @throws IllegalArgumentException If the supplied argument was not a valid memory region (i.e., + * positive address and non-negative size). + * @throws UnsupportedOperationException If java.lang.foreign.MemorySegment is not available in + * the running JDK. + */ + MemorySegmentShim(long address, long byteSize) { + if (memorySegmentClass != null) { + if (address > 0 && byteSize >= 0) { + try { + Object segment = ofAddress.invoke(address); + segment = reinterpret.invoke(segment, byteSize); + this.segment = segment; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new IllegalArgumentException( + "Invalid segment, found a non-positive address or a negative size, address = " + + address + + ", byteSize = " + + byteSize); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } + + /** + * Returns the MemorySegment instance. + * + * @return The MemorySegment. + */ + Object get() { + return segment; + } + + /** + * Returns the address of the MemorySegment. + * + * @return The address of the MemorySegment. + */ + long address() { + if (memorySegmentClass != null) { + try { + long ret = (long) address.invoke(segment); + return ret; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } + + /** + * Returns the size of the MemorySegment in bytes. + * + * @return The size of the MemorySegment in bytes. + */ + long byteSize() { + if (memorySegmentClass != null) { + try { + long ret = (long) byteSize.invoke(segment); + return ret; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } + + /** + * Returns true if this segment is backed by native memory, and false if it's backed by memory on + * the Java heap. + * + * @return True if the segment is native. + */ + boolean isNative() { + if (memorySegmentClass != null) { + try { + boolean ret = (boolean) isNative.invoke(segment); + return ret; + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } +} diff --git a/java/src/main/native/ai_onnxruntime_OnnxTensor.c b/java/src/main/native/ai_onnxruntime_OnnxTensor.c index d757bd6281499..7688060919e74 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxTensor.c +++ b/java/src/main/native/ai_onnxruntime_OnnxTensor.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -47,6 +47,40 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromBuffer return (jlong) ortValue; } +/* + * Class: ai_onnxruntime_OnnxTensor + * Method: createTensorFromSegment + * Signature: (JJJJ[JI)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxTensor_createTensorFromSegment + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong allocatorHandle, jlong bufferPtr, jlong bufferSize, + jlongArray shape, jint onnxTypeJava) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtAllocator* allocator = (OrtAllocator*) allocatorHandle; + const OrtMemoryInfo* allocatorInfo; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->AllocatorGetInfo(allocator, &allocatorInfo)); + if (code != ORT_OK) { + return (jlong) NULL; + } + + // Convert type to ONNX C enum + ONNXTensorElementDataType onnxType = convertToONNXDataFormat(onnxTypeJava); + + // Extract the shape information + jlong* shapeArr = (*jniEnv)->GetLongArrayElements(jniEnv, shape, NULL); + jsize shapeLen = (*jniEnv)->GetArrayLength(jniEnv, shape); + + // Create the OrtValue + OrtValue* ortValue = NULL; + checkOrtStatus(jniEnv, api, api->CreateTensorWithDataAsOrtValue(allocatorInfo, (void*)bufferPtr, bufferSize, + (int64_t*)shapeArr, shapeLen, onnxType, &ortValue)); + (*jniEnv)->ReleaseLongArrayElements(jniEnv, shape, shapeArr, JNI_ABORT); + + // Return the pointer to the OrtValue + return (jlong) ortValue; +} + /* * Class: ai_onnxruntime_OnnxTensor * Method: createString @@ -170,6 +204,39 @@ JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getBuffer return NULL; } +/* + * Class: ai_onnxruntime_OnnxTensor + * Method: getSegmentPointer + * Signature: (JJ)[J; + */ +JNIEXPORT jobject JNICALL Java_ai_onnxruntime_OnnxTensor_getSegmentPointer + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtValue* ortValue = (OrtValue *) handle; + JavaTensorTypeShape typeShape; + OrtErrorCode code = getTensorTypeShape(jniEnv, &typeShape, api, ortValue); + + if (code == ORT_OK) { + size_t typeSize = onnxTypeSize(typeShape.onnxTypeEnum); + size_t sizeBytes = typeShape.elementCount * typeSize; + + uint8_t* arr = NULL; + code = checkOrtStatus(jniEnv, api, api->GetTensorMutableData(ortValue, (void**)&arr)); + + if (code == ORT_OK) { + jlongArray outputArray = (*jniEnv)->NewLongArray(jniEnv, 2); + jlong* cOutputArr = (*jniEnv)->GetLongArrayElements(jniEnv, outputArray, NULL); + cOutputArr[0] = (jlong) arr; + cOutputArr[1] = (jlong) sizeBytes; + // mode is 0 to copy back and release arr + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputArray, cOutputArr, 0); + return outputArray; + } + } + return NULL; +} + /* * Class: ai_onnxruntime_OnnxTensor * Method: getFloat diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java new file mode 100644 index 0000000000000..0eee30b9cb3d3 --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ + +package ai.onnxruntime; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.nio.FloatBuffer; +import java.util.Map; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** Tests for interop with Java 22's MemorySegments. */ +public class MemorySegmentTest { + + @Test + public void testSegments() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Construct a big segment. + try (Arena arena = Arena.ofConfined()) { + // e.g., a 256k vocab with 4096 embedding dimensions + long[] shape = new long[] {256 * 1024, 4 * 1024}; + MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + // Fill segment with appropriate values + for (int i = 0; i < 256 * 1024; i++) { + float floati = (float) i; + for (int j = 0; j < 4096; j++) { + segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4096L + 4L * j, floati); + } + } + OnnxTensor bigTensor = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + + try { + FloatBuffer fb = bigTensor.getFloatBuffer(); + Assertions.fail("Should have thrown an exception"); + } catch (OrtException e) { + Assertions.assertTrue( + e.getMessage().contains("Cannot construct a java.nio.Buffer of this size.")); + } + + try { + float[][] arr = (float[][]) bigTensor.getValue(); + Assertions.fail("Should have thrown an exception"); + } catch (OrtException e) { + Assertions.assertTrue( + e.getMessage().contains("This tensor is not representable in Java, it's too big")); + } + + MemorySegment ref = (MemorySegment) bigTensor.getSegmentRef().get(); + Assertions.assertSame(segment, ref); + + MemorySegment other = (MemorySegment) bigTensor.getSegment(); + Assertions.assertEquals(segment, other); + } + } + + @Test + public void testSmallSegment() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + // Construct a small segment. + try (Arena arena = Arena.ofConfined()) { + long[] shape = new long[] {5, 4}; + MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + // Fill segment with appropriate values + for (int i = 0; i < 5; i++) { + float floati = (float) i; + for (int j = 0; j < 4; j++) { + segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4L + 4L * j, floati); + } + } + OnnxTensor smallTensor = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + + FloatBuffer fb = smallTensor.getFloatBuffer(); + + float[][] arr = (float[][]) smallTensor.getValue(); + + float[] fbArr = new float[fb.remaining()]; + fb.get(fbArr); + float[][] reshaped = (float[][]) OrtUtil.reshape(fbArr, shape); + Assertions.assertArrayEquals(arr, reshaped); + + MemorySegment ref = (MemorySegment) smallTensor.getSegmentRef().get(); + Assertions.assertSame(segment, ref); + + MemorySegment other = (MemorySegment) smallTensor.getSegment(); + Assertions.assertEquals(segment, other); + Assertions.assertNotSame(segment, other); + } + } + + @Test + public void testModel() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + String modelPath = TestHelpers.getResourcePath("/java-external-embedding.onnx").toString(); + + // Construct segment for use as embedding parameters. + try (Arena arena = Arena.ofConfined()) { + // i.e. a 256k vocab with 4096 embedding dimensions + long[] shape = new long[] {256 * 1024, 4 * 1024}; + MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + // Fill segment with appropriate values + for (int i = 0; i < 256 * 1024; i++) { + float floati = (float) i; + for (int j = 0; j < 4096; j++) { + segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4096L + 4L * j, floati); + } + } + OnnxTensor embedding = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + + // Construct input tensor + long[][] inputArr = + new long[][] {{64, 128, 256, 512, 0, 0, 0}, {1, 2, 3, 4, 5, 6, 128 * 1024}}; + OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputArr); + + // Construct model using external initializer. + try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) { + opts.addExternalInitializers(Map.of("embedding", embedding)); + + // Run model + try (OrtSession session = env.createSession(modelPath, opts); + OrtSession.Result output = session.run(Map.of("input", inputTensor))) { + // Validate output, which is [batch_size, seq_length, embedding_dimension] + // The embedding values should be filled with the index of that embedding + float[][][] outputArr = (float[][][]) output.get("output").get().getValue(); + Assertions.assertEquals(2, outputArr.length); + Assertions.assertEquals(7, outputArr[0].length); + Assertions.assertEquals(4096, outputArr[0][0].length); + for (int i = 0; i < inputArr.length; i++) { + for (int j = 0; j < inputArr[0].length; j++) { + float testVal = inputArr[i][j]; + for (int k = 0; k < 4096; k++) { + Assertions.assertEquals( + testVal, + outputArr[i][j][k], + "At position [" + i + "," + j + "," + k + "] values differ."); + } + } + } + } + } + } + } +} diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java index 7bf7cef43208a..c1fb6f07c7ff9 100644 --- a/java/src/test/java/ai/onnxruntime/ModelGenerators.java +++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2023, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -48,6 +48,75 @@ public static OnnxMl.TypeProto buildTensorTypeNode( return builder.build(); } + public void generateExternalEmbedding() throws IOException { + OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder(); + graph.setName("ort-test-embedding"); + + // Add placeholders + OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder(); + input.setName("input"); + OnnxMl.TypeProto inputType = + buildTensorTypeNode( + new long[] {-1, -1}, + new String[] {"batch_size", "sequence_length"}, + OnnxMl.TensorProto.DataType.INT64); + input.setType(inputType); + graph.addInput(input); + OnnxMl.ValueInfoProto.Builder output = OnnxMl.ValueInfoProto.newBuilder(); + output.setName("output"); + OnnxMl.TypeProto outputType = + buildTensorTypeNode( + new long[] {-1, -1, 4096}, + new String[] {"batch_size", "sequence_length", null}, + OnnxMl.TensorProto.DataType.FLOAT); + output.setType(outputType); + graph.addOutput(output); + + // Add initializer + OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder(); + tensor.addDims(256 * 1024); + tensor.addDims(4096); + tensor.setDataLocation(DataLocation.EXTERNAL); + tensor.addExternalData( + StringStringEntryProto.newBuilder() + .setKey("location") + .setValue("external-embedding.out") + .build()); + tensor.addExternalData( + StringStringEntryProto.newBuilder().setKey("offset").setValue("0").build()); + tensor.addExternalData( + StringStringEntryProto.newBuilder() + .setKey("length") + .setValue("" + (4L * 1024L * 1024L * 1024L)) + .build()); + tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + tensor.setName("embedding"); + graph.addInitializer(tensor); + + // Add operations + OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder(); + matmul.setName("gather-0"); + matmul.setOpType("Gather"); + matmul.addInput("embedding"); + matmul.addInput("input"); + matmul.addOutput("output"); + graph.addNode(matmul); + + // Build model + OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder(); + model.setGraph(graph); + model.setDocString("ORT embedding test"); + model.setModelVersion(0); + model.setIrVersion(8); + model.setDomain("ai.onnxruntime.test"); + model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build()); + try (OutputStream os = + Files.newOutputStream( + Paths.get("src", "test", "resources", "java-external-embedding.onnx"))) { + model.build().writeTo(os); + } + } + public void generateExternalMatMul() throws IOException { OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder(); graph.setName("ort-test-matmul"); diff --git a/java/src/test/resources/java-external-embedding.onnx b/java/src/test/resources/java-external-embedding.onnx new file mode 100644 index 0000000000000..3d29e1e722ef4 --- /dev/null +++ b/java/src/test/resources/java-external-embedding.onnx @@ -0,0 +1,20 @@ +"ai.onnxruntime.test2ORT embedding test:‹ +, + embedding +inputoutputgather-0"Gatherort-test-embedding*] +€€€ B embeddingj" +locationexternal-embedding.outj +offset0j +length +4294967296pZ0 +input' +%! +  +batch_size +sequence_lengthb6 +output, +*& +  +batch_size +sequence_length +€ B \ No newline at end of file From 1a96394c03fca8f15f0f71b8dff84e871d82e855 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 26 Dec 2025 22:27:39 -0500 Subject: [PATCH 2/8] Making the tests compile on Java 8, but run on Java 22 or newer. --- java/build.gradle | 5 +- .../jvm/ai/onnxruntime/MemorySegmentShim.java | 69 ++++++-- .../ai/onnxruntime/MemorySegmentTest.java | 147 ++++++++++++++---- 3 files changed, 182 insertions(+), 39 deletions(-) diff --git a/java/build.gradle b/java/build.gradle index 12e5f2ec70469..3dab290428783 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -161,8 +161,9 @@ if (cmakeBuildDir != null) { } dependencies { - testImplementation 'org.junit.jupiter:junit-jupiter-api:5.9.2' - testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.9.2' + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.14.1' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.14.1' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher:1.12.2' testImplementation 'com.google.protobuf:protobuf-java:3.25.5' } diff --git a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java index 5e801beabf262..eedfc43845ab4 100644 --- a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java +++ b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java @@ -29,19 +29,22 @@ final class MemorySegmentShim { private static final MethodHandle byteSize; private static final MethodHandle isNative; private static final MethodHandle set; // only used in tests + private static final Object floatLayout; static { - Class segmentClass = null; - MethodHandle tmpOfAddress = null; - MethodHandle tmpReinterpret = null; - MethodHandle tmpAddress = null; - MethodHandle tmpByteSize = null; - MethodHandle tmpIsNative = null; - MethodHandle tmpSet = null; + Class segmentClass; + MethodHandle tmpOfAddress; + MethodHandle tmpReinterpret; + MethodHandle tmpAddress; + MethodHandle tmpByteSize; + MethodHandle tmpIsNative; + MethodHandle tmpSet; + Object tmpLayout; MethodHandles.Lookup lookup = MethodHandles.lookup(); try { segmentClass = Class.forName("java.lang.foreign.MemorySegment"); - Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); + Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout"); + Class floatValueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); // Attempt to lookup the Java 22 memory segment methods. tmpOfAddress = lookup.findStatic( @@ -57,9 +60,33 @@ final class MemorySegmentShim { lookup.findVirtual( segmentClass, "set", - MethodType.methodType(valueLayoutClass, long.class, float.class)); - } catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e) { + MethodType.methodType(void.class, floatValueLayoutClass, long.class, float.class)); + tmpLayout = + lookup.findStaticGetter(valueLayoutClass, "JAVA_FLOAT", floatValueLayoutClass).invoke(); + } catch (IllegalAccessException + | NoSuchMethodException + | ClassNotFoundException + | NoSuchFieldException e) { logger.fine("Running on Java 21 or earlier, MemorySegment not available"); + segmentClass = null; + tmpOfAddress = null; + tmpReinterpret = null; + tmpAddress = null; + tmpByteSize = null; + tmpIsNative = null; + tmpSet = null; + tmpLayout = null; + } catch (Throwable e) { + logger.severe( + "Failed to load float value layout, while other Java 22 features were available."); + segmentClass = null; + tmpOfAddress = null; + tmpReinterpret = null; + tmpAddress = null; + tmpByteSize = null; + tmpIsNative = null; + tmpSet = null; + tmpLayout = null; } memorySegmentClass = segmentClass; ofAddress = tmpOfAddress; @@ -68,6 +95,7 @@ final class MemorySegmentShim { byteSize = tmpByteSize; isNative = tmpIsNative; set = tmpSet; + floatLayout = tmpLayout; } // Only holds java.lang.foreign.MemorySegment instances @@ -191,4 +219,25 @@ boolean isNative() { throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); } } + + /** + * Sets a float value on this memory segment at the specified index. + * + *

Only used in the tests, should not be used in user code as invoke is slower than + * invokeExact. + * + * @param idx The index to write to. + * @param value The value. + */ + void set(long idx, float value) { + if (memorySegmentClass != null) { + try { + set.invoke(segment, floatLayout, idx, value); + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.MemorySegment is not available."); + } + } } diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index 0eee30b9cb3d3..e374b4a5a6c02 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -5,33 +5,121 @@ package ai.onnxruntime; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; -import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; import java.nio.FloatBuffer; -import java.util.Map; +import java.util.Collections; +import java.util.logging.Logger; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; /** Tests for interop with Java 22's MemorySegments. */ +@EnabledForJreRange(min = JRE.JAVA_22) public class MemorySegmentTest { + /** Shim so the tests can create memory segments from an arena. */ + static final class ArenaShim implements AutoCloseable { + private static final Logger logger = Logger.getLogger(ArenaShim.class.getName()); + private static final Class arenaClass; + + /* + * Method handles that bind to methods on java.lang.foreign.MemorySegment. + */ + private static final MethodHandle ofConfined; + private static final MethodHandle allocate; + private static final MethodHandle close; + + private final Object arena; + + static { + Class tmpArenaClass; + MethodHandle tmpOfConfined; + MethodHandle tmpAllocate; + MethodHandle tmpClose; + MethodHandles.Lookup lookup = MethodHandles.lookup(); + try { + tmpArenaClass = Class.forName("java.lang.foreign.Arena"); + Class segmentClass = Class.forName("java.lang.foreign.MemorySegment"); + tmpOfConfined = + lookup.findStatic(tmpArenaClass, "ofConfined", MethodType.methodType(tmpArenaClass)); + tmpAllocate = + lookup.findVirtual( + tmpArenaClass, "allocate", MethodType.methodType(segmentClass, long.class)); + tmpClose = lookup.findVirtual(tmpArenaClass, "close", MethodType.methodType(void.class)); + } catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e) { + logger.info("Running on Java 21 or earlier, Arena not available"); + tmpArenaClass = null; + tmpOfConfined = null; + tmpAllocate = null; + tmpClose = null; + } + arenaClass = tmpArenaClass; + ofConfined = tmpOfConfined; + allocate = tmpAllocate; + close = tmpClose; + } + + private ArenaShim(Object arena) { + this.arena = arena; + } + + static ArenaShim ofConfined() { + if (arenaClass != null) { + try { + return new ArenaShim(ofConfined.invoke()); + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.Arena is not available."); + } + } + + Object allocate(long size) { + if (arenaClass != null) { + try { + return allocate.invoke(arena, size); + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.Arena is not available."); + } + } + + public void close() { + if (arenaClass != null) { + try { + close.invoke(arena); + } catch (Throwable e) { + throw new AssertionError("Should not reach here", e); + } + } else { + throw new UnsupportedOperationException("java.lang.foreign.Arena is not available."); + } + } + } @Test public void testSegments() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); // Construct a big segment. - try (Arena arena = Arena.ofConfined()) { + try (ArenaShim arena = ArenaShim.ofConfined()) { // e.g., a 256k vocab with 4096 embedding dimensions long[] shape = new long[] {256 * 1024, 4 * 1024}; - MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + MemorySegmentShim segment = + new MemorySegmentShim(arena.allocate(4L * OrtUtil.elementCount(shape))); // Fill segment with appropriate values for (int i = 0; i < 256 * 1024; i++) { float floati = (float) i; for (int j = 0; j < 4096; j++) { - segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4096L + 4L * j, floati); + segment.set(4L * i * 4096L + 4L * j, floati); } } - OnnxTensor bigTensor = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + OnnxTensor bigTensor = + OnnxTensor.createTensorFromMemorySegment(env, segment.get(), shape, OnnxJavaType.FLOAT); try { FloatBuffer fb = bigTensor.getFloatBuffer(); @@ -49,11 +137,11 @@ public void testSegments() throws OrtException { e.getMessage().contains("This tensor is not representable in Java, it's too big")); } - MemorySegment ref = (MemorySegment) bigTensor.getSegmentRef().get(); - Assertions.assertSame(segment, ref); + Object refMemorySegment = bigTensor.getSegmentRef().get(); + Assertions.assertSame(segment.get(), refMemorySegment); - MemorySegment other = (MemorySegment) bigTensor.getSegment(); - Assertions.assertEquals(segment, other); + Object otherMemorySegment = bigTensor.getSegment(); + Assertions.assertEquals(segment.get(), otherMemorySegment); } } @@ -61,17 +149,19 @@ public void testSegments() throws OrtException { public void testSmallSegment() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); // Construct a small segment. - try (Arena arena = Arena.ofConfined()) { + try (ArenaShim arena = ArenaShim.ofConfined()) { long[] shape = new long[] {5, 4}; - MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + MemorySegmentShim segment = + new MemorySegmentShim(arena.allocate(4L * OrtUtil.elementCount(shape))); // Fill segment with appropriate values for (int i = 0; i < 5; i++) { float floati = (float) i; for (int j = 0; j < 4; j++) { - segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4L + 4L * j, floati); + segment.set(4L * i * 4L + 4L * j, floati); } } - OnnxTensor smallTensor = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + OnnxTensor smallTensor = + OnnxTensor.createTensorFromMemorySegment(env, segment.get(), shape, OnnxJavaType.FLOAT); FloatBuffer fb = smallTensor.getFloatBuffer(); @@ -82,12 +172,12 @@ public void testSmallSegment() throws OrtException { float[][] reshaped = (float[][]) OrtUtil.reshape(fbArr, shape); Assertions.assertArrayEquals(arr, reshaped); - MemorySegment ref = (MemorySegment) smallTensor.getSegmentRef().get(); - Assertions.assertSame(segment, ref); + Object refMemorySegment = smallTensor.getSegmentRef().get(); + Assertions.assertSame(segment.get(), refMemorySegment); - MemorySegment other = (MemorySegment) smallTensor.getSegment(); - Assertions.assertEquals(segment, other); - Assertions.assertNotSame(segment, other); + Object otherMemorySegment = smallTensor.getSegment(); + Assertions.assertEquals(segment.get(), otherMemorySegment); + Assertions.assertNotSame(segment.get(), otherMemorySegment); } } @@ -97,18 +187,20 @@ public void testModel() throws OrtException { String modelPath = TestHelpers.getResourcePath("/java-external-embedding.onnx").toString(); // Construct segment for use as embedding parameters. - try (Arena arena = Arena.ofConfined()) { + try (ArenaShim arena = ArenaShim.ofConfined()) { // i.e. a 256k vocab with 4096 embedding dimensions long[] shape = new long[] {256 * 1024, 4 * 1024}; - MemorySegment segment = arena.allocate(4L * OrtUtil.elementCount(shape)); + MemorySegmentShim segment = + new MemorySegmentShim(arena.allocate(4L * OrtUtil.elementCount(shape))); // Fill segment with appropriate values for (int i = 0; i < 256 * 1024; i++) { float floati = (float) i; for (int j = 0; j < 4096; j++) { - segment.set(ValueLayout.JAVA_FLOAT, 4L * i * 4096L + 4L * j, floati); + segment.set(4L * i * 4096L + 4L * j, floati); } } - OnnxTensor embedding = OnnxTensor.createTensor(env, segment, shape, OnnxJavaType.FLOAT); + OnnxTensor embedding = + OnnxTensor.createTensorFromMemorySegment(env, segment.get(), shape, OnnxJavaType.FLOAT); // Construct input tensor long[][] inputArr = @@ -117,11 +209,12 @@ public void testModel() throws OrtException { // Construct model using external initializer. try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) { - opts.addExternalInitializers(Map.of("embedding", embedding)); + opts.addExternalInitializers(Collections.singletonMap("embedding", embedding)); // Run model try (OrtSession session = env.createSession(modelPath, opts); - OrtSession.Result output = session.run(Map.of("input", inputTensor))) { + OrtSession.Result output = + session.run(Collections.singletonMap("input", inputTensor))) { // Validate output, which is [batch_size, seq_length, embedding_dimension] // The embedding values should be filled with the index of that embedding float[][][] outputArr = (float[][][]) output.get("output").get().getValue(); From 873dbeae90dc1578d184763fc3da33099595aa32 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sat, 27 Dec 2025 20:51:01 -0500 Subject: [PATCH 3/8] Renaming the accessors and improving the docs. --- java/src/main/java/ai/onnxruntime/OnnxTensor.java | 14 ++++++++------ .../java/ai/onnxruntime/MemorySegmentTest.java | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 46945bdd1f6d6..f75869ed3c494 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -104,17 +104,19 @@ public Optional getBufferRef() { } /** - * Returns a reference to the segment which backs this {@code OnnxTensor}. If the tensor is not - * backed by a segment (i.e., it is backed by a buffer or memory allocated by ORT) this method + * Returns a reference to the {@code MemorySegment} which backs this {@code OnnxTensor}. If the tensor is not + * backed by a {@code MemorySegment} (i.e., it is backed by a buffer or memory allocated by ORT) this method * returns an empty {@link Optional}. * *

Changes to the segment elements will be reflected in the native {@code OrtValue}, this can * be used to repeatedly update a single tensor for multiple different inferences without * allocating new tensors, though the inputs must remain the same size and shape. * + *

{@code java.lang.foreign.MemorySegment}s are only supported on Java 22 or newer. + * * @return A reference to the segment. */ - public Optional getSegmentRef() { + public Optional getMemorySegmentRef() { return Optional.ofNullable(segment); } @@ -491,7 +493,7 @@ private ByteBuffer getBuffer() throws OrtException { * @return A MemorySegment wrapping the data. * @throws OrtException If the native code encountered an error. */ - public Object getSegment() throws OrtException { + public Object getMemorySegment() throws OrtException { long[] info = getSegmentPointer(OnnxRuntime.ortApiHandle, nativeHandle); MemorySegmentShim shim = new MemorySegmentShim(info[0], info[1]); return shim.get(); @@ -785,12 +787,12 @@ static OnnxTensor createTensor( } /** - * Create an OnnxTensor backed by a Java 22 native MemorySegment. + * Create an OnnxTensor backed by a Java 22 native {@code MemorySegment}. * *

If called on Java 21 or older this method throws {@link UnsupportedOperationException}. * * @param env The current OrtEnvironment. - * @param data The tensor data in a {@code java.lang.foreign.MemorySegment}. + * @param data The tensor data in a native {@code java.lang.foreign.MemorySegment}. * @param shape The shape of tensor. * @param type The type to use for the byte buffer elements. * @return An OnnxTensor of the required shape. diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index e374b4a5a6c02..d5a9f7ed1a599 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -137,10 +137,10 @@ public void testSegments() throws OrtException { e.getMessage().contains("This tensor is not representable in Java, it's too big")); } - Object refMemorySegment = bigTensor.getSegmentRef().get(); + Object refMemorySegment = bigTensor.getMemorySegmentRef().get(); Assertions.assertSame(segment.get(), refMemorySegment); - Object otherMemorySegment = bigTensor.getSegment(); + Object otherMemorySegment = bigTensor.getMemorySegment(); Assertions.assertEquals(segment.get(), otherMemorySegment); } } @@ -172,10 +172,10 @@ public void testSmallSegment() throws OrtException { float[][] reshaped = (float[][]) OrtUtil.reshape(fbArr, shape); Assertions.assertArrayEquals(arr, reshaped); - Object refMemorySegment = smallTensor.getSegmentRef().get(); + Object refMemorySegment = smallTensor.getMemorySegmentRef().get(); Assertions.assertSame(segment.get(), refMemorySegment); - Object otherMemorySegment = smallTensor.getSegment(); + Object otherMemorySegment = smallTensor.getMemorySegment(); Assertions.assertEquals(segment.get(), otherMemorySegment); Assertions.assertNotSame(segment.get(), otherMemorySegment); } From 1119dd5d3233e1582003e11f56e936e7ddb453aa Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 28 Dec 2025 22:14:41 -0500 Subject: [PATCH 4/8] MethodHandle combinators allow the use of invokeExact which is much faster. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 6 +- .../jvm/ai/onnxruntime/MemorySegmentShim.java | 58 ++++++++++++------- .../ai/onnxruntime/MemorySegmentTest.java | 21 ++++--- 3 files changed, 55 insertions(+), 30 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index f75869ed3c494..377075a20161a 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -104,9 +104,9 @@ public Optional getBufferRef() { } /** - * Returns a reference to the {@code MemorySegment} which backs this {@code OnnxTensor}. If the tensor is not - * backed by a {@code MemorySegment} (i.e., it is backed by a buffer or memory allocated by ORT) this method - * returns an empty {@link Optional}. + * Returns a reference to the {@code MemorySegment} which backs this {@code OnnxTensor}. If the + * tensor is not backed by a {@code MemorySegment} (i.e., it is backed by a buffer or memory + * allocated by ORT) this method returns an empty {@link Optional}. * *

Changes to the segment elements will be reflected in the native {@code OrtValue}, this can * be used to repeatedly update a single tensor for multiple different inferences without diff --git a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java index eedfc43845ab4..820e52d1ec34e 100644 --- a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java +++ b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java @@ -46,21 +46,42 @@ final class MemorySegmentShim { Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout"); Class floatValueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); // Attempt to lookup the Java 22 memory segment methods. + // The trailing .asType calls are to adapt the method handles to operate on values of type + // Object as the compiled code in this class does not have access to the concrete + // MemorySegment class, so the invokeExact calls would throw without the casts. + // The argument lists are (return type, argument types...) for static methods and + // (return type, receiver type, argument types...) for instance methods. tmpOfAddress = - lookup.findStatic( - segmentClass, "ofAddress", MethodType.methodType(segmentClass, long.class)); + lookup + .findStatic( + segmentClass, "ofAddress", MethodType.methodType(segmentClass, long.class)) + .asType(MethodType.methodType(Object.class, long.class)); tmpReinterpret = - lookup.findVirtual( - segmentClass, "reinterpret", MethodType.methodType(segmentClass, long.class)); - tmpAddress = lookup.findVirtual(segmentClass, "address", MethodType.methodType(long.class)); - tmpByteSize = lookup.findVirtual(segmentClass, "byteSize", MethodType.methodType(long.class)); + lookup + .findVirtual( + segmentClass, "reinterpret", MethodType.methodType(segmentClass, long.class)) + .asType(MethodType.methodType(Object.class, Object.class, long.class)); + tmpAddress = + lookup + .findVirtual(segmentClass, "address", MethodType.methodType(long.class)) + .asType(MethodType.methodType(long.class, Object.class)); + tmpByteSize = + lookup + .findVirtual(segmentClass, "byteSize", MethodType.methodType(long.class)) + .asType(MethodType.methodType(long.class, Object.class)); tmpIsNative = - lookup.findVirtual(segmentClass, "isNative", MethodType.methodType(boolean.class)); + lookup + .findVirtual(segmentClass, "isNative", MethodType.methodType(boolean.class)) + .asType(MethodType.methodType(boolean.class, Object.class)); tmpSet = - lookup.findVirtual( - segmentClass, - "set", - MethodType.methodType(void.class, floatValueLayoutClass, long.class, float.class)); + lookup + .findVirtual( + segmentClass, + "set", + MethodType.methodType(void.class, floatValueLayoutClass, long.class, float.class)) + .asType( + MethodType.methodType( + void.class, Object.class, Object.class, long.class, float.class)); tmpLayout = lookup.findStaticGetter(valueLayoutClass, "JAVA_FLOAT", floatValueLayoutClass).invoke(); } catch (IllegalAccessException @@ -138,8 +159,8 @@ final class MemorySegmentShim { if (memorySegmentClass != null) { if (address > 0 && byteSize >= 0) { try { - Object segment = ofAddress.invoke(address); - segment = reinterpret.invoke(segment, byteSize); + Object segment = ofAddress.invokeExact(address); + segment = reinterpret.invokeExact(segment, byteSize); this.segment = segment; } catch (Throwable e) { throw new AssertionError("Should not reach here", e); @@ -173,7 +194,7 @@ Object get() { long address() { if (memorySegmentClass != null) { try { - long ret = (long) address.invoke(segment); + long ret = (long) address.invokeExact(segment); return ret; } catch (Throwable e) { throw new AssertionError("Should not reach here", e); @@ -191,7 +212,7 @@ long address() { long byteSize() { if (memorySegmentClass != null) { try { - long ret = (long) byteSize.invoke(segment); + long ret = (long) byteSize.invokeExact(segment); return ret; } catch (Throwable e) { throw new AssertionError("Should not reach here", e); @@ -210,7 +231,7 @@ long byteSize() { boolean isNative() { if (memorySegmentClass != null) { try { - boolean ret = (boolean) isNative.invoke(segment); + boolean ret = (boolean) isNative.invokeExact(segment); return ret; } catch (Throwable e) { throw new AssertionError("Should not reach here", e); @@ -223,16 +244,13 @@ boolean isNative() { /** * Sets a float value on this memory segment at the specified index. * - *

Only used in the tests, should not be used in user code as invoke is slower than - * invokeExact. - * * @param idx The index to write to. * @param value The value. */ void set(long idx, float value) { if (memorySegmentClass != null) { try { - set.invoke(segment, floatLayout, idx, value); + set.invokeExact(segment, floatLayout, idx, value); } catch (Throwable e) { throw new AssertionError("Should not reach here", e); } diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index d5a9f7ed1a599..e146bb24daaf3 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -43,11 +43,18 @@ static final class ArenaShim implements AutoCloseable { tmpArenaClass = Class.forName("java.lang.foreign.Arena"); Class segmentClass = Class.forName("java.lang.foreign.MemorySegment"); tmpOfConfined = - lookup.findStatic(tmpArenaClass, "ofConfined", MethodType.methodType(tmpArenaClass)); + lookup + .findStatic(tmpArenaClass, "ofConfined", MethodType.methodType(tmpArenaClass)) + .asType(MethodType.methodType(Object.class)); tmpAllocate = - lookup.findVirtual( - tmpArenaClass, "allocate", MethodType.methodType(segmentClass, long.class)); - tmpClose = lookup.findVirtual(tmpArenaClass, "close", MethodType.methodType(void.class)); + lookup + .findVirtual( + tmpArenaClass, "allocate", MethodType.methodType(segmentClass, long.class)) + .asType(MethodType.methodType(Object.class, Object.class, long.class)); + tmpClose = + lookup + .findVirtual(tmpArenaClass, "close", MethodType.methodType(void.class)) + .asType(MethodType.methodType(void.class, Object.class)); } catch (IllegalAccessException | NoSuchMethodException | ClassNotFoundException e) { logger.info("Running on Java 21 or earlier, Arena not available"); tmpArenaClass = null; @@ -68,7 +75,7 @@ private ArenaShim(Object arena) { static ArenaShim ofConfined() { if (arenaClass != null) { try { - return new ArenaShim(ofConfined.invoke()); + return new ArenaShim(ofConfined.invokeExact()); } catch (Throwable e) { throw new AssertionError("Should not reach here", e); } @@ -80,7 +87,7 @@ static ArenaShim ofConfined() { Object allocate(long size) { if (arenaClass != null) { try { - return allocate.invoke(arena, size); + return allocate.invokeExact(arena, size); } catch (Throwable e) { throw new AssertionError("Should not reach here", e); } @@ -92,7 +99,7 @@ Object allocate(long size) { public void close() { if (arenaClass != null) { try { - close.invoke(arena); + close.invokeExact(arena); } catch (Throwable e) { throw new AssertionError("Should not reach here", e); } From 5f5e552fdda591d639abd3549b6a6705e75a11f8 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 5 Jan 2026 12:22:42 -0500 Subject: [PATCH 5/8] Adding specific test to check it rejects non-MemorySegment objects. --- .../java/ai/onnxruntime/MemorySegmentTest.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index e146bb24daaf3..42655df11282a 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -109,6 +109,20 @@ public void close() { } } + @Test + public void testNotASegment() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + long[] shape = new long[] {256 * 1024, 4 * 1024}; + Object notASegment = new Object(); + try { + OnnxTensor tensor = + OnnxTensor.createTensorFromMemorySegment(env, notASegment, shape, OnnxJavaType.FLOAT); + Assertions.fail("Should have thrown."); + } catch (IllegalArgumentException e) { + Assertions.assertTrue(e.getMessage().contains("Segment argument was not a java.lang.foreign.MemorySegment")); + } + } + @Test public void testSegments() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); From c1ef1de49f80425d832df4d020d707d12a4e0587 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 12 Jan 2026 21:19:21 -0500 Subject: [PATCH 6/8] Fixes after review. --- .../main/java/ai/onnxruntime/TensorInfo.java | 7 +- .../jvm/ai/onnxruntime/MemorySegmentShim.java | 133 ++++++++---------- .../ai/onnxruntime/MemorySegmentTest.java | 9 +- 3 files changed, 71 insertions(+), 78 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index da5382df8d05a..94b815c97c748 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -293,8 +293,6 @@ private boolean validateShape() { * @return The number of elements. */ private static long elementCount(long[] shape) { - // Java side tensors must be less than Integer.MAX_VALUE, - // tensors created in native code can be larger, but are not usable in Java. // Tensors should not be able to be created which will overflow a 64-bit long. long output = 1; for (int i = 0; i < shape.length; i++) { @@ -334,8 +332,9 @@ public Object makeCarrier() throws OrtException { // Zero length tensors are allowed to be returned. if ((!validateShape() && numElements != 0) || (numElements * type.size >= Integer.MAX_VALUE)) { throw new OrtException( - "This tensor is not representable in Java, it's too big - shape = " - + Arrays.toString(shape)); + "This tensor is not representable in Java as an array or a java.nio.Buffer, it's too big - shape = " + + Arrays.toString(shape) + + ", using tensors this large requires Java 22's java.lang.foreign.MemorySegment."); } switch (type) { case BFLOAT16: diff --git a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java index 820e52d1ec34e..c110a2f3d7b1f 100644 --- a/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java +++ b/java/src/main/jvm/ai/onnxruntime/MemorySegmentShim.java @@ -7,6 +7,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -33,81 +34,71 @@ final class MemorySegmentShim { static { Class segmentClass; - MethodHandle tmpOfAddress; - MethodHandle tmpReinterpret; - MethodHandle tmpAddress; - MethodHandle tmpByteSize; - MethodHandle tmpIsNative; - MethodHandle tmpSet; - Object tmpLayout; - MethodHandles.Lookup lookup = MethodHandles.lookup(); try { segmentClass = Class.forName("java.lang.foreign.MemorySegment"); - Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout"); - Class floatValueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); - // Attempt to lookup the Java 22 memory segment methods. - // The trailing .asType calls are to adapt the method handles to operate on values of type - // Object as the compiled code in this class does not have access to the concrete - // MemorySegment class, so the invokeExact calls would throw without the casts. - // The argument lists are (return type, argument types...) for static methods and - // (return type, receiver type, argument types...) for instance methods. - tmpOfAddress = - lookup - .findStatic( - segmentClass, "ofAddress", MethodType.methodType(segmentClass, long.class)) - .asType(MethodType.methodType(Object.class, long.class)); - tmpReinterpret = - lookup - .findVirtual( - segmentClass, "reinterpret", MethodType.methodType(segmentClass, long.class)) - .asType(MethodType.methodType(Object.class, Object.class, long.class)); - tmpAddress = - lookup - .findVirtual(segmentClass, "address", MethodType.methodType(long.class)) - .asType(MethodType.methodType(long.class, Object.class)); - tmpByteSize = - lookup - .findVirtual(segmentClass, "byteSize", MethodType.methodType(long.class)) - .asType(MethodType.methodType(long.class, Object.class)); - tmpIsNative = - lookup - .findVirtual(segmentClass, "isNative", MethodType.methodType(boolean.class)) - .asType(MethodType.methodType(boolean.class, Object.class)); - tmpSet = - lookup - .findVirtual( - segmentClass, - "set", - MethodType.methodType(void.class, floatValueLayoutClass, long.class, float.class)) - .asType( - MethodType.methodType( - void.class, Object.class, Object.class, long.class, float.class)); - tmpLayout = - lookup.findStaticGetter(valueLayoutClass, "JAVA_FLOAT", floatValueLayoutClass).invoke(); - } catch (IllegalAccessException - | NoSuchMethodException - | ClassNotFoundException - | NoSuchFieldException e) { + } catch (ClassNotFoundException e) { logger.fine("Running on Java 21 or earlier, MemorySegment not available"); segmentClass = null; - tmpOfAddress = null; - tmpReinterpret = null; - tmpAddress = null; - tmpByteSize = null; - tmpIsNative = null; - tmpSet = null; - tmpLayout = null; - } catch (Throwable e) { - logger.severe( - "Failed to load float value layout, while other Java 22 features were available."); - segmentClass = null; - tmpOfAddress = null; - tmpReinterpret = null; - tmpAddress = null; - tmpByteSize = null; - tmpIsNative = null; - tmpSet = null; - tmpLayout = null; + } + + MethodHandles.Lookup lookup = MethodHandles.lookup(); + MethodHandle tmpOfAddress = null; + MethodHandle tmpReinterpret = null; + MethodHandle tmpAddress = null; + MethodHandle tmpByteSize = null; + MethodHandle tmpIsNative = null; + MethodHandle tmpSet = null; + Object tmpLayout = null; + if (segmentClass != null) { + try { + Class valueLayoutClass = Class.forName("java.lang.foreign.ValueLayout"); + Class floatValueLayoutClass = Class.forName("java.lang.foreign.ValueLayout$OfFloat"); + // Attempt to lookup the Java 22 memory segment methods. + // The trailing .asType calls are to adapt the method handles to operate on values of type + // Object as the compiled code in this class does not have access to the concrete + // MemorySegment class, so the invokeExact calls would throw without the casts. + // The argument lists are (return type, argument types...) for static methods and + // (return type, receiver type, argument types...) for instance methods. + tmpOfAddress = + lookup + .findStatic( + segmentClass, "ofAddress", MethodType.methodType(segmentClass, long.class)) + .asType(MethodType.methodType(Object.class, long.class)); + tmpReinterpret = + lookup + .findVirtual( + segmentClass, "reinterpret", MethodType.methodType(segmentClass, long.class)) + .asType(MethodType.methodType(Object.class, Object.class, long.class)); + tmpAddress = + lookup + .findVirtual(segmentClass, "address", MethodType.methodType(long.class)) + .asType(MethodType.methodType(long.class, Object.class)); + tmpByteSize = + lookup + .findVirtual(segmentClass, "byteSize", MethodType.methodType(long.class)) + .asType(MethodType.methodType(long.class, Object.class)); + tmpIsNative = + lookup + .findVirtual(segmentClass, "isNative", MethodType.methodType(boolean.class)) + .asType(MethodType.methodType(boolean.class, Object.class)); + tmpSet = + lookup + .findVirtual( + segmentClass, + "set", + MethodType.methodType( + void.class, floatValueLayoutClass, long.class, float.class)) + .asType( + MethodType.methodType( + void.class, Object.class, Object.class, long.class, float.class)); + tmpLayout = + lookup.findStaticGetter(valueLayoutClass, "JAVA_FLOAT", floatValueLayoutClass).invoke(); + } catch (Throwable e) { + logger.log( + Level.SEVERE, + "Exception thrown while inspecting java.lang.foreign.MemorySegment methods", + e); + } } memorySegmentClass = segmentClass; ofAddress = tmpOfAddress; diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index 42655df11282a..0293460d83e1a 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -116,10 +116,11 @@ public void testNotASegment() throws OrtException { Object notASegment = new Object(); try { OnnxTensor tensor = - OnnxTensor.createTensorFromMemorySegment(env, notASegment, shape, OnnxJavaType.FLOAT); + OnnxTensor.createTensorFromMemorySegment(env, notASegment, shape, OnnxJavaType.FLOAT); Assertions.fail("Should have thrown."); } catch (IllegalArgumentException e) { - Assertions.assertTrue(e.getMessage().contains("Segment argument was not a java.lang.foreign.MemorySegment")); + Assertions.assertTrue( + e.getMessage().contains("Segment argument was not a java.lang.foreign.MemorySegment")); } } @@ -155,7 +156,9 @@ public void testSegments() throws OrtException { Assertions.fail("Should have thrown an exception"); } catch (OrtException e) { Assertions.assertTrue( - e.getMessage().contains("This tensor is not representable in Java, it's too big")); + e.getMessage() + .contains( + "This tensor is not representable in Java as an array or a java.nio.Buffer, it's too big")); } Object refMemorySegment = bigTensor.getMemorySegmentRef().get(); From a6ebdf31a36ec90ab25debaf5e3c335f80193089 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 13 Jan 2026 15:12:53 -0500 Subject: [PATCH 7/8] Fixing missing override on test class. --- java/src/test/java/ai/onnxruntime/MemorySegmentTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java index 0293460d83e1a..0a03205366e9c 100644 --- a/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java +++ b/java/src/test/java/ai/onnxruntime/MemorySegmentTest.java @@ -96,6 +96,7 @@ Object allocate(long size) { } } + @Override public void close() { if (arenaClass != null) { try { From a163a68d4202b167582cbd5b6e83d218aabdef22 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 20 Jan 2026 21:47:03 -0500 Subject: [PATCH 8/8] Fixing the exception handling in getBuffer, and adding a null check to OnnxTensor.createFromMemorySegment. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index 377075a20161a..7cc066efd08f3 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -475,12 +475,18 @@ public LongBuffer getLongBuffer() throws OrtException { * native code encountered an error. */ private ByteBuffer getBuffer() throws OrtException { - try { + // Definitely can't allocate a byte buffer greater than Integer.MAX_VALUE, and + // it's typically recommended to make it a little smaller than that as the actual + // upper limit is somewhat JVM dependent. + int maxSize = (Integer.MAX_VALUE / info.type.size) - 4; + if (info.getNumElements() < maxSize) { return getBuffer(OnnxRuntime.ortApiHandle, nativeHandle).order(ByteOrder.nativeOrder()); - } catch (IllegalArgumentException e) { - // thrown by the byte buffer constructor if the tensor is bigger than Integer.MAX_VALUE. + } else { throw new OrtException( - "Cannot construct a java.nio.Buffer of this size. Message: " + e.getMessage()); + "Cannot construct a java.nio.Buffer of this size. This tensor has " + + info.getNumElements() + + ", and the maximum supported is " + + maxSize); } } @@ -844,8 +850,11 @@ static OnnxTensor createTensor( long[] shape, OnnxJavaType type) throws OrtException { - if (!allocator.isClosed()) { + if (!allocator.isClosed() && env != null) { return createTensor(type, allocator, data, shape); + } else if (env == null) { + throw new IllegalStateException( + "Trying to create an OnnxTensor with an invalid OrtEnvironment."); } else { throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator."); }