Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/transformers/models/afmoe/modeling_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
self.layer_idx = layer_idx

self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
self.attention_type = config.layer_types[layer_idx]

# Dual normalization for attention
self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -595,10 +594,10 @@ def forward(

position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/afmoe/modular_afmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def __init__(self, config: AfmoeConfig, layer_idx: int):
self.layer_idx = layer_idx

self.self_attn = AfmoeAttention(config=config, layer_idx=layer_idx)
self.attention_type = config.layer_types[layer_idx]

# Dual normalization for attention
self.input_layernorm = AfmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -376,10 +375,10 @@ def forward(

position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_value=past_key_values,
use_cache=use_cache,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bamba/modeling_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,8 +1164,8 @@ def forward(
mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

for decoder_layer in self.layers:
layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
for i, decoder_layer in enumerate(self.layers):
layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

hidden_states, attn_weights = decoder_layer(
hidden_states,
Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/bamba/modular_bamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,8 @@ def forward(
mamba_mask = self._update_mamba_mask(attention_mask, past_key_values)
position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

for decoder_layer in self.layers:
layer_mask = mamba_mask if decoder_layer.layer_type == "mamba" else causal_mask
for i, decoder_layer in enumerate(self.layers):
layer_mask = mamba_mask if self.config.layers_block_type[i] == "mamba" else causal_mask

hidden_states, attn_weights = decoder_layer(
hidden_states,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/cohere2/modeling_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def __init__(self, config: Cohere2Config, layer_idx: int):
self.self_attn = Cohere2Attention(config=config, layer_idx=layer_idx)
self.mlp = Cohere2MLP(config)
self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps)
self.attention_type = config.layer_types[layer_idx]

def forward(
self,
Expand Down Expand Up @@ -413,10 +412,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
past_key_values=past_key_values,
use_cache=use_cache,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/cohere2/modular_cohere2.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ def forward(
class Cohere2DecoderLayer(CohereDecoderLayer):
def __init__(self, config: Cohere2Config, layer_idx: int):
super().__init__(config, layer_idx)
self.attention_type = config.layer_types[layer_idx]

def forward(
self,
Expand Down Expand Up @@ -301,10 +300,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers:
for i, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
past_key_values=past_key_values,
use_cache=use_cache,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/cwm/modeling_cwm.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
self.mlp = CwmMLP(config)
self.input_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = CwmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]

def forward(
self,
Expand Down Expand Up @@ -407,10 +406,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_values=past_key_values,
position_embeddings=position_embeddings,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/cwm/modular_cwm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def __init__(self, config: CwmConfig, layer_idx: int):
class CwmDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: CwmConfig, layer_idx: int):
super().__init__(config=config, layer_idx=layer_idx)
self.attention_type = config.layer_types[layer_idx]
self.self_attn = CwmAttention(config=config, layer_idx=layer_idx)


Expand Down Expand Up @@ -168,10 +167,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_ids=position_ids,
past_key_values=past_key_values,
position_embeddings=position_embeddings,
Expand Down
6 changes: 6 additions & 0 deletions src/transformers/models/dots1/configuration_dots1.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# This file was automatically generated from src/transformers/models/dots1/modular_dots1.py.
# Do NOT edit this file manually as any edits will be overwritten by the generation of
# the file from the modular. If any change should be done, please apply the change to the
# modular_dots1.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# Copyright 2025 The rednote-hilab team and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/dots1/modeling_dots1.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,6 @@ def __init__(self, config: Dots1Config, layer_idx: int):

self.input_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.attention_type = config.layer_types[layer_idx]

def forward(
self,
Expand Down Expand Up @@ -549,10 +548,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
Expand Down
139 changes: 134 additions & 5 deletions src/transformers/models/dots1/modular_dots1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
import torch

from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_outputs import CausalLMOutputWithPast
from ...modeling_rope_utils import RopeParameters
from ...processing_utils import Unpack
from ...utils import logging
from ...utils import auto_docstring, logging
from ..deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3DecoderLayer,
DeepseekV3MLP,
Expand All @@ -31,12 +33,140 @@
Qwen3RotaryEmbedding,
TransformersKwargs,
)
from .configuration_dots1 import Dots1Config


logger = logging.get_logger(__name__)


@auto_docstring(checkpoint="rednote-hilab/dots.llm1.base")
class Dots1Config(PreTrainedConfig):
    r"""
    n_group (`int`, *optional*, defaults to 1):
        Number of groups for routed experts.
    first_k_dense_replace (`int`, *optional*, defaults to 0):
        Number of dense layers at the beginning of the model before the first MoE layer.

    Examples:

    ```python
    >>> from transformers import Dots1Config

    >>> # Initializing a Dots1 style configuration
    >>> configuration = Dots1Config()

    >>> # Reading a value back from the configuration
    >>> configuration.model_type
    'dots1'
    ```
    """

    model_type = "dots1"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Tensor-parallel sharding style keyed by module-name pattern.
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.self_attn.q_norm": "replicated_with_grad_allreduce",
        "layers.*.self_attn.k_norm": "replicated_with_grad_allreduce",
        "layers.*.mlp.experts.gate_up_proj": "packed_colwise",
        "layers.*.mlp.experts.down_proj": "rowwise",
        "layers.*.mlp.experts": "moe_tp_experts",
        "layers.*.mlp.shared_experts.gate_proj": "colwise",
        "layers.*.mlp.shared_experts.up_proj": "colwise",
        "layers.*.mlp.shared_experts.down_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

    # Pipeline-parallel plan: (input names, output names) per top-level submodule.
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    # `num_local_experts` is exposed as an alias of `n_routed_experts`.
    attribute_map = {
        "num_local_experts": "n_routed_experts",
    }

    def __init__(
        self,
        vocab_size: int | None = 152064,
        hidden_size: int | None = 4608,
        intermediate_size: int | None = 10944,
        moe_intermediate_size: int | None = 1408,
        num_hidden_layers: int | None = 62,
        num_attention_heads: int | None = 32,
        num_key_value_heads: int | None = 32,
        n_shared_experts: int | None = None,
        n_routed_experts: int | None = None,
        n_group: int | None = 1,
        topk_group: int | None = 1,
        num_experts_per_tok: int | None = None,
        first_k_dense_replace: int | None = 0,
        norm_topk_prob: bool | None = False,
        hidden_act: str | None = "silu",
        max_position_embeddings: int | None = 2048,
        initializer_range: float | None = 0.02,
        rms_norm_eps: float | None = 1e-6,
        use_cache: bool | None = True,
        tie_word_embeddings: bool | None = False,
        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
        attention_bias: bool | None = False,
        attention_dropout: float | None = 0.0,
        routed_scaling_factor: float | None = 1.0,
        sliding_window: int | None = 4096,
        max_window_layers: int | None = 62,
        layer_types: list[str] | None = None,
        pad_token_id: int | None = None,
        bos_token_id: int | None = None,
        eos_token_id: int | None = None,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        # When no KV-head count is given, use one KV head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.routed_scaling_factor = routed_scaling_factor
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        self.layer_types = layer_types
        if self.layer_types is None:
            # Derive the per-layer attention pattern: layers at index
            # `max_window_layers` and beyond use sliding attention (only when
            # a sliding window is configured); all earlier layers use full
            # attention.
            self.layer_types = [
                "sliding_attention"
                if self.sliding_window is not None and i >= self.max_window_layers
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        # Validate both user-supplied and derived layer types.
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        self.tie_word_embeddings = tie_word_embeddings
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.rope_parameters = rope_parameters

        # Remaining keyword arguments are forwarded to the base configuration.
        super().__init__(**kwargs)


class Dots1RMSNorm(Qwen3RMSNorm):
    # Dots1-named subclass of Qwen3RMSNorm; it adds no behavior of its own.
    pass

Expand Down Expand Up @@ -85,9 +215,7 @@ def route_tokens_to_experts(self, router_logits):


class Dots1DecoderLayer(DeepseekV3DecoderLayer):
def __init__(self, config: Dots1Config, layer_idx: int):
super().__init__(config, layer_idx)
self.attention_type = config.layer_types[layer_idx]
pass


class Dots1PreTrainedModel(DeepseekV3PreTrainedModel):
Expand Down Expand Up @@ -129,6 +257,7 @@ def forward(


__all__ = [
"Dots1Config",
"Dots1PreTrainedModel",
"Dots1Model",
"Dots1ForCausalLM",
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -456,10 +455,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ def __init__(self, config: Gemma2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.config = config
self.attention_type = config.layer_types[layer_idx]
self.self_attn = Gemma2Attention(config=config, layer_idx=layer_idx)
self.mlp = Gemma2MLP(config)
self.input_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
Expand Down Expand Up @@ -370,10 +369,10 @@ def forward(
hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for i, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
hidden_states = decoder_layer(
hidden_states,
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[self.config.layer_types[i]],
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
Expand Down
Loading
Loading