@@ -159,6 +159,24 @@ def test_forward_merge(self, keras_mod):
159159 keras_model = keras_mod .models .Model (data , out )
160160 verify_keras_frontend (keras_model )
161161
162+ def test_forward_concatenate (self , keras_mod ):
163+ """test_forward_concatenate"""
164+ data1 = keras_mod .layers .Input (shape = (1 , 2 , 2 ))
165+ data2 = keras_mod .layers .Input (shape = (1 , 1 , 2 ))
166+ merge_func = keras_mod .layers .Concatenate (axis = 2 )
167+ out = merge_func ([data1 , data2 ])
168+ keras_model = keras_mod .models .Model ([data1 , data2 ], out )
169+ verify_keras_frontend (keras_model , layout = "NHWC" )
170+ verify_keras_frontend (keras_model , layout = "NCHW" )
171+ # test default axis (e.g., -1)
172+ data1 = keras_mod .layers .Input (shape = (1 , 2 , 2 ))
173+ data2 = keras_mod .layers .Input (shape = (1 , 2 , 3 ))
174+ merge_func = keras_mod .layers .Concatenate ()
175+ out = merge_func ([data1 , data2 ])
176+ keras_model = keras_mod .models .Model ([data1 , data2 ], out )
177+ verify_keras_frontend (keras_model , layout = "NHWC" )
178+ verify_keras_frontend (keras_model , layout = "NCHW" )
179+
162180 def test_forward_merge_dot (self , keras_mod ):
163181 """test_forward_merge_dot"""
164182 data1 = keras_mod .layers .Input (shape = (2 , 2 ))
@@ -793,6 +811,7 @@ def test_forward_time_distributed(self, keras_mod):
793811if __name__ == "__main__" :
794812 for k in [keras , tf_keras ]:
795813 sut = TestKeras ()
814+ sut .test_forward_concatenate (keras_mod = k )
796815 sut .test_forward_merge_dot (keras_mod = k )
797816 sut .test_forward_merge (keras_mod = k )
798817 sut .test_forward_activations (keras_mod = k )
0 commit comments