Skip to content

Commit

Permalink
[java] Allows the creation and extraction of zero length tensors (#15116
Browse files Browse the repository at this point in the history
)

### Description
Allows the creation of zero length tensors via the buffer path (the
array path with zero length arrays still throws as the validation logic
to check it's not ragged would require more intrusive revision), and
allows the `tensor.getValue()` method to return a Java multidimensional
array with a zero dimension. Also added a test for the creation and
extraction behaviour.

### Motivation and Context
The Python interface can return zero length tensors (e.g. if object
detection doesn't find any objects), and before this PR in Java calling
`tensor.getValue()` throws an exception with a confusing error message.
Fixes #7270 & #15107.
  • Loading branch information
Craigacp authored Apr 5, 2023
1 parent 9191e04 commit ef11032
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 12 deletions.
7 changes: 5 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -76,7 +76,10 @@ public Object getValue() throws OrtException {
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
}
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
Expand Down
15 changes: 9 additions & 6 deletions java/src/main/java/ai/onnxruntime/OrtUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -41,9 +41,9 @@ public static int[] transformShape(long[] shape) {
int[] newShape = new int[shape.length];
for (int i = 0; i < shape.length; i++) {
long curDim = shape[i];
if (curDim < 1 || curDim > Integer.MAX_VALUE) {
if (curDim < 0 || curDim > Integer.MAX_VALUE) {
throw new IllegalArgumentException(
"Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found "
"Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found "
+ Arrays.toString(shape));
} else {
newShape[i] = (int) curDim;
Expand Down Expand Up @@ -345,20 +345,23 @@ private static int reshape(Object input, Object output, int position) {
/**
* Counts the number of elements stored in a Tensor of this shape.
*
* <p>Multiplies all the elements together if they are positive, throws an {@link
* <p>Multiplies all the elements together if they are non-negative, throws an {@link
* IllegalArgumentException} otherwise.
*
* @param shape The shape to use.
* @return The number of elements.
*/
public static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long count = 1;
for (int i = 0; i < shape.length; i++) {
if (shape[i] > 0) {
if (shape[i] >= 0) {
count *= shape[i];
} else {
throw new IllegalArgumentException(
"Received non-positive value in shape " + Arrays.toString(shape) + " .");
"Received negative value in shape " + Arrays.toString(shape) + " .");
}
}
return count;
Expand Down
45 changes: 42 additions & 3 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -107,6 +107,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The native type of this tensor. */
public final OnnxTensorType onnxType;

/** The number of elements in this tensor. */
final long numElements;

/**
* Constructs a TensorInfo with the specified shape, Java type and native type.
*
Expand All @@ -118,6 +121,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.shape = shape;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
}

/**
Expand All @@ -132,6 +136,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.shape = shape;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
}

/**
Expand Down Expand Up @@ -173,6 +178,39 @@ private boolean validateShape() {
return OrtUtil.validateShape(shape);
}

/**
* Computes the number of elements in this tensor.
*
* <p>This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which
* are used for symbolic dimensions in input and output info objects.
*
* @param shape The tensor shape.
* @return The number of elements.
*/
private static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long output = 1;
for (int i = 0; i < shape.length; i++) {
output *= shape[i];
}
return output;
}

/**
* Returns the number of elements in this tensor.
*
* <p>If the returned value is negative, then this tensor info refers to an input or output
* placeholder which has symbolic dimensions, and the element count cannot be computed without
* specifying the symbolic dimensions.
*
* @return The number of elements.
*/
public long getNumElements() {
return numElements;
}

/**
* Constructs an array the right shape and type to hold this tensor.
*
Expand All @@ -181,11 +219,12 @@ private boolean validateShape() {
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
* greater than an int).
*/
public Object makeCarrier() throws OrtException {
if (!validateShape()) {
// Zero length tensors are allowed to be returned.
if (!validateShape() && numElements != 0) {
throw new OrtException(
"This tensor is not representable in Java, it's too big - shape = "
+ Arrays.toString(shape));
Expand Down
24 changes: 23 additions & 1 deletion java/src/test/java/ai/onnxruntime/TensorCreationTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -122,4 +123,25 @@ public void testUint8Creation() throws OrtException {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}

@Test
public void testEmptyTensor() throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();
FloatBuffer buf = FloatBuffer.allocate(0);
long[] shape = new long[] {4, 0};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(4, output.length);
Assertions.assertEquals(0, output[0].length);
FloatBuffer fb = t.getFloatBuffer();
Assertions.assertEquals(0, fb.remaining());
}
shape = new long[] {0, 4};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(0, output.length);
}
}
}

0 comments on commit ef11032

Please sign in to comment.