Skip to content

Commit 53361fa

Browse files
thefiddlertqchen
authored andcommitted
Update frontend for keras 2.1.3 compatibility (apache#314)
* Keras keeps renaming properties. Update frontend for keras 2.1.3 compatibility * Add error message when inbound_nodes is not found
1 parent d626eab commit 53361fa

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

nnvm/python/nnvm/frontend/keras.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,14 @@ def _convert_concat(insym, keras_layer, _):
349349

350350

351351
def _convert_reshape(insym, keras_layer, _):
352-
return _sym.reshape(insym, keras_layer.shape)
352+
shape = keras_layer.shape if hasattr(keras_layer, 'shape') else \
353+
keras_layer.target_shape if hasattr(keras_layer, 'target_shape') else\
354+
None
355+
356+
if shape is None:
357+
raise TypeError("No shape attribute in reshape layer: {}".format(keras_layer))
358+
359+
return _sym.reshape(insym, shape=shape)
353360

354361

355362
def _default_skip(insym, keras_layer, _): # pylint: disable=unused-argument
@@ -477,7 +484,15 @@ def from_keras(model):
477484
symtab.get_var(keras_layer.name, must_contain=False)
478485
else:
479486
predecessors = []
480-
for node in keras_layer.inbound_nodes:
487+
inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
488+
else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
489+
else None
490+
491+
if inbound_nodes is None:
492+
raise TypeError("Unknown layer type or unsupported Keras version : {}"
493+
.format(keras_layer))
494+
495+
for node in inbound_nodes:
481496
for pred in node.inbound_layers:
482497
predecessors.append(pred.name)
483498
if len(predecessors) == 1:

0 commit comments

Comments
 (0)