This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] [WIP] Use softmax with length in attention cells #910

Closed
wants to merge 5 commits into from
34 changes: 33 additions & 1 deletion scripts/language_model/transformer/attention_cell.py
@@ -21,11 +21,43 @@
__all__ = ['PositionalEmbeddingMultiHeadAttentionCell']

import math
import numpy as np

import mxnet as mx

from gluonnlp.model.attention_cell import _masked_softmax
def _masked_softmax(F, att_score, mask, dtype):
"""Ignore the masked elements when calculating the softmax

Parameters
----------
F : symbol or ndarray
att_score : Symbol or NDArray
Shape (batch_size, query_length, memory_length)
mask : Symbol or NDArray or None
Shape (batch_size, query_length, memory_length)
Returns
-------
att_weights : Symbol or NDArray
Shape (batch_size, query_length, memory_length)
"""
if mask is not None:
# Fill in the masked scores with a very small value
neg = -1e18
if np.dtype(dtype) == np.float16:
neg = -1e4
else:
try:
# if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
from mxnet.contrib import amp
if amp.amp._amp_initialized:
neg = -1e4
except ImportError:
pass
att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
att_weights = F.softmax(att_score, axis=-1) * mask
else:
att_weights = F.softmax(att_score, axis=-1)
return att_weights

class PositionalEmbeddingMultiHeadAttentionCell(mx.gluon.HybridBlock):
"""Multi-head Attention Cell with positional embeddings.
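For context, a minimal sketch of what the _masked_softmax helper copied above does (hypothetical values; assumes MXNet is installed): masked positions are filled with a large negative score before the softmax, and -1e4 is used instead of -1e18 under float16 or AMP so the fill value stays representable.

```python
# Illustrative only: mirrors the fill-with-negative approach of the copied helper.
import mxnet as mx

att_score = mx.nd.array([[[0.5, 1.0, 2.0]]])   # (batch_size=1, query_length=1, memory_length=3)
mask = mx.nd.array([[[1, 1, 0]]])              # binary mask; the last memory slot is padding

neg = -1e18                                    # the helper switches to -1e4 for float16/AMP
masked_score = mx.nd.where(mask, att_score, neg * mx.nd.ones_like(att_score))
att_weights = mx.nd.softmax(masked_score, axis=-1) * mask
print(att_weights)                             # the padded slot gets a weight of ~0
```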
44 changes: 12 additions & 32 deletions src/gluonnlp/model/attention_cell.py
@@ -23,14 +23,12 @@
__all__ = ['AttentionCell', 'MultiHeadAttentionCell', 'MLPAttentionCell', 'DotProductAttentionCell']

import math
import numpy as np
import mxnet as mx
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from .block import L2Normalization

# TODO(sxjscience) Add mask flag to softmax operator. Think about how to accelerate the kernel
def _masked_softmax(F, att_score, mask, dtype):
def _masked_softmax(F, att_score, mask):
"""Ignore the masked elements when calculating the softmax

Parameters
@@ -39,27 +37,14 @@ def _masked_softmax(F, att_score, mask, dtype):
att_score : Symbol or NDArray
Shape (batch_size, query_length, memory_length)
mask : Symbol or NDArray or None
Shape (batch_size, query_length, memory_length)
Shape (batch_size, query_length)
Returns
-------
att_weights : Symbol or NDArray
Shape (batch_size, query_length, memory_length)
"""
if mask is not None:
# Fill in the masked scores with a very small value
neg = -1e18
if np.dtype(dtype) == np.float16:
neg = -1e4
else:
try:
# if AMP (automatic mixed precision) is enabled, -1e18 will cause NaN.
from mxnet.contrib import amp
if amp.amp._amp_initialized:
neg = -1e4
except ImportError:
pass
att_score = F.where(mask, att_score, neg * F.ones_like(att_score))
att_weights = F.softmax(att_score, axis=-1) * mask
att_weights = F.softmax(att_score, length=mask, use_length=True, axis=-1)
else:
att_weights = F.softmax(att_score, axis=-1)
return att_weights
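The key change in this hunk replaces the fill-with-negative trick with MXNet's length-aware softmax, where the mask argument now carries valid lengths of shape (batch_size, query_length). A sketch of the semantics, assuming MXNet >= 1.5 (where softmax accepts length/use_length) and illustrative values:

```python
import mxnet as mx

# Attention scores, shape (batch_size=1, query_length=2, memory_length=3)
att_score = mx.nd.array([[[0.5, 1.0, 2.0],
                          [0.1, 0.2, 0.3]]])
# Valid memory length per query position, shape (batch_size, query_length)
length = mx.nd.array([[2, 3]], dtype='int32')

att_weights = mx.nd.softmax(att_score, length=length, use_length=True, axis=-1)
print(att_weights)  # row 0 is normalized over the first 2 slots only; its third weight is 0
```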
@@ -76,13 +61,8 @@ class AttentionCell(HybridBlock):

"""
def __init__(self, prefix=None, params=None):
self._dtype = np.float32
super(AttentionCell, self).__init__(prefix=prefix, params=params)

def cast(self, dtype):
self._dtype = dtype
super(AttentionCell, self).cast(dtype)

def _compute_weight(self, F, query, key, mask=None):
"""Compute attention weights based on the query and the keys

@@ -94,8 +74,8 @@ def _compute_weight(self, F, query, key, mask=None):
key : Symbol or NDArray
Key of the memory. Shape (batch_size, memory_length, key_dim)
mask : Symbol or NDArray or None
Mask the memory slots. Shape (batch_size, query_length, memory_length)
Only contains 0 or 1 where 0 means that the memory slot will not be used.
Mask the memory slots. Shape (batch_size, query_length)
Contains the length of the valid portion of the input.
If set to None, no mask will be used.

Returns
@@ -142,8 +122,8 @@ def __call__(self, query, key, value=None, mask=None): # pylint: disable=argume
Value of the memory. If set to None, the value will be set as the key.
Shape (batch_size, memory_length, value_dim)
mask : Symbol or NDArray or None, default None
Mask of the memory slots. Shape (batch_size, query_length, memory_length)
Only contains 0 or 1 where 0 means that the memory slot will not be used.
Mask of the memory slots. Shape (batch_size, query_length)
Contains the length of the valid portion of the input.
If set to None, no mask will be used.

Returns
@@ -237,8 +217,8 @@ def __call__(self, query, key, value=None, mask=None):
Value of the memory. If set to None, the value will be set as the key.
Shape (batch_size, memory_length, value_dim)
mask : Symbol or NDArray or None, default None
Mask of the memory slots. Shape (batch_size, query_length, memory_length)
Only contains 0 or 1 where 0 means that the memory slot will not be used.
Mask of the memory slots. Shape (batch_size, query_length)
Contains the length of the valid portion of the input.
If set to None, no mask will be used.

Returns
@@ -266,7 +246,7 @@ def _compute_weight(self, F, query, key, mask=None):
if mask is not None:
mask = F.broadcast_axis(F.expand_dims(mask, axis=1),
axis=1, size=self._num_heads)\
.reshape(shape=(-1, 0, 0), reverse=True)
.reshape(shape=(-1, 0), reverse=True)
att_weights = self._base_cell._compute_weight(F, query, key, mask)
return att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)

@@ -370,7 +350,7 @@ def _compute_weight(self, F, query, key, mask=None):
F.expand_dims(mapped_key, axis=1))
mid_feat = self._act(mid_feat)
att_score = self._attention_score(mid_feat).reshape(shape=(0, 0, 0))
att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask))
return att_weights


@@ -477,5 +457,5 @@ def _compute_weight(self, F, query, key, mask=None):

att_score = F.batch_dot(query, key, transpose_b=True)

att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask, self._dtype))
att_weights = self._dropout_layer(_masked_softmax(F, att_score, mask))
return att_weights
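Because the mask now holds lengths of shape (batch_size, query_length) rather than a (batch_size, query_length, memory_length) matrix, the multi-head cell earlier in this file's diff only repeats it per head and flattens it, hence the reshape change to shape=(-1, 0). A small illustrative sketch with plain NDArray ops and made-up values:

```python
import mxnet as mx

num_heads = 2
# Valid lengths per query position, shape (batch_size=1, query_length=3)
mask = mx.nd.array([[3, 3, 2]], dtype='int32')

# Repeat the lengths for every head and flatten: (batch_size * num_heads, query_length)
head_mask = mx.nd.broadcast_axis(mx.nd.expand_dims(mask, axis=1),
                                 axis=1, size=num_heads)
head_mask = head_mask.reshape(shape=(-1, 0), reverse=True)
print(head_mask.shape)  # (2, 3)
```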
46 changes: 16 additions & 30 deletions src/gluonnlp/model/transformer.py
@@ -265,7 +265,7 @@ def hybrid_forward(self, F, inputs, mask=None): # pylint: disable=arguments-dif
inputs : Symbol or NDArray
Input sequence. Shape (batch_size, length, C_in)
mask : Symbol or NDArray or None
Mask for inputs. Shape (batch_size, length, length)
Mask for inputs. Shape (batch_size, length)

Returns
-------
@@ -502,11 +502,10 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei

steps = self._arange_like(F, inputs, axis=1)
if valid_length is not None:
ones = F.ones_like(steps)
mask = F.broadcast_lesser(F.reshape(steps, shape=(1, -1)),
F.reshape(valid_length, shape=(-1, 1)))
mask = F.broadcast_mul(F.expand_dims(mask, axis=1),
F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1))))
zeros = F.zeros_like(steps)
mask = F.broadcast_add(F.reshape(valid_length, shape=(-1, 1)),
F.reshape(zeros, shape=(1, -1)))
mask = F.cast(mask, dtype='int32')
if states is None:
states = [mask]
else:
@@ -665,7 +664,7 @@ class TransformerEncoderCell(BaseTransformerEncoderCell):

Inputs:
- **inputs** : input sequence. Shape (batch_size, length, C_in)
- **mask** : mask for inputs. Shape (batch_size, length, length)
- **mask** : mask for inputs. Shape (batch_size, length)

Outputs:
- **outputs**: output tensor of the transformer encoder cell.
@@ -856,9 +855,9 @@ def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None): #pyli
mem_value : Symbol or NDArrays
Memory value, i.e. output of the encoder. Shape (batch_size, mem_length, C_in)
mask : Symbol or NDArray or None
Mask for inputs. Shape (batch_size, length, length)
Mask for inputs. Shape (batch_size, length)
mem_mask : Symbol or NDArray or None
Mask for mem_value. Shape (batch_size, length, mem_length)
Mask for mem_value. Shape (batch_size, length)

Returns
-------
@@ -990,13 +989,8 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
"""
mem_value = encoder_outputs
decoder_states = [mem_value]
mem_length = mem_value.shape[1]
if encoder_valid_length is not None:
dtype = encoder_valid_length.dtype
ctx = encoder_valid_length.context
mem_masks = mx.nd.broadcast_lesser(
mx.nd.arange(mem_length, ctx=ctx, dtype=dtype).reshape((1, -1)),
encoder_valid_length.reshape((-1, 1)))
mem_masks = mx.nd.cast(encoder_valid_length, dtype='int32')
decoder_states.append(mem_masks)
self._encoder_valid_length = encoder_valid_length
return decoder_states
@@ -1028,17 +1022,12 @@ def decode_seq(self, inputs, states, valid_length=None):
"""
batch_size = inputs.shape[0]
length = inputs.shape[1]
length_array = mx.nd.arange(length, ctx=inputs.context, dtype=inputs.dtype)
mask = mx.nd.broadcast_lesser_equal(
length_array.reshape((1, -1)),
length_array.reshape((-1, 1)))
valid_length_cast = mx.nd.cast(valid_length, dtype='int32')
arange = mx.nd.arange(1, length+1, ctx=inputs.context, dtype='int32')
if valid_length is not None:
arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype)
batch_mask = mx.nd.broadcast_lesser(
arange.reshape((1, -1)),
valid_length.reshape((-1, 1)))
mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1),
mx.nd.expand_dims(mask, 0))
mask = mx.nd.broadcast_lesser_equal(arange.reshape(1, -1),
valid_length_cast.reshape(-1, 1))
mask = mask * arange.reshape(1, -1)
else:
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size)
states = [None] + states
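The training-time mask built in decode_seq above now encodes, for each decoder position, how many steps it may attend to: a 0-based position i gets length i + 1, and positions past valid_length get length 0. A quick check of the arithmetic with made-up values:

```python
import mxnet as mx

length = 4
valid_length = mx.nd.array([2, 4])                   # per-sequence valid lengths

valid_length_cast = mx.nd.cast(valid_length, dtype='int32')
arange = mx.nd.arange(1, length + 1, dtype='int32')  # [1, 2, 3, 4]

mask = mx.nd.broadcast_lesser_equal(arange.reshape(1, -1),
                                    valid_length_cast.reshape(-1, 1))
mask = mask * arange.reshape(1, -1)
print(mask)
# [[1 2 0 0]   <- positions beyond valid_length get length 0
#  [1 2 3 4]]
```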
@@ -1106,11 +1095,8 @@ def forward(self, step_input, states, mask=None): #pylint: disable=arguments-di
.broadcast_axes(axis=1, size=step_input.shape[1])
states[-1] = augmented_mem_mask
if mask is None:
length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context,
dtype=step_input.dtype)
mask = mx.nd.broadcast_lesser_equal(
length_array.reshape((1, -1)),
length_array.reshape((-1, 1)))
mask = mx.nd.arange(1, step_input.shape[1]+1, ctx=step_input.context,
dtype='int32')
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
axis=0, size=step_input.shape[0])
steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context)
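At inference time, forward() above builds the matching causal length mask for the current step window: a 1-based arange broadcast over the batch, so step i may attend to i + 1 positions. A brief sketch under the same assumptions:

```python
import mxnet as mx

batch_size, step_len = 2, 3
mask = mx.nd.arange(1, step_len + 1, dtype='int32')  # [1, 2, 3]
mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0),
                            axis=0, size=batch_size)
print(mask)
# [[1 2 3]
#  [1 2 3]]  shape (batch_size, step_len): one valid length per decoder position
```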