From 32440abfee42f84734b2e87e040e801dcd4bfb7d Mon Sep 17 00:00:00 2001
From: Eugene Karaulov
Date: Wed, 5 Jun 2019 12:04:49 +0300
Subject: [PATCH 1/2] Using partial targets at inference time.

---
 tensor2tensor/models/transformer.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index d98c82cce..9cd11bfe8 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -699,7 +699,10 @@ def _fast_decode(self,
           features=features)
       encoder_output = encoder_output[0]
       encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
-      partial_targets = None
+      if 'partial_targets' in features:
+        partial_targets = features['partial_targets']
+      else:
+        partial_targets = None
     else:
       # The problem has no inputs.
       encoder_output = None
@@ -712,6 +715,8 @@ def _fast_decode(self,
       if partial_targets is None:
         partial_targets = features["targets"]
       assert partial_targets is not None
+
+    if partial_targets is not None:
       partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
       partial_targets = tf.to_int64(partial_targets)
       partial_targets_shape = common_layers.shape_list(partial_targets)

From bce60b9bb818987577bfcc00050ce1698d013124 Mon Sep 17 00:00:00 2001
From: Eugene Karaulov
Date: Wed, 12 Jun 2019 19:43:53 +0300
Subject: [PATCH 2/2] Saving attention history to Transformer's cache during fast decoding.

---
 tensor2tensor/models/transformer.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index c6a678250..a4d0e906f 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function
 
 from six.moves import range  # pylint: disable=redefined-builtin
+import re
 
 from tensor2tensor.data_generators import librispeech
 from tensor2tensor.layers import common_attention
@@ -786,6 +787,23 @@ def preprocess_targets(targets, i):
       decoder_self_attention_bias += common_attention.attention_bias_proximal(
           decode_length)
 
+    # Create tensors for encoder-decoder attention history
+    att_cache = {"attention_history": {}}
+    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
+    att_batch_size, enc_seq_length = common_layers.shape_list(encoder_output)[0:2]
+    for layer in range(num_layers):
+      att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
+          [att_batch_size, hparams.num_heads, 0, enc_seq_length])
+
+    def update_decoder_attention_history(cache):
+      for k in filter(lambda x: "decoder" in x and not "self" in x and not "logits" in x,
+                      self.attention_weights.keys()):
+        m = re.search(r"(layer_\d+)", k)
+        if m is None:
+          continue
+        cache["attention_history"][m[0]] = tf.concat(
+            [cache["attention_history"][m[0]], self.attention_weights[k]], axis=2)
+
     def symbols_to_logits_fn(ids, i, cache):
       """Go from ids to logits for next symbol."""
       ids = ids[:, -1:]
@@ -804,6 +822,8 @@ def symbols_to_logits_fn(ids, i, cache):
           cache,
           nonpadding=features_to_nonpadding(features, "targets"))
 
+      update_decoder_attention_history(cache)
+
       modality_name = hparams.name.get(
           "targets",
           modalities.get_name(target_modality))(hparams, target_vocab_size)
@@ -846,7 +866,8 @@ def forced_logits():
           batch_size=batch_size,
           force_decode_length=self._decode_hparams.force_decode_length,
           sos_id=sos_id,
-          eos_id=eos_id)
+          eos_id=eos_id,
+          cache=att_cache)
     if partial_targets is not None:
       if beam_size <= 1 or top_beams <= 1:
         ret["outputs"] = ret["outputs"][:, partial_targets_length:]
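Note on the first patch: it lets a caller pass features["partial_targets"] even when the problem has inputs, forcing decoding to start from a given target prefix; the final hunk then strips that prefix from the returned outputs. A minimal NumPy sketch of the prefix-stripping step only (strip_forced_prefix and the toy arrays are illustrative, not part of tensor2tensor):

import numpy as np

def strip_forced_prefix(outputs, partial_targets):
  # Mirrors ret["outputs"][:, partial_targets_length:] in the hunk above:
  # drop the forced prefix so only newly generated symbols are returned.
  if partial_targets is None:
    return outputs
  partial_targets_length = partial_targets.shape[1]
  return outputs[:, partial_targets_length:]

prefix = np.array([[4, 8, 15]])                  # forced target prefix
outputs = np.array([[4, 8, 15, 16, 23, 42, 1]])  # prefix + generated ids + EOS
print(strip_forced_prefix(outputs, prefix))      # [[16 23 42  1]]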
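Note on the second patch: each call to update_decoder_attention_history appends the current step's encoder-decoder attention weights (shape [batch, heads, 1, encoder_length]) to a per-layer history tensor that starts with a zero-length time axis, so after decoding the cache holds a [batch, heads, decoded_length, encoder_length] tensor per layer. A standalone NumPy sketch of that accumulation pattern (the dictionary keys are illustrative stand-ins for the long scope names in self.attention_weights):

import re
import numpy as np

batch_size, num_heads, enc_len, num_layers = 2, 4, 7, 3

# One zero-length history per decoder layer:
# [batch, heads, decoded_steps_so_far, encoder_length].
attention_history = {
    "layer_%d" % layer: np.zeros((batch_size, num_heads, 0, enc_len))
    for layer in range(num_layers)
}

def update_history(history, attention_weights):
  # Keep encoder-decoder attention only; skip self-attention and logits.
  for key, weights in attention_weights.items():
    if "decoder" not in key or "self" in key or "logits" in key:
      continue
    m = re.search(r"(layer_\d+)", key)
    if m is None:
      continue
    # weights has shape [batch, heads, 1, encoder_length] for one decode step.
    history[m.group(0)] = np.concatenate([history[m.group(0)], weights], axis=2)

# Simulate three decoding steps.
for _ in range(3):
  step_weights = {
      "decoder/layer_%d/encdec_attention" % layer:
          np.random.rand(batch_size, num_heads, 1, enc_len)
      for layer in range(num_layers)
  }
  update_history(attention_history, step_weights)

print(attention_history["layer_0"].shape)  # (2, 4, 3, 7)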