Skip to content

Commit

Permalink
[transformer] Make MoE runnable (#2474)
Browse files Browse the repository at this point in the history
  • Loading branch information
xingchensong authored Apr 14, 2024
1 parent 9e20eb9 commit c906392
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 49 deletions.
37 changes: 29 additions & 8 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,18 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
super().__init__()
attention_dim = encoder_output_size
Expand Down Expand Up @@ -121,8 +123,13 @@ def __init__(
attention_heads, attention_dim, src_attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim) if src_attention else None,
mlp_class(attention_dim, linear_units, dropout_rate,
activation, mlp_bias),
mlp_class(attention_dim,
linear_units,
dropout_rate,
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
dropout_rate,
normalize_before,
layer_norm_type,
Expand Down Expand Up @@ -327,17 +334,22 @@ def __init__(
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):

super().__init__()
Expand All @@ -356,17 +368,22 @@ def __init__(
input_layer,
use_output_layer,
normalize_before,
src_attention=src_attention,
query_bias=query_bias,
key_bias=key_bias,
value_bias=value_bias,
activation_type=activation_type,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)

self.right_decoder = TransformerDecoder(
vocab_size,
Expand All @@ -381,18 +398,22 @@ def __init__(
input_layer,
use_output_layer,
normalize_before,
src_attention=src_attention,
query_bias=query_bias,
key_bias=key_bias,
value_bias=value_bias,
mlp_bias=mlp_bias,
activation_type=activation_type,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)

def forward(
self,
Expand Down
23 changes: 17 additions & 6 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,18 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
selfattention_layer_type: str = "selfattn",
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct TransformerEncoder
Expand All @@ -404,8 +406,13 @@ def __init__(
attention_heads, output_size, attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim),
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias),
mlp_class(output_size,
linear_units,
dropout_rate,
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
Expand Down Expand Up @@ -445,15 +452,17 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
"""Construct ConformerEncoder
Expand Down Expand Up @@ -500,6 +509,8 @@ def __init__(
dropout_rate,
activation,
mlp_bias,
n_expert,
n_expert_activated,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
Expand Down
75 changes: 40 additions & 35 deletions wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class PositionwiseFeedForward(torch.nn.Module):
"""

def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = True,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = True,
*dummy_args,
**dummy_kwargs,
):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
Expand Down Expand Up @@ -66,30 +68,31 @@ class MoEFFNLayer(torch.nn.Module):
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
Args:
n_expert: number of experts.
n_expert_per_token: The actual number of experts used for each frame
n_expert_activated: The actual number of experts used for each frame
idim (int): Input dimension.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""

def __init__(
self,
n_expert: int,
n_expert_per_token: int,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = False,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = False,
n_expert: int = 8,
n_expert_activated: int = 2,
):
super(MoEFFNLayer, self).__init__()
bias = False
self.gate = torch.nn.Linear(idim, n_expert, bias=bias)
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
self.experts = torch.nn.ModuleList(
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
activation) for _ in range(n_expert))
self.n_expert_per_token = n_expert_per_token
PositionwiseFeedForward(
idim, hidden_units, dropout_rate, activation, bias=bias)
for _ in range(n_expert))
self.n_expert = n_expert
self.n_expert_activated = n_expert_activated

def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Forward function.
Expand All @@ -103,18 +106,18 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
) # batch size, sequence length, embedding dimension (idim)
xs = xs.view(-1, D) # (B*L, D)
router = self.gate(xs) # (B*L, n_expert)
logits, indices = torch.topk(
router, self.n_expert_per_token
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
logits, selected_experts = torch.topk(
router, self.n_expert_activated
) # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
weights = torch.nn.functional.softmax(
logits, dim=1,
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_activated)
output = torch.zeros_like(xs) # (B*L, D)
for i, expert in enumerate(self.experts):
mask = indices == i
batch_idx, ith_expert = torch.where(mask)
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
xs[batch_idx])
mask = selected_experts == i
token_ids, ith_expert = torch.where(mask)
output[token_ids] += weights[token_ids, ith_expert, None] * expert(
xs[token_ids])
return output.view(B, L, D)


Expand All @@ -123,12 +126,14 @@ class GatedVariantsMLP(torch.nn.Module):
"""

def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.GELU(),
bias: bool = True,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.GELU(),
bias: bool = True,
*dummy_args,
**dummy_kwargs,
):
"""Construct a GatedVariantsMLP object."""
super(GatedVariantsMLP, self).__init__()
Expand All @@ -140,7 +145,7 @@ def __init__(
# w_2 as down proj
self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)

def forward(self, x):
def forward(self, x) -> torch.Tensor:
"""Forward function.
Args:
xs: input tensor (B, L, D)
Expand Down

0 comments on commit c906392

Please sign in to comment.