Skip to content

Commit

Permalink
Implementation of the WeCNLP abstract "Cross+Self-Attention for Trans…
Browse files Browse the repository at this point in the history
…former Models" (#1097)

Summary:
This PR implements a new attention module which combines cross-attention (encoder-decoder attention) and the decoder self-attention. This work was accepted as an abstract at WeCNLP 2019 (https://www.wecnlp.ai/wecnlp-2019).

Cross+Self-Attention reduces the amount of parameter and increases the inference speed without any degradation in translation quality.
More details can be found in the attached [abstract](https://github.com/pytorch/fairseq/files/3561282/paper.pdf)
Pull Request resolved: #1097

Differential Revision: D17653168

Pulled By: myleott

fbshipit-source-id: deb834c2c78a229d7418ffbfea20ba3ce252991c
  • Loading branch information
stephanpeitz authored and facebook-github-bot committed Sep 29, 2019
1 parent ea1a410 commit 4ac2c5f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 12 deletions.
66 changes: 62 additions & 4 deletions fairseq/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ def add_args(parser):
'Must be used with adaptive_loss criterion'),
parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D',
help='sets adaptive softmax dropout for the tail projections')
# args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019)
parser.add_argument('--no-cross-attention', default=False, action='store_true',
help='do not perform cross-attention')
parser.add_argument('--cross-self-attention', default=False, action='store_true',
help='perform cross+self-attention')
parser.add_argument('--layer-wise-attention', default=False, action='store_true',
help='perform layer-wise attention (cross-attention or cross+self-attention)')
# fmt: on

@classmethod
Expand Down Expand Up @@ -180,7 +187,12 @@ def build_encoder(cls, args, src_dict, embed_tokens):

@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return TransformerDecoder(args, tgt_dict, embed_tokens)
return TransformerDecoder(
args,
tgt_dict,
embed_tokens,
no_encoder_attn=getattr(args, 'no_cross_attention', False),
)


class TransformerEncoder(FairseqEncoder):
Expand Down Expand Up @@ -211,6 +223,8 @@ def __init__(self, args, dictionary, embed_tokens):
learned=args.encoder_learned_pos,
) if not args.no_token_positional_embeddings else None

self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
Expand All @@ -230,21 +244,29 @@ def forward_embedding(self, src_tokens):
x = F.dropout(x, p=self.dropout, training=self.training)
return x, embed

def forward(self, src_tokens, src_lengths, cls_input=None):
def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
if self.layer_wise_attention:
return_all_hiddens = True

x, encoder_embedding = self.forward_embedding(src_tokens)

# B x T x C -> T x B x C
Expand All @@ -255,17 +277,24 @@ def forward(self, src_tokens, src_lengths, cls_input=None):
if not encoder_padding_mask.any():
encoder_padding_mask = None

encoder_states = [] if return_all_hiddens else None

# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
if return_all_hiddens:
encoder_states.append(x)

if self.layer_norm:
x = self.layer_norm(x)
if return_all_hiddens:
encoder_states[-1] = x

return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
'encoder_embedding': encoder_embedding, # B x T x C
'encoder_states': encoder_states, # List[T x B x C]
}

def reorder_encoder_out(self, encoder_out, new_order):
Expand All @@ -285,6 +314,9 @@ def reorder_encoder_out(self, encoder_out, new_order):
if encoder_out['encoder_padding_mask'] is not None:
encoder_out['encoder_padding_mask'] = \
encoder_out['encoder_padding_mask'].index_select(0, new_order)
if encoder_out.get('encoder_states', None) is not None:
for idx, state in enumerate(encoder_out['encoder_states']):
encoder_out['encoder_states'][idx] = state.index_select(1, new_order)
return encoder_out

def max_positions(self):
Expand All @@ -293,6 +325,14 @@ def max_positions(self):
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())

def buffered_future_mask(self, tensor):
dim = tensor.size(0)
if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
if self._future_mask.size(0) < dim:
self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
return self._future_mask[:dim, :dim]

def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
Expand Down Expand Up @@ -350,6 +390,9 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
learned=args.decoder_learned_pos,
) if not args.no_token_positional_embeddings else None

self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args, no_encoder_attn)
Expand Down Expand Up @@ -435,14 +478,26 @@ def extract_features(self, prev_output_tokens, encoder_out=None, incremental_sta

inner_states = [x]

self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
if not self_attn_padding_mask.any() and not self.cross_self_attention:
self_attn_padding_mask = None

# decoder layers
for layer in self.layers:
for idx, layer in enumerate(self.layers):
encoder_state = None
if encoder_out is not None:
if self.layer_wise_attention:
encoder_state = encoder_out['encoder_states'][idx]
else:
encoder_state = encoder_out['encoder_out']

x, attn = layer(
x,
encoder_out['encoder_out'] if encoder_out is not None else None,
encoder_state,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
self_attn_padding_mask=self_attn_padding_mask,
)
inner_states.append(x)

Expand Down Expand Up @@ -553,6 +608,9 @@ def base_architecture(args):
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
args.adaptive_input = getattr(args, 'adaptive_input', False)
args.no_cross_attention = getattr(args, 'no_cross_attention', False)
args.cross_self_attention = getattr(args, 'cross_self_attention', False)
args.layer_wise_attention = getattr(args, 'layer_wise_attention', False)

args.decoder_output_dim = getattr(args, 'decoder_output_dim', args.decoder_embed_dim)
args.decoder_input_dim = getattr(args, 'decoder_input_dim', args.decoder_embed_dim)
Expand Down
10 changes: 9 additions & 1 deletion fairseq/modules/multihead_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,15 @@ def forward(self, query, key, value, key_padding_mask=None, incremental_state=No
v = prev_value
else:
v = torch.cat((prev_value, v), dim=1)
if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
prev_key_padding_mask = saved_state['prev_key_padding_mask']
if static_kv:
key_padding_mask = prev_key_padding_mask
else:
key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state['prev_key_padding_mask'] = key_padding_mask

self._set_input_buffer(incremental_state, saved_state)

Expand Down Expand Up @@ -311,7 +318,8 @@ def reorder_incremental_state(self, incremental_state, new_order):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
for k in input_buffer.keys():
input_buffer[k] = input_buffer[k].index_select(0, new_order)
if input_buffer[k] is not None:
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)

def _get_input_buffer(self, incremental_state):
Expand Down
34 changes: 28 additions & 6 deletions fairseq/modules/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
Expand Down Expand Up @@ -134,13 +135,14 @@ class TransformerDecoderLayer(nn.Module):
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.cross_self_attention = getattr(args, 'cross_self_attention', False)
self.self_attn = MultiheadAttention(
embed_dim=self.embed_dim,
num_heads=args.decoder_attention_heads,
dropout=args.attention_dropout,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
self_attention=True
self_attention=not self.cross_self_attention,
)
self.dropout = args.dropout
self.activation_fn = utils.get_activation_fn(
Expand Down Expand Up @@ -208,13 +210,27 @@ def forward(
if prev_self_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_self_attn_state
prev_key, prev_value = prev_self_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_self_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
self.self_attn._set_input_buffer(incremental_state, saved_state)

if self.cross_self_attention and not (incremental_state is not None and "prev_key" in self.self_attn._get_input_buffer(incremental_state)):
if self_attn_mask is not None:
self_attn_mask = torch.cat((x.new(x.size(0), encoder_out.size(0)).zero_(), self_attn_mask), dim=1)
if self_attn_padding_mask is not None:
if encoder_padding_mask is None:
encoder_padding_mask = self_attn_padding_mask.new(encoder_out.size(1), encoder_out.size(0)).zero_()
self_attn_padding_mask = torch.cat((encoder_padding_mask, self_attn_padding_mask), dim=1)
y = torch.cat((encoder_out, x), dim=0)
else:
y = x

x, attn = self.self_attn(
query=x,
key=x,
value=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
Expand All @@ -230,9 +246,12 @@ def forward(
if prev_attn_state is not None:
if incremental_state is None:
incremental_state = {}
prev_key, prev_value = prev_attn_state
prev_key, prev_value = prev_attn_state[:2]
saved_state = {"prev_key": prev_key, "prev_value": prev_value}
if len(prev_attn_state) >= 3:
saved_state["prev_key_padding_mask"] = prev_attn_state[2]
self.encoder_attn._set_input_buffer(incremental_state, saved_state)

x, attn = self.encoder_attn(
query=x,
key=encoder_out,
Expand All @@ -256,7 +275,10 @@ def forward(
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
if self.onnx_trace and incremental_state is not None:
saved_state = self.self_attn._get_input_buffer(incremental_state)
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
if self_attn_padding_mask is not None:
self_attn_state = saved_state["prev_key"], saved_state["prev_value"], saved_state["prev_key_padding_mask"]
else:
self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
return x, attn, self_attn_state
return x, attn

Expand Down
22 changes: 21 additions & 1 deletion tests/test_binaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def test_transformer(self):
], run_validation=True)
generate_main(data_dir)

def test_transformer_cross_self_attention(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_transformer_cross_self_attention') as data_dir:
create_dummy_data(data_dir)
preprocess_translation_data(data_dir)
train_translation_model(data_dir, 'transformer_iwslt_de_en', [
'--encoder-layers', '2',
'--decoder-layers', '2',
'--encoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--decoder-embed-dim', '8',
'--no-cross-attention',
'--cross-self-attention',
'--layer-wise-attention',
], run_validation=True)
generate_main(data_dir, extra_flags=[])

def test_lightconv(self):
with contextlib.redirect_stdout(StringIO()):
with tempfile.TemporaryDirectory('test_lightconv') as data_dir:
Expand Down Expand Up @@ -543,6 +560,10 @@ def train_translation_model(data_dir, arch, extra_flags=None, task='translation'


def generate_main(data_dir, extra_flags=None):
if extra_flags is None:
extra_flags = [
'--print-alignment',
]
generate_parser = options.get_generation_parser()
generate_args = options.parse_args_and_arch(
generate_parser,
Expand All @@ -554,7 +575,6 @@ def generate_main(data_dir, extra_flags=None):
'--max-len-b', '5',
'--gen-subset', 'valid',
'--no-progress-bar',
'--print-alignment',
] + (extra_flags or []),
)

Expand Down

0 comments on commit 4ac2c5f

Please sign in to comment.