Skip to content

Commit e178375

Browse files
authored
[Frontend][Relay][Keras] Fix concatenate convert function in axis parsing (#15175)
1 parent 34637d7 commit e178375

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

python/tvm/relay/frontend/keras.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -955,10 +955,13 @@ def _convert_concat(
955955
if input_shape is None:
956956
input_shape = keras_layer.input_shape
957957

958-
if data_layout == "NHWC" or len(input_shape[0]) < 4:
959-
axis = -1
960-
else:
961-
axis = 1
958+
axis = keras_layer.axis
959+
dims = len(input_shape[0])
960+
if data_layout == "NCHW": # need_transpose
961+
if axis == -1:
962+
axis = 1
963+
else:
964+
axis = axis + 1 if axis < dims else 1
962965
return _op.concatenate(_as_list(inexpr), axis=axis)
963966

964967

tests/python/frontend/keras/test_forward.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,24 @@ def test_forward_merge(self, keras_mod):
159159
keras_model = keras_mod.models.Model(data, out)
160160
verify_keras_frontend(keras_model)
161161

162+
def test_forward_concatenate(self, keras_mod):
163+
"""test_forward_concatenate"""
164+
data1 = keras_mod.layers.Input(shape=(1, 2, 2))
165+
data2 = keras_mod.layers.Input(shape=(1, 1, 2))
166+
merge_func = keras_mod.layers.Concatenate(axis=2)
167+
out = merge_func([data1, data2])
168+
keras_model = keras_mod.models.Model([data1, data2], out)
169+
verify_keras_frontend(keras_model, layout="NHWC")
170+
verify_keras_frontend(keras_model, layout="NCHW")
171+
# test default axis (e.g., -1)
172+
data1 = keras_mod.layers.Input(shape=(1, 2, 2))
173+
data2 = keras_mod.layers.Input(shape=(1, 2, 3))
174+
merge_func = keras_mod.layers.Concatenate()
175+
out = merge_func([data1, data2])
176+
keras_model = keras_mod.models.Model([data1, data2], out)
177+
verify_keras_frontend(keras_model, layout="NHWC")
178+
verify_keras_frontend(keras_model, layout="NCHW")
179+
162180
def test_forward_merge_dot(self, keras_mod):
163181
"""test_forward_merge_dot"""
164182
data1 = keras_mod.layers.Input(shape=(2, 2))
@@ -793,6 +811,7 @@ def test_forward_time_distributed(self, keras_mod):
793811
if __name__ == "__main__":
794812
for k in [keras, tf_keras]:
795813
sut = TestKeras()
814+
sut.test_forward_concatenate(keras_mod=k)
796815
sut.test_forward_merge_dot(keras_mod=k)
797816
sut.test_forward_merge(keras_mod=k)
798817
sut.test_forward_activations(keras_mod=k)

0 commit comments

Comments
 (0)