Merged
Changes from 57 commits
Commits (61)
1e82f7e
Remove low_cpu_mem_usage and _fast_init
Cyrilvallez Mar 25, 2025
cd19480
Update deepspeed.py
Cyrilvallez Mar 25, 2025
148abd7
Update modeling_utils.py
Cyrilvallez Mar 25, 2025
076155b
remove the first 2 tests everywhere
Cyrilvallez Mar 25, 2025
5800645
Update test_modeling_common.py
Cyrilvallez Mar 25, 2025
b9501d7
remove what was remaining about fast_init
Cyrilvallez Mar 25, 2025
6813fcc
fix logic and simplify
Cyrilvallez Mar 25, 2025
ee83ab3
mismatched keys logic update
Cyrilvallez Mar 25, 2025
209d0d2
Update modeling_utils.py
Cyrilvallez Mar 25, 2025
28d8185
Update modeling_utils.py
Cyrilvallez Mar 25, 2025
f7c7490
Update modeling_utils.py
Cyrilvallez Mar 25, 2025
139d2a4
Update modeling_utils.py
Cyrilvallez Mar 25, 2025
5a30f74
fix 2 models init_weights
Cyrilvallez Mar 25, 2025
6878e1e
extend to others
Cyrilvallez Mar 25, 2025
d0918bb
remove grad
Cyrilvallez Mar 25, 2025
5f60e23
Update modeling_fsmt.py
Cyrilvallez Mar 26, 2025
f55fccb
init weights in tests
Cyrilvallez Mar 26, 2025
67baba2
style
Cyrilvallez Mar 26, 2025
b02f1cf
Update test_modeling_fsmt.py
Cyrilvallez Mar 26, 2025
6545e0c
more old models
Cyrilvallez Mar 26, 2025
15315b6
fix more init_weights
Cyrilvallez Mar 26, 2025
a178716
copies
Cyrilvallez Mar 26, 2025
8fe6897
fix
Cyrilvallez Mar 26, 2025
339483e
style
Cyrilvallez Mar 26, 2025
57ec41e
Update modeling_lxmert.py
Cyrilvallez Mar 26, 2025
a231ac5
fix inits
Cyrilvallez Mar 26, 2025
90e4acd
more and more
Cyrilvallez Mar 26, 2025
c1aeaf6
more
Cyrilvallez Mar 26, 2025
212f263
should finalize
Cyrilvallez Mar 27, 2025
faf6078
style
Cyrilvallez Mar 27, 2025
a958296
Update modeling_dinov2_with_registers.py
Cyrilvallez Mar 27, 2025
08af665
fix
Cyrilvallez Mar 27, 2025
1deb90a
Update modeling_encoder_decoder.py
Cyrilvallez Mar 27, 2025
30fc095
fix
Cyrilvallez Mar 27, 2025
fd737fb
style
Cyrilvallez Mar 27, 2025
315fe4d
Update modeling_lxmert.py
Cyrilvallez Mar 27, 2025
2f12c30
post rebase cleanup
Cyrilvallez Mar 27, 2025
9b6b0b7
Update modeling_informer.py
Cyrilvallez Mar 27, 2025
cf903a0
back to start for device
Cyrilvallez Mar 27, 2025
d2d13bb
fix
Cyrilvallez Mar 27, 2025
501b51b
add test to detect all failing cases correctly
Cyrilvallez Mar 27, 2025
f030a5c
Update test_modeling_common.py
Cyrilvallez Mar 27, 2025
ad69a59
fix
Cyrilvallez Mar 27, 2025
e73cf5b
fix
Cyrilvallez Mar 27, 2025
72c0385
sam
Cyrilvallez Mar 27, 2025
59b9685
style
Cyrilvallez Mar 27, 2025
308fa26
Update modeling_maskformer_swin.py
Cyrilvallez Mar 27, 2025
2d9df9a
CIs
Cyrilvallez Mar 27, 2025
3eef3bf
CIs
Cyrilvallez Mar 28, 2025
84548ef
remove test - will add it on separate PR
Cyrilvallez Mar 28, 2025
94e646c
fix
Cyrilvallez Mar 28, 2025
ec0e1af
fix
Cyrilvallez Mar 28, 2025
9942ed7
Update modeling_sam.py
Cyrilvallez Mar 28, 2025
bbc1a33
CIs
Cyrilvallez Mar 28, 2025
78ab3f5
CIs
Cyrilvallez Mar 28, 2025
a38b3b4
CIs
Cyrilvallez Mar 31, 2025
ecc3867
convnext
Cyrilvallez Mar 31, 2025
f3ab843
suggestions
Cyrilvallez Mar 31, 2025
ec10b3a
CIs
Cyrilvallez Mar 31, 2025
17753d7
Merge branch 'main' into remove-low-mem
ydshieh Mar 31, 2025
4d84dad
fix copies after merge
Cyrilvallez Mar 31, 2025
4 changes: 0 additions & 4 deletions conftest.py
@@ -46,10 +46,6 @@
"test_keep_in_fp32_modules",
"test_gradient_checkpointing_backward_compatibility",
"test_gradient_checkpointing_enable_disable",
"test_save_load_fast_init_from_base",
"test_fast_init_context_manager",
"test_fast_init_tied_embeddings",
"test_save_load_fast_init_to_base",
"test_torch_save_load",
"test_initialization",
"test_forward_signature",
7 changes: 2 additions & 5 deletions src/transformers/integrations/deepspeed.py
@@ -306,7 +306,7 @@ def deepspeed_config():
return None


def _load_state_dict_into_zero3_model(model_to_load, state_dict, assign_to_params_buffers=False):
def _load_state_dict_into_zero3_model(model_to_load, state_dict):
"""
Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
tensor parallelism API.
@@ -349,10 +349,7 @@ def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=Fals
if child is not None:
load(child, state_dict, prefix + name + ".", assign_to_params_buffers)

load(model_to_load, state_dict, assign_to_params_buffers=assign_to_params_buffers)
# Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
# it's safe to delete it.
del state_dict
load(model_to_load, state_dict, assign_to_params_buffers=False)

return error_msgs

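For context, the recursive loading scheme used by _load_state_dict_into_zero3_model can be sketched in isolation as follows. This is a simplified illustration, not the transformers implementation: the helper name load_into_zero3 is invented for the example, and it assumes DeepSpeed ZeRO-3 and torch.distributed are already initialized.

import deepspeed
import torch
import torch.nn as nn

def load_into_zero3(model: nn.Module, state_dict: dict) -> list:
    # Sketch: under ZeRO-3 every parameter is sharded across ranks, so each module's
    # parameters are gathered before rank 0 copies checkpoint values into them.
    error_msgs = []

    def load(module: nn.Module, prefix: str = "") -> None:
        local_params = list(module.parameters(recurse=False))
        if local_params:
            # Gather the sharded parameters, let rank 0 modify them, re-partition on exit.
            with deepspeed.zero.GatheredParameters(local_params, modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    module._load_from_state_dict(state_dict, prefix, {}, True, [], [], error_msgs)
        for name, child in module.named_children():
            load(child, prefix + name + ".")

    load(model)
    return error_msgs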
376 changes: 143 additions & 233 deletions src/transformers/modeling_utils.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -401,7 +401,6 @@ class ASTPreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn_2 = True

# Copied from transformers.models.deit.modeling_deit.DeiTPreTrainedModel._init_weights
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
@@ -415,6 +414,10 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ASTEmbeddings):
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()
module.distillation_token.data.zero_()


AUDIO_SPECTROGRAM_TRANSFORMER_START_DOCSTRING = r"""
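The _init_weights extension in this hunk and the similar ones that follow all serve the same purpose: parameters that used to rely on ad-hoc values set in __init__ (cls tokens, mask tokens, prediction-head biases, layer-scale factors) now get a branch in _init_weights, so a single recursive pass can (re)initialize anything missing from a checkpoint. A minimal sketch of that flow, using an invented TinyPreTrainedModel rather than the real PreTrainedModel (which additionally handles weight tying, head pruning and missing-key bookkeeping):

import torch.nn as nn

class TinyPreTrainedModel(nn.Module):
    # Invented stand-in for PreTrainedModel, illustrating only the dispatch.
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # Model-specific branches (e.g. the ASTEmbeddings case above) plug in here.

    def post_init(self):
        # nn.Module.apply visits every submodule (and self) exactly once,
        # so any custom parameter must be covered by a branch in _init_weights.
        self.apply(self._init_weights)

class TinyModel(TinyPreTrainedModel):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)

model = TinyModel()
model.post_init()  # dense and norm are now initialized by _init_weights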
12 changes: 5 additions & 7 deletions src/transformers/models/autoformer/modeling_autoformer.py
@@ -361,22 +361,20 @@ class AutoformerSinusoidalPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
super().__init__(num_positions, embedding_dim)

@staticmethod
def _init_weight(out: nn.Parameter) -> nn.Parameter:
def _init_weight(self):
"""
Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
the 2nd half of the vector. [dim // 2:]
"""
n_pos, dim = out.shape
n_pos, dim = self.weight.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
return out
self.weight = nn.Parameter(out, requires_grad=False)

@torch.no_grad()
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
@@ -903,7 +901,7 @@ def _init_weights(self, module):
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, AutoformerSinusoidalPositionalEmbedding):
module.weight = module._init_weight(module.weight)
module._init_weight()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
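The new instance-level _init_weight above can be exercised on its own; the sketch below reproduces the same sinusoidal table construction outside transformers (the class name and the usage lines are illustrative only).

import numpy as np
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Embedding):
    # Sketch of the pattern above: the table is filled in by _init_weight
    # (normally called from the model's _init_weights), not inside __init__.
    def __init__(self, num_positions: int, embedding_dim: int) -> None:
        super().__init__(num_positions, embedding_dim)

    def _init_weight(self):
        n_pos, dim = self.weight.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out = torch.empty(n_pos, dim, dtype=self.weight.dtype, requires_grad=False)
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # sin in the first half
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))   # cos in the second half
        self.weight = nn.Parameter(out, requires_grad=False)

emb = SinusoidalPositionalEmbedding(128, 64)
emb._init_weight()
print(emb.weight.shape)  # torch.Size([128, 64])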
12 changes: 12 additions & 0 deletions src/transformers/models/beit/modeling_beit.py
@@ -770,6 +770,18 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, BeitEmbeddings):
module.cls_token.data.zero_()
if module.mask_token is not None:
module.mask_token.data.zero_()
if module.position_embeddings is not None:
module.position_embeddings.data.zero_()
elif isinstance(module, BeitRelativePositionBias):
module.relative_position_bias_table.data.zero_()
elif isinstance(module, BeitLayer):
if module.lambda_1 is not None:
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
module.lambda_2.data.fill_(self.config.layer_scale_init_value)


BEIT_START_DOCSTRING = r"""
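The lambda_1 / lambda_2 parameters re-initialized above are BEiT's layer-scale factors. A self-contained sketch of the general pattern (LayerScale is an invented module for illustration; layer_scale_init_value stands in for the config field):

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Learnable per-channel factor, filled with a small constant at init time
    # and learned during training; checkpoint weights override it when loading.
    def __init__(self, hidden_size: int, layer_scale_init_value: float = 0.1):
        super().__init__()
        self.lambda_1 = nn.Parameter(torch.full((hidden_size,), layer_scale_init_value))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.lambda_1 * hidden_states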
2 changes: 2 additions & 0 deletions src/transformers/models/bert/modeling_bert.py
@@ -848,6 +848,8 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, BertLMPredictionHead):
module.bias.data.zero_()


@dataclass
4 changes: 3 additions & 1 deletion src/transformers/models/camembert/modeling_camembert.py
@@ -715,7 +715,7 @@ class CamembertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_sdpa = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->CamembertLMHead
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
@@ -731,6 +731,8 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, CamembertLMHead):
module.bias.data.zero_()


CAMEMBERT_INPUTS_DOCSTRING = r"""
5 changes: 4 additions & 1 deletion src/transformers/models/convnext/modeling_convnext.py
@@ -286,9 +286,12 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
elif isinstance(module, (nn.LayerNorm, ConvNextLayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ConvNextLayer):
if module.layer_scale_parameter is not None:
module.layer_scale_parameter.data.fill_(self.config.layer_scale_init_value)


CONVNEXT_START_DOCSTRING = r"""
6 changes: 4 additions & 2 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -287,7 +287,6 @@ def forward(
)


# Copied from transformers.models.convnext.modeling_convnext.ConvNextPreTrainedModel with ConvNext->ConvNextV2, convnext->convnextv2
class ConvNextV2PreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
@@ -307,9 +306,12 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
elif isinstance(module, (nn.LayerNorm, ConvNextV2LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ConvNextV2GRN):
module.weight.data.zero_()
module.bias.data.zero_()


CONVNEXTV2_START_DOCSTRING = r"""
12 changes: 12 additions & 0 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -784,6 +784,18 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Data2VecVisionEmbeddings):
module.cls_token.data.zero_()
if module.mask_token is not None:
module.mask_token.data.zero_()
if module.position_embeddings is not None:
module.position_embeddings.data.zero_()
elif isinstance(module, Data2VecVisionRelativePositionBias):
module.relative_position_bias_table.data.zero_()
elif isinstance(module, Data2VecVisionLayer):
if module.lambda_1 is not None:
module.lambda_1.data.fill_(self.config.layer_scale_init_value)
module.lambda_2.data.fill_(self.config.layer_scale_init_value)


DATA2VEC_VISION_START_DOCSTRING = r"""
Original file line number Diff line number Diff line change
@@ -1850,30 +1850,20 @@ def __init__(self, config: DeformableDetrConfig):
num_layers=3,
)

prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
self.class_embed.bias.data = torch.ones(config.num_labels) * bias_value
nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)

# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (config.decoder_layers + 1) if config.two_stage else config.decoder_layers
if config.with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
# hack implementation for iterative bounding box refinement
self.model.decoder.bbox_embed = self.bbox_embed
else:
nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
self.model.decoder.bbox_embed = None
if config.two_stage:
# hack implementation for two-stage
self.model.decoder.class_embed = self.class_embed
for box_embed in self.bbox_embed:
nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0)

# Initialize weights and apply final processing
self.post_init()
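The hard-coded class_embed bias dropped from __init__ above is the standard focal-loss prior initialization. As a quick worked example of what the constant encodes, with prior_prob = 0.01 as in the removed code:

import math
import torch

prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)   # about -4.595
# With this bias, a zero-weight classifier predicts the intended 1% foreground prior:
print(torch.sigmoid(torch.tensor(bias_value)))           # ~0.01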
6 changes: 6 additions & 0 deletions src/transformers/models/deit/modeling_deit.py
@@ -487,6 +487,12 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DeiTEmbeddings):
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()
module.distillation_token.data.zero_()
if module.mask_token is not None:
module.mask_token.data.zero_()


DEIT_START_DOCSTRING = r"""
Original file line number Diff line number Diff line change
@@ -44,7 +44,6 @@ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional
self.offset = 2
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
@@ -399,6 +398,11 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Speech2Text2SinusoidalPositionalEmbedding):
weight = module.get_embedding(*module.weight.shape, module.padding_idx)
weight = nn.Parameter(weight, requires_grad=False)
weight.detach_()
module.weight = weight


SPEECH_TO_TEXT_2_START_DOCSTRING = r"""
Original file line number Diff line number Diff line change
@@ -516,12 +516,12 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)

module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)
module.mask_token.data.zero_()


VIT_START_DOCSTRING = r"""
5 changes: 5 additions & 0 deletions src/transformers/models/dinov2/modeling_dinov2.py
@@ -548,6 +548,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
std=self.config.initializer_range,
).to(module.cls_token.dtype)

if self.config.use_mask_token:
module.mask_token.data.zero_()
elif isinstance(module, Dinov2LayerScale):
module.lambda1.data.fill_(self.config.layerscale_value)


DINOV2_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
Original file line number Diff line number Diff line change
@@ -556,6 +556,11 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No
std=self.config.initializer_range,
).to(module.cls_token.dtype)

module.mask_token.data.zero_()
module.register_tokens.data.zero_()
elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
module.lambda1.data.fill_(self.config.layerscale_value)


_EXPECTED_OUTPUT_SHAPE = [1, 257, 768]

Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Optional, Union

import torch
import torch.utils.checkpoint
@@ -277,7 +277,36 @@ class Dinov2WithRegistersEncoder(Dinov2Encoder):


class Dinov2WithRegistersPreTrainedModel(Dinov2PreTrainedModel):
pass
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
if isinstance(module, (nn.Linear, nn.Conv2d)):
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
# `trunc_normal_cpu` not implemented in `half` issues
module.weight.data = nn.init.trunc_normal_(
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
).to(module.weight.dtype)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, Dinov2WithRegistersEmbeddings):
module.position_embeddings.data = nn.init.trunc_normal_(
module.position_embeddings.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.position_embeddings.dtype)

module.cls_token.data = nn.init.trunc_normal_(
module.cls_token.data.to(torch.float32),
mean=0.0,
std=self.config.initializer_range,
).to(module.cls_token.dtype)

module.mask_token.data.zero_()
module.register_tokens.data.zero_()
elif isinstance(module, Dinov2WithRegistersLayerScale): # noqa: F821
module.lambda1.data.fill_(self.config.layerscale_value)


class Dinov2WithRegistersModel(Dinov2Model):
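Several of the _init_weights bodies above upcast to float32 before nn.init.trunc_normal_ and cast back, because (per the in-code comment) the truncated-normal kernel is not implemented for half precision on CPU. A minimal standalone sketch of that trick, assuming a float16 tensor:

import torch
import torch.nn as nn

weight = torch.empty(4, 4, dtype=torch.float16)
# Initialize in float32, then cast back to the tensor's original dtype.
weight.data = nn.init.trunc_normal_(
    weight.data.to(torch.float32), mean=0.0, std=0.02
).to(weight.dtype)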
7 changes: 7 additions & 0 deletions src/transformers/models/donut/modeling_donut_swin.py
@@ -869,6 +869,13 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, DonutSwinEmbeddings):
if module.mask_token is not None:
module.mask_token.data.zero_()
if module.position_embeddings is not None:
module.position_embeddings.data.zero_()
elif isinstance(module, DonutSwinSelfAttention):
module.relative_position_bias_table.data.zero_()


SWIN_START_DOCSTRING = r"""
3 changes: 3 additions & 0 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -855,6 +855,9 @@ def _init_weights(self, module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, (DPTViTEmbeddings, DPTViTHybridEmbeddings)):
module.cls_token.data.zero_()
module.position_embeddings.data.zero_()


DPT_START_DOCSTRING = r"""
1 change: 0 additions & 1 deletion src/transformers/models/electra/modeling_electra.py
@@ -670,7 +670,6 @@ class ElectraPreTrainedModel(PreTrainedModel):
base_model_prefix = "electra"
supports_gradient_checkpointing = True

# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):