Skip to content

Commit 62640d8

Browse files
msperberneubig
authored and committed
Decode-only (neulab#394)
* implement DecodingEvalTask + small clean up * implemented xnmt_decode.py * fix type hint
1 parent 32a7b86 commit 62640d8

File tree

6 files changed

+134
-47
lines changed

6 files changed

+134
-47
lines changed

examples/10_programmatic_load.py

-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
# if we were to continue training, we would need to set a save model file like this:
3131
# ParamManager.param_col.model_file = model_file
3232
ParamManager.populate()
33-
exp_global = loaded_experiment.exp_global
3433

3534
# run experiment
3635
loaded_experiment(save_fct=lambda: save_to_file(model_file, loaded_experiment))

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def get_git_revision():
5757
'console_scripts': [
5858
'xnmt = xnmt.xnmt_run_experiments:main',
5959
'xnmt_evaluate = xnmt.xnmt_evaluate:main',
60+
'xnmt_decode = xnmt.xnmt_decode:main',
6061
],
6162
}
6263
)

xnmt/eval_task.py

+41-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class EvalTask(object):
2222
def eval(self):
2323
raise NotImplementedError("EvalTask.eval() needs to be implemented in child classes")
2424

25-
class LossEvalTask(Serializable):
25+
class LossEvalTask(EvalTask, Serializable):
2626
"""
2727
A task that does evaluation of the loss function.
2828
@@ -52,7 +52,13 @@ def __init__(self, src_file: str, ref_file: str, model: GeneratorModel = Ref("mo
5252
self.max_trg_len = max_trg_len
5353
self.desc=desc
5454

55-
def eval(self):
55+
def eval(self) -> tuple:
56+
"""
57+
Perform evaluation task.
58+
59+
Returns:
60+
tuple of score and reference length
61+
"""
5662
self.model.set_train(False)
5763
if self.src_data is None:
5864
self.src_data, self.ref_data, self.src_batches, self.ref_batches = \
@@ -80,7 +86,7 @@ def eval(self):
8086
except KeyError:
8187
raise RuntimeError("Did you wrap your loss calculation with LossBuilder({'primary_loss': loss_value}) ?")
8288

83-
class AccuracyEvalTask(Serializable):
89+
class AccuracyEvalTask(EvalTask, Serializable):
8490
"""
8591
A task that does evaluation of some measure of accuracy.
8692
@@ -133,3 +139,35 @@ def eval(self):
133139
ref_words_cnt += self.model.trg_reader.count_words(ref_sent)
134140
ref_words_cnt += 0
135141
return eval_scores, ref_words_cnt
142+
143+
class DecodingEvalTask(EvalTask, Serializable):
  """
  A task that performs decoding without comparing against a reference.

  Args:
    src_file: path(s) to read source file(s) from
    hyp_file: path to write hypothesis file to
    model: generator model to generate hypothesis with
    inference: inference object; if not given, the model's own inference object is used
    candidate_id_file: if selecting from fixed candidates (e.g. retrieval), path to a file
                       restricting candidates to a subset of the full set
                       (passed through to the inference object)
  """

  yaml_tag = '!DecodingEvalTask'

  @serializable_init
  def __init__(self, src_file: Union[str,Sequence[str]], hyp_file: str, model: GeneratorModel = Ref("model"),
               inference: Optional[SimpleInference] = None, candidate_id_file: Optional[str] = None):

    self.model = model
    self.src_file = src_file
    self.hyp_file = hyp_file
    self.candidate_id_file = candidate_id_file
    # fall back to the inference object configured on the model itself
    self.inference = inference or self.model.inference

  def eval(self) -> tuple:
    """
    Perform decoding and write hypotheses to ``hyp_file``.

    Returns:
      tuple of (None, None): no score or reference length is produced; the two-element
      shape matches the (eval_scores, ref_words_cnt) return of the other EvalTasks.
    """
    self.model.set_train(False)
    self.inference(generator=self.model,
                   src_file=self.src_file,
                   trg_file=self.hyp_file,
                   candidate_id_file=self.candidate_id_file)
    return None, None

xnmt/inference.py

+59-42
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
# coding: utf-8
2-
31
from collections.abc import Iterable
2+
from typing import Optional
43

54
from xnmt.settings import settings
65

76
import dynet as dy
87

8+
from xnmt.batcher import Batcher
9+
from xnmt.generator import GeneratorModel
910
from xnmt import logger
1011
from xnmt.loss_calculator import MLELoss
1112
import xnmt.output
1213
from xnmt.reports import Reportable
1314
from xnmt.persistence import serializable_init, Serializable, Ref, bare
15+
from xnmt.search_strategy import SearchStrategy, BeamSearch
1416
from xnmt.util import make_parent_dir
15-
from xnmt.search_strategy import BeamSearch
1617

1718
NO_DECODING_ATTEMPTED = "@@NO_DECODING_ATTEMPTED@@"
1819

@@ -21,24 +22,29 @@ class SimpleInference(Serializable):
2122
Main class to perform decoding.
2223
2324
Args:
24-
src_file (str): path of input src file to be translated
25-
trg_file (str): path of file where trg translatons will be written
26-
ref_file (str): path of file with reference translations, e.g. for forced decoding
27-
max_src_len (int): Remove sentences from data to decode that are longer than this on the source side
28-
post_process (str): post-processing of translation outputs: ``none/join-char/join-bpe/join-piece``
29-
report_path (str): a path to which decoding reports will be written
30-
report_type (str): report to generate ``file/html``. Can be multiple, separate with comma.
31-
search_strategy (SearchStrategy): a search strategy used during decoding.
32-
mode (str): type of decoding to perform. ``onebest``: generate one best. ``forced``: perform forced decoding. ``forceddebug``: perform forced decoding, calculate training loss, and make suer the scores are identical for debugging purposes.
33-
batcher (Batcher):
25+
src_file: path of input src file to be translated
26+
trg_file: path of file where trg translations will be written
27+
ref_file: path of file with reference translations, e.g. for forced decoding
28+
max_src_len: Remove sentences from data to decode that are longer than this on the source side
29+
post_process: post-processing of translation outputs: ``none/join-char/join-bpe/join-piece``
30+
report_path: a path to which decoding reports will be written
31+
report_type: report to generate ``file/html``. Can be multiple, separate with comma.
32+
search_strategy: a search strategy used during decoding.
33+
mode: type of decoding to perform.
34+
``onebest``: generate one best.
35+
``forced``: perform forced decoding.
36+
``forceddebug``: perform forced decoding, calculate training loss, and make sure the scores are identical
37+
for debugging purposes.
38+
batcher: inference batcher, needed e.g. in connection with ``pad_src_token_to_multiple``
3439
"""
3540

3641
yaml_tag = '!SimpleInference'
3742

3843
@serializable_init
39-
def __init__(self, src_file=None, trg_file=None, ref_file=None, max_src_len=None,
40-
post_process="none", report_path=None, report_type="html",
41-
search_strategy=bare(BeamSearch), mode="onebest", max_len=None, batcher=Ref("train.batcher", default=None)):
44+
def __init__(self, src_file: Optional[str] = None, trg_file: Optional[str] = None, ref_file: Optional[str] = None,
45+
max_src_len: Optional[int] = None, post_process: str = "none", report_path: Optional[str] = None,
46+
report_type: str = "html", search_strategy: SearchStrategy = bare(BeamSearch), mode: str = "onebest",
47+
max_len: Optional[int] = None, batcher: Optional[Batcher] = Ref("train.batcher", default=None)):
4248
self.src_file = src_file
4349
self.trg_file = trg_file
4450
self.ref_file = ref_file
@@ -51,52 +57,63 @@ def __init__(self, src_file=None, trg_file=None, ref_file=None, max_src_len=None
5157
self.search_strategy = search_strategy
5258
self.max_len = max_len
5359

54-
55-
def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=None):
60+
def __call__(self, generator: GeneratorModel, src_file: str = None, trg_file: str = None,
61+
candidate_id_file: str = None):
5662
"""
63+
Perform inference.
64+
5765
Args:
58-
generator (GeneratorModel): the model to be used
59-
src_file (str): path of input src file to be translated
60-
trg_file (str): path of file where trg translatons will be written
61-
candidate_id_file (str): if we are doing something like retrieval where we select from fixed candidates, sometimes we want to limit our candidates to a certain subset of the full set. this setting allows us to do this.
66+
generator: the model to be used
67+
src_file: path of input src file to be translated
68+
trg_file: path of file where trg translatons will be written
69+
candidate_id_file: if we are doing something like retrieval where we select from fixed candidates, sometimes we
70+
want to limit our candidates to a certain subset of the full set. this setting allows us to do
71+
this.
6272
"""
63-
args = dict(src_file=src_file or self.src_file, trg_file=trg_file or self.trg_file, ref_file=self.ref_file, max_src_len=self.max_src_len,
64-
post_process=self.post_process, candidate_id_file=candidate_id_file, report_path=self.report_path, report_type=self.report_type, mode=self.mode)
73+
# TODO: should be broken into smaller methods
74+
75+
src_file = src_file or self.src_file
76+
trg_file = trg_file or self.trg_file
6577

66-
is_reporting = issubclass(generator.__class__, Reportable) and args["report_path"] is not None
78+
is_reporting = issubclass(generator.__class__, Reportable) and self.report_path is not None
6779
# Corpus
68-
src_corpus = list(generator.src_reader.read_sents(args["src_file"]))
80+
src_corpus = list(generator.src_reader.read_sents(src_file))
6981
# Get reference if it exists and is necessary
70-
if args["mode"] == "forced" or args["mode"] == "forceddebug" or args["mode"] == "score":
71-
if args["ref_file"] is None:
72-
raise RuntimeError("When performing {} decoding, must specify reference file".format(args["mode"]))
82+
if self.mode == "forced" or self.mode == "forceddebug" or self.mode == "score":
83+
if self.ref_file is None:
84+
raise RuntimeError("When performing {} decoding, must specify reference file".format(self.mode))
7385
score_src_corpus = []
7486
ref_corpus = []
75-
with open(args["ref_file"], "r", encoding="utf-8") as fp:
87+
with open(self.ref_file, "r", encoding="utf-8") as fp:
7688
for line in fp:
77-
if args["mode"] == "score":
89+
if self.mode == "score":
7890
nbest = line.split("|||")
7991
assert len(nbest) > 1, "When performing scoring, ref_file must have nbest format 'index ||| hypothesis'"
8092
src_index = int(nbest[0].strip())
81-
assert src_index < len(src_corpus), "The src_file has only {} instances, nbest file has invalid src_index {}".format(len(src_corpus), src_index)
93+
assert src_index < len(src_corpus),\
94+
f"The src_file has only {len(src_corpus)} instances, nbest file has invalid src_index {src_index}"
8295
score_src_corpus.append(src_corpus[src_index])
8396
trg_input = generator.trg_reader.read_sent(nbest[1].strip())
8497
else:
8598
trg_input = generator.trg_reader.read_sent(line)
8699
ref_corpus.append(trg_input)
87-
if args["mode"] == "score":
100+
if self.mode == "score":
88101
src_corpus = score_src_corpus
89102
else:
90103
if self.max_len and any(len(s) > self.max_len for s in ref_corpus):
91-
logger.warning("Forced decoding with some targets being longer than max_len. Increase max_len to avoid unexpected behavior.")
104+
logger.warning("Forced decoding with some targets being longer than max_len. "
105+
"Increase max_len to avoid unexpected behavior.")
92106
else:
93107
ref_corpus = None
94108
# Vocab
95109
src_vocab = generator.src_reader.vocab if hasattr(generator.src_reader, "vocab") else None
96110
trg_vocab = generator.trg_reader.vocab if hasattr(generator.trg_reader, "vocab") else None
97111
# Perform initialization
98112
generator.set_train(False)
99-
generator.initialize_generator(**args)
113+
generator.initialize_generator(src_file=src_file, trg_file=trg_file, ref_file=self.ref_file,
114+
max_src_len=self.max_src_len, post_process=self.post_process,
115+
candidate_id_file=candidate_id_file, report_path=self.report_path,
116+
report_type=self.report_type, mode=self.mode)
100117

101118
if hasattr(generator, "set_post_processor"):
102119
generator.set_post_processor(self.get_output_processor())
@@ -111,7 +128,7 @@ def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=No
111128

112129
# If we're debugging, calculate the loss for each target sentence
113130
ref_scores = None
114-
if args["mode"] == 'forceddebug' or args["mode"] == 'score':
131+
if self.mode == 'forceddebug' or self.mode == 'score':
115132
some_batcher = xnmt.batcher.InOrderBatcher(32) # Arbitrary
116133
if not isinstance(some_batcher, xnmt.batcher.InOrderBatcher):
117134
raise ValueError(f"forceddebug requires InOrderBatcher, got: {some_batcher}")
@@ -127,11 +144,11 @@ def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=No
127144
ref_scores = [-x for x in ref_scores]
128145

129146
# Make the parent directory if necessary
130-
make_parent_dir(args["trg_file"])
147+
make_parent_dir(trg_file)
131148

132149
# Perform generation of output
133-
if args["mode"] != 'score':
134-
with open(args["trg_file"], 'wt', encoding='utf-8') as fp: # Saving the translated output to a trg file
150+
if self.mode != 'score':
151+
with open(trg_file, 'wt', encoding='utf-8') as fp: # Saving the translated output to a trg file
135152
src_ret=[]
136153
for i, src in enumerate(src_corpus):
137154
# This is necessary when the batcher does some sort of pre-processing, e.g.
@@ -140,7 +157,7 @@ def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=No
140157
self.batcher.add_single_batch(src_curr=[src], trg_curr=None, src_ret=src_ret, trg_ret=None)
141158
src = src_ret.pop()[0]
142159
# Do the decoding
143-
if args["max_src_len"] is not None and len(src) > args["max_src_len"]:
160+
if self.max_src_len is not None and len(src) > self.max_src_len:
144161
output_txt = NO_DECODING_ATTEMPTED
145162
else:
146163
dy.renew_cg(immediate_compute=settings.IMMEDIATE_COMPUTE, check_validity=settings.CHECK_VALIDITY)
@@ -153,8 +170,8 @@ def __call__(self, generator, src_file=None, trg_file=None, candidate_id_file=No
153170
# Printing to trg file
154171
fp.write(f"{output_txt}\n")
155172
else:
156-
with open(args["trg_file"], 'wt', encoding='utf-8') as fp:
157-
with open(args["ref_file"], "r", encoding="utf-8") as nbest_fp:
173+
with open(trg_file, 'wt', encoding='utf-8') as fp:
174+
with open(self.ref_file, "r", encoding="utf-8") as nbest_fp:
158175
for nbest, score in zip(nbest_fp, ref_scores):
159176
fp.write("{} ||| score={}\n".format(nbest.strip(), score))
160177

xnmt/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
YamlSerializable=Union[None,bool,int,float,'Serializable',List['YamlSerializable'],Dict[str,'YamlSerializable']]
88

99
def make_parent_dir(filename):
10-
if not os.path.exists(os.path.dirname(filename)):
10+
if not os.path.exists(os.path.dirname(filename) or "."):
1111
try:
1212
os.makedirs(os.path.dirname(filename))
1313
except OSError as exc: # Guard against race condition

xnmt/xnmt_decode.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import argparse, os, sys
2+
3+
from xnmt import eval_task
4+
from xnmt import param_collection
5+
from xnmt import persistence
6+
7+
def main() -> None:
  """
  Command-line entry point: load a saved xnmt model and decode a source file.

  Reads ``--src`` / ``--hyp`` / ``--mod`` from the command line, deserializes the
  experiment stored in the model file, then runs a DecodingEvalTask to write
  hypotheses for the source file.
  """
  parser = argparse.ArgumentParser()
  # note: plain strings here — the original used an f-string with no placeholders
  parser.add_argument("--src", help="Path of source file to read from.", required=True)
  parser.add_argument("--hyp", help="Path of file to write hypothesis to.", required=True)
  parser.add_argument("--mod", help="Path of model file to read.", required=True)
  args = parser.parse_args()

  exp_dir = os.path.dirname(__file__)
  exp = "{EXP}"  # placeholder experiment name handed to the preloader — TODO confirm semantics

  param_collection.ParamManager.init_param_col()

  # TODO: can we avoid the LoadSerialized proxy and load stuff directly?
  load_experiment = persistence.LoadSerialized(filename=args.mod)

  uninitialized_experiment = persistence.YamlPreloader.preload_obj(load_experiment, exp_dir=exp_dir, exp_name=exp)
  loaded_experiment = persistence.initialize_if_needed(uninitialized_experiment)
  model = loaded_experiment.model
  inference = model.inference
  # parameters are populated only after the model object graph is fully constructed
  param_collection.ParamManager.populate()

  decoding_task = eval_task.DecodingEvalTask(args.src, args.hyp, model, inference)
  decoding_task.eval()

if __name__ == "__main__":
  sys.exit(main())

0 commit comments

Comments
 (0)