From ef11032c8956006e36c2ad26270ff28a41e9db89 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 5 Apr 2023 18:49:59 +0100 Subject: [PATCH] [java] Allows the creation and extraction of zero length tensors (#15116) ### Description Allows the creation of zero length tensors via the buffer path (the array path with zero length arrays still throws as the validation logic to check it's not ragged would require more intrusive revision), and allows the `tensor.getValue()` method to return a Java multidimensional array with a zero dimension. Also added a test for the creation and extraction behaviour. ### Motivation and Context The Python interface can return zero length tensors (e.g. if object detection doesn't find any objects), and before this PR in Java calling `tensor.getValue()` throws an exception with a confusing error message. Fixes #7270 & #15107. --- .../main/java/ai/onnxruntime/OnnxTensor.java | 7 ++- .../src/main/java/ai/onnxruntime/OrtUtil.java | 15 ++++--- .../main/java/ai/onnxruntime/TensorInfo.java | 45 +++++++++++++++++-- .../ai/onnxruntime/TensorCreationTest.java | 24 +++++++++- 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java index c5e60a3dbaf51..5703fb9c48495 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java +++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -76,7 +76,10 @@ public Object getValue() throws OrtException { } } else { Object carrier = info.makeCarrier(); - getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); + if (info.getNumElements() > 0) { + // If the tensor has values copy them out + getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier); + } if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) { // We read the strings out from native code in a flat array and then reshape // to the desired output shape. diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index ca340676e247d..eb27d1dafd5f2 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -41,9 +41,9 @@ public static int[] transformShape(long[] shape) { int[] newShape = new int[shape.length]; for (int i = 0; i < shape.length; i++) { long curDim = shape[i]; - if (curDim < 1 || curDim > Integer.MAX_VALUE) { + if (curDim < 0 || curDim > Integer.MAX_VALUE) { throw new IllegalArgumentException( - "Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found " + "Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found " + Arrays.toString(shape)); } else { newShape[i] = (int) curDim; @@ -345,20 +345,23 @@ private static int reshape(Object input, Object output, int position) { /** * Counts the number of elements stored in a Tensor of this shape. * - *

Multiplies all the elements together if they are positive, throws an {@link + *

Multiplies all the elements together if they are non-negative, throws an {@link * IllegalArgumentException} otherwise. * * @param shape The shape to use. * @return The number of elements. */ public static long elementCount(long[] shape) { + // Java side tensors must be less than Integer.MAX_VALUE, + // tensors created in native code can be larger, but are not usable in Java. + // Tensors should not be able to be created which will overflow a 64-bit long. long count = 1; for (int i = 0; i < shape.length; i++) { - if (shape[i] > 0) { + if (shape[i] >= 0) { count *= shape[i]; } else { throw new IllegalArgumentException( - "Received non-positive value in shape " + Arrays.toString(shape) + " ."); + "Received negative value in shape " + Arrays.toString(shape) + " ."); } } return count; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index b9b7835da2ee5..613fcd61ea476 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -107,6 +107,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { /** The native type of this tensor. */ public final OnnxTensorType onnxType; + /** The number of elements in this tensor. */ + final long numElements; + /** * Constructs a TensorInfo with the specified shape, Java type and native type. * @@ -118,6 +121,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { this.shape = shape; this.type = type; this.onnxType = onnxType; + this.numElements = elementCount(shape); } /** @@ -132,6 +136,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) { this.shape = shape; this.onnxType = OnnxTensorType.mapFromInt(typeInt); this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType); + this.numElements = elementCount(shape); } /** @@ -173,6 +178,39 @@ private boolean validateShape() { return OrtUtil.validateShape(shape); } + /** + * Computes the number of elements in this tensor. + * + *

This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which + * are used for symbolic dimensions in input and output info objects. + * + * @param shape The tensor shape. + * @return The number of elements. + */ + private static long elementCount(long[] shape) { + // Java side tensors must be less than Integer.MAX_VALUE, + // tensors created in native code can be larger, but are not usable in Java. + // Tensors should not be able to be created which will overflow a 64-bit long. + long output = 1; + for (int i = 0; i < shape.length; i++) { + output *= shape[i]; + } + return output; + } + + /** + * Returns the number of elements in this tensor. + * + *

If the returned value is negative, then this tensor info refers to an input or output + * placeholder which has symbolic dimensions, and the element count cannot be computed without + * specifying the symbolic dimensions. + * + * @return The number of elements. + */ + public long getNumElements() { + return numElements; + } + /** * Constructs an array the right shape and type to hold this tensor. * @@ -181,11 +219,12 @@ private boolean validateShape() { * correct shape using {@link OrtUtil#reshape(String[],long[])}. * * @return A multidimensional array of the appropriate primitive type (or String). - * @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is + * @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is * greater than an int). */ public Object makeCarrier() throws OrtException { - if (!validateShape()) { + // Zero length tensors are allowed to be returned. + if (!validateShape() && numElements != 0) { throw new OrtException( "This tensor is not representable in Java, it's too big - shape = " + Arrays.toString(shape)); diff --git a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java index 681179beff9c9..bd3209279f11a 100644 --- a/java/src/test/java/ai/onnxruntime/TensorCreationTest.java +++ b/java/src/test/java/ai/onnxruntime/TensorCreationTest.java @@ -1,10 +1,11 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; import java.nio.ByteBuffer; +import java.nio.FloatBuffer; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -122,4 +123,25 @@ public void testUint8Creation() throws OrtException { Assertions.assertArrayEquals(buf, (byte[]) t.getValue()); } } + + @Test + public void testEmptyTensor() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + FloatBuffer buf = FloatBuffer.allocate(0); + long[] shape = new long[] {4, 0}; + try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) { + Assertions.assertArrayEquals(shape, t.getInfo().getShape()); + float[][] output = (float[][]) t.getValue(); + Assertions.assertEquals(4, output.length); + Assertions.assertEquals(0, output[0].length); + FloatBuffer fb = t.getFloatBuffer(); + Assertions.assertEquals(0, fb.remaining()); + } + shape = new long[] {0, 4}; + try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) { + Assertions.assertArrayEquals(shape, t.getInfo().getShape()); + float[][] output = (float[][]) t.getValue(); + Assertions.assertEquals(0, output.length); + } + } }