diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py
index c6a678250..a4d0e906f 100644
--- a/tensor2tensor/models/transformer.py
+++ b/tensor2tensor/models/transformer.py
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function
 
 from six.moves import range  # pylint: disable=redefined-builtin
+import re
 
 from tensor2tensor.data_generators import librispeech
 from tensor2tensor.layers import common_attention
@@ -786,6 +787,23 @@ def preprocess_targets(targets, i):
       decoder_self_attention_bias += common_attention.attention_bias_proximal(
           decode_length)
 
+    # Create tensors for encoder-decoder attention history
+    att_cache = {"attention_history": {}}
+    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
+    att_batch_size, enc_seq_length = common_layers.shape_list(encoder_output)[0:2]
+    for layer in range(num_layers):
+      att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
+          [att_batch_size, hparams.num_heads, 0, enc_seq_length])
+
+    def update_decoder_attention_history(cache):
+      for k in filter(lambda x: "decoder" in x and "self" not in x and "logits" not in x,
+                      self.attention_weights.keys()):
+        m = re.search(r"(layer_\d+)", k)
+        if m is None:
+          continue
+        cache["attention_history"][m.group(1)] = tf.concat(
+            [cache["attention_history"][m.group(1)], self.attention_weights[k]], axis=2)
+
     def symbols_to_logits_fn(ids, i, cache):
       """Go from ids to logits for next symbol."""
       ids = ids[:, -1:]
@@ -804,6 +822,8 @@ def symbols_to_logits_fn(ids, i, cache):
           cache,
           nonpadding=features_to_nonpadding(features, "targets"))
 
+      update_decoder_attention_history(cache)
+
       modality_name = hparams.name.get(
           "targets",
           modalities.get_name(target_modality))(hparams, target_vocab_size)
@@ -846,7 +866,8 @@ def forced_logits():
         batch_size=batch_size,
         force_decode_length=self._decode_hparams.force_decode_length,
         sos_id=sos_id,
-        eos_id=eos_id)
+        eos_id=eos_id,
+        cache=att_cache)
       if partial_targets is not None:
         if beam_size <= 1 or top_beams <= 1:
           ret["outputs"] = ret["outputs"][:, partial_targets_length:]
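
The attention_history tensors created above follow a simple accumulation pattern: each per-layer entry starts with length 0 along the query axis and gains one slice per decoding step via tf.concat inside symbols_to_logits_fn. Below is a minimal standalone sketch of that pattern, outside tensor2tensor, using hypothetical sizes in place of hparams.num_heads, the encoder sequence length, and the decode length; it is for illustration only, not part of the change.

    import tensorflow as tf

    # Hypothetical sizes standing in for hparams.num_heads, the encoder
    # sequence length read from encoder_output, and the decode length.
    batch_size, num_heads, enc_seq_length, decode_length = 2, 4, 7, 5

    # Per-layer history starts empty along axis 2, as in att_cache above.
    attention_history = tf.zeros([batch_size, num_heads, 0, enc_seq_length])

    for _ in range(decode_length):
      # One decoding step yields encoder-decoder attention weights for a
      # single query position attending over all encoder positions.
      step_weights = tf.random.uniform([batch_size, num_heads, 1, enc_seq_length])
      # Append along the query axis, as update_decoder_attention_history does.
      attention_history = tf.concat([attention_history, step_weights], axis=2)

    print(attention_history.shape)  # (2, 4, 5, 7) after decode_length steps

In the diff itself this concatenation happens once per call to symbols_to_logits_fn, so the cache handed to fast_decode carries the full per-layer encoder-decoder attention history once decoding finishes.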