diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 372e249a3094..ea6ee3f3e51d 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -500,6 +500,7 @@ class BartPretrainedModel(PreTrainedModel):
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder.version", r"decoder.version"]
+    _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -712,10 +713,10 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -801,6 +802,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -884,10 +886,10 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1043,6 +1045,7 @@ def forward(
 
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1373,7 +1376,9 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None:
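The `_no_split_modules` attribute added above lists the layer classes that Accelerate must keep on a single device when a model is sharded with `device_map="auto"`, so a `BartEncoderLayer` or `BartDecoderLayer` never straddles two devices. A minimal usage sketch of what this enables; the checkpoint name, dtype, and the assumption that the embeddings land on GPU 0 are illustrative, not taken from this diff:

```python
# Sketch: loading BART with Accelerate's device_map sharding, which is what
# `_no_split_modules` enables. Checkpoint, dtype, and GPU placement are assumptions.
import torch
from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large-cnn",
    device_map="auto",          # requires `accelerate`; whole Bart*Layer blocks stay on one device
    torch_dtype=torch.float16,  # optional, shrinks each shard
)

inputs = tokenizer("Summarize: the meeting covered hiring, budget, and roadmap.", return_tensors="pt")
inputs = inputs.to(0)  # assumed: the first shard (the embeddings) was placed on GPU 0
summary_ids = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
```

The same pattern applies to the BigBirdPegasus and PLBart classes changed below.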
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 40ed916b0c71..c05370217596 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -1595,6 +1595,7 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
     config_class = BigBirdPegasusConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1788,10 +1789,10 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2082,10 +2083,10 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = BigBirdPegasusLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -2240,6 +2241,7 @@ def forward(
 
         # embed positions
         positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
 
@@ -2573,7 +2575,9 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
        )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None:
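The repeated `embed_tokens` rewrite (in the BART and BigBirdPegasus encoders/decoders above, and the PLBart ones below) swaps module sharing for weight sharing: each encoder/decoder now owns its own `nn.Embedding`, and only the `.weight` Parameter is tied to the shared embedding. The weight tensor stays single, but a device map can now treat each module separately. A standalone sketch of the tying mechanics, with toy sizes:

```python
# Sketch of tying embeddings by Parameter assignment rather than module reuse.
# The vocabulary size, dimension, and padding index are toy values.
import torch
import torch.nn as nn

shared = nn.Embedding(100, 16, padding_idx=1)

encoder_embed = nn.Embedding(100, 16, padding_idx=1)
decoder_embed = nn.Embedding(100, 16, padding_idx=1)

# Point both modules at the shared Parameter, mirroring
# `self.embed_tokens.weight = embed_tokens.weight` in the diff.
encoder_embed.weight = shared.weight
decoder_embed.weight = shared.weight

assert encoder_embed.weight is shared.weight  # one Parameter, three module wrappers

ids = torch.tensor([[2, 5, 7, 1]])
assert torch.equal(encoder_embed(ids), decoder_embed(ids))  # identical lookups
```

Because the same `nn.Parameter` object is registered in all three modules, an update made through any of them is seen by the others.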
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index fa58563ec42d..f46615c00b15 100755
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -506,6 +506,7 @@ class PLBartPreTrainedModel(PreTrainedModel):
     config_class = PLBartConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -683,10 +684,10 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
         self.max_source_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -772,6 +773,7 @@ def forward(
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         embed_pos = self.embed_positions(input)
+        embed_pos = embed_pos.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -856,10 +858,10 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
         self.max_target_positions = config.max_position_embeddings
         self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
 
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+
         if embed_tokens is not None:
-            self.embed_tokens = embed_tokens
-        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.embed_positions = PLBartLearnedPositionalEmbedding(
             config.max_position_embeddings,
@@ -1015,6 +1017,7 @@ def forward(
 
         # embed positions
         positions = self.embed_positions(input, past_key_values_length)
+        positions = positions.to(inputs_embeds.device)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1334,7 +1337,8 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias
+        lm_logits = self.lm_head(outputs[0])
+        lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
 
         masked_lm_loss = None
         if labels is not None:
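Finally, the added `.to(...)` calls on `embed_pos`/`positions` and on `final_logits_bias` cover the sharded case where a buffer or positional-embedding module lands on a different device than the activations it is added to; moving the smaller tensor onto the activations' device keeps the elementwise add valid and is a no-op on single-device setups. A minimal standalone illustration of that pattern; the device placement here is hypothetical:

```python
# Sketch: adding a tensor that may live on a different device than the activations,
# mirroring `positions = positions.to(inputs_embeds.device)` in the diff.
# The CUDA branch is hypothetical; on CPU-only machines `.to(...)` is a no-op.
import torch

def add_positions(hidden_states: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    # Always move the (small) positional tensor to wherever the activations live.
    return hidden_states + positions.to(hidden_states.device)

hidden_states = torch.randn(1, 4, 8)
positions = torch.randn(4, 8)

if torch.cuda.is_available():
    hidden_states = hidden_states.to("cuda:0")  # pretend this decoder shard sits on GPU 0

print(add_positions(hidden_states, positions).device)
```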