Merged

40 commits
47c8e54 Update test_modeling_common.py (Cyrilvallez, Mar 28, 2025)
abe0488 Fix Llama and its modular children (Cyrilvallez, Mar 28, 2025)
4749c0d Update test_modeling_common.py (Cyrilvallez, Mar 28, 2025)
49d7625 qwen3 (Cyrilvallez, Mar 31, 2025)
80e52d7 first try at prioritizing models (Cyrilvallez, Mar 31, 2025)
155969d Update test_modeling_common.py (Cyrilvallez, Mar 31, 2025)
9a15450 Update test_modeling_common.py (Cyrilvallez, Mar 31, 2025)
62725b4 Update test_modeling_common.py (Cyrilvallez, Mar 31, 2025)
b5eb6cd test (Cyrilvallez, Apr 1, 2025)
d0a016a fix (Cyrilvallez, Apr 1, 2025)
dceab88 fix (Cyrilvallez, Apr 1, 2025)
3251852 more models (Cyrilvallez, Apr 1, 2025)
15bcf97 more (Cyrilvallez, Apr 1, 2025)
997ba7e more (Cyrilvallez, Apr 1, 2025)
0551911 more (Cyrilvallez, Apr 1, 2025)
eb4bfd3 smarter init for composite models! (Cyrilvallez, Apr 1, 2025)
2082647 fix post rebase (Cyrilvallez, Apr 2, 2025)
9a83073 smol (Cyrilvallez, Apr 2, 2025)
4fc6898 fix missing args (Cyrilvallez, Apr 2, 2025)
fa736c9 more (Cyrilvallez, Apr 2, 2025)
7001d13 typo (Cyrilvallez, Apr 2, 2025)
b083965 Super elegant and efficient init for submodels (Cyrilvallez, Apr 2, 2025)
7adb98c Update modeling_utils.py (Cyrilvallez, Apr 2, 2025)
bf9b49f style (Cyrilvallez, Apr 2, 2025)
e4141c0 last fixes (Cyrilvallez, Apr 2, 2025)
a04f7d5 cleanup (Cyrilvallez, Apr 2, 2025)
8ed50ab finalize cleanup (Cyrilvallez, Apr 2, 2025)
a9c303e CIs (Cyrilvallez, Apr 2, 2025)
23eb8c1 improve docstring (Cyrilvallez, Apr 2, 2025)
1aa8914 Update modeling_utils.py (Cyrilvallez, Apr 2, 2025)
28f8657 llama4 (Cyrilvallez, Apr 8, 2025)
ce281b8 style (Cyrilvallez, Apr 8, 2025)
c135488 CIs (Cyrilvallez, Apr 9, 2025)
f41f9cc style (Cyrilvallez, Apr 14, 2025)
6f6364c add dpt (Cyrilvallez, Apr 14, 2025)
ce665b8 granite speech (Cyrilvallez, Apr 14, 2025)
e3ccb5f qwen 2.5 omni (Cyrilvallez, Apr 14, 2025)
1034c8c better fix (Cyrilvallez, Apr 14, 2025)
25d47e4 Parse the config file instead (Cyrilvallez, Apr 14, 2025)
4400c52 CIs (Cyrilvallez, Apr 14, 2025)
37 changes: 34 additions & 3 deletions src/transformers/modeling_utils.py
@@ -2449,6 +2449,37 @@ def _initialize_weights(self, module):
self._init_weights(module)
module._is_hf_initialized = True

@torch.no_grad()
def initialize_weights(self):
"""
This is equivalent to calling `self.apply(self._initialize_weights)`, but correctly handles composite models.
This function dynamically dispatches the correct `init_weights` function to the modules as we advance in the
module graph along the recursion. It can handle an arbitrary number of sub-models. Without it, every composite
model would have to recurse a second time on all sub-models explicitly in the outer-most `_init_weights`, which
is extremely error prone and inefficient.

Note that the `torch.no_grad()` decorator is very important as well, as most of our `_init_weights` do not use
`torch.nn.init` functions (which are all no_grad by default), but simply do in-place ops such as
`module.weight.data.zero_()`.
"""
if not hasattr(torch.nn.Module, "smart_apply"):
# This function is equivalent to `torch.nn.Module.apply`, except that it dynamically adjusts the function
# to apply as we go down the graph
def smart_apply(self, fn):
for module in self.children():
# We found a sub-model: recursively dispatch its own init function now!
if hasattr(module, "_init_weights"):
module.smart_apply(module._initialize_weights)
else:
module.smart_apply(fn)
fn(self)
return self

torch.nn.Module.smart_apply = smart_apply

# Let the magic happen with this simple call
self.smart_apply(self._initialize_weights)
Comment on lines +2452 to +2481 (Cyrilvallez, Member Author):
This is the most important change to review @ArthurZucker. It's the most efficient and elegant way to handle it, as we only need to traverse the modules once. However, it requires hot-patching torch.nn.Module, which is a bummer but fine IMO.
The other options that avoid doing so all require traversing the modules several times (at least twice), which is less efficient.
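To make the dispatch concrete, here is a minimal, self-contained sketch of the same idea using hypothetical toy classes. It is not the Transformers API: `smart_apply` is written as a free function instead of being patched onto `torch.nn.Module`, and it dispatches `_init_weights` directly rather than `_initialize_weights`, so it skips the `_is_hf_initialized` bookkeeping. Whenever a child defines its own `_init_weights`, the function being applied is swapped at that point in the module graph, so a single traversal initializes every sub-model with its own rule.

import torch
from torch import nn


def smart_apply(module, fn):
    # Like nn.Module.apply, but switches to a child's own init function when it defines one.
    for child in module.children():
        if hasattr(child, "_init_weights"):
            # Found a sub-model: from here on, apply its own init rule.
            smart_apply(child, child._init_weights)
        else:
            smart_apply(child, fn)
    fn(module)
    return module


class ToyTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.proj = nn.Linear(4, 4)

    @torch.no_grad()
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)


class ToyVisionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3)

    @torch.no_grad()
    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            module.weight.data.zero_()


class ToyComposite(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_model = ToyTextModel()
        self.vision_model = ToyVisionModel()
        self.head = nn.Linear(4, 2)

    @torch.no_grad()
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.fill_(1.0)


model = ToyComposite()
smart_apply(model, model._init_weights)
assert (model.vision_model.conv.weight == 0).all()  # vision sub-model used its own rule
assert (model.head.weight == 1).all()  # outer model's rule used for its own modules

In the PR itself the same traversal runs through `smart_apply` attached to `torch.nn.Module` and dispatches `_initialize_weights`, which also marks each module with `_is_hf_initialized = True`.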


def tie_weights(self):
"""
Tie the weights between the input embeddings and the output embeddings.
@@ -3074,7 +3105,7 @@ def init_weights(self):

if _init_weights:
# Initialize weights
self.apply(self._initialize_weights)
self.initialize_weights()

# Tie weights should be skipped when not initializing all weights
# since from_pretrained(...) calls tie weights anyways
@@ -5286,9 +5317,9 @@ def _initialize_missing_keys(
)
)
with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
self.apply(self._initialize_weights)
self.initialize_weights()
else:
self.apply(self._initialize_weights)
self.initialize_weights()

def get_parameter_or_buffer(self, target: str):
"""
17 changes: 9 additions & 8 deletions src/transformers/models/aria/modeling_aria.py
@@ -679,12 +679,10 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()


ARIA_TEXT_START_DOCSTRING = r"""
@@ -724,14 +722,17 @@ class AriaPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range

if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)

17 changes: 9 additions & 8 deletions src/transformers/models/aria/modular_aria.py
@@ -1255,12 +1255,10 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, AriaTextRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, AriaGroupedExpertsGemm):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Conv2d):
module.weight.data.normal_(mean=0.0, std=std)
if hasattr(module, "bias") and module.bias is not None:
module.bias.data.zero_()


class AriaPreTrainedModel(LlamaPreTrainedModel):
@@ -1269,14 +1267,17 @@ class AriaPreTrainedModel(LlamaPreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range

if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.MultiheadAttention):
# This uses torch's original init
module._reset_parameters()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
elif isinstance(module, AriaProjector):
nn.init.trunc_normal_(module.query, std=std)

15 changes: 4 additions & 11 deletions src/transformers/models/aya_vision/modeling_aya_vision.py
@@ -127,26 +127,19 @@ class AyaVisionPreTrainedModel(PreTrainedModel):
_supports_static_cache = False

def _init_weights(self, module):
# important: this ported version of AyaVision isn't meant for training from scratch - only
# inference and fine-tuning - so the proper init weights code has been removed - the original codebase
# https://github.com/haotian-liu/AyaVision/tree/main/aya_vision should serve for that purpose
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)

if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)

if isinstance(module, (nn.Linear, nn.Conv2d)):
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


@dataclass
15 changes: 15 additions & 0 deletions src/transformers/models/aya_vision/modular_aya_vision.py
@@ -113,6 +113,21 @@ class AyaVisionPreTrainedModel(LlavaPreTrainedModel):
_supports_quantized_cache = False
_supports_static_cache = False

def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else self.config.text_config.initializer_range
)

if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.LayerNorm):
module.weight.data.fill_(1.0)
module.bias.data.zero_()


class AyaVisionCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
pass
6 changes: 6 additions & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1052,10 +1052,16 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)


BAMBA_INPUTS_DOCSTRING = r"""
6 changes: 6 additions & 0 deletions src/transformers/models/bamba/modular_bamba.py
@@ -820,10 +820,16 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (BambaRMSNormGated, BambaRMSNorm)):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, BambaMixer):
module.dt_bias.data.fill_(1.0)
module.A_log.data = torch.log(torch.arange(1, module.num_heads + 1))
module.D.data.fill_(1.0)


BAMBA_INPUTS_DOCSTRING = r"""
30 changes: 19 additions & 11 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -423,22 +423,30 @@ class Blip2PreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=factor)
if hasattr(module, "bias") and module.bias is not None:
if module.bias is not None:
module.bias.data.zero_()

if isinstance(module, Blip2VisionEmbeddings):
if hasattr(self.config, "vision_config") and not isinstance(self.config, Blip2VisionConfig):
factor = self.config.vision_config.initializer_range
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, Blip2VisionEmbeddings):
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(
module,
(
Blip2Model,
Blip2TextModelWithProjection,
Blip2VisionModelWithProjection,
Blip2ForConditionalGeneration,
Blip2ForImageTextRetrieval,
),
):
module.query_tokens.data.zero_()


BLIP_2_START_DOCSTRING = r"""
22 changes: 7 additions & 15 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1056,12 +1056,16 @@ class ChameleonPreTrainedModel(PreTrainedModel):

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, ChameleonVQVAE):
module.apply(module._init_weights)
elif isinstance(module, (nn.Linear, nn.Conv2d)):

if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, (nn.GroupNorm, nn.LayerNorm)):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, ChameleonRMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
@@ -1096,18 +1100,6 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
config_class = ChameleonVQVAEConfig
_no_split_modules = ["ChameleonVQVAEVectorQuantizer"]

def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.GroupNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()

def __init__(self, config: ChameleonVQVAEConfig):
super().__init__(config)

2 changes: 2 additions & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -416,6 +416,8 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, CohereLayerNorm):
module.weight.data.fill_(1.0)


COHERE_INPUTS_DOCSTRING = r"""
16 changes: 16 additions & 0 deletions src/transformers/models/cohere/modular_cohere.py
@@ -41,6 +41,7 @@
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
LlamaPreTrainedModel,
LlamaRotaryEmbedding,
eager_attention_forward,
)
@@ -277,6 +278,21 @@ def forward(
return outputs


class CoherePreTrainedModel(LlamaPreTrainedModel):
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, CohereLayerNorm):
module.weight.data.fill_(1.0)


class CohereModel(LlamaModel):
def __init__(self, config: CohereConfig):
super().__init__(config)
2 changes: 2 additions & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -424,6 +424,8 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, Cohere2LayerNorm):
module.weight.data.fill_(1.0)


COHERE2_INPUTS_DOCSTRING = r"""
4 changes: 2 additions & 2 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -557,10 +557,10 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DeepseekV3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, DeepseekV3TopkRouter):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=std)


DEEPSEEK_V3_INPUTS_DOCSTRING = r"""
4 changes: 2 additions & 2 deletions src/transformers/models/deepseek_v3/modular_deepseek_v3.py
@@ -347,10 +347,10 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DeepseekV3RMSNorm):
module.weight.data.fill_(1.0)
elif isinstance(module, DeepseekV3TopkRouter):
module.weight.data.normal_(mean=0.0, std=std)
elif isinstance(module, nn.Parameter):
module.weight.data.normal_(mean=0.0, std=std)


class DeepseekV3Model(LlamaModel):
7 changes: 7 additions & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -625,6 +625,13 @@ def _init_weights(self, module):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, DiffLlamaRMSNorm): # noqa: F821
module.weight.data.fill_(1.0)
elif isinstance(module, DiffLlamaAttention):
module.lambda_q1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k1.data.normal_(0, self.config.lambda_std_dev)
module.lambda_q2.data.normal_(0, self.config.lambda_std_dev)
module.lambda_k2.data.normal_(0, self.config.lambda_std_dev)


class DiffLlamaRotaryEmbedding(nn.Module):