@@ -2,7 +2,7 @@
 from chainer import Chain, links as L, Variable, functions as F
 from chainer import optimizers, Variable, serializers
 import chainer
-from typing import List, Tuple
+from typing import List, Tuple, Dict
 import pickle
 from logging import getLogger, StreamHandler, DEBUG
 import multiprocessing
@@ -292,7 +292,16 @@ def save_model(filename: str, model: EncoderDecoder) -> None:

 def load_model(filename: str) -> EncoderDecoder:
     with np.load(filename) as f:
-        deserializer = serializers.NpzDeserializer(f)
+        d = dict(f.items())
+
+        # Rename _W to _extract_output for backward compatibility
+        for k, v in list(d.items()):
+            old_prefix = "model/_W/"
+            new_prefix = "model/_extract_output/"
+            if k.startswith(old_prefix):
+                d[new_prefix + k[len(old_prefix):]] = v
+
+        deserializer = serializers.NpzDeserializer(d)

         pickled_params = deserializer("hyper_parameters", None)
         params = pickle.loads(pickled_params.tobytes())
@@ -343,6 +352,22 @@ def translate_it2(model: EncoderDecoder, dataset: DataSet, index: int):
     print("Output:", " ".join(ev.to_word(i) for i in model.translate(dataset.ja_sentences[index])))


+def translate_it3(models: Dict[str, EncoderDecoder], dataset: DataSet, index: int):
+    max_name_len = max(len(name) for name in models)
+
+    jv = dataset.ja_vocabulary
+    ev = dataset.en_vocabulary
+
+    print("{}:".format("JA".rjust(max_name_len)),
+          " ".join(jv.to_word(i) for i in dataset.ja_sentences[index]))
+    print("{}:".format("EN".rjust(max_name_len)),
+          " ".join(ev.to_word(i) for i in dataset.en_sentences[index]))
+
+    for name, model in models.items():
+        print("{}:".format(name.rjust(max_name_len)),
+              " ".join(ev.to_word(i) for i in model.translate(dataset.ja_sentences[index])))
+
+
 def translate_all(model: EncoderDecoder, dataset: DataSet, use_beam_search: bool):
     result = []
     for i in range(dataset.n_sentences):
@@ -356,6 +381,24 @@ def translate_all(model: EncoderDecoder, dataset: DataSet, use_beam_search: bool
     return result


+def show_translations(translations: Dict[str, List[List[int]]], dataset: DataSet, index: int):
+    max_name_len = max(len(name) for name in translations)
+
+    jv = dataset.ja_vocabulary
+    ev = dataset.en_vocabulary
+
+    print("{}: ----- :".format("JA".ljust(max_name_len)),
+          " ".join(jv.to_word(i) for i in dataset.ja_sentences[index]))
+    print("{}: ----- :".format("EN".ljust(max_name_len)),
+          " ".join(ev.to_word(i) for i in dataset.en_sentences[index]))
+
+    for name, outputs in translations.items():
+        bleu = calculate_bleu(dataset.en_sentences[index], outputs[index])
+        print("{}:".format(name.ljust(max_name_len)),
+              "{:.3f} :".format(bleu),
+              " ".join(ev.to_word(i) for i in outputs[index]))
+
+
 def run_single(model_prefix, test_set, epoch, use_beam_search):
     logger.info("epoch {}, use_beam_search {}".format(epoch, use_beam_search))
