src/transformers/models/flaubert/modeling_flaubert.py (5 additions & 3 deletions)
```diff
@@ -58,10 +58,10 @@
 # Copied from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
 def create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+    out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
     out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
     out.detach_()
-    out.requires_grad = False


 # Copied from transformers.models.xlm.modeling_xlm.get_masks
@@ -370,6 +370,10 @@ def _init_weights(self, module):
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
+            )


 class FlaubertModel(FlaubertPreTrainedModel):
@@ -407,8 +411,6 @@ def __init__(self, config):  # , dico, is_encoder, with_output):

         # embeddings
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
-        if config.sinusoidal_embeddings:
-            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
         if config.n_langs > 1 and config.use_lang_emb:
            self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
         self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
```
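Two things happen in the hunks above. First, `out.requires_grad = False` moves above the in-place slice writes: under grad mode, writing in place into a leaf tensor that still requires grad raises a `RuntimeError`, which the reordering avoids. A minimal standalone repro of that PyTorch behavior (not library code):

```python
import torch.nn as nn

weight = nn.Embedding(512, 64).weight  # a leaf Parameter, requires_grad=True
try:
    weight[:, 0::2] = 0.0  # in-place write into a leaf that requires grad
except RuntimeError as err:
    print(err)  # "a leaf Variable that requires grad is being used in an in-place operation"

weight.requires_grad = False  # clearing the flag first, as the new code does,
weight[:, 0::2] = 0.0         # makes the in-place fill legal
```

Second, the sinusoidal fill moves from `__init__` into `_init_weights`, the hook transformers runs whenever it (re)initializes weights, so the table is applied on the library's initialization paths instead of being written once at construction and potentially overwritten.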
src/transformers/models/informer/modeling_informer.py (1 addition & 1 deletion)
```diff
@@ -890,7 +890,7 @@ def _init_weights(self, module):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.bias is not None:
                 module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
+        elif isinstance(module, nn.Embedding) and not isinstance(module, InformerSinusoidalPositionalEmbedding):
             module.weight.data.normal_(mean=0.0, std=std)
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
```
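The added guard matters because `InformerSinusoidalPositionalEmbedding` subclasses `nn.Embedding`, so the generic branch used to match it and re-randomize the precomputed sinusoidal table on every weight init. A tiny illustration of the subclass trap, using a hypothetical stand-in class:

```python
import torch.nn as nn

class SinusoidalEmbedding(nn.Embedding):  # stand-in for InformerSinusoidalPositionalEmbedding
    pass

module = SinusoidalEmbedding(16, 8)
print(isinstance(module, nn.Embedding))  # True: the old check matched and re-randomized it
print(isinstance(module, nn.Embedding) and not isinstance(module, SinusoidalEmbedding))  # False: the new check skips it
```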
src/transformers/models/xlm/modeling_xlm.py (5 additions & 3 deletions)
```diff
@@ -59,10 +59,10 @@

 def create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
+    out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
     out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
     out.detach_()
-    out.requires_grad = False


 def get_masks(slen, lengths, causal, padding_mask=None):
@@ -245,6 +245,10 @@ def _init_weights(self, module):
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
+        if isinstance(module, XLMModel) and self.config.sinusoidal_embeddings:
+            create_sinusoidal_embeddings(
+                self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
+            )


 @dataclass
@@ -414,8 +418,6 @@ def __init__(self, config):

         # embeddings
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
-        if config.sinusoidal_embeddings:
-            create_sinusoidal_embeddings(config.max_position_embeddings, self.dim, out=self.position_embeddings.weight)
         if config.n_langs > 1 and config.use_lang_emb:
             self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
         self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
```
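For context on why `_init_weights` is the right home: transformers applies this hook to every submodule via `Module.apply`, children before parents, roughly as sketched below (a simplification, not the real `PreTrainedModel` code). Because the top-level `XLMModel`/`FlaubertModel` is visited last, the sinusoidal fill added above runs after the generic `nn.Embedding` random init and wins:

```python
import torch.nn as nn

class SketchPreTrainedModel(nn.Module):
    """Simplified stand-in for transformers.PreTrainedModel, for illustration only."""

    def post_init(self):
        # Module.apply visits children first, then the module itself,
        # so per-module random init runs before any model-level fixup.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)  # generic random init
        if isinstance(module, SketchPreTrainedModel):
            pass  # model-level branch: this is where the sinusoidal fill lands
```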
tests/models/flaubert/test_modeling_flaubert.py (9 additions & 0 deletions)
```diff
@@ -36,6 +36,7 @@
         FlaubertModel,
         FlaubertWithLMHeadModel,
     )
+    from transformers.models.flaubert.modeling_flaubert import create_sinusoidal_embeddings


 class FlaubertModelTester(object):
@@ -431,6 +432,14 @@ def test_flaubert_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_flaubert_model(*config_and_inputs)

+    # Copied from tests/models/distilbert/test_modeling_distilbert.py with Distilbert->Flaubert
+    def test_flaubert_model_with_sinusoidal_encodings(self):
```
> **Collaborator:** let's use copied from for the test as well!
>
> **Contributor Author:** does this solve it?

```diff
+        config = FlaubertConfig(sinusoidal_embeddings=True)
+        model = FlaubertModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.emb_dim), dtype=torch.float32)
+        create_sinusoidal_embeddings(config.max_position_embeddings, config.emb_dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(torch.equal(model.position_embeddings.weight, sinusoidal_pos_embds))
+
     def test_flaubert_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_flaubert_lm_head(*config_and_inputs)
```
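On the `Copied from` exchange above: transformers enforces these markers mechanically (the repo ships a `utils/check_copies.py` consistency check), which is why the reviewer insists on the exact `with Distilbert->Flaubert` suffix: it tells the checker how to rewrite identifiers when comparing against the DistilBERT original. With a development checkout, `python -m pytest tests/models/flaubert/test_modeling_flaubert.py -k sinusoidal` runs just the new test.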
tests/models/informer/test_modeling_informer.py (11 additions & 1 deletion)
```diff
@@ -35,7 +35,11 @@
     import torch

     from transformers import InformerConfig, InformerForPrediction, InformerModel
-    from transformers.models.informer.modeling_informer import InformerDecoder, InformerEncoder
+    from transformers.models.informer.modeling_informer import (
+        InformerDecoder,
+        InformerEncoder,
+        InformerSinusoidalPositionalEmbedding,
+    )


 @require_torch
@@ -164,6 +168,12 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict):

         self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)

+        embed_positions = InformerSinusoidalPositionalEmbedding(
+            config.context_length + config.prediction_length, config.d_model
+        )
+        self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight))
+        self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight))
+
         with tempfile.TemporaryDirectory() as tmpdirname:
             decoder = model.get_decoder()
             decoder.save_pretrained(tmpdirname)
```
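The exact `torch.equal` comparison is safe here because, as the test itself relies on, the sinusoidal table is a deterministic function of `(num_positions, embedding_dim)`: the encoder, the decoder, and a freshly constructed module all compute bit-identical weights. A quick sanity sketch under that assumption (shapes chosen arbitrarily):

```python
import torch
from transformers.models.informer.modeling_informer import InformerSinusoidalPositionalEmbedding

# two independently constructed tables with the same arguments match exactly
emb_a = InformerSinusoidalPositionalEmbedding(64, 32)
emb_b = InformerSinusoidalPositionalEmbedding(64, 32)
assert torch.equal(emb_a.weight, emb_b.weight)
```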
tests/models/xlm/test_modeling_xlm.py (9 additions & 0 deletions)
```diff
@@ -36,6 +36,7 @@
         XLMModel,
         XLMWithLMHeadModel,
     )
+    from transformers.models.xlm.modeling_xlm import create_sinusoidal_embeddings


 class XLMModelTester:
@@ -432,6 +433,14 @@ def test_xlm_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_model(*config_and_inputs)

+    # Copied from tests/models/distilbert/test_modeling_distilbert.py with Distilbert->XLM
+    def test_xlm_model_with_sinusoidal_encodings(self):
```
> **Collaborator:** missing copied from
>
> **Contributor Author:** @ArthurZucker does this solve it?
>
> **Collaborator:** Yes, but we need "with Distilbert->XLM"
>
> **Contributor Author:** changed them

```diff
+        config = XLMConfig(sinusoidal_embeddings=True)
+        model = XLMModel(config=config)
+        sinusoidal_pos_embds = torch.empty((config.max_position_embeddings, config.emb_dim), dtype=torch.float32)
+        create_sinusoidal_embeddings(config.max_position_embeddings, config.emb_dim, sinusoidal_pos_embds)
+        self.model_tester.parent.assertTrue(torch.equal(model.position_embeddings.weight, sinusoidal_pos_embds))
+
     def test_xlm_lm_head(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_xlm_lm_head(*config_and_inputs)
```