Skip to content

Commit d5a0eed

Browse files
committed
feat(encoder-decoder): show_translations
1 parent c68999b commit d5a0eed

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

learning-chainer/encoder-decoder/encoder_decoder.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from chainer import Chain, links as L, Variable, functions as F
33
from chainer import optimizers, Variable, serializers
44
import chainer
5-
from typing import List, Tuple
5+
from typing import List, Tuple, Dict
66
import pickle
77
from logging import getLogger, StreamHandler, DEBUG
88
import multiprocessing
@@ -292,7 +292,16 @@ def save_model(filename: str, model: EncoderDecoder) -> None:
292292

293293
def load_model(filename: str) -> EncoderDecoder:
294294
with np.load(filename) as f:
295-
deserializer = serializers.NpzDeserializer(f)
295+
d = dict(f.iteritems())
296+
297+
# 後方互換性のために _W を _extract_output に改名する
298+
for k, v in list(d.items()):
299+
old_prefix = "model/_W/"
300+
new_prefix = "model/_extract_output/"
301+
if k.startswith(old_prefix):
302+
d[new_prefix + k[len(old_prefix):]] = v
303+
304+
deserializer = serializers.NpzDeserializer(d)
296305

297306
pickled_params = deserializer("hyper_parameters", None)
298307
params = pickle.loads(pickled_params.tobytes())
@@ -343,6 +352,22 @@ def translate_it2(model: EncoderDecoder, dataset: DataSet, index: int):
343352
print("Output:", " ".join(ev.to_word(i) for i in model.translate(dataset.ja_sentences[index])))
344353

345354

355+
def translate_it3(models: Dict[str, EncoderDecoder], dataset: DataSet, index: int):
    """Print the JA source, the EN reference, and each model's translation
    for the sentence pair at *index*, with labels aligned in one column.

    Args:
        models: Mapping from a display name to a trained EncoderDecoder.
        dataset: Parallel corpus providing vocabularies and id sequences.
        index: Index of the sentence pair to show.
    """
    # Right-align every label to the longest model name so the columns line up.
    max_name_len = max(len(name) for name in models)

    jv = dataset.ja_vocabulary
    ev = dataset.en_vocabulary

    print("{}:".format("JA".rjust(max_name_len)),
          " ".join(jv.to_word(i) for i in dataset.ja_sentences[index]))
    print("{}:".format("EN".rjust(max_name_len)),
          " ".join(ev.to_word(i) for i in dataset.en_sentences[index]))

    # BUG FIX: iterating a dict yields keys only, so the original
    # `for name, model in models:` unpacked each key *string* into
    # (name, model) and then called .translate on a 1-char string.
    # Use .items() to get (name, model) pairs, matching show_translations.
    for name, model in models.items():
        print("{}:".format(name.rjust(max_name_len)),
              " ".join(ev.to_word(i) for i in model.translate(dataset.ja_sentences[index])))
369+
370+
346371
def translate_all(model: EncoderDecoder, dataset: DataSet, use_beam_search: bool):
347372
result = []
348373
for i in range(dataset.n_sentences):
@@ -356,6 +381,24 @@ def translate_all(model: EncoderDecoder, dataset: DataSet, use_beam_search: bool
356381
return result
357382

358383

384+
def show_translations(translations: Dict[str, List[List[int]]], dataset: DataSet, index: int):
    """Show the JA source, the EN reference, and every system's output
    (together with its sentence-level BLEU score) for one sentence pair.

    Args:
        translations: Mapping from a system name to its full list of
            translated sentences (word-id sequences, one per corpus index).
        dataset: Parallel corpus providing vocabularies and id sequences.
        index: Index of the sentence pair to show.
    """
    # Left-pad labels to the widest system name so the rows align.
    width = max(len(name) for name in translations)

    ja_vocab = dataset.ja_vocabulary
    en_vocab = dataset.en_vocabulary

    def render(vocab, ids):
        # Map a word-id sequence back to a space-separated sentence.
        return " ".join(vocab.to_word(i) for i in ids)

    reference = dataset.en_sentences[index]

    print("{}: ----- :".format("JA".ljust(width)),
          render(ja_vocab, dataset.ja_sentences[index]))
    print("{}: ----- :".format("EN".ljust(width)),
          render(en_vocab, reference))

    for name, outputs in translations.items():
        hypothesis = outputs[index]
        score = calculate_bleu(reference, hypothesis)
        print("{}:".format(name.ljust(width)),
              "{:.3f} :".format(score),
              render(en_vocab, hypothesis))
400+
401+
359402
def run_single(model_prefix, test_set, epoch, use_beam_search):
360403
logger.info("epoch {}, use_beam_search {}".format(epoch, use_beam_search))
361404

0 commit comments

Comments
 (0)