Split up Seq2SeqDecoder into Seq2SeqDecoder and Seq2SeqOneStepDecoder
In the current Gluon API, each HybridBlock has to serve one purpose and can only
define a single callable interface. The previous Seq2SeqDecoder interface required
each Seq2SeqDecoder Block to provide two functionalities (multi-step-ahead and
single-step-ahead decoding), so in practice neither of the two could be hybridized
completely. Thus, use two separate Blocks for the two functionalities. They may
share parameters.

Update the NMTModel API accordingly.

Further refactor TransformerDecoder to make it completely hybridizable.
TransformerOneStepDecoder still relies on a small hack but can be hybridized
completely when we enable numpy shape semantics.
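
For illustration only (not part of this commit): a minimal Gluon sketch of the pattern described above, assuming Gluon's name-scope and `params=` sharing semantics. Two single-purpose HybridBlocks share one set of weights via `params=decoder.collect_params()`, so each exposes exactly one callable interface and each can be hybridized on its own. The `Toy*` names and the Dense "decoder" are made up for this sketch.

```python
import mxnet as mx
from mxnet.gluon import HybridBlock, nn


class ToyMultiStepDecoder(HybridBlock):
    """Decodes a whole (teacher-forced) target sequence in one call."""

    def __init__(self, hidden_size=8, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.proj = nn.Dense(hidden_size, flatten=False)

    def hybrid_forward(self, F, inputs):  # inputs: (batch_size, length, C_in)
        return self.proj(inputs)


class ToyOneStepDecoder(HybridBlock):
    """Decodes a single step; meant to reuse the multi-step decoder's weights."""

    def __init__(self, hidden_size=8, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.proj = nn.Dense(hidden_size, flatten=False)

    def hybrid_forward(self, F, step_input):  # step_input: (batch_size, C_in)
        return self.proj(step_input)


decoder = ToyMultiStepDecoder(prefix='toy_dec_')
# Same prefix plus a shared ParameterDict -> the Dense layer in both Blocks
# resolves to the same underlying parameters.
one_step_decoder = ToyOneStepDecoder(prefix='toy_dec_', params=decoder.collect_params())
decoder.initialize()

# Each Block now has a single callable interface, so each hybridizes cleanly.
decoder.hybridize()
one_step_decoder.hybridize()

seq = mx.nd.ones((2, 5, 8))
full_out = decoder(seq)                    # training-style whole-sequence pass
step_out = one_step_decoder(seq[:, 0, :])  # inference-style single step
```

The GNMT and Transformer decoders below follow the same construction: the one-step decoder is built with the multi-step decoder's `collect_params()`, so the two blocks literally share their weights.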
leezu committed Oct 19, 2019
1 parent 02056ea commit 8a9be69
Showing 10 changed files with 473 additions and 375 deletions.
10 changes: 5 additions & 5 deletions docs/examples/machine_translation/gnmt.md
@@ -337,12 +337,12 @@ feed the encoder and decoder to the `NMTModel` to construct the GNMT model.
`model.hybridize` allows computation to be done using the symbolic backend. To understand what it means to be "hybridized," please refer to [this](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/hybrid.html) page on MXNet hybridization and its advantages.

```{.python .input}
-encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden,
-                                                      dropout=dropout,
-                                                      num_layers=num_layers,
-                                                      num_bi_layers=num_bi_layers)
+encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
+    hidden_size=num_hidden, dropout=dropout, num_layers=num_layers,
+    num_bi_layers=num_bi_layers)
model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder,
-                                       decoder=decoder, embed_size=num_hidden, prefix='gnmt_')
+                                       decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder,
+                                       embed_size=num_hidden, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
188 changes: 144 additions & 44 deletions scripts/machine_translation/gnmt.py
@@ -22,7 +22,7 @@
from mxnet.gluon import nn, rnn
from mxnet.gluon.block import HybridBlock
from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \
-    _get_attention_cell, _get_cell_type, _nested_sequence_last
+    Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last


class GNMTEncoder(Seq2SeqEncoder):
@@ -158,48 +158,14 @@ def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arg
return [outputs, new_states], []


-class GNMTDecoder(HybridBlock, Seq2SeqDecoder):
-    """Structure of the RNN Encoder similar to that used in the
-    Google Neural Machine Translation paper.
-    We use gnmt_v2 strategy in tensorflow/nmt
-    Parameters
-    ----------
-    cell_type : str or type
-    attention_cell : AttentionCell or str
-        Arguments of the attention cell.
-        Can be 'scaled_luong', 'normed_mlp', 'dot'
-    num_layers : int
-    hidden_size : int
-    dropout : float
-    use_residual : bool
-    output_attention: bool
-        Whether to output the attention weights
-    i2h_weight_initializer : str or Initializer
-        Initializer for the input weights matrix, used for the linear
-        transformation of the inputs.
-    h2h_weight_initializer : str or Initializer
-        Initializer for the recurrent weights matrix, used for the linear
-        transformation of the recurrent state.
-    i2h_bias_initializer : str or Initializer
-        Initializer for the bias vector.
-    h2h_bias_initializer : str or Initializer
-        Initializer for the bias vector.
-    prefix : str, default 'rnn_'
-        Prefix for name of `Block`s
-        (and name of weight if params is `None`).
-    params : Parameter or None
-        Container for weight sharing between cells.
-        Created if `None`.
-    """
+class _BaseGNMTDecoder(HybridBlock):
def __init__(self, cell_type='lstm', attention_cell='scaled_luong',
num_layers=2, hidden_size=128,
dropout=0.0, use_residual=True, output_attention=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
prefix=None, params=None):
-        super(GNMTDecoder, self).__init__(prefix=prefix, params=params)
+        super().__init__(prefix=prefix, params=params)
self._cell_type = _get_cell_type(cell_type)
self._num_layers = num_layers
self._hidden_size = hidden_size
@@ -301,7 +267,8 @@ def decode_seq(self, inputs, states, valid_length=None):
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs

-    def __call__(self, step_input, states): #pylint: disable=arguments-differ

+    def forward(self, step_input, states): #pylint: disable=arguments-differ
"""One-step-ahead decoding of the GNMT decoder.
Parameters
@@ -326,11 +293,7 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ
The attention weights will have shape (batch_size, 1, mem_length) or
(batch_size, num_heads, 1, mem_length)
"""
-        return super(GNMTDecoder, self).__call__(step_input, states)
-
-    def forward(self, step_input, states): #pylint: disable=arguments-differ, missing-docstring
-        step_output, new_states, step_additional_outputs =\
-            super(GNMTDecoder, self).forward(step_input, states)
+        step_output, new_states, step_additional_outputs = super().forward(step_input, states)
# In hybrid_forward, only the rnn_states and attention_vec are calculated.
# We directly append the mem_value and mem_masks in the forward() function.
# We apply this trick because the memory value/mask can be directly appended to the next
@@ -402,6 +365,134 @@ def hybrid_forward(self, F, step_input, states): #pylint: disable=arguments-dif
return rnn_out, new_states, step_additional_outputs


class GNMTOneStepDecoder(_BaseGNMTDecoder, Seq2SeqOneStepDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.
One-step ahead decoder used during inference.
We use gnmt_v2 strategy in tensorflow/nmt
Parameters
----------
cell_type : str or type
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
hidden_size : int
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""


class GNMTDecoder(_BaseGNMTDecoder, Seq2SeqDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.
Multi-step decoder used during training with teacher forcing.
We use gnmt_v2 strategy in tensorflow/nmt
Parameters
----------
cell_type : str or type
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
hidden_size : int
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""

def forward(self, inputs, states, valid_length=None):
"""Decode the decoder inputs. This function is only used for training.
Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of sequence has
been padded. Shape (batch_size,)
Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, includes:
- rnn_states : NDArray
- attention_vec : NDArray
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list
Either be an empty list or contains the attention weights in this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
length = inputs.shape[1]
output = []
additional_outputs = []
inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
rnn_states_l = []
attention_output_l = []
fixed_states = states[2:]
for i in range(length):
ele_output, states, ele_additional_outputs = super().forward(inputs[i], states)
rnn_states_l.append(states[0])
attention_output_l.append(states[1])
output.append(ele_output)
additional_outputs.extend(ele_additional_outputs)
output = mx.nd.stack(*output, axis=1)
if valid_length is not None:
states = [_nested_sequence_last(rnn_states_l, valid_length),
_nested_sequence_last(attention_output_l, valid_length)] + fixed_states
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
if self._output_attention:
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs


def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', num_layers=2,
num_bi_layers=1, hidden_size=128, dropout=0.0, use_residual=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
@@ -450,4 +541,13 @@ def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', nu
i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer,
prefix=prefix + 'dec_', params=params)
-    return encoder, decoder
+    one_step_ahead_decoder = GNMTOneStepDecoder(cell_type=cell_type, attention_cell=attention_cell,
+                                                num_layers=num_layers,
+                                                hidden_size=hidden_size, dropout=dropout,
+                                                use_residual=use_residual,
+                                                i2h_weight_initializer=i2h_weight_initializer,
+                                                h2h_weight_initializer=h2h_weight_initializer,
+                                                i2h_bias_initializer=i2h_bias_initializer,
+                                                h2h_bias_initializer=h2h_bias_initializer,
+                                                prefix=prefix + 'dec_', params=decoder.collect_params())
+    return encoder, decoder, one_step_ahead_decoder
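
A hedged usage sketch of the new three-value return (not part of the commit, and not tested here): the multi-step decoder handles teacher-forced training in one call, while inference steps through the one-step-ahead decoder. `init_state_from_encoder` is assumed from the existing decoder interface (it is not shown in this diff), and the random arrays stand in for real token embeddings.

```python
import mxnet as mx
from gnmt import get_gnmt_encoder_decoder  # assuming scripts/machine_translation is on sys.path

encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
    hidden_size=16, num_layers=2, num_bi_layers=1)
encoder.initialize()
decoder.initialize()  # one_step_ahead_decoder shares these parameters

batch_size, src_len, tgt_len, embed_size = 4, 7, 5, 16
src_embed = mx.nd.random.uniform(shape=(batch_size, src_len, embed_size))
tgt_embed = mx.nd.random.uniform(shape=(batch_size, tgt_len, embed_size))

encoder_outputs, _ = encoder(src_embed)
# Assumed helper from the existing decoder interface (not shown in this diff).
states = decoder.init_state_from_encoder(encoder_outputs)

# Training path: teacher-forced decoding of the whole target sequence in one call.
seq_output, _, _ = decoder(tgt_embed, states)

# Inference path: step through the one-step-ahead decoder. A real translator would
# feed the embedding of the previously predicted token instead of a dummy input.
step_input = tgt_embed[:, 0, :]
for _ in range(tgt_len):
    step_output, states, _ = one_step_ahead_decoder(step_input, states)
    step_input = mx.nd.random.uniform(shape=(batch_size, embed_size))
```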
17 changes: 7 additions & 10 deletions scripts/machine_translation/inference_transformer.py
@@ -154,17 +154,14 @@
else:
tgt_max_len = max_len[1]

-encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
-                                                    hidden_size=args.hidden_size,
-                                                    dropout=args.dropout,
-                                                    num_layers=args.num_layers,
-                                                    num_heads=args.num_heads,
-                                                    max_src_length=max(src_max_len, 500),
-                                                    max_tgt_length=max(tgt_max_len, 500),
-                                                    scaled=args.scaled)
+encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
+    units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
+    num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
+    max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
-                 share_embed=args.dataset != 'TOY', embed_size=args.num_units,
-                 tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')
+                 one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset != 'TOY',
+                 embed_size=args.num_units, tie_weights=args.dataset != 'TOY',
+                 embed_initializer=None, prefix='transformer_')

param_name = args.model_parameter
if (not os.path.exists(param_name)):
10 changes: 5 additions & 5 deletions scripts/machine_translation/train_gnmt.py
@@ -122,12 +122,12 @@
else:
ctx = mx.gpu(args.gpu)

-encoder, decoder = get_gnmt_encoder_decoder(hidden_size=args.num_hidden,
-                                             dropout=args.dropout,
-                                             num_layers=args.num_layers,
-                                             num_bi_layers=args.num_bi_layers)
+encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
+    hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers,
+    num_bi_layers=args.num_bi_layers)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
-                 embed_size=args.num_hidden, prefix='gnmt_')
+                 one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden,
+                 prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
34 changes: 16 additions & 18 deletions scripts/machine_translation/train_transformer.py
@@ -33,25 +33,26 @@
# pylint:disable=redefined-outer-name,logging-format-interpolation

import argparse
-import time
-import random
-import os
import logging
import math
-import numpy as np
+import os
+import random
+import time
+
import mxnet as mx
+import numpy as np
from mxnet import gluon
-import gluonnlp as nlp

-from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing
+import dataprocessor
+import gluonnlp as nlp
+from bleu import _bpe_to_words, compute_bleu
+from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss
+from gluonnlp.model.transformer import (ParallelTransformer,
+                                        get_transformer_encoder_decoder)
from gluonnlp.model.translation import NMTModel
-from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer
from gluonnlp.utils.parallel import Parallel
from translation import BeamSearchTranslator

from utils import logging_config
-from bleu import _bpe_to_words, compute_bleu
-import dataprocessor

np.random.seed(100)
random.seed(100)
@@ -174,15 +175,12 @@
tgt_max_len = args.tgt_max_len
else:
tgt_max_len = max_len[1]
-encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
-                                                    hidden_size=args.hidden_size,
-                                                    dropout=args.dropout,
-                                                    num_layers=args.num_layers,
-                                                    num_heads=args.num_heads,
-                                                    max_src_length=max(src_max_len, 500),
-                                                    max_tgt_length=max(tgt_max_len, 500),
-                                                    scaled=args.scaled)
+encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
+    units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
+    num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
+    max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
+                 one_step_ahead_decoder=one_step_ahead_decoder,
share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units,
tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None,
prefix='transformer_')
