This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Masking Layer doesn't work after adding a NaiveRunGraph feature in MXNet #228

Open · 3 tasks done
karan6181 opened this issue Mar 29, 2019 · 2 comments

@karan6181

  • The Masking layer fails with the MXNet backend after PR #14192 was merged into MXNet master.
  • A couple of RNN tests fail, such as test_masking_correctness() and test_masking_layer() in ./keras-apache-mxnet/tests/keras/layers/recurrent_test.py.

Thank you!

  • Check that you are up-to-date with the master branch of Keras. You can update with:
    pip install git+git://github.com/awslabs/keras-apache-mxnet.git --upgrade --no-deps

  • If running on MXNet, check that you are up-to-date with the latest version. The installation
    instructions can be found here.

  • Provide a link to a GitHub Gist of a Python script that can reproduce your issue (or just copy the script here if it is short).

Below is the minimal reproducible code:

import numpy as np
from keras.layers import LSTM
from keras.layers import Embedding
from keras.models import Sequential

num_samples = 2
timesteps = 5
embedding_dim = 4
units = 3
embedding_num = 12

model = Sequential()
# mask_zero=True makes the Embedding layer mask zero-valued timesteps
model.add(Embedding(embedding_num, embedding_dim,
                    mask_zero=True,
                    input_length=timesteps))

# layer = recurrent.SimpleRNN(units)
layer = LSTM(units)  # unroll defaults to False, which hits the failing sym.foreach path
model.add(layer)
model.compile(optimizer='sgd', loss='mse')

# left-pad with zeros so the leading timesteps of each sample are masked
left_padded_input = np.ones((num_samples, timesteps))
left_padded_input[0, :1] = 0
left_padded_input[1, :2] = 0
out6 = model.predict(left_padded_input)
@roywei commented Mar 29, 2019

I think it triggers the naive run graph only if masking is enabled and the sym.foreach operator is used.
That means an RNN layer with unroll=False does not work with the Masking layer.
Current workaround to enable masking: use unroll=True in the RNN layer.

@karan6181 (Author)

Yes, absolutely correct. If we add unroll=True to the RNN/LSTM/GRU layer, it uses the static forward pass and works without any issue.

Below is the working code where I have added unroll=True:

import numpy as np
from keras.layers import LSTM
from keras.layers import Embedding
from keras.models import Sequential

num_samples = 2
timesteps = 5
embedding_dim = 4
units = 3
embedding_num = 12

model = Sequential()
model.add(Embedding(embedding_num, embedding_dim,
                    mask_zero=True,
                    input_length=timesteps))

# layer = recurrent.SimpleRNN(units)
layer = LSTM(units, unroll=True)  # unroll=True selects the static (unrolled) forward pass
model.add(layer)
model.compile(optimizer='sgd', loss='mse')

left_padded_input = np.ones((num_samples, timesteps))
left_padded_input[0, :1] = 0
left_padded_input[1, :2] = 0
out6 = model.predict(left_padded_input)
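
For completeness, here is a minimal sketch of a masking-correctness check in the spirit of test_masking_correctness() (the weight-sharing setup via the functional API is an assumption, not code from this issue): two models share one Embedding and one LSTM(unroll=True), so if masked timesteps are truly skipped, predicting on a left-padded sequence gives the same result as predicting on the unpadded one.

import numpy as np
from keras.layers import Input, Embedding, LSTM
from keras.models import Model

embedding_num, embedding_dim, units = 12, 4, 3

# Shared layers, so both models below use identical weights
embed = Embedding(embedding_num, embedding_dim, mask_zero=True)
lstm = LSTM(units, unroll=True)

# One model over the padded length, one over the unpadded length
padded_in = Input(shape=(5,))
padded_model = Model(padded_in, lstm(embed(padded_in)))
trimmed_in = Input(shape=(3,))
trimmed_model = Model(trimmed_in, lstm(embed(trimmed_in)))

seq = np.array([[1, 2, 3]])
padded_seq = np.array([[0, 0, 1, 2, 3]])  # leading zeros are masked

# If masked timesteps are skipped, both predictions should match
print(np.allclose(padded_model.predict(padded_seq),
                  trimmed_model.predict(seq), atol=1e-6))

With the bug present (unroll=False), the first script fails outright; with the unroll=True workaround, this check should print True if masking behaves correctly.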
