Skip to content

Commit 7116a3d

Browse files
fix issue 207
1 parent 6a71750 commit 7116a3d

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

SeqLSTM.lua

+2-2
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,8 @@ function SeqLSTM:accGradParameters(input, gradOutput, scale)
359359
end
360360

361361
function SeqLSTM:forget()
362-
self.c0:zero()
363-
self.h0:zero()
362+
self.c0:resize(0)
363+
self.h0:resize(0)
364364
end
365365

366366
-- Toggle to feed long sequences using multiple forwards.

test/test.lua

+11
Original file line numberDiff line numberDiff line change
@@ -4898,6 +4898,17 @@ function rnntest.FastLSTM_issue203()
48984898
mytester:assert(err < 0.000001, "error "..err)
48994899
end
49004900

4901+
function rnntest.SeqLSTM_issue207()
4902+
local lstm = nn.SeqLSTM(10, 10)
4903+
lstm.batchfirst = true
4904+
lstm:remember('both')
4905+
lstm:training()
4906+
lstm:forward(torch.Tensor(32, 20, 10))
4907+
lstm:evaluate()
4908+
lstm:forget()
4909+
lstm:forward(torch.Tensor(1, 20, 10))
4910+
end
4911+
49014912
function rnn.test(tests, benchmark_)
49024913
mytester = torch.Tester()
49034914
benchmark = benchmark_

0 commit comments

Comments
 (0)