We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7a8b56f commit 57808cdCopy full SHA for 57808cd
pytorch/rnnlm.py
@@ -58,7 +58,7 @@ class RNNLM(nn.Module):
58
def __init__(self):
59
super(RNNLM, self).__init__()
60
self.embeddings = nn.Embedding(vocab_size, args.EMBED_SIZE)
61
- self.rnn = nn.RNN(args.EMBED_SIZE, args.HIDDEN_SIZE)
+ self.rnn = nn.LSTM(args.EMBED_SIZE, args.HIDDEN_SIZE)
62
self.proj = nn.Linear(args.HIDDEN_SIZE, vocab_size)
63
def forward(self, sequences):
64
rnn_output, _ = self.rnn(self.embeddings(sequences))
0 commit comments