Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,8 @@ def _init_weights(self, module):
# Mesh TensorFlow embeddings initialization
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
elif isinstance(module, T5DenseReluDense):
# Mesh TensorFlow FF initialization
# See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
Expand Down
13 changes: 9 additions & 4 deletions src/transformers/models/t5/modeling_tf_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,9 @@ def _shift_right(self, input_ids):
class TFT5Model(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
Expand Down Expand Up @@ -1259,8 +1261,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.model_dim = config.d_model

self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
Expand Down Expand Up @@ -1600,7 +1603,9 @@ def _reorder_cache(self, past, beam_idx):
class TFT5EncoderModel(TFT5PreTrainedModel):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
self.shared = TFSharedEmbeddings(
config.vocab_size, config.d_model, name="shared", initializer_range=self.config.initializer_factor
)

# retrieve correct absolute scope for embed token wrapper
with tf.compat.v1.variable_scope("shared") as shared_abs_scope_name:
Expand Down