Skip to content

Commit da66bc6

Browse files
jsmeredithtensorflower-gardener
authored andcommitted
Check for unexpected scalars in the shape argument to ParallelConcat.
PiperOrigin-RevId: 504901518
1 parent 789ed75 commit da66bc6

File tree

4 files changed

+26
-3
lines changed

4 files changed

+26
-3
lines changed

tensorflow/core/kernels/inplace_ops.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ class ParallelConcatUpdate : public OpKernel {
7878
OP_REQUIRES(
7979
ctx, value.dim_size(0) > loc_,
8080
errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
81-
" is less than loc_=", loc_));
81+
" must be greater than loc_ = ", loc_));
8282

8383
auto update = ctx->input(1);
8484

tensorflow/core/ops/array_ops.cc

+7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#include <algorithm>
1717
#include <ostream>
18+
#include <vector>
1819

1920
#include "tensorflow/core/framework/common_shape_fns.h"
2021
#include "tensorflow/core/framework/full_type.pb.h"
@@ -309,6 +310,12 @@ REGISTER_OP("ParallelConcat")
309310
return errors::InvalidArgument(
310311
"All input shapes must be fully defined.");
311312
}
313+
if (c->Rank(c->input(i)) < 1) {
314+
return errors::InvalidArgument(
315+
"The rank of all input shapes must be greater than 0, "
316+
"but input ",
317+
i, " had rank ", c->Rank(c->input(i)), ".");
318+
}
312319
DimensionHandle unused;
313320
if (!c->WithValue(c->Dim(c->input(i), 0), 1, &unused).ok()) {
314321
return errors::InvalidArgument("Size of first dimension must be 1.");

tensorflow/python/kernel_tests/array_ops/stack_op_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,9 @@ def f():
8383
y = gen_array_ops.parallel_concat(values=[["tf"]], shape=0)
8484
return y
8585

86-
with self.assertRaisesRegex(errors.InvalidArgumentError,
87-
r"0th dimension of value .* is less than"):
86+
with self.assertRaisesRegex(
87+
errors.InvalidArgumentError, r"0th dimension .* must be greater than"
88+
):
8889
f()
8990

9091
def testSimpleParallelGPU(self):

tensorflow/python/ops/array_ops_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from tensorflow.python.eager import def_function
1919
from tensorflow.python.framework import dtypes
2020
from tensorflow.python.framework import tensor_spec
21+
from tensorflow.python.framework import test_util
2122
from tensorflow.python.ops import array_ops
2223
from tensorflow.python.ops import math_ops
2324
from tensorflow.python.ops import random_ops
@@ -91,6 +92,20 @@ def g(x):
9192
conc = g.get_concrete_function(tensor_spec.TensorSpec([10, None]))
9293
self.assertAllEqual(conc.output_shapes.as_list(), [10])
9394

95+
@test_util.run_in_graph_and_eager_modes
96+
def testParallelConcatFailsWithRankZeroShape(self):
97+
op = array_ops.ParallelConcat
98+
para = {"shape": 0, "values": [1]}
99+
100+
def func():
101+
y = op(**para)
102+
return y
103+
104+
with self.assertRaisesRegex(
105+
Exception, "(rank|dimension) of .* must be greater than .* 0"
106+
):
107+
func()
108+
94109

95110
if __name__ == "__main__":
96111
test.main()

0 commit comments

Comments
 (0)