Skip to content

Commit 03fecba

Browse files
authored
[Bugfix][Frontend][Keras] Add a check to reject the invalid input shape (#15335)
* reject invalid input_shape * Update test_forward.py * Update keras.py * Update keras.py * Update test_forward.py * Update test_forward.py * Update test_forward.py * Update test_forward.py * Update keras.py * Update test_forward.py * Update test_forward.py * Update keras.py * Update test_forward.py
1 parent 8e33401 commit 03fecba

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,6 +1436,12 @@ def _check_model_is_tf_keras():
14361436
def _convert_input_layer(keras_layer):
14371437
input_name = keras_layer.name
14381438
input_shape = shape[input_name] if shape is not None and input_name in shape else None
1439+
if input_shape and len(input_shape) > 1 and any(dim <= 0 for dim in input_shape[1:]):
1440+
msg = (
1441+
"Expected input's non-batch dimensions to have positive length, "
1442+
f"but the input has a shape of {input_shape}"
1443+
)
1444+
raise ValueError(msg)
14391445
etab.set_expr(input_name, new_var(input_name, shape=input_shape))
14401446

14411447
def _convert_layer(keras_layer, etab, scope=""):

tests/python/frontend/keras/test_forward.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tvm import relay
3333
from tvm.contrib import graph_executor
3434
import tvm.testing
35+
import pytest
3536

3637
if tf.executing_eagerly():
3738
GPUS = tf.config.experimental.list_physical_devices("GPU")
@@ -295,6 +296,7 @@ def test_forward_sequential(self, keras_mod):
295296
verify_keras_frontend(keras_model)
296297

297298
def test_forward_pool(self, keras_mod):
299+
"""test_forward_pool"""
298300
data = keras_mod.layers.Input(shape=(32, 32, 1))
299301
# maxpool
300302
x = keras_mod.layers.MaxPooling2D((3, 3), strides=(1, 1), padding="same")(data)
@@ -304,6 +306,12 @@ def test_forward_pool(self, keras_mod):
304306
y = keras_mod.layers.AveragePooling2D((3, 3), strides=(1, 1), padding="same")(data)
305307
keras_model = keras_mod.models.Model(data, y)
306308
verify_keras_frontend(keras_model)
309+
# reject the invalid input shape
310+
data = keras_mod.layers.Input(shape=(0, 3, 6, 4))
311+
x = keras_mod.layers.GlobalAveragePooling3D()(data)
312+
keras_model = keras_mod.models.Model(data, x)
313+
with pytest.raises(ValueError):
314+
verify_keras_frontend(keras_model)
307315

308316
def test_forward_conv1d(self, keras_mod):
309317
"""test_forward_conv1d"""

0 commit comments

Comments
 (0)