Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,22 +1053,30 @@ def _convert_simple_rnn(
in_data = inexpr[0]
prev_op = inexpr[1]
weightList = keras_layer.get_weights()
kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
weightList0 = weightList[0].transpose([1, 0])
assert len(in_data.type_annotation.shape) == 3
for i in range(in_data.type_annotation.shape[1].value - 1):
weightList0 = np.hstack((weightList0, weightList[0].transpose([1, 0])))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware of this line

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please elaborate? I still cannot see any data type issue here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check your code with bfloat16 weights? numpy.hstack has a dtype argument, and I suspect it validates the dtype; if so, NumPy will fail when the dtype is bfloat16.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see. I will check it later. Thanks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's true that numpy.hstack does not support bfloat16. But weightList[0] and weightList0 can never be bfloat16, to the best of my understanding. These two variables come from weightList = keras_layer.get_weights(), and they are NumPy arrays. If NumPy does not support bfloat16, the dtype of weightList[0] should never be bfloat16. So this concern seems unnecessary here.

kernel_weight = etab.new_const(weightList0)
recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
if keras_layer.use_bias:
in_bias = etab.new_const(weightList[2])
units = list(weightList[0].shape)[1]
assert units > 0, "The value of units must be a positive integer"
dim = weightList0.shape[0]
in_data = _op.nn.batch_flatten(in_data)
ixh = _op.nn.dense(in_data, kernel_weight, units=units)
if keras_layer.use_bias:
ixh = _op.nn.bias_add(ixh, bias=in_bias)
split_list = []
for i in range(1, dim):
split_list.append(i)
ixh_tuple = _op.split(ixh, split_list, 1)
prev_op = _op.nn.batch_flatten(prev_op)
ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
output = ixh + ixh2
output = _convert_activation(output, keras_layer, etab, data_layout)
out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
output = _op.reshape(output, newshape=out_shape)
for i in range(dim):
ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
prev_op = ixh_tuple[0] + ixh2
output = prev_op
return [output, output]


Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,16 @@ def test_forward_time_distributed(self, keras_mod):
)
verify_keras_frontend(dense_model, need_transpose=False)

def test_simplernn_with_infertype(self, keras_mod):
    """Regression test from https://github.com/apache/tvm/issues/14868:
    a SimpleRNN model converted from Keras must pass Relay type inference."""
    shape = (2, 2, 2)
    # Build a minimal Keras model: one input feeding a SimpleRNN layer.
    inputs = keras_mod.layers.Input(shape=shape[1:], dtype="float32")
    outputs = keras_mod.layers.SimpleRNN(units=4)(inputs)
    model = keras_mod.models.Model(inputs, outputs)
    # Convert to Relay and run InferType; the test passes if neither raises.
    mod, _ = relay.frontend.from_keras(model, {"input_1": shape})
    relay.transform.InferType()(mod)


if __name__ == "__main__":
for k in [keras, tf_keras]:
Expand Down Expand Up @@ -867,3 +877,4 @@ def test_forward_time_distributed(self, keras_mod):
sut.test_forward_repeat_vector(keras_mod=k)
sut.test_forward_l2_normalize(keras_mod=k)
sut.test_forward_time_distributed(keras_mod=k)
sut.test_simplernn_with_infertype(keras_mod=k)