diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py
index 9b5dffd3517b..d0fb12a634d3 100644
--- a/src/transformers/models/longt5/modeling_longt5.py
+++ b/src/transformers/models/longt5/modeling_longt5.py
@@ -648,9 +648,12 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
 
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
         )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
 
         # (block_length, 3 * block_length)
@@ -843,9 +846,12 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
 
     def compute_bias(self, block_length: int):
         """Compute binned relative position bias"""
-        memory_position = torch.arange(
-            3 * block_length, dtype=torch.long, device=self.relative_attention_bias.weight.device
+        target_device = (
+            self.relative_attention_bias.weight.device
+            if self.relative_attention_bias.weight.device.type != "meta"
+            else None
         )
+        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
         context_position = memory_position[block_length:-block_length]
 
         # (block_length, 3 * block_length)
@@ -1271,6 +1277,7 @@ class LongT5PreTrainedModel(PreTrainedModel):
     config_class = LongT5Config
     base_model_prefix = "transformer"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LongT5Block"]
 
     @property
     # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
@@ -1366,7 +1373,9 @@ class LongT5Stack(LongT5PreTrainedModel):
     def __init__(self, config, embed_tokens=None):
         super().__init__(config)
 
-        self.embed_tokens = embed_tokens
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+        if embed_tokens is not None:
+            self.embed_tokens.weight = embed_tokens.weight
 
         self.is_decoder = config.is_decoder
         self.local_radius = config.local_radius
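
The changes above wire LongT5 into accelerate-style big-model loading: `compute_bias` no longer tries to allocate index tensors on the meta device, `_no_split_modules` keeps each `LongT5Block` on a single device when a device map is inferred, and `LongT5Stack` now owns an `nn.Embedding` whose weight is tied to the shared embedding rather than holding a bare reference. Below is a minimal sketch of how this path is typically exercised; it assumes `accelerate` is installed, and the `google/long-t5-local-base` checkpoint is used purely for illustration (neither is part of this diff).

```python
from transformers import LongT5ForConditionalGeneration

# device_map="auto" first builds the model with empty (meta-device) weights and
# then dispatches whole submodules across the available devices, which appears
# to be the code path the compute_bias guard and the _no_split_modules entry
# target.
model = LongT5ForConditionalGeneration.from_pretrained(
    "google/long-t5-local-base",
    device_map="auto",
)

# Mapping of submodules to devices chosen by accelerate; no LongT5Block is split
# across devices.
print(model.hf_device_map)
```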