diff --git a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py index effcda58141c..fb8cba72bdfc 100644 --- a/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py @@ -101,10 +101,6 @@ class HunYuanMoEV1Config(PretrainedConfig): The number of experts for moe. If it is a list, it will be used as the number of experts for each layer. moe_topk (int or List, *optional*, defaults to 1): Number of experts selected per token (Top-K routing). List form enables layer-wise customization. - moe_drop_tokens (bool, *optional*, defaults to `False`): - Whether to drop tokens exceeding expert capacity instead of padding. - moe_random_routing_dropped_token (bool, *optional*, defaults to `False`): - If True, randomly routes dropped tokens to available experts. head_dim (`int`, *optional*, defaults to 128): The attention head dimension. """ @@ -138,9 +134,6 @@ def __init__( attention_dropout=0.0, num_experts: Union[int, list] = 1, moe_topk: Union[int, list] = 1, - # capacity_factor: Union[int, List]=1.0, - moe_drop_tokens=False, - moe_random_routing_dropped_token=False, head_dim=None, **kwargs, ): @@ -152,9 +145,6 @@ def __init__( self.num_attention_heads = num_attention_heads self.num_experts = num_experts self.moe_topk = moe_topk - # self.capacity_factor = capacity_factor - self.moe_drop_tokens = moe_drop_tokens - self.moe_random_routing_dropped_token = moe_random_routing_dropped_token self.head_dim = head_dim # for backward compatibility diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index 861df2ac529e..8bfdab6d159b 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -23,7 +23,7 @@ import torch import torch.nn.functional as F -from torch import Tensor, nn +from torch import nn from transformers.cache_utils import Cache @@ -228,112 +228,11 @@ def forward( return attn_output, attn_weights -def topkgating(logits: Tensor, topk: int): - if topk == 1: - """Implements Top1Gating on logits.""" - # everything is in fp32 in this function - logits = logits.float() - gates = F.softmax(logits, dim=1) - capacity = gates.shape[0] - - # Create a mask for 1st's expert per token - # noisy gating - indices1_s = torch.argmax(gates, dim=1) - num_experts = int(gates.shape[1]) - mask1 = F.one_hot(indices1_s, num_classes=num_experts) - - # gating decisions - - top_idx = torch.topk(mask1, k=capacity, dim=0)[1] - - new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) - mask1 = new_mask1 - # Compute locations in capacity buffer - locations1 = torch.cumsum(mask1, dim=0) - 1 - - # Store the capacity location for each token - locations1_s = torch.sum(locations1 * mask1, dim=1) - - # Normalize gate probabilities - mask1_float = mask1.float() - gates = gates * mask1_float - - locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float() # one hot to float - combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc) - - dispatch_mask = combine_weights.bool() - return combine_weights, dispatch_mask - - logits = logits.float() - gates = F.softmax(logits, dim=1) - # expert_capacity = topk * gates.shape[0] - expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1]) - num_experts = int(gates.shape[1]) - # Top-k router 
probability and corresponding expert indices for each token. - # Shape: [tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(gates, topk) - expert_mask = F.one_hot(expert_index, num_experts) - # For a given token, determine if it was routed to a given expert. - # Shape: [tokens_per_group, num_experts] - - gates_s = torch.clamp( - torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps - ) - router_probs = gates / gates_s - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = torch.transpose(expert_index, 0, 1) - # Shape: [num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(-1) - - # Create mask out of indices. - # Shape: [tokens_per_group * num_selected_experts, num_experts]. - expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 - # Shape: [num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((topk, -1, num_experts)) - # Shape: [tokens_per_group, num_selected_experts, num_experts]. - token_priority = torch.transpose(token_priority, 0, 1) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, dim=1)[0] - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [tokens_per_group, num_experts, expert_capacity]. - valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) - token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) - dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) - valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) - dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. 
- # combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) - router_probs_expanded = router_probs.unsqueeze(-1) - combine_weights = router_probs_expanded * dispatch_mask - return combine_weights, dispatch_mask - - -class HunYuanTopKGate(nn.Module): +class HunYuanMoEV1Gate(nn.Module): def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.moe_topk = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] - self.drop_tokens = config.moe_drop_tokens - self.random_routing_dropped_token = config.moe_random_routing_dropped_token num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32) @@ -343,54 +242,64 @@ def forward(self, hidden_states): if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() logits = self.wg(hidden_states) - gate_output = topkgating(logits, self.moe_topk) + return logits - return gate_output - -class HunYuanMoE(nn.Module): +class HunYuanMoEV1Moe(nn.Module): def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.moe_topk = config.moe_topk self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] - self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True) - self.gate = HunYuanTopKGate(config, layer_idx=layer_idx) + self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.norm_topk_prob = config.norm_topk_prob + self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx) + # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32) self.experts = nn.ModuleList( [HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)] ) - def forward(self, hidden_states): - bsz, seq_len, hidden_size = hidden_states.shape + self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_mlp = self.shared_mlp(hidden_states) + router_logits = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) - combine_weights, dispatch_mask = self.gate(hidden_states) - - reshaped_input = hidden_states.reshape(-1, hidden_size) - - # dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) - dispatch_mask_expanded = dispatch_mask.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1) - reshaped_input_expanded = reshaped_input.unsqueeze(1).unsqueeze(1) # (s, 1, 1, m) - dispatched_input = (dispatch_mask_expanded * reshaped_input_expanded).sum(dim=(0)) 
# (s, m) - - chunks = dispatched_input.chunk(self.num_experts, dim=0) - expert_outputs = [] - for chunk, expert in zip(chunks, self.experts): - expert_outputs.append(expert(chunk)) - - expert_output = torch.cat(expert_outputs, dim=0) - # combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) - combine_exp = combine_weights.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1) - expert_exp = expert_output.unsqueeze(0) # (1, e, c, m) - combined_output = (combine_exp * expert_exp).sum(dim=(1, 2)) # (s, m) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - output = hidden_states_mlp + combined_output + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - return output + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + hidden_states_mlp class HunYuanMoEV1DecoderLayer(GradientCheckpointingLayer): @@ -398,7 +307,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx) - self.mlp = HunYuanMoE(config, layer_idx=layer_idx) + self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx) self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx diff --git a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py index cfcf22b15441..a8b7f92f9941 100644 --- a/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py @@ -19,7 +19,7 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint -from torch import Tensor, nn +from torch import nn from transformers.cache_utils import Cache from transformers.utils import ( @@ -113,112 +113,11 @@ def forward( return attn_output, attn_weights -def topkgating(logits: Tensor, topk: int): - if topk == 1: - """Implements Top1Gating on logits.""" - # everything is in fp32 in this function - logits = logits.float() - gates = F.softmax(logits, dim=1) - capacity = gates.shape[0] - - # Create a mask for 1st's expert per token - # noisy gating - indices1_s = torch.argmax(gates, dim=1) - num_experts = 
int(gates.shape[1]) - mask1 = F.one_hot(indices1_s, num_classes=num_experts) - - # gating decisions - - top_idx = torch.topk(mask1, k=capacity, dim=0)[1] - - new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) - mask1 = new_mask1 - # Compute locations in capacity buffer - locations1 = torch.cumsum(mask1, dim=0) - 1 - - # Store the capacity location for each token - locations1_s = torch.sum(locations1 * mask1, dim=1) - - # Normalize gate probabilities - mask1_float = mask1.float() - gates = gates * mask1_float - - locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float() # one hot to float - combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc) - - dispatch_mask = combine_weights.bool() - return combine_weights, dispatch_mask - - logits = logits.float() - gates = F.softmax(logits, dim=1) - # expert_capacity = topk * gates.shape[0] - expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1]) - num_experts = int(gates.shape[1]) - # Top-k router probability and corresponding expert indices for each token. - # Shape: [tokens_per_group, num_selected_experts]. - expert_gate, expert_index = torch.topk(gates, topk) - expert_mask = F.one_hot(expert_index, num_experts) - # For a given token, determine if it was routed to a given expert. - # Shape: [tokens_per_group, num_experts] - - gates_s = torch.clamp( - torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps - ) - router_probs = gates / gates_s - # Make num_selected_experts the leading axis to ensure that top-1 choices - # have priority over top-2 choices, which have priority over top-3 choices, - # etc. - expert_index = torch.transpose(expert_index, 0, 1) - # Shape: [num_selected_experts * tokens_per_group] - expert_index = expert_index.reshape(-1) - - # Create mask out of indices. - # Shape: [tokens_per_group * num_selected_experts, num_experts]. - expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32) - - # Experts have a fixed capacity that we cannot exceed. A token's priority - # within the expert's buffer is given by the masked, cumulative capacity of - # its target expert. - # Shape: [tokens_per_group * num_selected_experts, num_experts]. - token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1 - # Shape: [num_selected_experts, tokens_per_group, num_experts]. - token_priority = token_priority.reshape((topk, -1, num_experts)) - # Shape: [tokens_per_group, num_selected_experts, num_experts]. - token_priority = torch.transpose(token_priority, 0, 1) - # For each token, across all selected experts, select the only non-negative - # (unmasked) priority. Now, for group G routing to expert E, token T has - # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E - # is its targeted expert. - # Shape: [tokens_per_group, num_experts]. - token_priority = torch.max(token_priority, dim=1)[0] - - # Token T can only be routed to expert E if its priority is positive and - # less than the expert capacity. One-hot matrix will ignore indices outside - # the range [0, expert_capacity). - # Shape: [tokens_per_group, num_experts, expert_capacity]. 
- valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity) - token_priority = torch.masked_fill(token_priority, ~valid_mask, 0) - dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool) - valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity) - dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0) - - # The combine array will be used for combining expert outputs, scaled by the - # router probabilities. Shape: [num_groups, tokens_per_group, num_experts, - # expert_capacity]. - # combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask) - router_probs_expanded = router_probs.unsqueeze(-1) - combine_weights = router_probs_expanded * dispatch_mask - return combine_weights, dispatch_mask - - -class HunYuanTopKGate(nn.Module): +class HunYuanMoEV1Gate(nn.Module): def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.moe_topk = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] - self.drop_tokens = config.moe_drop_tokens - self.random_routing_dropped_token = config.moe_random_routing_dropped_token num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32) @@ -228,54 +127,64 @@ def forward(self, hidden_states): if self.wg.weight.dtype == torch.float32: hidden_states = hidden_states.float() logits = self.wg(hidden_states) - gate_output = topkgating(logits, self.moe_topk) + return logits - return gate_output - -class HunYuanMoE(nn.Module): +class HunYuanMoEV1Moe(nn.Module): def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx - self.moe_topk = config.moe_topk self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx] - self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True) - self.gate = HunYuanTopKGate(config, layer_idx=layer_idx) + self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx] + self.norm_topk_prob = config.norm_topk_prob + self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx) + # self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32) self.experts = nn.ModuleList( [HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)] ) - def forward(self, hidden_states): - bsz, seq_len, hidden_size = hidden_states.shape + self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_mlp = self.shared_mlp(hidden_states) + router_logits = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, 
hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) - combine_weights, dispatch_mask = self.gate(hidden_states) - - reshaped_input = hidden_states.reshape(-1, hidden_size) - - # dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input) - dispatch_mask_expanded = dispatch_mask.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1) - reshaped_input_expanded = reshaped_input.unsqueeze(1).unsqueeze(1) # (s, 1, 1, m) - dispatched_input = (dispatch_mask_expanded * reshaped_input_expanded).sum(dim=(0)) # (s, m) - - chunks = dispatched_input.chunk(self.num_experts, dim=0) - expert_outputs = [] - for chunk, expert in zip(chunks, self.experts): - expert_outputs.append(expert(chunk)) - - expert_output = torch.cat(expert_outputs, dim=0) - # combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output) - combine_exp = combine_weights.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1) - expert_exp = expert_output.unsqueeze(0) # (1, e, c, m) - combined_output = (combine_exp * expert_exp).sum(dim=(1, 2)) # (s, m) + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - combined_output = combined_output.reshape(bsz, seq_len, hidden_size) + # Loop over all available experts in the model and perform the computation on each expert + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - output = hidden_states_mlp + combined_output + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - return output + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + hidden_states_mlp class HunYuanMoEV1DecoderLayer(LlamaDecoderLayer): @@ -283,7 +192,7 @@ def __init__(self, config: HunYuanMoEV1Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx) - self.mlp = HunYuanMoE(config, layer_idx=layer_idx) + self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx) self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx
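
For reviewers who want to sanity-check the new routing path outside the model, here is a minimal, self-contained sketch of the Mixtral-style top-k routing and `index_add_` dispatch that this refactor adopts in `HunYuanMoEV1Moe.forward`. The `ToyExpert` module, the `topk_route_and_dispatch` helper, and the toy dimensions are illustrative stand-ins rather than code from this PR, and the sketch omits the shared-expert branch (`shared_mlp`) whose output the real layer adds to the routed result.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyExpert(nn.Module):
    """Illustrative stand-in for HunYuanMoEV1MLP: a small gated feed-forward block."""

    def __init__(self, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        self.up = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.gate = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.down = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))


def topk_route_and_dispatch(
    hidden_states: torch.Tensor,   # (batch, seq, hidden)
    router: nn.Linear,             # hidden -> num_experts, kept in fp32 like the gate in the PR
    experts: nn.ModuleList,
    top_k: int,
    norm_topk_prob: bool = True,
) -> torch.Tensor:
    batch, seq, hidden = hidden_states.shape
    flat = hidden_states.view(-1, hidden)

    # Router logits and softmax in fp32 for numerical stability.
    logits = router(flat.float())
    probs = F.softmax(logits, dim=-1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    if norm_topk_prob:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(flat.dtype)

    out = torch.zeros_like(flat)
    # (num_experts, top_k, tokens): for each expert, which (slot, token) pairs selected it.
    expert_mask = F.one_hot(selected_experts, num_classes=len(experts)).permute(2, 1, 0)
    # Only visit experts that at least one token was routed to.
    for expert_idx in (expert_mask.sum(dim=(-1, -2)) > 0).nonzero().squeeze(-1).tolist():
        slot, token = torch.where(expert_mask[expert_idx])
        # Gather the routed tokens, run the expert, scale by the routing weight,
        # then scatter-accumulate the result back to the token positions.
        expert_out = experts[expert_idx](flat[token]) * routing_weights[token, slot, None]
        out.index_add_(0, token, expert_out.to(flat.dtype))
    return out.view(batch, seq, hidden)


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, num_experts, top_k = 16, 4, 2
    experts = nn.ModuleList(ToyExpert(hidden, 32) for _ in range(num_experts))
    router = nn.Linear(hidden, num_experts, bias=False, dtype=torch.float32)
    x = torch.randn(2, 5, hidden)
    y = topk_route_and_dispatch(x, router, experts, top_k)
    print(y.shape)  # torch.Size([2, 5, 16])
```

The design choice mirrored here is the one the diff makes: the removed capacity-buffer `topkgating` path (dense combine/dispatch masks of shape `[tokens, experts, capacity]`, with `moe_drop_tokens` and `moe_random_routing_dropped_token` handling) is replaced by a per-expert gather/scatter loop in which only the tokens routed to an expert are run through it and `index_add_` accumulates the weighted outputs in place, so no tokens are dropped and no capacity padding is materialized.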