Skip to content

Commit

Permalink
[Relay] Allow converting keras.layers.Sequential (apache#2842)
Browse files Browse the repository at this point in the history
* Allow converting keras.layers.Sequential

* Use existing new_var function

* Only update expr when missing

* Add test
  • Loading branch information
nhynes authored and MarisaKirisame committed Apr 9, 2019
1 parent 57b31d4 commit b8b5ad5
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,8 @@ def get_expr(self, name):

def set_expr(self, name, expr):
assert isinstance(expr, _expr.Expr)
self.exprs[name] = expr
if name not in self.exprs:
self.exprs[name] = expr

def set_padding(self, paddings):
self.paddings = paddings
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .. import expr as _expr
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable
from .common import ExprTable, new_var

__all__ = ['from_keras']

Expand Down Expand Up @@ -661,12 +661,15 @@ def from_keras(model, shape=None):
raise ValueError("Keras frontend currently supports data_format = channels_last only.")
_check_unsupported_layers(model)

def _convert_input_layer(keras_layer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, new_var(input_name, shape=input_shape))

etab = ExprTable()
for keras_layer in model.layers:
if isinstance(keras_layer, keras.engine.InputLayer):
input_name = keras_layer.name
input_shape = shape[input_name] if shape is not None and input_name in shape else None
etab.set_expr(input_name, _expr.var(input_name, shape=input_shape))
_convert_input_layer(keras_layer)
else:
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
Expand All @@ -690,6 +693,7 @@ def from_keras(model, shape=None):
for n_idx, t_idx, inbound_layer in zip_node:
if isinstance(inbound_layer, keras.engine.InputLayer):
expr_name = inbound_layer.name
_convert_input_layer(inbound_layer)
else:
expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
expr = etab.get_expr(expr_name)
Expand Down
12 changes: 12 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ def test_forward_dense():
verify_keras_frontend(keras_model)


def test_forward_sequential():
keras_model = keras.models.Sequential([
keras.layers.Dense(16, input_dim=32, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(8, activation='relu'),
keras.layers.Dropout(0.5),
keras.layers.Dense(1, activation='sigmoid')
])
verify_keras_frontend(keras_model)


def test_forward_pool():
data = keras.layers.Input(shape=(32,32,1))
# maxpool
Expand Down Expand Up @@ -244,6 +255,7 @@ def test_forward_mobilenet():
test_forward_merge()
test_forward_activations()
test_forward_dense()
test_forward_sequential()
test_forward_pool()
test_forward_conv()
test_forward_upsample(interpolation='nearest')
Expand Down

0 comments on commit b8b5ad5

Please sign in to comment.