Commit 8ee8fd5

Lattice-to-sequence (neulab#547)
* add __getitem__ and get_unpadded_sent to Sentence
* started integrating / updating lattices
* series of bug fixes to lattice encoder
* fixed config file and serializable interface
* added documentation
* removed last remaining from_spec from preproc code
* WIP: added LatticeFromPlfExtractor
* extracting lattices from PLF works
* implement lattice reader
* config file with lattice reader working
* removed some specialized code
* remove broken arc dropout
* move Lattice class to sent module
* moved lattice reader to input_readers
* simplified lattice embedder by delegating to base embedder
* simplify config
* moved lattice embedder to embedders module
* move lattice lstm out of specialized encoders package
* minor cleanup
* add link
* fix inconsistency in preproc code
* remove unused config file
* add Lattice.__len__
* simplified code by passing on expr seqs instead of lattices
* lattice plotting
* remove legacy comment
* made lattice plotting more flexible
* fix lattice.reversed()
* remove unused fields
* lattice padding
* fix access to bwd prob
* prepared LatticeBiasedMlpAttender
* finished LatticeBiasedMlpAttender
* text_input feature for LatticeReader
* fix reading in of bwd probs
* unpadded sent handling for lattice
* 'flatten' option for lattice reader
* remove duplicated classes
1 parent c17b2de commit 8ee8fd5

16 files changed: +857 -106 lines

docs/getting_started.rst (+1 -1)

@@ -11,7 +11,7 @@ Prerequisites
 Before running *xnmt* you must install the required packages, including Python bindings for
 `DyNet <https://github.com/clab/dynet>`_.
 This can be done by running ``pip install -r requirements.txt``.
-(There is also ``requirements-extra.txt`` that has some requirements for utility scripts that are not part of *xnmt* itself.)
+(There are also optional package requirements under ``requirements-extra/`` for features that are non-central to *xnmt*.)
 
 Next, install *xnmt* by running ``python setup.py install`` for normal usage or ``python setup.py develop`` for
 development.

examples/05_preproc.yaml (+3 -3)

@@ -60,9 +60,9 @@
       - '{DATA_OUT}/train.tok.norm.filter.ja'
       - '{DATA_OUT}/train.tok.norm.filter.en'
     specs:
-    - type: length
-      min: 1
-      max: 60
+    - !SentenceFiltererLength
+      min_all: 1
+      max_all: 60
   - !PreprocVocab
     in_files:
     - '{DATA_OUT}/train.tok.norm.ja'

examples/data/fisher_dev.en (+5, new file)

@@ -0,0 +1,5 @@
+afternoon .
+good afternoon
+my name is carmen , in chicago . you ?
+oh , my name is ricardo .
+of

examples/data/fisher_dev.en.vocab (+15, new file)

@@ -0,0 +1,15 @@
+afternoon
+carmen
+chicago
+good
+in
+is
+my
+name
+of
+oh
+ricardo
+you
+,
+?
+.

examples/data/fisher_dev.es.plf (+5, new file)

@@ -0,0 +1,5 @@
+((('tal', -0.727828979, 1),('tardes', -2.55085754, 2),('tarde', -0.823196411, 2),),(('ves', -2.08010864, 1),('vez', -0.731903076, 1),('de', -0.931167603, 1),),)
+((('buenas', 0, 1),),(('tardes', 0, 1),),)
+((('mi', 0, 1),),(('nombre', 0, 1),),(('es', 0, 1),),(('carmen', 0, 1),),(('de', 0, 1),),(('chicago', 0, 1),),(('y', 0, 1),),(('tu', 0, 1),),)
+((('no', -0.760681152, 1),('o', -1.75738525, 5),('oh', -1.02124023, 7),),(('me', 0, 1),),(('no', 0, 1),),(('me', 0, 1),),(('ricardo', 0, 11),),(('me', 0, 1),),(('no', 0, 3),),(('me', -0.817199707, 1),('mi', -0.582763672, 4),),(('no', 0, 1),),(('me', 0, 1),),(('ricardo', 0, 5),),(('nombre', -0.5859375, 1),('no', -0.813262939, 2),),(('ricardo', 0, 3),),(('me', 0, 1),),(('ricardo', 0, 1),),)
+((('yea', -2.66070557, 1),('sí', -1.18453979, 1),('ya', -1.33209229, 1),('yeah', -1.02085876, 1),),)
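
These lines are in PLF ("Python Lattice Format", as used by lattice-aware MT tools such as Moses and cdec): each line is a nested Python tuple with one inner tuple per lattice state, and each arc is a (token, score, distance) triple, where distance is the number of states the arc jumps ahead. As a rough illustration of how such a line decodes into arcs, here is a minimal sketch (not the actual LatticeFromPlfExtractor code; plf_to_arcs is a hypothetical helper):

import ast

def plf_to_arcs(plf_line):
    # Each PLF line is a tuple of states; each state is a tuple of
    # (token, score, distance) arcs, where `distance` says how many
    # states ahead the arc points.
    states = ast.literal_eval(plf_line)
    arcs = []
    for state_i, outgoing in enumerate(states):
        for token, score, distance in outgoing:
            arcs.append((state_i, state_i + distance, token, score))
    return arcs

# The second lattice above is a plain two-word chain:
print(plf_to_arcs("((('buenas', 0, 1),),(('tardes', 0, 1),),)"))
# -> [(0, 1, 'buenas', 0), (1, 2, 'tardes', 0)]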

examples/data/fisher_dev.es.vocab (+23, new file)

@@ -0,0 +1,23 @@
+buenas
+carmen
+chicago
+de
+es
+me
+mi
+no
+nombre
+o
+oh
+ricardo
+sí
+tal
+tarde
+tardes
+tu
+ves
+vez
+y
+ya
+yea
+yeah

requirements-extra/lattice.txt (+1, new file)

@@ -0,0 +1 @@
+graphviz

test/config/lattice.yaml (+60, new file)

@@ -0,0 +1,60 @@
+lattice: !Experiment
+  exp_global: !ExpGlobal
+    default_layer_dim: 32
+    dropout: 0.3
+  preproc: !PreprocRunner
+    overwrite: False
+    tasks:
+    - !PreprocExtract
+      in_files:
+      - examples/data/fisher_dev.es.plf
+      out_files:
+      - examples/output/fisher_dev.es.xlat
+      specs: !LatticeFromPlfExtractor {}
+  model: !DefaultTranslator
+    src_embedder: !SimpleWordEmbedder {}
+    encoder: !BiLatticeLSTMTransducer
+      layers: 2
+    attender: !LatticeBiasedMlpAttender {}
+    trg_embedder: !SimpleWordEmbedder {}
+    decoder: !AutoRegressiveDecoder
+      rnn: !UniLSTMSeqTransducer
+        layers: 1
+      transform: !AuxNonLinear
+        output_dim: 512
+        activation: 'tanh'
+      bridge: !CopyBridge {}
+      scorer: !Softmax {}
+    src_reader: !LatticeReader
+      vocab: !Vocab
+        vocab_file: examples/data/fisher_dev.es.vocab
+    trg_reader: !PlainTextReader
+      vocab: !Vocab
+        _xnmt_id: trg_vocab
+        vocab_file: examples/data/fisher_dev.en.vocab
+  train: !SimpleTrainingRegimen
+    trainer: !AdamTrainer
+      alpha: 0.0003
+    run_for_epochs: 10
+    batcher: !SrcBatcher
+      batch_size: 1
+    restart_trainer: True
+    lr_decay: 0.8
+    patience: 5
+    src_file: examples/output/fisher_dev.es.xlat
+    trg_file: examples/data/fisher_dev.en
+    dev_tasks:
+    - !AccuracyEvalTask
+      eval_metrics: bleu
+      src_file: examples/output/fisher_dev.es.xlat
+      ref_file: examples/data/fisher_dev.en
+      hyp_file: examples/output/{EXP}.dev_hyp
+    - !LossEvalTask
+      src_file: examples/output/fisher_dev.es.xlat
+      ref_file: examples/data/fisher_dev.en
+  evaluate:
+  - !AccuracyEvalTask
+    eval_metrics: bleu
+    src_file: examples/output/fisher_dev.es.xlat
+    ref_file: examples/data/fisher_dev.en
+    hyp_file: examples/output/{EXP}.test_hyp
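
Taken together, this config exercises the full pipeline added in this commit: !PreprocExtract with !LatticeFromPlfExtractor converts the PLF file into the serialized node/arc lattice format, !LatticeReader reads it back as source sentences, !BiLatticeLSTMTransducer encodes the lattice, and !LatticeBiasedMlpAttender biases the decoder's attention toward confident nodes. The unit test added below runs it end to end.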

test/config/preproc.yaml (+3 -3)

@@ -36,9 +36,9 @@ standard-preproc: !Experiment
       - test/tmp/head.tok.norm.filter.ja
       - test/tmp/head.tok.norm.filter.en
     specs:
-    - type: length
-      min: 1
-      max: 50
+    - !SentenceFiltererLength
+      min_all: 1
+      max_all: 50
   - !PreprocVocab
     in_files:
     - test/tmp/head.tok.norm.ja

test/test_run.py (+3)

@@ -31,6 +31,9 @@ def test_ensembling(self):
   def test_forced(self):
     run.main(["test/config/forced.yaml"])
 
+  def test_lattice(self):
+    run.main(["test/config/lattice.yaml"])
+
   def test_lm(self):
     run.main(["test/config/lm.yaml"])

xnmt/__init__.py (+1)

@@ -51,6 +51,7 @@
 import xnmt.train.regimens
 import xnmt.train.tasks
 import xnmt.transducers.convolution
+import xnmt.transducers.lattice
 import xnmt.transducers.network_in_network
 import xnmt.transducers.positional
 import xnmt.transducers.pyramidal

xnmt/input_readers.py (+70 -1)

@@ -1,6 +1,6 @@
+import ast
 from itertools import zip_longest
 from functools import lru_cache
-import ast
 from typing import Iterator, Optional, Sequence, Union
 import numbers
 
@@ -468,6 +468,75 @@ def read_sent(self, line, idx):
   def read_sents(self, filename, filter_ids=None):
     return [l for l in self.iterate_filtered(filename, filter_ids)]
 
+
+class LatticeReader(BaseTextReader, Serializable):
+  """
+  Reads lattices from a text file.
+
+  The expected lattice file format is as follows:
+
+  * 1 line per lattice
+  * lines are serialized python lists / tuples
+  * 2 lists per lattice:
+    - list of nodes, with every node a 4-tuple: (lexicon_entry, fwd_log_prob, marginal_log_prob, bwd_log_prob)
+    - list of arcs, each arc a tuple: (node_id_start, node_id_end)
+      - node_id references the nodes and is 0-indexed
+      - node_id_start < node_id_end
+  * All paths must share a common start and end node, i.e. <s> and </s> need to be contained in the lattice
+
+  A simple example lattice:
+    [('<s>', 0.0, 0.0, 0.0), ('buenas', 0, 0.0, 0.0), ('tardes', 0, 0.0, 0.0), ('</s>', 0.0, 0.0, 0.0)],[(0, 1), (1, 2), (2, 3)]
+
+  Args:
+    vocab: Vocabulary to convert string tokens to integer ids. If not given, plain text will be assumed to contain
+           space-separated integer ids.
+    text_input: If ``True``, assume a standard text file as input and convert it to a flat lattice.
+    flatten: If ``True``, convert to a flat lattice, with all probabilities set to 1.
+  """
+  yaml_tag = '!LatticeReader'
+
+  @serializable_init
+  def __init__(self, vocab: Vocab, text_input: bool = False, flatten: bool = False) -> None:
+    self.vocab = vocab
+    self.text_input = text_input
+    self.flatten = flatten
+
+  def read_sent(self, line, idx):
+    if self.text_input:
+      nodes = [sent.LatticeNode(nodes_prev=[], nodes_next=[1], value=Vocab.SS,
+                                fwd_log_prob=0.0, marginal_log_prob=0.0, bwd_log_prob=0.0)]
+      for word in line.strip().split():
+        nodes.append(
+          sent.LatticeNode(nodes_prev=[len(nodes)-1], nodes_next=[len(nodes)+1], value=self.vocab.convert(word),
+                           fwd_log_prob=0.0, marginal_log_prob=0.0, bwd_log_prob=0.0))
+      nodes.append(
+        sent.LatticeNode(nodes_prev=[len(nodes)-1], nodes_next=[], value=Vocab.ES,
+                         fwd_log_prob=0.0, marginal_log_prob=0.0, bwd_log_prob=0.0))
+    else:
+      node_list, arc_list = ast.literal_eval(line)
+      nodes = [sent.LatticeNode(nodes_prev=[], nodes_next=[],
+                                value=self.vocab.convert(item[0]),
+                                fwd_log_prob=item[1], marginal_log_prob=item[2], bwd_log_prob=item[3])
+               for item in node_list]
+      if self.flatten:
+        for node_i in range(len(nodes)):
+          if node_i < len(nodes)-1: nodes[node_i].nodes_next.append(node_i+1)
+          if node_i > 0: nodes[node_i].nodes_prev.append(node_i-1)
+          nodes[node_i].fwd_log_prob = nodes[node_i].bwd_log_prob = nodes[node_i].marginal_log_prob = 0.0
+      else:
+        for from_index, to_index in arc_list:
+          nodes[from_index].nodes_next.append(to_index)
+          nodes[to_index].nodes_prev.append(from_index)
+
+    assert nodes[0].value == self.vocab.SS
+    assert nodes[-1].value == self.vocab.ES
+
+    return sent.Lattice(idx=idx, nodes=nodes, vocab=self.vocab)
+
+  def vocab_size(self):
+    return len(self.vocab)
+
+
 ###### A utility function to read a parallel corpus
 def read_parallel_corpus(src_reader: InputReader,
                          trg_reader: InputReader,
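
To make the expected serialized format concrete, the following standalone sketch mimics the parsing in read_sent's non-text branch on the docstring's example line, using plain lists in place of xnmt's sent.LatticeNode (illustrative only, not part of the commit):

import ast

line = ("[('<s>', 0.0, 0.0, 0.0), ('buenas', 0.0, 0.0, 0.0), "
        "('tardes', 0.0, 0.0, 0.0), ('</s>', 0.0, 0.0, 0.0)], "
        "[(0, 1), (1, 2), (2, 3)]")

# A bare `a, b` expression evaluates to a tuple, so a single
# literal_eval call yields both the node list and the arc list.
node_list, arc_list = ast.literal_eval(line)

# Wire up successor/predecessor lists the same way read_sent does.
nodes_next = [[] for _ in node_list]
nodes_prev = [[] for _ in node_list]
for from_index, to_index in arc_list:
    nodes_next[from_index].append(to_index)
    nodes_prev[to_index].append(from_index)

assert node_list[0][0] == '<s>' and node_list[-1][0] == '</s>'
print(nodes_next)  # [[1], [2], [3], []]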

xnmt/modelparts/attenders.py (+54 -1)

@@ -1,10 +1,11 @@
 import math
 import numbers
 
+import numpy as np
 import dynet as dy
 
 from xnmt import logger
-from xnmt import batchers, expression_seqs, param_collections, param_initializers
+from xnmt import batchers, expression_seqs, events, param_collections, param_initializers
 from xnmt.persistence import serializable_init, Serializable, Ref, bare
 
 class Attender(object):
@@ -203,4 +204,56 @@ def calc_context(self, state: dy.Expression) -> dy.Expression:
     attention = self.calc_attention(state)
     return self.I * attention
 
+class LatticeBiasedMlpAttender(MlpAttender, Serializable):
+  """
+  Modified MLP attention, where lattices are assumed as input and the attention is biased toward confident nodes.
+
+  Args:
+    input_dim: input dimension
+    state_dim: dimension of state inputs
+    hidden_dim: hidden MLP dimension
+    param_init: how to initialize weight matrices
+    bias_init: how to initialize bias vectors
+    truncate_dec_batches: whether the decoder drops batch elements as soon as these are masked at some time step.
+  """
+
+  yaml_tag = '!LatticeBiasedMlpAttender'
+
+  @events.register_xnmt_handler
+  @serializable_init
+  def __init__(self,
+               input_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
+               state_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
+               hidden_dim: numbers.Integral = Ref("exp_global.default_layer_dim"),
+               param_init: param_initializers.ParamInitializer = Ref("exp_global.param_init", default=bare(param_initializers.GlorotInitializer)),
+               bias_init: param_initializers.ParamInitializer = Ref("exp_global.bias_init", default=bare(param_initializers.ZeroInitializer)),
+               truncate_dec_batches: bool = Ref("exp_global.truncate_dec_batches", default=False)) -> None:
+    super().__init__(input_dim=input_dim, state_dim=state_dim, hidden_dim=hidden_dim, param_init=param_init,
+                     bias_init=bias_init, truncate_dec_batches=truncate_dec_batches)
+
+  @events.handle_xnmt_event
+  def on_start_sent(self, src):
+    self.cur_sent_bias = np.full((src.sent_len(), 1, src.batch_size()), -1e10)
+    for batch_i, lattice_batch_elem in enumerate(src):
+      for node_i, node in enumerate(lattice_batch_elem.nodes):
+        self.cur_sent_bias[node_i, 0, batch_i] = node.marginal_log_prob
+    self.cur_sent_bias_expr = None
+
+  def calc_attention(self, state):
+    V = dy.parameter(self.pV)
+    U = dy.parameter(self.pU)
+
+    WI = self.WI
+    curr_sent_mask = self.curr_sent.mask
+    if self.truncate_dec_batches:
+      if curr_sent_mask: state, WI, curr_sent_mask = batchers.truncate_batches(state, WI, curr_sent_mask)
+      else: state, WI = batchers.truncate_batches(state, WI)
+    h = dy.tanh(dy.colwise_add(WI, V * state))
+    scores = dy.transpose(U * h)
+    if curr_sent_mask is not None:
+      scores = curr_sent_mask.add_to_tensor_expr(scores, multiplicator=-1e10)
+    if self.cur_sent_bias_expr is None: self.cur_sent_bias_expr = dy.inputTensor(self.cur_sent_bias, batched=True)
+    normalized = dy.softmax(scores + self.cur_sent_bias_expr)
+    self.attention_vecs.append(normalized)
+    return normalized
