Skip to content

Commit 8e33401

Browse files
authored
[Bugfix][Relay][Frontend][Keras] Add a assertion to reject a invalid value for attribute units in RNN layers (#15334)
1 parent ab75b58 commit 8e33401

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ def _convert_dense(
256256
weightList = keras_layer.get_weights()
257257
weight = etab.new_const(weightList[0].transpose([1, 0]))
258258
params = {"weight": weight, "units": weightList[0].shape[1]}
259+
units = list(weightList[0].shape)[1]
260+
assert units > 0, "The value of units must be a positive integer"
259261
if input_shape is None:
260262
input_shape = keras_layer.input_shape
261263
input_dim = len(input_shape)
@@ -1010,6 +1012,7 @@ def _convert_lstm(
10101012
if keras_layer.go_backwards:
10111013
in_data = _op.reverse(in_data, axis=1)
10121014
units = list(weightList[0].shape)[1]
1015+
assert units > 0, "The value of units must be a positive integer"
10131016
time_steps = in_shape[1]
10141017
in_data = _op.squeeze(in_data, axis=[0])
10151018
in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0)
@@ -1053,6 +1056,7 @@ def _convert_simple_rnn(
10531056
if keras_layer.use_bias:
10541057
in_bias = etab.new_const(weightList[2])
10551058
units = list(weightList[0].shape)[1]
1059+
assert units > 0, "The value of units must be a positive integer"
10561060
in_data = _op.nn.batch_flatten(in_data)
10571061
ixh = _op.nn.dense(in_data, kernel_weight, units=units)
10581062
if keras_layer.use_bias:
@@ -1082,6 +1086,7 @@ def _convert_gru(
10821086
if keras_layer.use_bias:
10831087
in_bias = etab.new_const(weightList[2])
10841088
units = list(weightList[0].shape)[1]
1089+
assert units > 0, "The value of units must be a positive integer"
10851090
in_data = _op.nn.batch_flatten(in_data)
10861091
matrix_x = _op.nn.dense(in_data, kernel_weight, units=units)
10871092
if keras_layer.use_bias:

tests/python/frontend/keras/test_forward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def test_forward_activations_except(self, keras_mod):
251251
):
252252
act_funcs = [
253253
keras_mod.layers.LeakyReLU(alpha=None),
254-
keras_mod.layers.LEU(2, 3, 4),
254+
keras_mod.layers.ELU(2, 3, 4),
255255
keras_mod.layers.ReLU(threshold=None),
256256
]
257257
data = keras_mod.layers.Input(shape=(2, 3, 4))

0 commit comments

Comments
 (0)