diff --git a/src/gluonnlp/models/bart.py b/src/gluonnlp/models/bart.py
index 6b011a7566..fae242c5bc 100644
--- a/src/gluonnlp/models/bart.py
+++ b/src/gluonnlp/models/bart.py
@@ -80,6 +80,7 @@ def bart_base():
     cfg.MODEL.ENCODER.recurrent = False
     cfg.MODEL.ENCODER.pre_norm = False
     cfg.MODEL.ENCODER.activation = 'gelu'
+    cfg.MODEL.ENCODER.use_qkv_bias = True
 
     # Parameters for the decoder
     cfg.MODEL.DECODER = CN()
@@ -90,6 +91,7 @@ def bart_base():
     cfg.MODEL.DECODER.recurrent = False
     cfg.MODEL.DECODER.pre_norm = False
     cfg.MODEL.DECODER.activation = 'gelu'
+    cfg.MODEL.DECODER.use_qkv_bias = True
 
     # Parameters for the initializer
     cfg.INITIALIZER = CN()
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 8e55111e19..8e30f1048a 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -45,6 +45,7 @@ def transformer_base():
     cfg.MODEL.ENCODER.recurrent = False
     cfg.MODEL.ENCODER.activation = 'relu'
     cfg.MODEL.ENCODER.pre_norm = False
+    cfg.MODEL.ENCODER.use_qkv_bias = True
 
     # Parameters for the decoder
     cfg.MODEL.DECODER = CN()
@@ -55,6 +56,7 @@ def transformer_base():
     cfg.MODEL.DECODER.recurrent = False
     cfg.MODEL.DECODER.activation = 'relu'
     cfg.MODEL.DECODER.pre_norm = False
+    cfg.MODEL.DECODER.use_qkv_bias = False
 
     # Parameters for the initializer
     cfg.INITIALIZER = CN()
@@ -161,6 +163,7 @@ def __init__(self,
                 data -> attn -> norm(res(+data)) -> ffn
         use_qkv_bias
+            Whether to use bias for the self-attention layer
         weight_initializer
         bias_initializer
         activation
@@ -265,7 +268,7 @@ def hybrid_forward(self, F, data, attn_mask):
 class TransformerEncoder(HybridBlock):
     def __init__(self, num_layers=6, recurrent=False,
                  units=512, hidden_size=2048, num_heads=8,
-                 activation_dropout=0.0, dropout=0.1,
+                 activation_dropout=0.0, dropout=0.1, use_qkv_bias=True,
                  attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False,
                  pre_norm=False, weight_initializer=None, bias_initializer='zeros',
                  activation='relu', dtype='float32', layout='NT'):
@@ -319,6 +322,7 @@ def __init__(self, num_layers=6, recurrent=False,
                     hidden_dropout_prob=dropout,
                     attention_dropout_prob=attention_dropout,
                     activation_dropout_prob=activation_dropout,
+                    use_qkv_bias=use_qkv_bias,
                     layer_norm_eps=layer_norm_eps,
                     weight_initializer=weight_initializer,
                     bias_initializer=bias_initializer,
@@ -384,6 +388,7 @@ def __init__(self, units: int = 512,
                  layer_norm_eps: float = 1E-5,
                  activation: str = 'relu',
                  pre_norm: bool = False,
+                 use_qkv_bias: bool = True,
                  weight_initializer=None,
                  bias_initializer='zeros',
                  dtype='float32',
@@ -405,6 +410,8 @@ def __init__(self, units: int = 512,
         activation
         pre_norm
             Whether to apply normalization before the attention layer
+        use_qkv_bias
+            Whether to use bias for both self-attention and contextual attention
         weight_initializer
         bias_initializer
         dtype
@@ -431,7 +438,7 @@ def __init__(self, units: int = 512,
             raise ValueError('In Transformer, units should be divided exactly by the number of '
                              'heads. Received units={}, num_heads={}'.format(units, num_heads))
         self.attn_in_qkv = nn.Dense(3 * units, in_units=units,
-                                    use_bias=True,
+                                    use_bias=use_qkv_bias,
                                     flatten=False,
                                     weight_initializer=weight_initializer,
                                     bias_initializer=bias_initializer,
@@ -447,19 +454,19 @@ def __init__(self, units: int = 512,
                                      dtype=dtype)
         self.attn_inter_q = nn.Dense(units, in_units=units,
-                                     use_bias=True,
+                                     use_bias=use_qkv_bias,
                                      flatten=False,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer,
                                      dtype=dtype)
         self.attn_inter_k = nn.Dense(units, in_units=mem_units,
-                                     use_bias=True,
+                                     use_bias=use_qkv_bias,
                                      flatten=False,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer,
                                      dtype=dtype)
         self.attn_inter_v = nn.Dense(units, in_units=mem_units,
-                                     use_bias=True,
+                                     use_bias=use_qkv_bias,
                                      flatten=False,
                                      weight_initializer=weight_initializer,
                                      bias_initializer=bias_initializer,
@@ -706,7 +713,7 @@ def incremental_decode(self, F, data, states, mem, mem_valid_length, mem_attn_ma
 @use_np
 class TransformerDecoder(HybridBlock):
     def __init__(self, num_layers=6, recurrent=False,
-                 units=512, mem_units=None, hidden_size=2048,
+                 units=512, mem_units=None, hidden_size=2048, use_qkv_bias=True,
                  num_heads=8, max_shift=None, activation_dropout=0.0, dropout=0.1,
                  attention_dropout=0.1, layer_norm_eps=1E-5, data_norm=False,
                  pre_norm=False, weight_initializer=None, bias_initializer=None,
@@ -740,6 +747,7 @@ def __init__(self, num_layers=6, recurrent=False,
                     hidden_size=hidden_size,
                     num_heads=num_heads,
                     activation_dropout=activation_dropout,
+                    use_qkv_bias=use_qkv_bias,
                     dropout=dropout,
                     attention_dropout=attention_dropout,
                     layer_norm_eps=layer_norm_eps,
@@ -930,6 +938,7 @@ def __init__(self, src_vocab_size: int,
                  enc_recurrent: bool = False,
                  enc_activation='relu',
                  enc_pre_norm: bool = False,
+                 enc_use_qkv_bias: bool = True,
                  dec_units: int = 512,
                  dec_hidden_size: int = 2048,
                  dec_num_heads: int = 8,
@@ -937,6 +946,7 @@ def __init__(self, src_vocab_size: int,
                  dec_recurrent: bool = False,
                  dec_activation='relu',
                  dec_pre_norm: bool = False,
+                 dec_use_qkv_bias: bool = True,
                  embed_initializer=mx.init.Xavier('gaussian', 'in', 1),
                  weight_initializer=mx.init.Xavier('uniform', 'avg', 3),
                  bias_initializer='zeros',
@@ -988,6 +998,8 @@ def __init__(self, src_vocab_size: int,
             Activation of the encoder layer
         enc_pre_norm
             Whether to add layer_norm before self-attention in the encoder
+        enc_use_qkv_bias
+            Whether to use bias for the attention layers in the encoder
         dec_units
             Units of the decoder
         dec_hidden_size
@@ -1002,6 +1014,8 @@ def __init__(self, src_vocab_size: int,
             Activation of the decoder layer
         dec_pre_norm
             Whether to add layer_norm before self-attention in the decoder
+        dec_use_qkv_bias
+            Whether to use bias for the attention layers in the decoder
         embed_initializer
             Initializer of the embedding layer
         weight_initializer
@@ -1067,6 +1081,7 @@ def __init__(self, src_vocab_size: int,
                 hidden_size=enc_hidden_size,
                 num_heads=enc_num_heads,
                 activation_dropout=activation_dropout,
+                use_qkv_bias=enc_use_qkv_bias,
                 dropout=dropout,
                 attention_dropout=attention_dropout,
                 layer_norm_eps=layer_norm_eps,
@@ -1084,6 +1099,7 @@ def __init__(self, src_vocab_size: int,
                 hidden_size=dec_hidden_size,
                 num_heads=dec_num_heads,
                 activation_dropout=activation_dropout,
+                use_qkv_bias=dec_use_qkv_bias,
                 dropout=dropout,
                 attention_dropout=attention_dropout,
                 layer_norm_eps=layer_norm_eps,
@@ -1272,6 +1288,7 @@ def from_cfg(cls, cfg, dtype=None):
                    enc_recurrent=cfg.MODEL.ENCODER.recurrent,
                    enc_activation=cfg.MODEL.ENCODER.activation,
                    enc_pre_norm=cfg.MODEL.ENCODER.pre_norm,
+                   enc_use_qkv_bias=cfg.MODEL.ENCODER.use_qkv_bias,
                    dec_num_layers=cfg.MODEL.DECODER.num_layers,
                    dec_units=cfg.MODEL.DECODER.units,
                    dec_num_heads=cfg.MODEL.DECODER.num_heads,
@@ -1279,6 +1296,7 @@ def from_cfg(cls, cfg, dtype=None):
                    dec_recurrent=cfg.MODEL.DECODER.recurrent,
                    dec_activation=cfg.MODEL.DECODER.activation,
                    dec_pre_norm=cfg.MODEL.DECODER.pre_norm,
+                   dec_use_qkv_bias=cfg.MODEL.DECODER.use_qkv_bias,
                    layout=cfg.MODEL.layout,
                    embed_initializer=embed_initializer,
                    weight_initializer=weight_initializer,
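
For reference, a minimal sketch of how the new flag would be driven from user code. The cfg.MODEL.*.use_qkv_bias fields and the enc_use_qkv_bias / dec_use_qkv_bias keyword arguments come from the patch above; the import path, the TransformerModel class name, the tgt_vocab_size argument, and the defrost()/freeze() calls on the preset config are assumptions based on the surrounding gluonnlp and yacs conventions, not part of this diff.

# Sketch only: transformer_base() and the use_qkv_bias knobs are introduced by the
# patch above; TransformerModel, tgt_vocab_size and the import path are assumed.
import mxnet as mx
from gluonnlp.models.transformer import transformer_base, TransformerModel

# Option 1: flip the flags in the preset config. from_cfg() forwards them as
# enc_use_qkv_bias / dec_use_qkv_bias (see the from_cfg hunks above).
cfg = transformer_base()
cfg.defrost()                              # harmless if the preset is not frozen
cfg.MODEL.ENCODER.use_qkv_bias = True      # keep bias in the encoder attention projections
cfg.MODEL.DECODER.use_qkv_bias = False     # drop bias in the decoder self/contextual attention
cfg.freeze()

# Option 2: pass the flags straight to the constructor.
model = TransformerModel(src_vocab_size=32000,
                         tgt_vocab_size=32000,   # assumed companion argument
                         enc_use_qkv_bias=True,
                         dec_use_qkv_bias=False)
model.initialize(init=mx.init.Xavier())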