Skip to content

Commit 57808cd

Browse files
committed
convert vanilla rnn to lstm
1 parent 7a8b56f commit 57808cd

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch/rnnlm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class RNNLM(nn.Module):
5858
def __init__(self):
5959
super(RNNLM, self).__init__()
6060
self.embeddings = nn.Embedding(vocab_size, args.EMBED_SIZE)
61-
self.rnn = nn.RNN(args.EMBED_SIZE, args.HIDDEN_SIZE)
61+
self.rnn = nn.LSTM(args.EMBED_SIZE, args.HIDDEN_SIZE)
6262
self.proj = nn.Linear(args.HIDDEN_SIZE, vocab_size)
6363
def forward(self, sequences):
6464
rnn_output, _ = self.rnn(self.embeddings(sequences))

0 commit comments

Comments
 (0)