src/transformers/models/afmoe/modeling_afmoe.py (5 changes: 2 additions & 3 deletions)

@@ -423,7 +423,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.layer_idx = layer_idx

         self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]

         # Dual normalization for attention
         self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -595,10 +594,10 @@ def forward(

         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
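The afmoe hunks above establish the pattern every file in this PR follows: the decoder layer stops caching its attention type as an instance attribute at `__init__` time, and the model's forward loop instead reads `config.layer_types[i]` by layer position. A minimal runnable sketch of the before/after, using hypothetical `ToyConfig`/`ToyModel` stand-ins rather than the real transformers classes:

```python
# Illustrative sketch only -- ToyConfig/ToyLayer/ToyModel are hypothetical
# stand-ins, not the actual transformers classes touched by this PR.

class ToyLayer:
    def __init__(self, config, layer_idx):
        self.layer_idx = layer_idx
        # Before: each layer cached its own type at construction time.
        # self.attention_type = config.layer_types[layer_idx]


class ToyModel:
    def __init__(self, config):
        self.config = config
        self.layers = [ToyLayer(config, i) for i in range(len(config.layer_types))]

    def forward(self, causal_mask_mapping):
        picked = []
        # After: the type is looked up from the config by position, so the
        # layer object no longer carries a redundant copy of it.
        for i, _layer in enumerate(self.layers):
            picked.append(causal_mask_mapping[self.config.layer_types[i]])
        return picked


class ToyConfig:
    layer_types = ["full_attention", "sliding_attention", "full_attention"]


masks = ToyModel(ToyConfig()).forward(
    {"full_attention": "dense-mask", "sliding_attention": "windowed-mask"}
)
print(masks)  # ['dense-mask', 'windowed-mask', 'dense-mask']
```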
src/transformers/models/afmoe/modular_afmoe.py (5 changes: 2 additions & 3 deletions)

@@ -204,7 +204,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
         self.layer_idx = layer_idx

         self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]

         # Dual normalization for attention
         self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -376,10 +375,10 @@ def forward(

         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_value=past_key_values,
                 use_cache=use_cache,
src/transformers/models/bamba/modeling_bamba.py (4 changes: 2 additions & 2 deletions)

@@ -1164,8 +1164,8 @@ def forward(
         mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

-        for decoder_layer in self.layers:
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

             hidden_states, attn_weights = decoder_layer(
                 hidden_states,
src/transformers/models/bamba/modular_bamba.py (4 changes: 2 additions & 2 deletions)

@@ -839,8 +839,8 @@ def forward(
         mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
         position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

-        for decoder_layer in self.layers:
-            layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
+        for i, decoder_layer in enumerate(self.layers):
+            layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

             hidden_states, attn_weights = decoder_layer(
                 hidden_states,
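Bamba applies the same change to its hybrid stack: the per-layer `layer_type` attribute goes away, and the mask is selected per position from `config.layers_block_type`, whose entries name the block kind for each layer. A toy sketch of that selection (placeholder string masks, not the real tensors):

```python
# Toy sketch of Bamba's per-layer mask selection after this PR.
# The mask values are placeholder strings, not real attention/Mamba masks.
layers_block_type = ["mamba", "attention", "mamba"]
mamba_mask, causal_mask = "mamba-mask", "causal-mask"

for i, block in enumerate(layers_block_type):
    # Same expression as the diff, with the config field inlined as a list.
    layer_mask = mamba_mask if layers_block_type[i] == "mamba" else causal_mask
    print(i, block, layer_mask)
```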
src/transformers/models/cohere2/modeling_cohere2.py (5 changes: 2 additions & 3 deletions)

@@ -290,7 +290,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
         self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Cohere2MLP(config)
         self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -413,10 +412,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
src/transformers/models/cohere2/modular_cohere2.py (5 changes: 2 additions & 3 deletions)

@@ -220,7 +220,6 @@ def forward(
 class Cohere2DecoderLayer(CohereDecoderLayer):
     def __init__(self, config: Cohere2Config, layer_idx: int):
         super().__init__(config, layer_idx)
-        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -301,10 +300,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers:
+        for i, decoder_layer in enumerate(self.layers):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 past_key_values=past_key_values,
                 use_cache=use_cache,
src/transformers/models/cwm/modeling_cwm.py (5 changes: 2 additions & 3 deletions)

@@ -285,7 +285,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
         self.mlp = CwmMLP(config)
         self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -407,10 +406,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
src/transformers/models/cwm/modular_cwm.py (5 changes: 2 additions & 3 deletions)

@@ -105,7 +105,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
 class CwmDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: CwmConfig, layer_idx: int):
         super().__init__(config=config, layer_idx=layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)


@@ -168,10 +167,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 position_embeddings=position_embeddings,
src/transformers/models/dots1/configuration_dots1.py (8 changes: 6 additions & 2 deletions)

@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_dots1.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,8 +17,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from huggingface_hub.dataclasses import strict

 from ...configuration_utils import PreTrainedConfig
src/transformers/models/dots1/modeling_dots1.py (5 changes: 2 additions & 3 deletions)

@@ -417,7 +417,6 @@ def __init__(self, config: Dots1Config, layer_idx: int):

         self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.attention_type = config.layer_types[layer_idx]

     def forward(
         self,
@@ -549,10 +548,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
src/transformers/models/dots1/modular_dots1.py (107 changes: 102 additions & 5 deletions)

@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+from huggingface_hub.dataclasses import strict

+from ...configuration_utils import PreTrainedConfig
 from ...modeling_outputs import CausalLMOutputWithPast
+from ...modeling_rope_utils import RopeParameters
 from ...processing_utils import Unpack
-from ...utils import logging
+from ...utils import auto_docstring, logging
 from ..deepseek_v3.modeling_deepseek_v3 import (
     DeepseekV3DecoderLayer,
     DeepseekV3MLP,
@@ -31,12 +34,107 @@
     Qwen3RotaryEmbedding,
     TransformersKwargs,
 )
-from .configuration_dots1 import Dots1Config


 logger = logging.get_logger(__name__)


+@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
+@strict(accept_kwargs=True)
+class Dots1Config(PreTrainedConfig):
+    r"""
+    n_group (`int`, *optional*, defaults to 1):
+        Number of groups for routed experts.
+    first_k_dense_replace (`int`, *optional*, defaults to 0):
+        Number of dense layers at the beginning of the model before the first MoE layer.
+
+    Examples:
+
+    ```python
+    >>> from transformers import Dots1Model, Dots1Config
+    >>> # Initializing a Dots1 style configuration
+    >>> configuration = Dots1Config()
+    >>> # Initializing a model from the configuration
+    >>> model = Dots1Model(configuration)
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "dots1"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    base_model_tp_plan = {
+        "layers.*.self_attn.q_proj": "colwise",
+        "layers.*.self_attn.k_proj": "colwise",
+        "layers.*.self_attn.v_proj": "colwise",
+        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
+        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
+        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
+        "layers.*.mlp.experts.down_proj": "rowwise",
+        "layers.*.mlp.experts": "moe_tp_experts",
+        "layers.*.mlp.shared_experts.gate_proj": "colwise",
+        "layers.*.mlp.shared_experts.up_proj": "colwise",
+        "layers.*.mlp.shared_experts.down_proj": "rowwise",
+        "layers.*.mlp.gate_proj": "colwise",
+        "layers.*.mlp.up_proj": "colwise",
+        "layers.*.mlp.down_proj": "rowwise",
+    }
+
+    base_model_pp_plan = {
+        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
+        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
+        "norm": (["hidden_states"], ["hidden_states"]),
+    }
+    attribute_map = {
+        "num_local_experts": "n_routed_experts",
+    }
+
+    vocab_size: int = 152064
+    hidden_size: int = 4608
+    intermediate_size: int = 10944
+    moe_intermediate_size: int = 1408
+    num_hidden_layers: int = 62
+    num_attention_heads: int = 32
+    num_key_value_heads: int | None = 32
+    n_shared_experts: int | None = None
+    n_routed_experts: int | None = None
+    n_group: int | None = 1
+    topk_group: int | None = 1
+    num_experts_per_tok: int | None = None
+    first_k_dense_replace: int | None = 0
+    norm_topk_prob: bool | None = False
+    hidden_act: str = "silu"
+    max_position_embeddings: int = 2048
+    initializer_range: float = 0.02
+    rms_norm_eps: float = 1e-6
+    use_cache: bool = True
+    tie_word_embeddings: bool = False
+    rope_parameters: RopeParameters | dict | None = None
+    attention_bias: bool = False
+    attention_dropout: float | int | None = 0.0
+    routed_scaling_factor: float = 1.0
+    sliding_window: int | None = 4096
+    max_window_layers: int | None = 62
+    layer_types: list[str] | None = None
+    pad_token_id: int | None = None
+    bos_token_id: int | None = None
+    eos_token_id: int | list[int] | None = None
+
+    def __post_init__(self, **kwargs):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention"
+                if self.sliding_window is not None and i >= self.max_window_layers
+                else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+
+        super().__post_init__(**kwargs)
+
+
 class Dots1RMSNorm(Qwen3RMSNorm):
     pass

@@ -85,9 +183,7 @@ def route_tokens_to_experts(self, router_logits):


 class Dots1DecoderLayer(DeepseekV3DecoderLayer):
-    def __init__(self, config: Dots1Config, layer_idx: int):
-        super().__init__(config, layer_idx)
-        self.attention_type = config.layer_types[layer_idx]
+    pass


 class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
@@ -129,6 +225,7 @@ def forward(


 __all__ = [
+    "Dots1Config",
     "Dots1PreTrainedModel",
     "Dots1Model",
     "Dots1ForCausalLM",
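The new `__post_init__` derives `layer_types` when none is supplied: a layer gets `"sliding_attention"` only when `sliding_window` is set and its index is at least `max_window_layers`, otherwise `"full_attention"`. With the Dots1 defaults (`sliding_window=4096`, `max_window_layers=62`, `num_hidden_layers=62`) no index reaches the threshold, so all 62 layers resolve to full attention. A standalone re-derivation of that rule (a plain helper for illustration, not part of the transformers API):

```python
# Standalone re-derivation of the layer_types default from __post_init__ above;
# a plain helper for illustration, not part of the transformers API.
def default_layer_types(num_hidden_layers: int, sliding_window: int | None, max_window_layers: int) -> list[str]:
    return [
        "sliding_attention"
        if sliding_window is not None and i >= max_window_layers
        else "full_attention"
        for i in range(num_hidden_layers)
    ]

# Dots1 defaults: a sliding window is configured, but it only kicks in from
# layer index max_window_layers onward -- never reached here.
print(default_layer_types(62, 4096, 62).count("full_attention"))  # 62

# A hypothetical smaller setup where the last two layers use sliding attention.
print(default_layer_types(4, 1024, 2))
# ['full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']
```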
src/transformers/models/gemma2/modeling_gemma2.py (5 changes: 2 additions & 3 deletions)

@@ -304,7 +304,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.config = config
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -456,10 +455,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
src/transformers/models/gemma2/modular_gemma2.py (5 changes: 2 additions & 3 deletions)

@@ -271,7 +271,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.config = config
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma2MLP(config)
         self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -370,10 +369,10 @@ def forward(
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
                 position_embeddings=position_embeddings,
                 position_ids=position_ids,
                 past_key_values=past_key_values,
src/transformers/models/gemma3/modeling_gemma3.py (7 changes: 3 additions & 4 deletions)

@@ -394,7 +394,6 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
         self.config = config
         self.hidden_size = config.hidden_size
         self.layer_idx = layer_idx
-        self.attention_type = config.layer_types[layer_idx]
         self.self_attn = Gemma3Attention(config=config, layer_idx=layer_idx)
         self.mlp = Gemma3MLP(config)
         self.input_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
@@ -571,11 +570,11 @@ def forward(
         for layer_type in self.config.layer_types:
             position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

-        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
+        for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
             hidden_states = decoder_layer(
                 hidden_states,
-                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
-                position_embeddings=position_embeddings[decoder_layer.attention_type],
+                attention_mask=causal_mask_mapping[self.config.layer_types[i]],
+                position_embeddings=position_embeddings[self.config.layer_types[i]],
                 position_ids=position_ids,
                 past_key_values=past_key_values,
                 **kwargs,
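Gemma3 is the one model here with a second lookup: it precomputes rotary embeddings per layer type, so the loop indexes `config.layer_types[i]` for both the mask and the position embeddings. A small sketch of that double lookup (`fake_rotary` is a placeholder, not the real `Gemma3RotaryEmbedding`):

```python
# Sketch of Gemma3's double per-layer-type lookup after this PR.
# `fake_rotary` stands in for Gemma3RotaryEmbedding; values are placeholders.
def fake_rotary(layer_type: str) -> str:
    return f"rope<{layer_type}>"

layer_types = ["full_attention", "sliding_attention", "full_attention"]
causal_mask_mapping = {"full_attention": "dense-mask", "sliding_attention": "windowed-mask"}

# One set of position embeddings per distinct layer type, computed up front.
position_embeddings = {lt: fake_rotary(lt) for lt in layer_types}

for i, layer_type in enumerate(layer_types):
    mask = causal_mask_mapping[layer_types[i]]
    rope = position_embeddings[layer_types[i]]
    print(i, layer_type, mask, rope)
```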