diff --git a/scripts/bert/finetune_classifier.py b/scripts/bert/finetune_classifier.py index 8c3c17a36a..c1c2a8d56e 100644 --- a/scripts/bert/finetune_classifier.py +++ b/scripts/bert/finetune_classifier.py @@ -203,8 +203,10 @@ args = parser.parse_args() + log = logging.getLogger() log.setLevel(logging.INFO) + logging.captureWarnings(True) fh = logging.FileHandler('log_{0}.txt'.format(args.task_name)) formatter = logging.Formatter(fmt='%(levelname)s:%(name)s:%(asctime)s %(message)s', diff --git a/scripts/tests/test_bert_checkpoints.py b/scripts/tests/test_bert_checkpoints.py new file mode 100644 index 0000000000..16e58216bb --- /dev/null +++ b/scripts/tests/test_bert_checkpoints.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test inference with BERT checkpoints""" +import pytest +import zipfile +import subprocess +import sys +import re +import mxnet as mx + +@pytest.mark.serial +@pytest.mark.gpu +@pytest.mark.remote_required +@pytest.mark.integration +def test_bert_checkpoints(): + script = './scripts/bert/finetune_classifier.py' + param = 'bert_base_uncased_sst-a628b1d4.params' + param_zip = 'bert_base_uncased_sst-a628b1d4.zip' + arguments = ['--log_interval', '1000000', '--model_parameters', param, + '--gpu', '0', '--only_inference', '--task_name', 'SST', + '--epochs', '1'] + url = 'https://apache-mxnet.s3-accelerate.amazonaws.com/gluon/models/' + param_zip + mx.gluon.utils.download(url , path='.') + with zipfile.ZipFile(param_zip) as zf: + zf.extractall('.') + p = subprocess.check_call([sys.executable, script] + arguments) + with open('log_SST.txt', 'r') as f: + x = f.read() + find = re.compile('accuracy:0.[0-9]+').search(str(x)).group(0) + assert float(find[len('accuracy:'):]) > 0.92 diff --git a/src/gluonnlp/model/bert.py b/src/gluonnlp/model/bert.py index c30bab8b7d..022d1103b2 100644 --- a/src/gluonnlp/model/bert.py +++ b/src/gluonnlp/model/bert.py @@ -31,13 +31,225 @@ from ..base import get_home_dir from .block import GELU from .seq2seq_encoder_decoder import Seq2SeqEncoder -from .transformer import TransformerEncoderCell +from .transformer import PositionwiseFFN from .utils import _load_pretrained_params, _load_vocab ############################################################################### # COMPONENTS # ############################################################################### +class DotProductSelfAttentionCell(HybridBlock): + r"""Multi-head Dot Product Self Attention Cell. + + In the DotProductSelfAttentionCell, the input query/key/value will be linearly projected + for `num_heads` times with different projection matrices. Each projected key, value, query + will be used to calculate the attention weights and values. The output of each head will be + concatenated to form the final output. 
+
+    This is a more efficient implementation of MultiHeadAttentionCell with
+    DotProductAttentionCell as the base_cell:
+
+    score = <h_q, h_k> / sqrt(dim_q)
+
+    Parameters
+    ----------
+    units : int
+        Total number of projected units for query. Must be divisible by num_heads.
+    num_heads : int
+        Number of parallel attention heads
+    dropout : float, default 0.0
+        Dropout probability applied to the attention weights
+    use_bias : bool, default True
+        Whether to use bias when projecting the query/key/values
+    weight_initializer : str or `Initializer` or None, default None
+        Initializer of the weights.
+    bias_initializer : str or `Initializer`, default 'zeros'
+        Initializer of the bias.
+    prefix : str or None, default None
+        See document of `Block`.
+    params : str or None, default None
+        See document of `Block`.
+
+    Inputs:
+        - **qkv** : Symbol or NDArray
+            Query / Key / Value vector. Shape (query_length, batch_size, C_in)
+        - **valid_len** : Symbol or NDArray or None, default None
+            Valid length of the query/key/value slots. Shape (batch_size, query_length)
+
+    Outputs:
+        - **context_vec** : Symbol or NDArray
+            Shape (query_length, batch_size, context_vec_dim)
+        - **att_weights** : Symbol or NDArray
+            Attention weights of multiple heads.
+            Shape (batch_size, num_heads, query_length, memory_length)
+    """
+    def __init__(self, units, num_heads, dropout=0.0, use_bias=True,
+                 weight_initializer=None, bias_initializer='zeros',
+                 prefix=None, params=None):
+        super().__init__(prefix=prefix, params=params)
+        self._num_heads = num_heads
+        self._use_bias = use_bias
+        self._dropout = dropout
+        self.units = units
+        with self.name_scope():
+            if self._use_bias:
+                self.query_bias = self.params.get('query_bias', shape=(self.units,),
+                                                  init=bias_initializer)
+                self.key_bias = self.params.get('key_bias', shape=(self.units,),
+                                                init=bias_initializer)
+                self.value_bias = self.params.get('value_bias', shape=(self.units,),
+                                                  init=bias_initializer)
+            weight_shape = (self.units, self.units)
+            self.query_weight = self.params.get('query_weight', shape=weight_shape,
+                                                init=weight_initializer,
+                                                allow_deferred_init=True)
+            self.key_weight = self.params.get('key_weight', shape=weight_shape,
+                                              init=weight_initializer,
+                                              allow_deferred_init=True)
+            self.value_weight = self.params.get('value_weight', shape=weight_shape,
+                                                init=weight_initializer,
+                                                allow_deferred_init=True)
+            self.dropout_layer = nn.Dropout(self._dropout)
+
+    def _collect_params_with_prefix(self, prefix=''):
+        # the registered parameter names in v0.8 are the following:
+        # prefix_proj_query.weight, prefix_proj_query.bias
+        # prefix_proj_value.weight, prefix_proj_value.bias
+        # prefix_proj_key.weight, prefix_proj_key.bias
+        # this is a temporary fix to keep backward compatibility, due to an issue in MXNet:
+        # https://github.com/apache/incubator-mxnet/issues/17220
+        if prefix:
+            prefix += '.'
+        ret = {prefix + 'proj_' + k.replace('_', '.') : v for k, v in self._reg_params.items()}
+        for name, child in self._children.items():
+            ret.update(child._collect_params_with_prefix(prefix + name))
+        return ret
+
+    # pylint: disable=arguments-differ
+    def hybrid_forward(self, F, qkv, valid_len, query_bias, key_bias, value_bias,
+                       query_weight, key_weight, value_weight):
+        # interleaved_matmul_selfatt ops assume the projection is done with interleaving
+        # weights for query/key/value. The concatenated weight should have shape
+        # (num_heads, C_out/num_heads * 3, C_in).
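+        # For example, with num_heads=2 the stacked rows are ordered head by head as
+        # [q_h0; k_h0; v_h0; q_h1; k_h1; v_h1], each block holding units/num_heads rows,
+        # so the single FullyConnected call below yields the interleaved q/k/v projection
+        # that interleaved_matmul_selfatt_qk expects.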
+        query_weight = query_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True)
+        key_weight = key_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True)
+        value_weight = value_weight.reshape(shape=(self._num_heads, -1, 0), reverse=True)
+        in_weight = F.concat(query_weight, key_weight, value_weight, dim=-2)
+        in_weight = in_weight.reshape(shape=(-1, 0), reverse=True)
+        in_bias = F.concat(query_bias, key_bias, value_bias, dim=0)
+
+        # qkv_proj shape = (seq_length, batch_size, num_heads * head_dim * 3)
+        qkv_proj = F.FullyConnected(data=qkv, weight=in_weight, bias=in_bias,
+                                    num_hidden=self.units*3, no_bias=False, flatten=False)
+        att_score = F.contrib.interleaved_matmul_selfatt_qk(qkv_proj, heads=self._num_heads)
+        if valid_len is not None:
+            valid_len = F.broadcast_axis(F.expand_dims(valid_len, axis=1),
+                                         axis=1, size=self._num_heads)
+            valid_len = valid_len.reshape(shape=(-1, 0), reverse=True)
+            att_weights = F.softmax(att_score, length=valid_len, use_length=True, axis=-1)
+        else:
+            att_weights = F.softmax(att_score, axis=-1)
+        # att_weights shape = (batch_size * num_heads, seq_length, seq_length)
+        att_weights = self.dropout_layer(att_weights)
+        context_vec = F.contrib.interleaved_matmul_selfatt_valatt(qkv_proj, att_weights,
+                                                                  heads=self._num_heads)
+        att_weights = att_weights.reshape(shape=(-1, self._num_heads, 0, 0), reverse=True)
+        return context_vec, att_weights
+
+
+class BERTEncoderCell(HybridBlock):
+    """Structure of the BERT Encoder Cell.
+
+    Parameters
+    ----------
+    units : int
+        Number of units for the output
+    hidden_size : int
+        Number of units in the hidden layer of position-wise feed-forward networks
+    num_heads : int
+        Number of heads in multi-head attention
+    dropout : float
+        Dropout probability applied in the attention cell, the projection and the
+        position-wise feed-forward network
+    output_attention: bool
+        Whether to output the attention weights
+    attention_use_bias : bool, default True
+        Whether to use bias term in the attention cell
+    weight_initializer : str or Initializer
+        Initializer for the input weights matrix, used for the linear
+        transformation of the inputs.
+    bias_initializer : str or Initializer
+        Initializer for the bias vector.
+    prefix : str, default None
+        Prefix for name of `Block`s. (and name of weight if params is `None`).
+    params : Parameter or None
+        Container for weight sharing between cells. Created if `None`.
+    activation : str, default 'gelu'
+        Activation method used in the PositionwiseFFN
+    layer_norm_eps : float, default 1e-5
+        Epsilon for layer_norm
+
+    Inputs:
+        - **inputs** : input sequence. Shape (length, batch_size, C_in)
+        - **valid_length** : valid length of inputs for attention. Shape (batch_size, length)
+
+    Outputs:
+        - **outputs**: output tensor of the transformer encoder cell.
+            Shape (length, batch_size, C_out)
+        - **additional_outputs**: the additional outputs of the BERT encoder cell,
+            e.g. the attention weights when `output_attention` is True.
+ """ + def __init__(self, units=128, hidden_size=512, num_heads=4, + dropout=0.0, output_attention=False, + attention_use_bias=True, + weight_initializer=None, bias_initializer='zeros', + prefix=None, params=None, activation='gelu', + layer_norm_eps=1e-5): + super().__init__(prefix=prefix, params=params) + self._dropout = dropout + self._output_attention = output_attention + with self.name_scope(): + if dropout: + self.dropout_layer = nn.Dropout(rate=dropout) + self.attention_cell = DotProductSelfAttentionCell(units, num_heads, + use_bias=attention_use_bias, + dropout=dropout) + self.proj = nn.Dense(units=units, flatten=False, use_bias=True, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, prefix='proj_') + self.ffn = PositionwiseFFN(units=units, hidden_size=hidden_size, dropout=dropout, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, activation=activation, + layer_norm_eps=layer_norm_eps) + self.layer_norm = nn.LayerNorm(in_channels=units, epsilon=layer_norm_eps) + + + def hybrid_forward(self, F, inputs, valid_len=None): # pylint: disable=arguments-differ + """Transformer Encoder Attention Cell. + + Parameters + ---------- + inputs : Symbol or NDArray + Input sequence. Shape (length, batch_size, C_in) + valid_len : Symbol or NDArray or None + Valid length for inputs. Shape (batch_size, length) + + Returns + ------- + encoder_cell_outputs: list + Outputs of the encoder cell. Contains: + + - outputs of the transformer encoder cell. Shape (length, batch_size, C_out) + - additional_outputs of all the transformer encoder cell + """ + outputs, attention_weights = self.attention_cell(inputs, valid_len) + outputs = self.proj(outputs) + if self._dropout: + outputs = self.dropout_layer(outputs) + # use residual + outputs = outputs + inputs + outputs = self.layer_norm(outputs) + outputs = self.ffn(outputs) + additional_outputs = [] + if self._output_attention: + additional_outputs.append(attention_weights) + return outputs, additional_outputs + class BERTEncoder(HybridBlock, Seq2SeqEncoder): """Structure of the BERT Encoder. @@ -47,9 +259,6 @@ class BERTEncoder(HybridBlock, Seq2SeqEncoder): Parameters ---------- - attention_cell : AttentionCell or str, default 'multi_head' - Arguments of the attention cell. - Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' num_layers : int Number of attention layers. units : int @@ -60,12 +269,8 @@ class BERTEncoder(HybridBlock, Seq2SeqEncoder): Maximum length of the input sequence num_heads : int Number of heads in multi-head attention - scaled : bool - Whether to scale the softmax input by the sqrt of the input dimension - in multi-head attention dropout : float Dropout probability of the attention probabilities and embedding. - use_residual : bool output_attention: bool, default False Whether to output the attention weights output_all_encodings: bool, default False @@ -85,21 +290,21 @@ class BERTEncoder(HybridBlock, Seq2SeqEncoder): Epsilon for layer_norm Inputs: - - **inputs** : input sequence of shape (batch_size, length, C_in) - - **states** : list of tensors for initial states and masks. + - **inputs** : input sequence of shape (length, batch_size, C_in) + - **states** : list of tensors for initial states and valid length for self attention. - **valid_length** : valid lengths of each sequence. Usually used when part of sequence has been padded. Shape is (batch_size, ) Outputs: - - **outputs** : the output of the encoder. 
Shape is (batch_size, length, C_out) + - **outputs** : the output of the encoder. Shape is (length, batch_size, C_out) - **additional_outputs** : list of tensors. Either be an empty list or contains the attention weights in this step. The attention weights will have shape (batch_size, num_heads, length, mem_length) """ - def __init__(self, *, attention_cell='multi_head', num_layers=2, units=512, hidden_size=2048, - max_length=50, num_heads=4, scaled=True, dropout=0.0, use_residual=True, + def __init__(self, *, num_layers=2, units=512, hidden_size=2048, + max_length=50, num_heads=4, dropout=0.0, output_attention=False, output_all_encodings=False, weight_initializer=None, bias_initializer='zeros', prefix=None, params=None, activation='gelu', layer_norm_eps=1e-12): @@ -122,11 +327,10 @@ def __init__(self, *, attention_cell='multi_head', num_layers=2, units=512, hidd init=weight_initializer) self.transformer_cells = nn.HybridSequential() for i in range(num_layers): - cell = TransformerEncoderCell( + cell = BERTEncoderCell( units=units, hidden_size=hidden_size, num_heads=num_heads, - attention_cell=attention_cell, weight_initializer=weight_initializer, - bias_initializer=bias_initializer, dropout=dropout, use_residual=use_residual, - attention_proj_use_bias=True, attention_use_bias=True, scaled=scaled, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, dropout=dropout, output_attention=output_attention, prefix='transformer%d_' % i, activation=activation, layer_norm_eps=layer_norm_eps) self.transformer_cells.add(cell) @@ -139,7 +343,7 @@ def __call__(self, inputs, states=None, valid_length=None): # pylint: disable=a inputs : NDArray or Symbol Input sequence. Shape (batch_size, length, C_in) states : list of NDArrays or Symbols - Initial states. The list of initial states and masks + Initial states. The list of initial states and valid length for self attention valid_length : NDArray or Symbol Valid lengths of each sequence. This is usually used when part of sequence has been padded. Shape (batch_size,) @@ -161,9 +365,9 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei Parameters ---------- inputs : NDArray or Symbol - Input sequence. Shape (batch_size, length, C_in) + Input sequence. Shape (length, batch_size, C_in) states : list of NDArrays or Symbols - Initial states. The list of initial states and masks + Initial states. The list of initial states and valid length for self attention valid_length : NDArray or Symbol Valid lengths of each sequence. This is usually used when part of sequence has been padded. Shape (batch_size,) @@ -173,26 +377,27 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei outputs : NDArray or Symbol, or List[NDArray] or List[Symbol] If output_all_encodings flag is False, then the output of the last encoder. If output_all_encodings flag is True, then the list of all outputs of all encoders. - In both cases, shape of the tensor(s) is/are (batch_size, length, C_out) + In both cases, shape of the tensor(s) is/are (length, batch_size, C_out) additional_outputs : list Either be an empty list or contains the attention weights in this step. 
- The attention weights will have shape (batch_size, length, length) or + The attention weights will have shape (batch_size, length) or (batch_size, num_heads, length, length) """ - steps = F.contrib.arange_like(inputs, axis=1) + # axis 0 is for length + steps = F.contrib.arange_like(inputs, axis=0) if valid_length is not None: - ones = F.ones_like(steps) - mask = F.broadcast_lesser(F.reshape(steps, shape=(1, -1)), - F.reshape(valid_length, shape=(-1, 1))) - mask = F.broadcast_mul(F.expand_dims(mask, axis=1), - F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1)))) + zeros = F.zeros_like(steps) + # valid_length for attention, shape = (batch_size, seq_length) + attn_valid_len = F.broadcast_add(F.reshape(valid_length, shape=(-1, 1)), + F.reshape(zeros, shape=(1, -1))) + attn_valid_len = F.cast(attn_valid_len, dtype='int32') if states is None: - states = [mask] + states = [attn_valid_len] else: - states.append(mask) + states.append(attn_valid_len) else: - mask = None + attn_valid_len = None if states is None: states = [steps] @@ -201,7 +406,7 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei # positional encoding positional_embed = F.Embedding(steps, position_weight, self._max_length, self._units) - inputs = F.broadcast_add(inputs, F.expand_dims(positional_embed, axis=0)) + inputs = F.broadcast_add(inputs, F.expand_dims(positional_embed, axis=1)) if self._dropout: inputs = self.dropout_layer(inputs) @@ -211,12 +416,12 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei all_encodings_outputs = [] additional_outputs = [] for cell in self.transformer_cells: - outputs, attention_weights = cell(inputs, mask) + outputs, attention_weights = cell(inputs, attn_valid_len) inputs = outputs if self._output_all_encodings: if valid_length is not None: outputs = F.SequenceMask(outputs, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, axis=0) all_encodings_outputs.append(outputs) if self._output_attention: @@ -225,7 +430,7 @@ def hybrid_forward(self, F, inputs, states=None, valid_length=None, position_wei if valid_length is not None and not self._output_all_encodings: # if self._output_all_encodings, SequenceMask is already applied above outputs = F.SequenceMask(outputs, sequence_length=valid_length, - use_sequence_length=True, axis=1) + use_sequence_length=True, axis=0) if self._output_all_encodings: return all_encodings_outputs, additional_outputs @@ -426,8 +631,15 @@ def _encode_sequence(self, inputs, token_types, valid_length=None): if self._use_token_type_embed: type_embedding = self.token_type_embed(token_types) embedding = embedding + type_embedding + # (batch, seq_len, C) -> (seq_len, batch, C) + embedding = embedding.transpose((1, 0, 2)) # encoding outputs, additional_outputs = self.encoder(embedding, valid_length=valid_length) + # (seq_len, batch, C) -> (batch, seq_len, C) + if isinstance(outputs, (list, tuple)): + outputs = [o.transpose((1, 0, 2)) for o in outputs] + else: + outputs = outputs.transpose((1, 0, 2)) return outputs, additional_outputs def _apply_pooling(self, sequence): @@ -791,30 +1003,24 @@ def hybrid_forward(self, F, inputs, valid_length=None): ]}) roberta_12_768_12_hparams = { - 'attention_cell': 'multi_head', 'num_layers': 12, 'units': 768, 'hidden_size': 3072, 'max_length': 512, 'num_heads': 12, - 'scaled': True, 'dropout': 0.1, - 'use_residual': True, 'embed_size': 768, 'word_embed': None, 'layer_norm_eps': 1e-5 } roberta_24_1024_16_hparams = { - 'attention_cell': 
'multi_head', 'num_layers': 24, 'units': 1024, 'hidden_size': 4096, 'max_length': 512, 'num_heads': 16, - 'scaled': True, 'dropout': 0.1, - 'use_residual': True, 'embed_size': 1024, 'word_embed': None, 'layer_norm_eps': 1e-5 @@ -835,45 +1041,36 @@ def hybrid_forward(self, F, inputs, valid_length=None): } bert_12_768_12_hparams = { - 'attention_cell': 'multi_head', 'num_layers': 12, 'units': 768, 'hidden_size': 3072, 'max_length': 512, 'num_heads': 12, - 'scaled': True, 'dropout': 0.1, - 'use_residual': True, 'embed_size': 768, 'token_type_vocab_size': 2, 'word_embed': None, } bert_24_1024_16_hparams = { - 'attention_cell': 'multi_head', 'num_layers': 24, 'units': 1024, 'hidden_size': 4096, 'max_length': 512, 'num_heads': 16, - 'scaled': True, 'dropout': 0.1, - 'use_residual': True, 'embed_size': 1024, 'token_type_vocab_size': 2, 'word_embed': None, } ernie_12_768_12_hparams = { - 'attention_cell': 'multi_head', 'num_layers': 12, 'units': 768, 'hidden_size': 3072, 'max_length': 513, 'num_heads': 12, - 'scaled': True, 'dropout': 0.1, - 'use_residual': True, 'embed_size': 768, 'token_type_vocab_size': 2, 'word_embed': None, @@ -1307,17 +1504,14 @@ def get_roberta_model(model_name=None, dataset_name=None, vocab=None, pretrained predefined_args.update(kwargs) # encoder - encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], - num_layers=predefined_args['num_layers'], + encoder = BERTEncoder(num_layers=predefined_args['num_layers'], units=predefined_args['units'], hidden_size=predefined_args['hidden_size'], max_length=predefined_args['max_length'], num_heads=predefined_args['num_heads'], - scaled=predefined_args['scaled'], dropout=predefined_args['dropout'], output_attention=output_attention, output_all_encodings=output_all_encodings, - use_residual=predefined_args['use_residual'], activation=predefined_args.get('activation', 'gelu'), layer_norm_eps=predefined_args.get('layer_norm_eps', 1e-5)) @@ -1424,17 +1618,14 @@ def get_bert_model(model_name=None, dataset_name=None, vocab=None, pretrained=Tr 'Cannot override predefined model settings.' predefined_args.update(kwargs) # encoder - encoder = BERTEncoder(attention_cell=predefined_args['attention_cell'], - num_layers=predefined_args['num_layers'], + encoder = BERTEncoder(num_layers=predefined_args['num_layers'], units=predefined_args['units'], hidden_size=predefined_args['hidden_size'], max_length=predefined_args['max_length'], num_heads=predefined_args['num_heads'], - scaled=predefined_args['scaled'], dropout=predefined_args['dropout'], output_attention=output_attention, output_all_encodings=output_all_encodings, - use_residual=predefined_args['use_residual'], activation=predefined_args.get('activation', 'gelu'), layer_norm_eps=predefined_args.get('layer_norm_eps', 1e-12))
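As a quick sanity check of the new attention cell and the time-major (length, batch_size, channel) layout it expects, something along the following lines should work. This is only a sketch: it assumes DotProductSelfAttentionCell is importable from gluonnlp.model.bert after this change, and it uses a GPU context because the interleaved_matmul_selfatt operators in MXNet 1.6 are written for (and may require) GPUs.

    import mxnet as mx
    from gluonnlp.model.bert import DotProductSelfAttentionCell

    ctx = mx.gpu(0)
    seq_len, batch_size, units, num_heads = 8, 2, 64, 4

    cell = DotProductSelfAttentionCell(units=units, num_heads=num_heads)
    cell.initialize(ctx=ctx)

    # Time-major input, as introduced by this change: (seq_len, batch_size, C_in)
    qkv = mx.nd.random.normal(shape=(seq_len, batch_size, units), ctx=ctx)
    # Per-position valid lengths, shape (batch_size, seq_len); int32 as expected by
    # softmax(..., use_length=True)
    valid_len = mx.nd.full((batch_size, seq_len), seq_len, dtype='int32', ctx=ctx)

    context_vec, att_weights = cell(qkv, valid_len)
    print(context_vec.shape)   # (seq_len, batch_size, units)
    print(att_weights.shape)   # (batch_size, num_heads, seq_len, seq_len)

The end-to-end check in scripts/tests/test_bert_checkpoints.py above exercises the full fine-tuned model; this snippet only probes the new cell in isolation.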