3 changes: 2 additions & 1 deletion examples/modular-transformers/modeling_new_task_model.py
@@ -455,8 +455,9 @@ def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of=None,
+ mean_resizing=True
) -> nn.Embedding:
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

# Update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
3 changes: 2 additions & 1 deletion examples/modular-transformers/modular_new_task_model.py
@@ -73,8 +73,9 @@ def resize_token_embeddings(
self,
new_num_tokens: Optional[int] = None,
pad_to_multiple_of=None,
+ mean_resizing=True
) -> nn.Embedding:
- model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

# Update vocab size
self.config.text_config.vocab_size = model_embeds.num_embeddings
19 changes: 15 additions & 4 deletions src/transformers/models/bark/modeling_bark.py
@@ -1189,11 +1189,11 @@ def set_output_embeddings(self, new_output_embeddings):
# one lm_head for each codebook
self.lm_heads = new_output_embeddings

- def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
old_embeddings_list = self.get_input_embeddings()
new_embeddings_list = nn.ModuleList(
[
- self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+ self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing)
for old_embeddings in old_embeddings_list
]
)
@@ -1211,7 +1211,10 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
return self.get_input_embeddings()

def resize_token_embeddings(
- self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+ self,
+ new_num_tokens: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ mean_resizing: bool = True,
) -> nn.Embedding:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
@@ -1230,11 +1233,19 @@ def resize_token_embeddings(
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
details about this, or help on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+ mean_resizing (`bool`):
+ Whether to initialize the added embeddings from a multivariate normal distribution with the old embeddings' mean and
+ covariance, or from a normal distribution with a mean of zero and a std equal to `config.initializer_range`.
+
+ Setting `mean_resizing` to `True` is useful when growing the embeddings of causal language models: initializing the new
+ embeddings at the old embeddings' mean keeps the KL divergence between the next-token distributions before and after
+ resizing small, so the probabilities of generated tokens are barely affected by the added embeddings.
+ Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html

Return:
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
- model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds

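For readers unfamiliar with the initialization strategy described in the `mean_resizing` docstring above, the sketch below illustrates the idea: fit a multivariate normal to the existing embedding rows and draw the new rows from it. This is an illustrative approximation only, not the code added by this PR; the helper name `resize_with_mean_init` and the small covariance ridge term are assumptions.

# Illustrative sketch only -- not the transformers implementation.
import torch

def resize_with_mean_init(old: torch.nn.Embedding, new_num_tokens: int) -> torch.nn.Embedding:
    old_num_tokens, dim = old.weight.shape
    resized = torch.nn.Embedding(new_num_tokens, dim).to(old.weight.device, old.weight.dtype)
    resized.weight.data[:old_num_tokens] = old.weight.data  # keep the existing rows unchanged

    if new_num_tokens > old_num_tokens:
        weights = old.weight.data.float()
        mean = weights.mean(dim=0)
        # torch.cov expects variables in rows, hence the transpose; a small ridge term
        # keeps the covariance positive definite when there are fewer tokens than dimensions.
        cov = torch.cov(weights.T) + 1e-5 * torch.eye(dim, device=weights.device)
        dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
        new_rows = dist.sample((new_num_tokens - old_num_tokens,))
        resized.weight.data[old_num_tokens:] = new_rows.to(old.weight.dtype)
    return resized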
6 changes: 4 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -1577,8 +1577,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

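A recurring pattern in the Bart change above and in the seq2seq heads below: these models override `resize_token_embeddings` because, besides the embedding matrix, they keep a vocab-sized `final_logits_bias` buffer that has to track the new vocabulary size. The body below is a hedged sketch of what such a helper typically does (truncate or zero-pad the buffer); the exact code in each modeling file may differ.

# Hedged sketch of a final-logits-bias resize helper; per-model details may differ.
import torch

def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
    old_num_tokens = self.final_logits_bias.shape[-1]
    if new_num_tokens <= old_num_tokens:
        new_bias = self.final_logits_bias[:, :new_num_tokens]  # truncate
    else:
        extra = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
        new_bias = torch.cat([self.final_logits_bias, extra], dim=1)  # zero-pad the new tokens
    self.register_buffer("final_logits_bias", new_bias)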
@@ -2457,8 +2457,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -1232,8 +1232,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

@@ -1184,8 +1184,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

@@ -1295,8 +1295,10 @@ def prepare_inputs_for_generation(
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
return self._shift_right(labels)

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/led/modeling_led.py
@@ -2319,8 +2319,10 @@ def get_encoder(self):
def get_decoder(self):
return self.led.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/lxmert/modeling_lxmert.py
@@ -1072,9 +1072,11 @@ def __init__(self, config):
}
self.visual_losses = visual_losses

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
# Adding the following steps to resize bias to match the shape of resized embeddings
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self.cls.predictions.bias = self._resize_bias(self.cls.predictions.bias, new_num_tokens)
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/marian/modeling_marian.py
@@ -1252,8 +1252,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
if self.config.share_encoder_decoder_embeddings:
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings
6 changes: 4 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -1546,8 +1546,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/mvp/modeling_mvp.py
@@ -1370,8 +1370,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_num_tokens)
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/omdet_turbo/modeling_omdet_turbo.py
@@ -1658,9 +1658,11 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_backbone.model.set_input_embeddings(value)

- def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
+ def resize_token_embeddings(
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing: bool = True
+ ) -> nn.Embedding:
model_embeds = self.language_backbone.model.resize_token_embeddings(
- new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of
+ new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of, mean_resizing=mean_resizing
)
self.config.text_config.vocab_size = model_embeds.num_embeddings
self.vocab_size = model_embeds.num_embeddings
6 changes: 4 additions & 2 deletions src/transformers/models/pegasus/modeling_pegasus.py
@@ -1265,8 +1265,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

6 changes: 4 additions & 2 deletions src/transformers/models/plbart/modeling_plbart.py
@@ -1274,8 +1274,10 @@ def get_encoder(self):
def get_decoder(self):
return self.model.get_decoder()

- def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
- new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
self._resize_final_logits_bias(new_embeddings.weight.shape[0])
return new_embeddings

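Taken together, these changes thread `mean_resizing` through every `resize_token_embeddings` override so the new keyword reaches the base implementation. A brief usage sketch follows; the checkpoint and the added tokens are just examples, the keyword itself is the one introduced by this PR.

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_tokens(["<extra_token_1>", "<extra_token_2>"])

# Default after this PR: new rows follow the old embeddings' mean and covariance.
model.resize_token_embeddings(len(tokenizer))

# Opt out to keep the previous behavior (zero-mean init with std=config.initializer_range).
model.resize_token_embeddings(len(tokenizer), mean_resizing=False)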