diff --git a/wenet/transformer/decoder.py b/wenet/transformer/decoder.py
index 4efd764f37..e4191bd242 100644
--- a/wenet/transformer/decoder.py
+++ b/wenet/transformer/decoder.py
@@ -76,14 +76,18 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -117,10 +121,15 @@ def __init__(
                     value_bias, use_sdpa),
                 WENET_ATTENTION_CLASSES["crossattn"](
                     attention_heads, attention_dim, src_attention_dropout_rate,
-                    query_bias, key_bias, value_bias, use_sdpa)
-                if src_attention else None,
-                mlp_class(attention_dim, linear_units, dropout_rate,
-                          activation, mlp_bias),
+                    query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
+                    head_dim) if src_attention else None,
+                mlp_class(attention_dim,
+                          linear_units,
+                          dropout_rate,
+                          activation,
+                          mlp_bias,
+                          n_expert=n_expert,
+                          n_expert_activated=n_expert_activated),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type,
@@ -325,15 +334,22 @@ def __init__(
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
+        src_attention: bool = True,
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
+        activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         tie_word_embedding: bool = False,
         use_sdpa: bool = False,
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         super().__init__()

@@ -352,15 +368,22 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            src_attention=src_attention,
             query_bias=query_bias,
             key_bias=key_bias,
             value_bias=value_bias,
+            activation_type=activation_type,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
             layer_norm_type=layer_norm_type,
             norm_eps=norm_eps,
-        )
+            n_kv_head=n_kv_head,
+            head_dim=head_dim,
+            mlp_type=mlp_type,
+            mlp_bias=mlp_bias,
+            n_expert=n_expert,
+            n_expert_activated=n_expert_activated)

         self.right_decoder = TransformerDecoder(
             vocab_size,
@@ -375,16 +398,22 @@ def __init__(
             input_layer,
             use_output_layer,
             normalize_before,
+            src_attention=src_attention,
             query_bias=query_bias,
             key_bias=key_bias,
             value_bias=value_bias,
-            mlp_bias=mlp_bias,
+            activation_type=activation_type,
             gradient_checkpointing=gradient_checkpointing,
             tie_word_embedding=tie_word_embedding,
             use_sdpa=use_sdpa,
             layer_norm_type=layer_norm_type,
             norm_eps=norm_eps,
-        )
+            n_kv_head=n_kv_head,
+            head_dim=head_dim,
+            mlp_type=mlp_type,
+            mlp_bias=mlp_bias,
+            n_expert=n_expert,
+            n_expert_activated=n_expert_activated)

     def forward(
         self,
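The decoder changes above only thread the new constructor arguments through to the attention and MLP classes; nothing in this file chooses values for them. As a rough, non-authoritative sketch of how the extended interface might be exercised (the 'moe' registry key and every size below are assumptions, not values taken from this diff):

# Illustrative only: sizes and the 'moe' mlp_type key are assumed.
from wenet.transformer.decoder import TransformerDecoder

decoder = TransformerDecoder(
    vocab_size=5000,            # hypothetical vocabulary size
    encoder_output_size=256,
    attention_heads=8,
    n_kv_head=2,                # fewer key/value heads than query heads (GQA-style)
    head_dim=64,                # explicit per-head dim instead of deriving it from the model dim
    mlp_type='moe',             # assumed registry key that maps to MoEFFNLayer
    n_expert=8,                 # experts per feed-forward block
    n_expert_activated=2,       # top-k experts routed per token
)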
diff --git a/wenet/transformer/encoder.py b/wenet/transformer/encoder.py
index be319a89a7..d8ca3bbfc2 100644
--- a/wenet/transformer/encoder.py
+++ b/wenet/transformer/encoder.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 # Modified from ESPnet(https://github.com/espnet/espnet)
 """Encoder definition."""
-from typing import Tuple
+from typing import Tuple, Optional

 import torch
 import torch.utils.checkpoint as ckpt
@@ -368,13 +368,18 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         activation_type: str = "relu",
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        selfattention_layer_type: str = "selfattn",
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         """ Construct TransformerEncoder

@@ -392,13 +397,17 @@ def __init__(
         self.encoders = torch.nn.ModuleList([
             TransformerEncoderLayer(
                 output_size,
-                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
-                                                    output_size,
-                                                    attention_dropout_rate,
-                                                    query_bias, key_bias,
-                                                    value_bias, use_sdpa),
-                mlp_class(output_size, linear_units, dropout_rate, activation,
-                          mlp_bias),
+                WENET_ATTENTION_CLASSES[selfattention_layer_type](
+                    attention_heads, output_size, attention_dropout_rate,
+                    query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
+                    head_dim),
+                mlp_class(output_size,
+                          linear_units,
+                          dropout_rate,
+                          activation,
+                          mlp_bias,
+                          n_expert=n_expert,
+                          n_expert_activated=n_expert_activated),
                 dropout_rate,
                 normalize_before,
                 layer_norm_type=layer_norm_type,
@@ -438,13 +447,17 @@ def __init__(
         query_bias: bool = True,
         key_bias: bool = True,
         value_bias: bool = True,
-        mlp_bias: bool = True,
         conv_bias: bool = True,
         gradient_checkpointing: bool = False,
         use_sdpa: bool = False,
-        mlp_type: str = 'position_wise_feed_forward',
         layer_norm_type: str = 'layer_norm',
         norm_eps: float = 1e-5,
+        n_kv_head: Optional[int] = None,
+        head_dim: Optional[int] = None,
+        mlp_type: str = 'position_wise_feed_forward',
+        mlp_bias: bool = True,
+        n_expert: int = 8,
+        n_expert_activated: int = 2,
     ):
         """Construct ConformerEncoder

@@ -489,6 +502,8 @@ def __init__(
             dropout_rate,
             activation,
             mlp_bias,
+            n_expert,
+            n_expert_activated,
         )
         # convolution module definition
         convolution_layer_args = (output_size, cnn_module_kernel, activation,
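With the encoder (and decoder) now calling mlp_class(...) with one uniform argument list, including mlp_bias and the MoE-only n_expert/n_expert_activated, the non-MoE MLP classes must tolerate arguments they do not use; that is what the *dummy_args/**dummy_kwargs additions in the next file provide. A minimal sketch of that call pattern, with toy sizes not taken from this diff:

# The MoE-only keywords are silently absorbed by **dummy_kwargs.
import torch
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward

mlp = PositionwiseFeedForward(256, 2048, 0.1, torch.nn.ReLU(), True,
                              n_expert=8, n_expert_activated=2)
xs = torch.randn(2, 10, 256)
assert mlp(xs).shape == (2, 10, 256)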
diff --git a/wenet/transformer/positionwise_feed_forward.py b/wenet/transformer/positionwise_feed_forward.py
index 7d6ab3251e..e4c38e0f99 100644
--- a/wenet/transformer/positionwise_feed_forward.py
+++ b/wenet/transformer/positionwise_feed_forward.py
@@ -31,12 +31,14 @@ class PositionwiseFeedForward(torch.nn.Module):
     """

     def __init__(
-        self,
-        idim: int,
-        hidden_units: int,
-        dropout_rate: float,
-        activation: torch.nn.Module = torch.nn.ReLU(),
-        bias: bool = True,
+            self,
+            idim: int,
+            hidden_units: int,
+            dropout_rate: float,
+            activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = True,
+            *dummy_args,
+            **dummy_kwargs,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(PositionwiseFeedForward, self).__init__()
@@ -66,7 +68,7 @@ class MoEFFNLayer(torch.nn.Module):
     https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
     Args:
         n_expert: number of expert.
-        n_expert_per_token: The actual number of experts used for each frame
+        n_expert_activated: The actual number of experts used for each frame
         idim (int): Input dimenstion.
         hidden_units (int): The number of hidden units.
         dropout_rate (float): Dropout rate.
@@ -74,22 +76,23 @@ class MoEFFNLayer(torch.nn.Module):
     """

     def __init__(
-        self,
-        n_expert: int,
-        n_expert_per_token: int,
-        idim: int,
-        hidden_units: int,
-        dropout_rate: float,
-        activation: torch.nn.Module = torch.nn.ReLU(),
-        bias: bool = False,
+            self,
+            idim: int,
+            hidden_units: int,
+            dropout_rate: float,
+            activation: torch.nn.Module = torch.nn.ReLU(),
+            bias: bool = False,
+            n_expert: int = 8,
+            n_expert_activated: int = 2,
     ):
         super(MoEFFNLayer, self).__init__()
-        bias = False
-        self.gate = torch.nn.Linear(idim, n_expert, bias=bias)
+        self.gate = torch.nn.Linear(idim, n_expert, bias=False)
         self.experts = torch.nn.ModuleList(
-            PositionwiseFeedForward(idim, hidden_units, dropout_rate,
-                                    activation) for _ in range(n_expert))
-        self.n_expert_per_token = n_expert_per_token
+            PositionwiseFeedForward(
+                idim, hidden_units, dropout_rate, activation, bias=bias)
+            for _ in range(n_expert))
+        self.n_expert = n_expert
+        self.n_expert_activated = n_expert_activated

     def forward(self, xs: torch.Tensor) -> torch.Tensor:
         """Foward function.
@@ -103,18 +106,18 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
         )  # batch size, sequence length, embedding dimension (idim)
         xs = xs.view(-1, D)  # (B*L, D)
         router = self.gate(xs)  # (B*L, n_expert)
-        logits, indices = torch.topk(
-            router, self.n_expert_per_token
-        )  # probs:(B*L, n_expert), indices: (B*L, n_expert)
+        logits, selected_experts = torch.topk(
+            router, self.n_expert_activated
+        )  # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
         weights = torch.nn.functional.softmax(
             logits, dim=1,
-            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_per_token)
+            dtype=torch.float).to(dtype=xs.dtype)  # (B*L, n_expert_activated)
         output = torch.zeros_like(xs)  # (B*L, D)
         for i, expert in enumerate(self.experts):
-            mask = indices == i
-            batch_idx, ith_expert = torch.where(mask)
-            output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
-                xs[batch_idx])
+            mask = selected_experts == i
+            token_ids, ith_expert = torch.where(mask)
+            output[token_ids] += weights[token_ids, ith_expert, None] * expert(
+                xs[token_ids])
         return output.view(B, L, D)


@@ -123,12 +126,14 @@ class GatedVariantsMLP(torch.nn.Module):
     """

     def __init__(
-        self,
-        idim: int,
-        hidden_units: int,
-        dropout_rate: float,
-        activation: torch.nn.Module = torch.nn.GELU(),
-        bias: bool = True,
+            self,
+            idim: int,
+            hidden_units: int,
+            dropout_rate: float,
+            activation: torch.nn.Module = torch.nn.GELU(),
+            bias: bool = True,
+            *dummy_args,
+            **dummy_kwargs,
     ):
         """Construct a PositionwiseFeedForward object."""
         super(GatedVariantsMLP, self).__init__()
@@ -140,7 +145,7 @@ def __init__(
         # w_2 as down proj
         self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)

-    def forward(self, x):
+    def forward(self, x) -> torch.Tensor:
         """Foward function.
         Args:
             xs: input tensor (B, L, D)
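The rewritten MoEFFNLayer.forward routes each token to its top n_expert_activated experts and renormalizes the selected gate logits with a softmax. The self-contained sketch below replays that routing with toy sizes and plain Linear experts (an illustration, not code from the diff), so the gating math can be checked in isolation:

# Standalone replay of the routing loop in MoEFFNLayer.forward (toy sizes).
import torch

tokens, idim, n_expert, n_expert_activated = 4, 8, 3, 2
xs = torch.randn(tokens, idim)                  # already flattened to (B*L, D)
gate = torch.nn.Linear(idim, n_expert, bias=False)
experts = torch.nn.ModuleList(
    torch.nn.Linear(idim, idim) for _ in range(n_expert))

router = gate(xs)                               # (B*L, n_expert)
logits, selected_experts = torch.topk(router, n_expert_activated)
weights = torch.softmax(logits, dim=1)          # softmax over the top-k logits only

output = torch.zeros_like(xs)
for i, expert in enumerate(experts):
    # tokens that routed to expert i, and at which of their k slots
    token_ids, ith_expert = torch.where(selected_experts == i)
    output[token_ids] += weights[token_ids, ith_expert, None] * expert(xs[token_ids])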