diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 4b165c6aa..d6962ca5a 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -76,16 +76,18 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
         n_kv_head: Optional[int] = None,
         head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -121,8 +123,13 @@ def __init__(
                     attention_heads, attention_dim, src_attention_dropout_rate,
                     query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
                     head_dim) if src_attention else None,
-                mlp_class(attention_dim, linear_units, dropout_rate,
-                          activation, mlp_bias),
+                mlp_class(attention_dim,
+                          linear_units,
+                          dropout_rate,
+                          activation,
+                          mlp_bias,
+                          n_expert=n_expert,
+                          n_expert_activated=n_expert_activated),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type,
@@ -327,10 +334,11 @@ def __init__(
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
+        src_attention: bool = True,
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
+        activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
@@ -338,6 +346,10 @@ def __init__(
         norm_eps: float = 1e-5,
         n_kv_head: Optional[int] = None,
         head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):

         super().__init__()
@@ -356,9 +368,11 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            src_attention=src_attention,
             query_bias=query_bias,
             key_bias=key_bias,
             value_bias=value_bias,
+            activation_type=activation_type,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
@@ -366,7 +380,10 @@ def __init__(
             norm_eps=norm_eps,
             n_kv_head=n_kv_head,
             head_dim=head_dim,
-        )
+            mlp_type=mlp_type,
+            mlp_bias=mlp_bias,
+            n_expert=n_expert,
+            n_expert_activated=n_expert_activated)

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -381,10 +398,11 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            src_attention=src_attention,
             query_bias=query_bias,
             key_bias=key_bias,
             value_bias=value_bias,
-            mlp_bias=mlp_bias,
+            activation_type=activation_type,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
@@ -392,7 +410,10 @@ def __init__(
             norm_eps=norm_eps,
             n_kv_head=n_kv_head,
             head_dim=head_dim,
-        )
+            mlp_type=mlp_type,
+            mlp_bias=mlp_bias,
+            n_expert=n_expert,
+            n_expert_activated=n_expert_activated)

     def forward(
         self,
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index 0f5ccef86..6cb4e6abb 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -371,16 +371,18 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
         n_kv_head: Optional[int] = None,
         head_dim: Optional[int] = None,
         selfattention_layer_type: str = "selfattn",
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         """ Construct TransformerEncoder

@@ -404,8 +406,13 @@ def __init__(
                     attention_heads, output_size, attention_dropout_rate,
                     query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
                     head_dim),
-                mlp_class(output_size, linear_units, dropout_rate, activation,
-                          mlp_bias),
+                mlp_class(output_size,
+                          linear_units,
+                          dropout_rate,
+                          activation,
+                          mlp_bias,
+                          n_expert=n_expert,
+                          n_expert_activated=n_expert_activated),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type=layer_norm_type,
@@ -445,15 +452,17 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         conv_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
         n_kv_head: Optional[int] = None,
         head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         """Construct ConformerEncoder

@@ -500,6 +509,8 @@ def __init__(
             dropout_rate,
             activation,
             mlp_bias,
+            n_expert,
+            n_expert_activated,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index 7d6ab3251..e4c38e0f9 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -31,12 +31,14 @@ class PositionwiseFeedForward(torch.nn.Module):
     """

     def __init__(
-            self,
-            idim: int,
-            hidden_units: int,
-            dropout_rate: float,
-            activation: torch.nn.Module = torch.nn.ReLU(),
-            bias: bool = True,
+        self,
+        idim: int,
+        hidden_units: int,
+        dropout_rate: float,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        bias: bool = True,
+        *dummy_args,
+        **dummy_kwargs,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(PositionwiseFeedForward, self).__init__()
@@ -66,7 +68,7 @@ class MoEFFNLayer(torch.nn.Module):
                   https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
     Args:
         n_expert: number of expert.
-        n_expert_per_token: The actual number of experts used for each frame
+        n_expert_activated: The actual number of experts used for each frame
         idim (int): Input dimenstion.
         hidden_units (int): The number of hidden units.
         dropout_rate (float): Dropout rate.
@@ -74,22 +76,23 @@ class MoEFFNLayer(torch.nn.Module):
     """

     def __init__(
-            self,
-            n_expert: int,
-            n_expert_per_token: int,
-            idim: int,
-            hidden_units: int,
-            dropout_rate: float,
-            activation: torch.nn.Module = torch.nn.ReLU(),
-            bias: bool = False,
+        self,
+        idim: int,
+        hidden_units: int,
+        dropout_rate: float,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        bias: bool = False,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         super(MoEFFNLayer, self).__init__()
-        bias = False
-        self.gate = torch.nn.Linear(idim, n_expert, bias=bias)
+        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
         self.experts = torch.nn.ModuleList(
-            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
-                                    activation) for _ in range(n_expert))
-        self.n_expert_per_token = n_expert_per_token
+            PositionwiseFeedForward(
+                idim, hidden_units, dropout_rate, activation, bias=bias)
+            for _ in range(n_expert))
+        self.n_expert = n_expert
+        self.n_expert_activated = n_expert_activated

     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Foward function.
@@ -103,18 +106,18 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         )  # batch size, sequence length, embedding dimension (idim)
         xs = xs.view(-1, D)  # (B*L, D)
         router = self.gate(xs)  # (B*L, n_expert)
-        logits, indices = torch.topk(
-            router, self.n_expert_per_token
-        )  # probs:(B*L, n_expert), indices: (B*L, n_expert)
+        logits, selected_experts = torch.topk(
+            router, self.n_expert_activated
+        )  # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
         weights = torch.nn.functional.softmax(
             logits, dim=1,
-            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
+            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_activated)
         output = torch.zeros_like(xs)  # (B*L, D)
         for i, expert in enumerate(self.experts):
-            mask = indices == i
-            batch_idx, ith_expert = torch.where(mask)
-            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
-                xs[batch_idx])
+            mask = selected_experts == i
+            token_ids, ith_expert = torch.where(mask)
+            output[token_ids] += weights[token_ids, ith_expert, None] * expert(
+                xs[token_ids])
         return output.view(B, L, D)


@@ -123,12 +126,14 @@ class GatedVariantsMLP(torch.nn.Module):
     """

     def __init__(
-            self,
-            idim: int,
-            hidden_units: int,
-            dropout_rate: float,
-            activation: torch.nn.Module = torch.nn.GELU(),
-            bias: bool = True,
+        self,
+        idim: int,
+        hidden_units: int,
+        dropout_rate: float,
+        activation: torch.nn.Module = torch.nn.GELU(),
+        bias: bool = True,
+        *dummy_args,
+        **dummy_kwargs,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(GatedVariantsMLP, self).__init__()
@@ -140,7 +145,7 @@ def __init__(
         # w_2 as down proj
         self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)

-    def forward(self, x):
+    def forward(self, x) -> torch.Tensor:
         """Foward function.
         Args:
             xs: input tensor (B, L, D)
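
Reviewer note on the MoE FFN signature change: after this patch the expert count and the routing width are ordinary keyword arguments (`n_expert`, `n_expert_activated`) that the encoder/decoder constructors forward into `mlp_class(...)`, while `PositionwiseFeedForward` and `GatedVariantsMLP` simply absorb them via `*dummy_args`/`**dummy_kwargs`. The snippet below is a minimal smoke test of the re-ordered `MoEFFNLayer` constructor; it only uses names defined in this diff, and the tensor shapes are illustrative.

```python
import torch

from wenet.transformer.positionwise_feed_forward import MoEFFNLayer

# Build an MoE feed-forward block with the new keyword-style arguments.
moe = MoEFFNLayer(
    idim=256,                  # model dimension (input dim == output dim)
    hidden_units=1024,         # hidden size of each expert FFN
    dropout_rate=0.1,
    activation=torch.nn.ReLU(),
    bias=False,                # per-expert linear bias; the router gate is always bias-free
    n_expert=8,                # total experts in the layer (new default)
    n_expert_activated=2,      # top-k experts routed per token (new default)
)

xs = torch.randn(4, 50, 256)   # (batch, length, idim)
out = moe(xs)
assert out.shape == xs.shape   # the MoE FFN preserves the input dimension
```

At the model level the same knobs travel through `mlp_type`, `mlp_bias`, `n_expert`, and `n_expert_activated` on `TransformerEncoder`, `ConformerEncoder`, `TransformerDecoder`, and `BiTransformerDecoder`; selecting `MoEFFNLayer` via `mlp_type` assumes the usual `mlp_class` registry lookup, which lives outside this diff.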