''' Translate input text with trained model. '''
import contextlib
import pathlib
import os
import sys
import argparse

import torch
import dill as pickle
from tqdm import tqdm

import transformer.Constants as Constants
from transformer.Models import Transformer
from transformer.Translator import Translator


@contextlib.contextmanager
def _with_sys_path(path):
    """Temporarily add the given path to `sys.path`."""
    path = os.fspath(path)
    try:
        sys.path.insert(0, path)
        yield
    finally:
        sys.path.remove(path)


package_root = pathlib.Path(os.path.dirname(os.path.realpath(__file__))).parent.parent.parent
with _with_sys_path(package_root):
    from util.torchtext_legacy.data import Dataset


def load_model(opt, device):
    checkpoint = torch.load(opt.model, map_location=device)
    model_opt = checkpoint['settings']

    model = Transformer(
        model_opt.src_vocab_size,
        model_opt.trg_vocab_size,
        model_opt.src_pad_idx,
        model_opt.trg_pad_idx,
        trg_emb_prj_weight_sharing=model_opt.proj_share_weight,
        emb_src_trg_weight_sharing=model_opt.embs_share_weight,
        d_k=model_opt.d_k,
        d_v=model_opt.d_v,
        d_model=model_opt.d_model,
        d_word_vec=model_opt.d_word_vec,
        d_inner=model_opt.d_inner_hid,
        n_layers=model_opt.n_layers,
        n_head=model_opt.n_head,
        dropout=model_opt.dropout).to(device)

    model.load_state_dict(checkpoint['model'])
    print('[Info] Trained model state loaded.')
    return model
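

# Minimal usage sketch (illustrative only; 'trained.chkpt' is just the example
# checkpoint name from the usage string below):
#
#   opt = argparse.Namespace(model='trained.chkpt')
#   model = load_model(opt, torch.device('cpu'))
#
# load_model only reads `opt.model`; every hyper-parameter comes from the
# checkpoint's 'settings' entry, so a bare Namespace is enough for a quick CPU
# smoke test before a full CUDA or DirectML run.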


def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True,
                        help='Path to model weight file')
    parser.add_argument('-data_pkl', required=True,
                        help='Pickle file with both instances and vocabulary.')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5)
    parser.add_argument('-max_seq_len', type=int, default=100)
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-use_dml', action='store_true')

    # TODO: Translate bpe encoded files
    #parser.add_argument('-src', required=True,
    #                    help='Source sequence to decode (one line per sequence)')
    #parser.add_argument('-vocab', required=True,
    #                    help='Source sequence to decode (one line per sequence)')

    # TODO: Batch translation
    #parser.add_argument('-batch_size', type=int, default=30,
    #                    help='Batch size')
    #parser.add_argument('-n_best', type=int, default=1,
    #                    help="""If verbose is set, will output the n_best
    #                    decoded sentences""")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Load the preprocessed dataset and vocabularies from the pickle file.
    data = pickle.load(open(opt.data_pkl, 'rb'))
    SRC, TRG = data['vocab']['src'], data['vocab']['trg']
    opt.src_pad_idx = SRC.vocab.stoi[Constants.PAD_WORD]
    opt.trg_pad_idx = TRG.vocab.stoi[Constants.PAD_WORD]
    opt.trg_bos_idx = TRG.vocab.stoi[Constants.BOS_WORD]
    opt.trg_eos_idx = TRG.vocab.stoi[Constants.EOS_WORD]

    test_loader = Dataset(examples=data['test'], fields={'src': SRC, 'trg': TRG})

    # Use DirectML when requested; otherwise fall back to CUDA or CPU.
    if opt.use_dml:
        import torch_directml
        device = torch_directml.device(torch_directml.default_device())
    else:
        device = torch.device('cuda' if opt.cuda else 'cpu')

    translator = Translator(
        model=load_model(opt, device),
        beam_size=opt.beam_size,
        max_seq_len=opt.max_seq_len,
        src_pad_idx=opt.src_pad_idx,
        trg_pad_idx=opt.trg_pad_idx,
        trg_bos_idx=opt.trg_bos_idx,
        trg_eos_idx=opt.trg_eos_idx).to(device)

    unk_idx = SRC.vocab.stoi[SRC.unk_token]
    with open(opt.output, 'w') as f:
        for example in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False):
            #print(' '.join(example.src))
            # Map each source token to its vocabulary index, falling back to <unk>.
            src_seq = [SRC.vocab.stoi.get(word, unk_idx) for word in example.src]
            pred_seq = translator.translate_sentence(torch.LongTensor([src_seq]).to(device))
            pred_line = ' '.join(TRG.vocab.itos[idx] for idx in pred_seq)
            # Drop the BOS/EOS markers before writing the decoded sentence.
            pred_line = pred_line.replace(Constants.BOS_WORD, '').replace(Constants.EOS_WORD, '')
            #print(pred_line)
            f.write(pred_line.strip() + '\n')

    print('[Info] Finished.')


if __name__ == "__main__":
    '''
    Usage: python translate.py -data_pkl .data\m30k_deen_shr.pkl -model trained.chkpt -output dml_prediction.txt
    '''
    main()
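
# Additional example (assumes the optional torch-directml package is installed
# and a DirectML-capable GPU is available):
#   python translate.py -data_pkl .data\m30k_deen_shr.pkl -model trained.chkpt -output dml_prediction.txt -use_dml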