diff --git a/example/rnn/lstm.py b/example/rnn/lstm.py index 996861a80894..25245aad18ee 100644 --- a/example/rnn/lstm.py +++ b/example/rnn/lstm.py @@ -17,7 +17,7 @@ def lstm(num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): """LSTM Cell symbol""" if dropout > 0.: - in_data = mx.sym.Dropout(data=in_data, p=dropout) + indata = mx.sym.Dropout(data=indata, p=dropout) i2h = mx.sym.FullyConnected(data=indata, weight=param.i2h_weight, bias=param.i2h_bias,