Skip to content

Commit be4257e

Browse files
author
yanwii
committed
修复pytorch 更新后的bug
1 parent e14c3f9 commit be4257e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

seq2seq.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# -*- coding: UTF-8 -*-
12
import math
23
import os
34
import random
@@ -118,8 +119,8 @@ def __init__(self):
118119
self.batch_index = 0
119120
self.GO_token = 2
120121
self.EOS_token = 1
121-
self.input_size = 7
122-
self.output_size = 9
122+
self.input_size = 14
123+
self.output_size = 15
123124
self.hidden_size = 100
124125
self.max_length = 15
125126
self.show_epoch = 100
@@ -247,13 +248,13 @@ def step(self, input_variable, target_variable, max_length):
247248
if use_teacher_forcing:
248249
for di in range(target_length):
249250
decoder_output, decoder_context, decoder_hidden, decoder_attention = self.decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
250-
loss += self.criterion(decoder_output[0], target_variable[di])
251+
loss += self.criterion(decoder_output, target_variable[di])
251252
decoder_input = target_variable[di]
252253
decoder_outputs.append(decoder_output.unsqueeze(0))
253254
else:
254255
for di in range(target_length):
255256
decoder_output, decoder_context, decoder_hidden, decoder_attention = self.decoder(decoder_input, decoder_context, decoder_hidden, encoder_outputs)
256-
loss += self.criterion(decoder_output[0], target_variable[di])
257+
loss += self.criterion(decoder_output, target_variable[di])
257258
decoder_outputs.append(decoder_output.unsqueeze(0))
258259
topv, topi = decoder_output.data.topk(1)
259260
ni = topi[0][0]

0 commit comments

Comments
 (0)