This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[DOC] Adding documentation to xlnet scripts #985

Merged · 7 commits · Oct 29, 2019
67 changes: 67 additions & 0 deletions scripts/language_model/transformer/transformer.py
@@ -486,6 +486,45 @@ def hybrid_forward(self, F, inputs, pos_emb, mem_value, mask, segments):


class _BaseXLNet(mx.gluon.HybridBlock):
"""
Parameters
----------
vocab_size : int
The size of the vocabulary.
num_layers : int
Number of attention layers.
units : int
Number of units in the model (dimensionality of the embeddings and attention outputs).
hidden_size : int
Number of units in the hidden layer of position-wise feed-forward networks.
num_heads : int
Number of heads in multi-head attention
activation : str
Activation function used for the position-wise feed-forward networks.
two_stream : bool
If True, use Two-Stream Self-Attention. Typically set to True for
pre-training and False during fine-tuning.
scaled : bool
Whether to scale the softmax input by the sqrt of the input dimension
in multi-head attention
dropout : float
Dropout rate.
attention_dropout : float
Dropout rate for the attention weights.
use_residual : bool
Whether to use residual connections.
clamp_len : int, optional
Clamp all relative distances larger than clamp_len.
use_decoder : bool, default True
Whether to include the decoder for language model prediction.
tie_decoder_weight : bool, default True
Whether to tie the decoder weight with the input embeddings
weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'

Contributor (review comment): default is None

Prefix for name of `Block`s (and name of weight if params is `None`).
params : Parameter or None

Contributor (review comment): Should be ParameterDict or None

Container for weight sharing between cells. Created if `None`.

"""
def __init__(self, vocab_size, num_layers=2, units=128, hidden_size=2048, num_heads=4,
activation='gelu', two_stream: bool = False, scaled=True, dropout=0.0,
attention_dropout=0.0, use_residual=True, clamp_len: typing.Optional[int] = None,
@@ -529,6 +568,32 @@ def __init__(self, vocab_size, num_layers=2, units=128, hidden_size=2048, num_heads=4,
params=self.word_embed.params if tie_decoder_weight else None)
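As a quick orientation for readers of this diff (not part of the PR itself), here is a minimal construction sketch using the documented defaults. The import assumes the scripts/language_model/transformer directory is on the Python path, and the vocabulary size is an arbitrary toy value.

```python
import mxnet as mx
from transformer import _BaseXLNet  # assumes scripts/language_model/transformer is on PYTHONPATH

# Toy model with the documented defaults: 2 layers, 128 units, 4 heads, gelu activation.
net = _BaseXLNet(vocab_size=1000, num_layers=2, units=128, hidden_size=2048,
                 num_heads=4, activation='gelu', two_stream=False,
                 dropout=0.1, attention_dropout=0.1, clamp_len=None)
net.initialize(mx.init.Normal(0.02))  # parameter shapes are resolved lazily at the first forward pass
```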

def hybrid_forward(self, F, step_input, segments, mask, pos_seq, mems, mask_embed): #pylint: disable=arguments-differ
"""
Parameters
----------
step_input : Symbol or NDArray
Input of shape [batch_size, query_length]
segments : Symbol or NDArray
One-hot vector indicating if a query-key pair is in the same
segment or not. Shape [batch_size, query_length, query_length +
memory_length, 2]. `1` indicates that the pair is not in the same
segment.
mask : Symbol or NDArray
Attention mask of shape (batch_size, length, length + mem_length)
pos_seq : Symbol or NDArray
Relative distances
mems : List of NDArray or Symbol, optional
Memory from previous forward passes containing
`num_layers` `NDArray`s or `Symbol`s each of shape [batch_size,
memory_length, units].

Returns
-------
core_out : NDArray or Symbol
If use_decoder=True, the logits over the vocabulary; otherwise the output of the last layer.
hids : List of NDArray or Symbol
The output of each layer.
"""
if self._clamp_len:
pos_seq = F.clip(pos_seq, a_min=0, a_max=self._clamp_len)
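A companion sketch (again not part of the PR) that builds toy inputs with the shapes the new docstring describes; the all-ones mask, the descending position sequence, and the concrete sizes are simplifications chosen only for illustration.

```python
import mxnet as mx

batch_size, qlen, mlen, units, num_layers = 2, 4, 3, 128, 2
# Token ids: [batch_size, query_length]
step_input = mx.nd.random.randint(low=0, high=1000, shape=(batch_size, qlen))
# One-hot same-segment indicator: [batch_size, qlen, qlen + mlen, 2]
segments = mx.nd.one_hot(mx.nd.zeros((batch_size, qlen, qlen + mlen)), 2)
# Attention mask: (batch_size, qlen, qlen + mlen); every position visible in this toy case
mask = mx.nd.ones((batch_size, qlen, qlen + mlen))
# Relative distances; with clamp_len set, the block clips them to [0, clamp_len]
pos_seq = mx.nd.arange(qlen + mlen - 1, -1, -1)
# Per-layer memory from previous segments: [batch_size, memory_length, units]
mems = [mx.nd.zeros((batch_size, mlen, units)) for _ in range(num_layers)]
```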

@@ -635,6 +700,8 @@ def forward(self, step_input, token_types, mems=None, mask=None): # pylint: disable=arguments-differ
Optional memory from previous forward passes containing
`num_layers` `NDArray`s or `Symbol`s each of shape [batch_size,
memory_length, units].
mask : Symbol or NDArray
Attention mask of shape (batch_size, length, length + mem_length)

Returns
-------
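For completeness, a short sketch of preparing the optional `mems` and `mask` arguments documented for `forward`; the sizes, the single-segment `token_types`, and the `model` name in the final comment are illustrative assumptions, not taken from the PR.

```python
import mxnet as mx

batch_size, qlen, mlen, units, num_layers = 2, 4, 3, 128, 2
step_input = mx.nd.random.randint(low=0, high=1000, shape=(batch_size, qlen))
token_types = mx.nd.zeros((batch_size, qlen))  # single-segment input (shape assumed)
# `mems`: one array per layer, each of shape [batch_size, memory_length, units]
mems = [mx.nd.zeros((batch_size, mlen, units)) for _ in range(num_layers)]
# `mask`: attention mask of shape (batch_size, length, length + mem_length)
mask = mx.nd.ones((batch_size, qlen, qlen + mlen))
# A call would then look like: output = model(step_input, token_types, mems=mems, mask=mask)
```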