From 764d8dbb209f4fd88a6ab3ebf842186c0909cefb Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Mon, 12 Aug 2019 16:55:32 -0700
Subject: [PATCH 1/5] Use softmax with length

---
 src/gluonnlp/model/attention_cell.py | 32 ++++++---------------
 src/gluonnlp/model/transformer.py | 43 +++++++++++-----------------
 2 files changed, 25 insertions(+), 50 deletions(-)

diff --git a/src/gluonnlp/model/attention_cell.py b/src/gluonnlp/model/attention_cell.py
index 32d4500992..76998d7ae2 100644
--- a/src/gluonnlp/model/attention_cell.py
+++ b/src/gluonnlp/model/attention_cell.py
@@ -29,7 +29,6 @@
 from mxnet.gluon import nn
 from .block import L2Normalization
 
-# TODO(sxjscience) Add mask flag to softmax operator. Think about how to accelerate the kernel
 def _masked_softmax(F, att_score, mask, dtype):
     """Ignore the masked elements when calculating the softmax
 
@@ -39,27 +38,14 @@ def _masked_softmax(F, att_score, mask, dtype):
     Parameters
     ----------
     F : symbol or ndarray
     att_score : Symborl or NDArray
         Shape (batch_size, query_length, memory_length)
     mask : Symbol or NDArray or None
-        Shape (batch_size, query_length, memory_length)
+        Shape (batch_size, query_length)
 
     Returns
     -------
     att_weights : Symborl or NDArray
         Shape (batch_size, query_length, memory_length)
     """
     if mask is not None:
-        # Fill in the masked scores with a very small value
-        neg = -1e18
-        if np.dtype(dtype) == np.float16:
-            neg = -1e4
-        else:
-            try:
-                # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
-                from mxnet.contrib import amp
-                if amp.amp._amp_initialized:
-                    neg = -1e4
-            except ImportError:
-                pass
-        att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
-        att_weights = F.softmax(att_score, axis=-1) * mask
+        att_weights = F.softmax(att_score, length=mask, use_length=True, axis=-1)
     else:
         att_weights = F.softmax(att_score, axis=-1)
     return att_weights
@@ -94,8 +80,8 @@ def _compute_weight(self, F, query, key, mask=None):
         key : Symbol or NDArray
             Key of the memory. Shape (batch_size, memory_length, key_dim)
         mask : Symbol or NDArray or None
-            Mask the memory slots. Shape (batch_size, query_length, memory_length)
-            Only contains 0 or 1 where 0 means that the memory slot will not be used.
+            Mask the memory slots. Shape (batch_size, query_length)
+            Contains length of the valid portion of the input.
             If set to None. No mask will be used.
 
         Returns
         -------
@@ -142,8 +128,8 @@ def __call__(self, query, key, value=None, mask=None):  # pylint: disable=argume
             Value of the memory. If set to None, the value will be set as the key.
             Shape (batch_size, memory_length, value_dim)
         mask : Symbol or NDArray or None, default None
-            Mask of the memory slots. Shape (batch_size, query_length, memory_length)
-            Only contains 0 or 1 where 0 means that the memory slot will not be used.
+            Mask of the memory slots. Shape (batch_size, query_length)
+            Contains length of the valid portion of the input.
             If set to None. No mask will be used.
 
         Returns
@@ -237,8 +223,8 @@ def __call__(self, query, key, value=None, mask=None):
             Value of the memory. If set to None, the value will be set as the key.
             Shape (batch_size, memory_length, value_dim)
         mask : Symbol or NDArray or None, default None
-            Mask of the memory slots. Shape (batch_size, query_length, memory_length)
-            Only contains 0 or 1 where 0 means that the memory slot will not be used.
+            Mask of the memory slots. Shape (batch_size, query_length)
+            Contains length of the valid portion of the input.
             If set to None. No mask will be used.
 
         Returns
@@ -266,7 +252,7 @@ def _compute_weight(self, F, query, key, mask=None):
         if mask is not None:
             mask = F.broadcast_axis(F.expand_dims(mask, axis=1),
                                     axis=1, size=self._num_heads)\
-                    .reshape(shape=(-1, 0, 0), reverse=True)
+                    .reshape(shape=(-1, 0), reverse=True)
         att_weights = self._base_cell._compute_weight(F, query, key, mask)
         return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
 
diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py
index d640a93c86..42ed2f89cc 100644
--- a/src/gluonnlp/model/transformer.py
+++ b/src/gluonnlp/model/transformer.py
@@ -265,7 +265,7 @@ def hybrid_forward(self, F, inputs, mask=None):  # pylint: disable=arguments-dif
         inputs : Symbol or NDArray
             Input sequence. Shape (batch_size, length, C_in)
         mask : Symbol or NDArray or None
-            Mask for inputs. Shape (batch_size, length, length)
+            Mask for inputs. Shape (batch_size, length)
 
         Returns
         -------
@@ -502,11 +502,10 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei
         steps = self._arange_like(F, inputs, axis=1)
         if valid_length is not None:
-            ones = F.ones_like(steps)
-            mask = F.broadcast_lesser(F.reshape(steps, shape=(1, -1)),
-                                      F.reshape(valid_length, shape=(-1, 1)))
-            mask = F.broadcast_mul(F.expand_dims(mask, axis=1),
-                                   F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1))))
+            zeros = F.zeros_like(steps)
+            mask = F.broadcast_add(F.reshape(valid_length, shape=(-1, 1)),
+                                   F.reshape(zeros, shape=(1, -1)))
+            mask = F.cast(mask, dtype='int32')
             if states is None:
                 states = [mask]
             else:
@@ -665,7 +664,7 @@ class TransformerEncoderCell(BaseTransformerEncoderCell):
 
     Inputs:
         - **inputs** : input sequence. Shape (batch_size, length, C_in)
-        - **mask** : mask for inputs. Shape (batch_size, length, length)
+        - **mask** : mask for inputs. Shape (batch_size, length)
 
     Outputs:
         - **outputs**: output tensor of the transformer encoder cell.
@@ -856,9 +855,9 @@ def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None):  #pyli
         mem_value : Symbol or NDArrays
             Memory value, i.e. output of the encoder. Shape (batch_size, mem_length, C_in)
         mask : Symbol or NDArray or None
-            Mask for inputs. Shape (batch_size, length, length)
+            Mask for inputs. Shape (batch_size, length)
         mem_mask : Symbol or NDArray or None
-            Mask for mem_value. Shape (batch_size, length, mem_length)
+            Mask for mem_value. Shape (batch_size, length)
 
         Returns
         -------
@@ -994,9 +993,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
         if encoder_valid_length is not None:
             dtype = encoder_valid_length.dtype
             ctx = encoder_valid_length.context
-            mem_masks = mx.nd.broadcast_lesser(
-                mx.nd.arange(mem_length, ctx=ctx, dtype=dtype).reshape((1, -1)),
-                encoder_valid_length.reshape((-1, 1)))
+            mem_masks = mx.nd.cast(encoder_valid_length, dtype='int32')
             decoder_states.append(mem_masks)
         self._encoder_valid_length = encoder_valid_length
         return decoder_states
@@ -1028,17 +1025,12 @@ def decode_seq(self, inputs, states, valid_length=None):
         """
         batch_size = inputs.shape[0]
         length = inputs.shape[1]
-        length_array = mx.nd.arange(length, ctx=inputs.context, dtype=inputs.dtype)
-        mask = mx.nd.broadcast_lesser_equal(
-            length_array.reshape((1, -1)),
-            length_array.reshape((-1, 1)))
+        mask = mx.nd.arange(1, length+1, ctx=inputs.context, dtype=inputs.dtype)
         if valid_length is not None:
-            arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype)
-            batch_mask = mx.nd.broadcast_lesser(
-                arange.reshape((1, -1)),
-                valid_length.reshape((-1, 1)))
-            mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1),
-                                       mx.nd.expand_dims(mask, 0))
+            mask = mx.nd.broadcast_lesser_equal(mask.reshape(1, -1),
+                                                valid_length.reshape(-1, 1))
+            mask = mask * mx.nd.broadcast_minimum(mask.reshape(1, -1),
+                                                  valid_length.reshape(-1, 1))
         else:
             mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size)
         states = [None] + states
@@ -1106,11 +1098,8 @@ def forward(self, step_input, states, mask=None):  #pylint: disable=arguments-di
                 .broadcast_axes(axis=1, size=step_input.shape[1])
             states[-1] = augmented_mem_mask
         if mask is None:
-            length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context,
-                                        dtype=step_input.dtype)
-            mask = mx.nd.broadcast_lesser_equal(
-                length_array.reshape((1, -1)),
-                length_array.reshape((-1, 1)))
+            mask = mx.nd.arange(1, step_input.shape[1]+1, ctx=step_input.context,
+                                dtype=step_input.dtype)
             mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                                         axis=0, size=step_input.shape[0])
         steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
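
Note: the core change in PATCH 1/5 is that _masked_softmax no longer expands the mask into a
(batch_size, query_length, memory_length) tensor of zeros and ones and fills masked scores with a
large negative value; it now passes per-row valid lengths of shape (batch_size, query_length)
directly to softmax. Below is a minimal standalone sketch of the intended equivalence. It is not
part of the patch; it assumes MXNet >= 1.6, where softmax accepts the length/use_length arguments,
and the toy shapes and valid_lengths values are invented for illustration.

    import mxnet as mx

    batch_size, query_length, memory_length = 2, 3, 4
    att_score = mx.nd.random.normal(shape=(batch_size, query_length, memory_length))
    # Sample 0 may attend to all 4 memory slots, sample 1 only to the first 2.
    valid_lengths = mx.nd.array([[4, 4, 4], [2, 2, 2]])

    # Old style: expand the lengths into a 0/1 mask and fill masked scores with -1e18.
    old_mask = mx.nd.broadcast_lesser(
        mx.nd.arange(memory_length).reshape((1, 1, -1)),
        valid_lengths.reshape((batch_size, query_length, 1)))
    old = mx.nd.softmax(mx.nd.where(old_mask, att_score,
                                    -1e18 * mx.nd.ones_like(att_score)),
                        axis=-1) * old_mask

    # New style: hand the valid lengths directly to softmax.
    new = mx.nd.softmax(att_score, axis=-1,
                        length=valid_lengths.astype('int32'), use_length=True)

    print((old - new).abs().max())   # expected to be ~0 up to float32 rounding

Both paths zero out the masked memory slots and renormalise over the valid ones; the length-based
form simply avoids materialising the full attention-shaped mask.
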
From c7ae5b43ec17dbd13e73612450e9d7c641a23c60 Mon Sep 17 00:00:00 2001
From: Brenton Chu
Date: Tue, 27 Aug 2019 17:54:25 -0700
Subject: [PATCH 2/5] Fix transformer decode for softmax length mask

---
 src/gluonnlp/model/transformer.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py
index 42ed2f89cc..efea8f19e9 100644
--- a/src/gluonnlp/model/transformer.py
+++ b/src/gluonnlp/model/transformer.py
@@ -1025,12 +1025,12 @@ def decode_seq(self, inputs, states, valid_length=None):
         """
         batch_size = inputs.shape[0]
         length = inputs.shape[1]
-        mask = mx.nd.arange(1, length+1, ctx=inputs.context, dtype=inputs.dtype)
+        valid_length_cast = mx.nd.cast(valid_length, dtype='int32')
+        arange = mx.nd.arange(1, length+1, ctx=inputs.context, dtype='int32')
         if valid_length is not None:
-            mask = mx.nd.broadcast_lesser_equal(mask.reshape(1, -1),
-                                                valid_length.reshape(-1, 1))
-            mask = mask * mx.nd.broadcast_minimum(mask.reshape(1, -1),
-                                                  valid_length.reshape(-1, 1))
+            mask = mx.nd.broadcast_lesser_equal(arange.reshape(1, -1),
+                                                valid_length_cast.reshape(-1, 1))
+            mask = mask * arange.reshape(1, -1)
         else:
             mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size)
         states = [None] + states
@@ -1099,7 +1099,7 @@ def forward(self, step_input, states, mask=None):  #pylint: disable=arguments-di
             states[-1] = augmented_mem_mask
         if mask is None:
             mask = mx.nd.arange(1, step_input.shape[1]+1, ctx=step_input.context,
-                                dtype=step_input.dtype)
+                                dtype='int32')
             mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                                         axis=0, size=step_input.shape[0])
         steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
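
Note: with PATCH 2/5 applied, the decoder self-attention mask built in decode_seq is also
length-style rather than an explicit (length, length) lower-triangular matrix: query position i
(1-based) is given length i while it lies inside the valid part of the target sequence, and 0 once
it falls into padding (those rows are discarded anyway). A hypothetical toy check of that
construction, not taken from the patch, with the length and valid_length values invented for
illustration:

    import mxnet as mx

    length = 4
    valid_length = mx.nd.array([4, 2])

    valid_length_cast = mx.nd.cast(valid_length, dtype='int32')
    arange = mx.nd.arange(1, length + 1, dtype='int32')            # [1, 2, 3, 4]
    mask = mx.nd.broadcast_lesser_equal(arange.reshape(1, -1),
                                        valid_length_cast.reshape(-1, 1))
    mask = mask * arange.reshape(1, -1)

    print(mask.asnumpy())
    # [[1 2 3 4]    <- causal lengths for a fully valid sequence
    #  [1 2 0 0]]   <- positions past valid_length=2 get length 0

This is the same arithmetic the patch performs, minus the ctx= arguments that tie the arrays to the
input's device.
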
From 00a97d220c7fa8c2dbfce94e802b5f37c22148ff Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Thu, 29 Aug 2019 13:36:48 -0700
Subject: [PATCH 3/5] Fix lint

---
 src/gluonnlp/model/attention_cell.py | 11 +++--------
 src/gluonnlp/model/transformer.py | 3 ---
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/src/gluonnlp/model/attention_cell.py b/src/gluonnlp/model/attention_cell.py
index 76998d7ae2..5a2b7ca204 100644
--- a/src/gluonnlp/model/attention_cell.py
+++ b/src/gluonnlp/model/attention_cell.py
@@ -29,7 +29,7 @@
 from mxnet.gluon import nn
 from .block import L2Normalization
 
-def _masked_softmax(F, att_score, mask, dtype):
+def _masked_softmax(F, att_score, mask):
     """Ignore the masked elements when calculating the softmax
 
     Parameters
@@ -62,13 +62,8 @@ class AttentionCell(HybridBlock):
     """
 
     def __init__(self, prefix=None, params=None):
-        self._dtype = np.float32
         super(AttentionCell, self).__init__(prefix=prefix, params=params)
 
-    def cast(self, dtype):
-        self._dtype = dtype
-        super(AttentionCell, self).cast(dtype)
-
     def _compute_weight(self, F, query, key, mask=None):
         """Compute attention weights based on the query and the keys
 
@@ -356,7 +351,7 @@ def _compute_weight(self, F, query, key, mask=None):
                                        F.expand_dims(mapped_key, axis=1))
         mid_feat = self._act(mid_feat)
         att_score = self._attention_score(mid_feat).reshape(shape=(0, 0, 0))
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask))
         return att_weights
 
 
@@ -463,5 +458,5 @@ def _compute_weight(self, F, query, key, mask=None):
 
             att_score = F.batch_dot(query, key, transpose_b=True)
 
-        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
+        att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask))
         return att_weights
diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py
index efea8f19e9..eb0486f076 100644
--- a/src/gluonnlp/model/transformer.py
+++ b/src/gluonnlp/model/transformer.py
@@ -989,10 +989,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
         """
         mem_value = encoder_outputs
         decoder_states = [mem_value]
-        mem_length = mem_value.shape[1]
         if encoder_valid_length is not None:
-            dtype = encoder_valid_length.dtype
-            ctx = encoder_valid_length.context
             mem_masks = mx.nd.cast(encoder_valid_length, dtype='int32')
             decoder_states.append(mem_masks)
         self._encoder_valid_length = encoder_valid_length

From 8675077e5ea90c290f1a2d93caa66350a8c477f4 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Thu, 29 Aug 2019 13:41:05 -0700
Subject: [PATCH 4/5] Fix more lint

---
 src/gluonnlp/model/attention_cell.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/gluonnlp/model/attention_cell.py b/src/gluonnlp/model/attention_cell.py
index 5a2b7ca204..4984a13b6f 100644
--- a/src/gluonnlp/model/attention_cell.py
+++ b/src/gluonnlp/model/attention_cell.py
@@ -23,7 +23,6 @@ __all__ = ['AttentionCell', 'MultiHeadAttentionCell', 'MLPAttentionCell',
            'DotProductAttentionCell']
 
 import math
-import numpy as np
 import mxnet as mx
 from mxnet.gluon.block import HybridBlock
 from mxnet.gluon import nn

From d473c9878d747d8fdbd61eb2eb5874fdbcead954 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Fri, 27 Sep 2019 10:16:56 -0700
Subject: [PATCH 5/5] Fix TransformerXL

---
 .../transformer/attention_cell.py | 34 ++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/scripts/language_model/transformer/attention_cell.py b/scripts/language_model/transformer/attention_cell.py
index 532f3f8eea..0b75f06082 100644
--- a/scripts/language_model/transformer/attention_cell.py
+++ b/scripts/language_model/transformer/attention_cell.py
@@ -21,11 +21,43 @@
 __all__ = ['PositionalEmbeddingMultiHeadAttentionCell']
 
 import math
+import numpy as np
 import mxnet as mx
 
-from gluonnlp.model.attention_cell import _masked_softmax
+def _masked_softmax(F, att_score, mask, dtype):
+    """Ignore the masked elements when calculating the softmax
+    Parameters
+    ----------
+    F : symbol or ndarray
+    att_score : Symborl or NDArray
+        Shape (batch_size, query_length, memory_length)
+    mask : Symbol or NDArray or None
+        Shape (batch_size, query_length, memory_length)
+    Returns
+    -------
+    att_weights : Symborl or NDArray
+        Shape (batch_size, query_length, memory_length)
+    """
+    if mask is not None:
+        # Fill in the masked scores with a very small value
+        neg = -1e18
+        if np.dtype(dtype) == np.float16:
+            neg = -1e4
+        else:
+            try:
+                # if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
+                from mxnet.contrib import amp
+                if amp.amp._amp_initialized:
+                    neg = -1e4
+            except ImportError:
+                pass
+        att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
+        att_weights = F.softmax(att_score, axis=-1) * mask
+    else:
+        att_weights = F.softmax(att_score, axis=-1)
+    return att_weights
 
 
 class PositionalEmbeddingMultiHeadAttentionCell(mx.gluon.HybridBlock):
     """Multi-head Attention Cell with positional embeddings.
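
Note: PATCH 5/5 gives the TransformerXL script its own copy of the original where-based
_masked_softmax, since the positional-embedding attention cell there still consumes a full
(batch_size, query_length, memory_length) mask of zeros and ones, which the length-based softmax
cannot express. The float16/AMP special case it preserves can be illustrated with a small,
hypothetical numpy snippet (not part of the patch): -1e18 is not representable in float16, so it
overflows to -inf, and a fully masked row then turns into NaN inside softmax, which is why the fill
value falls back to -1e4.

    import numpy as np

    print(np.float16(-1e18))    # -inf: overflows the float16 range (about +/-65504)
    print(np.float16(-1e4))     # -10000.0: representable, and exp(-10000) still underflows to 0

    row = np.full(3, np.float16(-1e18))              # a fully masked attention row
    shifted = row - row.max()                        # -inf - (-inf) = nan
    print(np.exp(shifted) / np.exp(shifted).sum())   # [nan nan nan]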