Skip to content

Commit

Permalink
Make keras reshape less restrictive
Browse files Browse the repository at this point in the history
  • Loading branch information
Trevor Morris committed Feb 11, 2021
1 parent 12c6b70 commit d365069
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 23 deletions.
31 changes: 8 additions & 23 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,29 +864,14 @@ def _convert_reshape(inexpr, keras_layer, etab):
_check_data_format(keras_layer)
inshape = keras_layer.input_shape # includes batch
tshape = keras_layer.target_shape # no batch
if len(inshape) == 3 and len(tshape) == 1:
# (?, a, b) -> (-1, ab)
shape = (-1, tshape[0])
elif len(inshape) in [2, 3] and len(tshape) == 2:
# (?, cc) -> (-1, c, c)
# (?, a, b) -> (-1, c, c)
assert tshape[0] == tshape[1], "Only supports square target shapes, but got {}".format(
tshape
)
shape = (-1,) + tshape
else:
# (?, h, w, c) -> (-1, c, H, W)
# (?, h, w, c) -> (-1, c, hw)
# (?, hw, c) -> (-1, c, h, w)
ch = inshape[-1]
assert ch == tshape[-1], (
"Only supports last dimension in target shape being equal to "
"the channel number of input tensor."
)
if etab.data_layout == "NCHW":
shape = (-1, ch) + tshape[:-1]
else:
shape = (-1,) + tshape[:-1] + (ch,)
shape = (-1,) + tshape

if etab.data_layout == "NCHW" and (len(inshape) > 3 or len(tshape) > 2):
# Perform reshape in original NHWC format.
inexpr = _op.transpose(inexpr, [0] + list(range(2, len(inshape))) + [1])
inexpr = _op.reshape(inexpr, newshape=shape)
return _op.transpose(inexpr, axes=[0, -1] + list(range(1, len(shape) - 1)))

return _op.reshape(inexpr, newshape=shape)


Expand Down
10 changes: 10 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,16 @@ def test_forward_reshape(self, keras):
x = keras.layers.Reshape(target_shape=(4, 4))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
# "non-square" target shape
data = keras.layers.Input(shape=(15,))
x = keras.layers.Reshape(target_shape=(5, 3))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)
# modify channel dim
data = keras.layers.Input(shape=(3, 2, 4))
x = keras.layers.Reshape(target_shape=(3, 8))(data)
keras_model = keras.models.Model(data, x)
verify_keras_frontend(keras_model, need_transpose=False)

def test_forward_crop(self, keras):
data = keras.layers.Input(shape=(32, 32, 3))
Expand Down

0 comments on commit d365069

Please sign in to comment.