Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java] Allows the creation and extraction of zero length tensors #15116

Merged
merged 3 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions java/src/main/java/ai/onnxruntime/OnnxTensor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -76,7 +76,10 @@ public Object getValue() throws OrtException {
}
} else {
Object carrier = info.makeCarrier();
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
if (info.getNumElements() > 0) {
// If the tensor has values copy them out
getArray(OnnxRuntime.ortApiHandle, nativeHandle, carrier);
}
if ((info.type == OnnxJavaType.STRING) && (info.shape.length != 1)) {
// We read the strings out from native code in a flat array and then reshape
// to the desired output shape.
Expand Down
15 changes: 9 additions & 6 deletions java/src/main/java/ai/onnxruntime/OrtUtil.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -41,9 +41,9 @@ public static int[] transformShape(long[] shape) {
int[] newShape = new int[shape.length];
for (int i = 0; i < shape.length; i++) {
long curDim = shape[i];
if (curDim < 1 || curDim > Integer.MAX_VALUE) {
if (curDim < 0 || curDim > Integer.MAX_VALUE) {
yuslepukhin marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException(
"Invalid shape for a Java array, expected positive entries smaller than Integer.MAX_VALUE. Found "
"Invalid shape for a Java array, expected non-negative entries smaller than Integer.MAX_VALUE. Found "
+ Arrays.toString(shape));
} else {
newShape[i] = (int) curDim;
Expand Down Expand Up @@ -345,20 +345,23 @@ private static int reshape(Object input, Object output, int position) {
/**
* Counts the number of elements stored in a Tensor of this shape.
*
* <p>Multiplies all the elements together if they are positive, throws an {@link
* <p>Multiplies all the elements together if they are non-negative, throws an {@link
* IllegalArgumentException} otherwise.
*
* @param shape The shape to use.
* @return The number of elements.
*/
public static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long count = 1;
for (int i = 0; i < shape.length; i++) {
if (shape[i] > 0) {
if (shape[i] >= 0) {
count *= shape[i];
} else {
throw new IllegalArgumentException(
"Received non-positive value in shape " + Arrays.toString(shape) + " .");
"Received negative value in shape " + Arrays.toString(shape) + " .");
}
}
return count;
Expand Down
45 changes: 42 additions & 3 deletions java/src/main/java/ai/onnxruntime/TensorInfo.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -107,6 +107,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The native type of this tensor. */
public final OnnxTensorType onnxType;

/** The number of elements in this tensor. */
final long numElements;

/**
* Constructs a TensorInfo with the specified shape, Java type and native type.
*
Expand All @@ -118,6 +121,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.shape = shape;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
}

/**
Expand All @@ -132,6 +136,7 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
this.shape = shape;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
}

/**
Expand Down Expand Up @@ -173,6 +178,39 @@ private boolean validateShape() {
return OrtUtil.validateShape(shape);
}

/**
* Computes the number of elements in this tensor.
*
* <p>This replicates {@link OrtUtil#elementCount}, but does not throw on negative values which
* are used for symbolic dimensions in input and output info objects.
*
* @param shape The tensor shape.
* @return The number of elements.
*/
private static long elementCount(long[] shape) {
// Java side tensors must be less than Integer.MAX_VALUE,
// tensors created in native code can be larger, but are not usable in Java.
// Tensors should not be able to be created which will overflow a 64-bit long.
long output = 1;
for (int i = 0; i < shape.length; i++) {
output *= shape[i];
}
return output;
}

/**
* Returns the number of elements in this tensor.
*
* <p>If the returned value is negative, then this tensor info refers to an input or output
* placeholder which has symbolic dimensions, and the element count cannot be computed without
* specifying the symbolic dimensions.
*
* @return The number of elements.
*/
public long getNumElements() {
return numElements;
}

/**
* Constructs an array the right shape and type to hold this tensor.
*
Expand All @@ -181,11 +219,12 @@ private boolean validateShape() {
* correct shape using {@link OrtUtil#reshape(String[],long[])}.
*
* @return A multidimensional array of the appropriate primitive type (or String).
* @throws OrtException If the shape isn't representable in Java (i.e. if one of it's indices is
* @throws OrtException If the shape isn't representable in Java (i.e. if one of its indices is
* greater than an int).
*/
public Object makeCarrier() throws OrtException {
if (!validateShape()) {
// Zero length tensors are allowed to be returned.
if (!validateShape() && numElements != 0) {
throw new OrtException(
"This tensor is not representable in Java, it's too big - shape = "
+ Arrays.toString(shape));
Expand Down
24 changes: 23 additions & 1 deletion java/src/test/java/ai/onnxruntime/TensorCreationTest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
/*
* Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import java.nio.ByteBuffer;
import java.nio.FloatBuffer;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

Expand Down Expand Up @@ -122,4 +123,25 @@ public void testUint8Creation() throws OrtException {
Assertions.assertArrayEquals(buf, (byte[]) t.getValue());
}
}

@Test
public void testEmptyTensor() throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();
FloatBuffer buf = FloatBuffer.allocate(0);
long[] shape = new long[] {4, 0};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(4, output.length);
Assertions.assertEquals(0, output[0].length);
FloatBuffer fb = t.getFloatBuffer();
Assertions.assertEquals(0, fb.remaining());
}
shape = new long[] {0, 4};
try (OnnxTensor t = OnnxTensor.createTensor(env, buf, shape)) {
Assertions.assertArrayEquals(shape, t.getInfo().getShape());
float[][] output = (float[][]) t.getValue();
Assertions.assertEquals(0, output.length);
}
}
}