Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 68 additions & 55 deletions src/transformers/models/glm_moe_dsa/configuration_glm_moe_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,90 +35,91 @@ class GlmMoeDsaConfig(PreTrainedConfig):

Args:
vocab_size (`int`, *optional*, defaults to 154880):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Glm4MoeLiteModel`]
Vocabulary size of the model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GlmMoeDsaModel`].
hidden_size (`int`, *optional*, defaults to 6144):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 12288):
Dimension of the MLP representations.
Dimension of the dense MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 2048):
Dimension of the MoE representations.
Dimension of the MoE expert representations.
num_hidden_layers (`int`, *optional*, defaults to 78):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*, defaults to 64):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
Number of key-value heads for Grouped Query Attention. If equal to `num_attention_heads`, uses MHA.
n_shared_experts (`int`, *optional*, defaults to 1):
Number of shared experts.
Number of shared experts in MoE layers.
n_routed_experts (`int`, *optional*, defaults to 256):
Number of routed experts.
Number of routed experts in MoE layers.
routed_scaling_factor (`float`, *optional*, defaults to 2.5):
Scaling factor or routed experts.
Scaling factor for routed experts.
kv_lora_rank (`int`, *optional*, defaults to 512):
Rank of the LoRA matrices for key and value projections.
Rank of the LoRA matrices for key and value projections (MLA).
q_lora_rank (`int`, *optional*, defaults to 2048):
Rank of the LoRA matrices for query projections.
Rank of the LoRA matrices for query projections (MLA).
qk_rope_head_dim (`int`, *optional*, defaults to 64):
Dimension of the query/key heads that use rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 256):
Dimension of the value heads.
qk_nope_head_dim (`int`, *optional*, defaults to 192):
Dimension of the query/key heads that don't use rotary position embeddings.
v_head_dim (`int`, *optional*, defaults to 256):
Dimension of the value heads.
n_group (`int`, *optional*, defaults to 1):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to 1):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
Number of selected groups for each token.
num_experts_per_tok (`int`, *optional*, defaults to 8):
Number of selected experts, None means dense model.
Number of experts selected per token.
norm_topk_prob (`bool`, *optional*, defaults to `True`):
Whether to normalize the weights of the routed experts.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
The non-linear activation function in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 202752):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
Whether or not the model should return the last key/values attentions.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 0):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 1):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
Whether to tie weight embeddings.
rope_parameters (`RopeParameters`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
Configuration parameters for the RoPE embeddings, including `rope_theta` and optional scaling parameters.
rope_interleave (`bool`, *optional*, defaults to `True`):
Whether to interleave the rotary position embeddings.
mlp_layer_types (`list`, *optional*):
MLP (Moe vs Dense) pattern for each layer.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
MLP type pattern for each layer (`"dense"` or `"sparse"`). Defaults to 3 dense + rest sparse.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
index_topk (`int`, *optional*, defaults to 2048):
Number of top tokens selected by the indexer for retrieval/attention in each step.
Number of top tokens selected by the indexer for sparse attention.
index_head_dim (`int`, *optional*, defaults to 128):
Head dimension for the indexer projections (DSA).
index_n_heads (`int | None`, *optional*, defaults to 32):
Number of heads for the indexer projections (DSA).
indexer_rope_interleave (`bool`, *optional*, defaults to `True`):
Whether the indexer uses interleaved rotary position embeddings.


```python
>>> from transformers import Glm4MoeLiteModel, Glm4MoeLiteConfig
>>> from transformers import GlmMoeDsaConfig, GlmMoeDsaModel

>>> # Initializing a GLM-MOE-DSA style configuration
>>> # Initializing a GLM-MoE-DSA configuration
>>> configuration = GlmMoeDsaConfig()

>>> # Initializing a model from the configuration
>>> model = GlmMoeDsaModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""
Expand Down Expand Up @@ -158,73 +159,85 @@ def __init__(
kv_lora_rank: int | None = 512,
q_lora_rank: int | None = 2048,
qk_rope_head_dim: int | None = 64,
v_head_dim: int | None = 256,
qk_nope_head_dim: int | None = 192,
v_head_dim: int | None = 256,
n_group: int | None = 1,
topk_group: int | None = 1,
num_experts_per_tok: int | None = 8,
norm_topk_prob: bool | None = True,
hidden_act: str | None = "silu",
max_position_embeddings: int | None = 202752,
initializer_range: float | None = 0.02,
rms_norm_eps: int | None = 1e-5,
rms_norm_eps: float | None = 1e-5,
use_cache: bool | None = True,
pad_token_id: int | None = None,
bos_token_id: int | None = 0,
eos_token_id: int | None = 1,
tie_word_embeddings: bool | None = False,
rope_parameters: RopeParameters | dict[str, RopeParameters] | None = None,
rope_interleave: bool | None = True,
mlp_layer_types=None,
attention_bias: bool | None = False,
attention_dropout: float | None = 0.0,
index_topk: int | None = 2048,
index_head_dim: int | None = 128,
index_n_heads: int | None = 32,
**kwargs,
):
# Model dimensions
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.max_position_embeddings = max_position_embeddings

# Default to MoE from the fourth layer and on
if mlp_layer_types is None:
mlp_layer_types = ["dense"] * min(3, self.num_hidden_layers) + ["sparse"] * (self.num_hidden_layers - 3)
layer_type_validation(mlp_layer_types, self.num_hidden_layers, attention=False)
self.mlp_layer_types = mlp_layer_types

self.moe_intermediate_size = moe_intermediate_size
# Attention dimensions (MLA)
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.routed_scaling_factor = routed_scaling_factor
self.num_key_value_heads = num_key_value_heads
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.head_dim = qk_rope_head_dim

# MoE parameters
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.routed_scaling_factor = routed_scaling_factor
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.norm_topk_prob = norm_topk_prob
self.rope_interleave = rope_interleave
self.num_key_value_heads = num_key_value_heads

# MLP layer types: first 3 dense, rest sparse
self.mlp_layer_types = mlp_layer_types
if self.mlp_layer_types is None:
self.mlp_layer_types = ["dense"] * min(3, num_hidden_layers) + ["sparse"] * (num_hidden_layers - 3)
layer_type_validation(self.mlp_layer_types, self.num_hidden_layers, attention=False)

# Indexer (DSA) parameters
self.index_topk = index_topk
self.index_head_dim = index_head_dim
self.index_n_heads = index_n_heads

# General config
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.index_topk = index_topk
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.rope_parameters = rope_parameters
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.tie_word_embeddings = tie_word_embeddings

super().__init__(**kwargs)
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


__all__ = ["GlmMoeDsaConfig"]
Loading