diff --git a/model_cards/squeezebert/squeezebert-uncased/README.md b/model_cards/squeezebert/squeezebert-uncased/README.md
index 597abcd56604..c1a38cab0b8d 100644
--- a/model_cards/squeezebert/squeezebert-uncased/README.md
+++ b/model_cards/squeezebert/squeezebert-uncased/README.md
@@ -22,7 +22,7 @@ The authors found that SqueezeBERT is 4.3x faster than `bert-base-uncased` on a
 
 The model is pretrained using the Masked Language Model (MLM) and Sentence Order Prediction (SOP) tasks. (Author's note: If you decide to pretrain your own model, and you prefer to train with MLM only, that should work too.)
 
-The SqueezeBERT paper presents 2 approaches to finetuning the model:
+From the SqueezeBERT paper:
 > We pretrain SqueezeBERT from scratch (without distillation) using the [LAMB](https://arxiv.org/abs/1904.00962) optimizer, and we employ the hyperparameters recommended by the LAMB authors: a global batch size of 8192, a learning rate of 2.5e-3, and a warmup proportion of 0.28. Following the LAMB paper's recommendations, we pretrain for 56k steps with a maximum sequence length of 128 and then for 6k steps with a maximum sequence length of 512.
 
 ## Finetuning
diff --git a/src/transformers/modeling_squeezebert.py b/src/transformers/modeling_squeezebert.py
index 620280a31681..875f0fdd4fb8 100644
--- a/src/transformers/modeling_squeezebert.py
+++ b/src/transformers/modeling_squeezebert.py
@@ -373,6 +373,53 @@ def forward(self, hidden_states):
         return pooled_output
 
 
+class SqueezeBertPredictionHeadTransform(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        if isinstance(config.hidden_act, str):
+            self.transform_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.transform_act_fn = config.hidden_act
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.transform_act_fn(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states)
+        return hidden_states
+
+
+class SqueezeBertLMPredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.transform = SqueezeBertPredictionHeadTransform(config)
+
+        # The output weights are the same as the input embeddings, but there is
+        # an output-only bias for each token.
+        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+        self.decoder.bias = self.bias
+
+    def forward(self, hidden_states):
+        hidden_states = self.transform(hidden_states)
+        hidden_states = self.decoder(hidden_states)
+        return hidden_states
+
+
+class SqueezeBertOnlyMLMHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.predictions = SqueezeBertLMPredictionHead(config)
+
+    def forward(self, sequence_output):
+        prediction_scores = self.predictions(sequence_output)
+        return prediction_scores
+
+
 class SqueezeBertPreTrainedModel(PreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -594,16 +641,19 @@ def forward(
 
 @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
 class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
+
+    authorized_missing_keys = [r"predictions.decoder.bias"]
+
     def __init__(self, config):
         super().__init__(config)
 
         self.transformer = SqueezeBertModel(config)
-        self.lm_head = nn.Linear(config.embedding_size, config.vocab_size)
+        self.cls = SqueezeBertOnlyMLMHead(config)
 
         self.init_weights()
 
     def get_output_embeddings(self):
-        return self.lm_head
+        return self.cls.predictions.decoder
 
     @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
@@ -646,7 +696,7 @@ def forward(
         )
 
         sequence_output = outputs[0]
-        prediction_scores = self.lm_head(sequence_output)
+        prediction_scores = self.cls(sequence_output)
 
         masked_lm_loss = None
         if labels is not None: