diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index c5e60a3dbaf51..5703fb9c48495 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -76,7 +76,10 @@ public Object getValue() throws OrtException { } } else { Object carrier = info.makeCarrier(); - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); + if (info.getNumElements() > 0) { + // If the tensor has values copy them out + getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); + } if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { // We read the strings out from native code in a flat array and then reshape // to the desired output shape. diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index ca340676e247d..eb27d1dafd5f2 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -41,9 +41,9 @@ public static int[] transformShape(long[] shape) { int[] newShape = new int[shape.length]; for (int i = 0; i < shape.length; i++) { long curDim = shape[i]; - if (curDim < 1 || curDim > Integer.MAX_VALUE) { + if (curDim < 0 || curDim > Integer.MAX_VALUE) { throw new IllegalArgumentException( - "Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found " + "Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape)); } else { newShape[i] = (int) curDim; @@ -345,20 +345,23 @@ private static int reshape(Object input, Object output, int position) { /** * Counts the number of elements stored in a Tensor of this shape. * - *

Multiplies all the elements together if they are positive, throws an {@link + *

Multiplies all the elements together if they are non-negative, throws an {@link * IllegalArgumentException} otherwise. * * @param shape The shape to use. * @return The number of elements. */ public static long elementCount(long[] shape) { + // Java side tensors must be less than Integer.MAX_VALUE, + // tensors created in native code can be larger, but are not usable in Java. + // Tensors should not be able to be created which will overflow a 64-bit long. long count = 1; for (int i = 0; i < shape.length; i++) { - if (shape[i] > 0) { + if (shape[i] >= 0) { count *= shape[i]; } else { throw new IllegalArgumentException( - "Received non-positive value in shape " + Arrays.toString(shape) + " ."); + "Received negative value in shape " + Arrays.toString(shape) + " ."); } } return count; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index b9b7835da2ee5..613fcd61ea476 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -107,6 +107,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The native type of this tensor. */ public final OnnxTensorType onnxType; + /** The number of elements in this tensor. */ + final long numElements; + /** * Constructs a TensorInfo with the specified shape, Java type and native type. * @@ -118,6 +121,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { this.shape = shape; this.type = type; this.onnxType = onnxType; + this.numElements = elementCount(shape); } /** @@ -132,6 +136,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { this.shape = shape; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); + this.numElements = elementCount(shape); } /** @@ -173,6 +178,39 @@ private boolean validateShape() { return OrtUtil.validateShape(shape); } + /** + * Computes the number of elements in this tensor. + * + *

This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which + * are used for symbolic dimensions in input and output info objects. + * + * @param shape The tensor shape. + * @return The number of elements. + */ + private static long elementCount(long[] shape) { + // Java side tensors must be less than Integer.MAX_VALUE, + // tensors created in native code can be larger, but are not usable in Java. + // Tensors should not be able to be created which will overflow a 64-bit long. + long output = 1; + for (int i = 0; i < shape.length; i++) { + output *= shape[i]; + } + return output; + } + + /** + * Returns the number of elements in this tensor. + * + *

If the returned value is negative, then this tensor info refers to an input or output + * placeholder which has symbolic dimensions, and the element count cannot be computed without + * specifying the symbolic dimensions. + * + * @return The number of elements. + */ + public long getNumElements() { + return numElements; + } + /** * Constructs an array the right shape and type to hold this tensor. * @@ -181,11 +219,12 @@ private boolean validateShape() { * correct shape using {@link OrtUtil#reshape(String[],long[])}. * * @return A multidimensional array of the appropriate primitive type (or String). - * @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is + * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). */ public Object makeCarrier() throws OrtException { - if (!validateShape()) { + // Zero length tensors are allowed to be returned. + if (!validateShape() && numElements != 0) { throw new OrtException( "This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(shape)); diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java index 681179beff9c9..bd3209279f11a 100644 --- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java +++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java @@ -1,10 +1,11 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; import java.nio.ByteBuffer; +import java.nio.FloatBuffer; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -122,4 +123,25 @@ public void testUint8Creation() throws OrtException { Assertions.assertArrayEquals(buf, (byte[]) t.getValue()); } } + + @Test + public void testEmptyTensor() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + FloatBuffer buf = FloatBuffer.allocate(0); + long[] shape = new long[] {4, 0}; + try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) { + Assertions.assertArrayEquals(shape, t.getInfo().getShape()); + float[][] output = (float[][]) t.getValue(); + Assertions.assertEquals(4, output.length); + Assertions.assertEquals(0, output[0].length); + FloatBuffer fb = t.getFloatBuffer(); + Assertions.assertEquals(0, fb.remaining()); + } + shape = new long[] {0, 4}; + try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) { + Assertions.assertArrayEquals(shape, t.getInfo().getShape()); + float[][] output = (float[][]) t.getValue(); + Assertions.assertEquals(0, output.length); + } + } }