Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[transformer] Make MoE runnable #2474

Merged
merged 2 commits into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions wenet/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,18 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
super().__init__()
attention_dim = encoder_output_size
Expand Down Expand Up @@ -121,8 +123,13 @@ def __init__(
attention_heads, attention_dim, src_attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim) if src_attention else None,
mlp_class(attention_dim, linear_units, dropout_rate,
activation, mlp_bias),
mlp_class(attention_dim,
linear_units,
dropout_rate,
activation,
mlp_bias,
n_expert=n_expert,
Mddct marked this conversation as resolved.
Show resolved Hide resolved
n_expert_activated=n_expert_activated),
dropout_rate,
normalize_before,
layer_norm_type,
Expand Down Expand Up @@ -327,17 +334,22 @@ def __init__(
input_layer: str = "embed",
use_output_layer: bool = True,
normalize_before: bool = True,
src_attention: bool = True,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
tie_word_embedding: bool = False,
use_sdpa: bool = False,
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):

super().__init__()
Expand All @@ -356,17 +368,22 @@ def __init__(
input_layer,
use_output_layer,
normalize_before,
src_attention=src_attention,
query_bias=query_bias,
key_bias=key_bias,
value_bias=value_bias,
activation_type=activation_type,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)

self.right_decoder = TransformerDecoder(
vocab_size,
Expand All @@ -381,18 +398,22 @@ def __init__(
input_layer,
use_output_layer,
normalize_before,
src_attention=src_attention,
query_bias=query_bias,
key_bias=key_bias,
value_bias=value_bias,
mlp_bias=mlp_bias,
activation_type=activation_type,
gradient_checkpointing=gradient_checkpointing,
tie_word_embedding=tie_word_embedding,
use_sdpa=use_sdpa,
layer_norm_type=layer_norm_type,
norm_eps=norm_eps,
n_kv_head=n_kv_head,
head_dim=head_dim,
)
mlp_type=mlp_type,
mlp_bias=mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated)

def forward(
self,
Expand Down
23 changes: 17 additions & 6 deletions wenet/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,18 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
activation_type: str = "relu",
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
selfattention_layer_type: str = "selfattn",
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
""" Construct TransformerEncoder

Expand All @@ -404,8 +406,13 @@ def __init__(
attention_heads, output_size, attention_dropout_rate,
query_bias, key_bias, value_bias, use_sdpa, n_kv_head,
head_dim),
mlp_class(output_size, linear_units, dropout_rate, activation,
mlp_bias),
mlp_class(output_size,
linear_units,
dropout_rate,
activation,
mlp_bias,
n_expert=n_expert,
n_expert_activated=n_expert_activated),
dropout_rate,
normalize_before,
layer_norm_type=layer_norm_type,
Expand Down Expand Up @@ -445,15 +452,17 @@ def __init__(
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
mlp_bias: bool = True,
conv_bias: bool = True,
gradient_checkpointing: bool = False,
use_sdpa: bool = False,
mlp_type: str = 'position_wise_feed_forward',
layer_norm_type: str = 'layer_norm',
norm_eps: float = 1e-5,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None,
mlp_type: str = 'position_wise_feed_forward',
mlp_bias: bool = True,
n_expert: int = 8,
n_expert_activated: int = 2,
):
"""Construct ConformerEncoder

Expand Down Expand Up @@ -500,6 +509,8 @@ def __init__(
dropout_rate,
activation,
mlp_bias,
n_expert,
n_expert_activated,
)
# convolution module definition
convolution_layer_args = (output_size, cnn_module_kernel, activation,
Expand Down
75 changes: 40 additions & 35 deletions wenet/transformer/positionwise_feed_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ class PositionwiseFeedForward(torch.nn.Module):
"""

def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = True,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = True,
*dummy_args,
**dummy_kwargs,
):
"""Construct a PositionwiseFeedForward object."""
super(PositionwiseFeedForward, self).__init__()
Expand Down Expand Up @@ -66,30 +68,31 @@ class MoEFFNLayer(torch.nn.Module):
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
Args:
n_expert: number of expert.
n_expert_per_token: The actual number of experts used for each frame
n_expert_activated: The actual number of experts used for each frame
idim (int): Input dimenstion.
hidden_units (int): The number of hidden units.
dropout_rate (float): Dropout rate.
activation (torch.nn.Module): Activation function
"""

def __init__(
self,
n_expert: int,
n_expert_per_token: int,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = False,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.ReLU(),
bias: bool = False,
n_expert: int = 8,
n_expert_activated: int = 2,
):
super(MoEFFNLayer, self).__init__()
bias = False
self.gate = torch.nn.Linear(idim, n_expert, bias=bias)
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
self.experts = torch.nn.ModuleList(
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
activation) for _ in range(n_expert))
self.n_expert_per_token = n_expert_per_token
PositionwiseFeedForward(
idim, hidden_units, dropout_rate, activation, bias=bias)
for _ in range(n_expert))
self.n_expert = n_expert
self.n_expert_activated = n_expert_activated

def forward(self, xs: torch.Tensor) -> torch.Tensor:
"""Foward function.
Expand All @@ -103,18 +106,18 @@ def forward(self, xs: torch.Tensor) -> torch.Tensor:
) # batch size, sequence length, embedding dimension (idim)
xs = xs.view(-1, D) # (B*L, D)
router = self.gate(xs) # (B*L, n_expert)
logits, indices = torch.topk(
router, self.n_expert_per_token
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
logits, selected_experts = torch.topk(
router, self.n_expert_activated
) # probs:(B*L, n_expert_activated), selected_exp: (B*L, n_expert_activated)
weights = torch.nn.functional.softmax(
logits, dim=1,
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_activated)
output = torch.zeros_like(xs) # (B*L, D)
for i, expert in enumerate(self.experts):
mask = indices == i
batch_idx, ith_expert = torch.where(mask)
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
xs[batch_idx])
mask = selected_experts == i
token_ids, ith_expert = torch.where(mask)
output[token_ids] += weights[token_ids, ith_expert, None] * expert(
xs[token_ids])
return output.view(B, L, D)


Expand All @@ -123,12 +126,14 @@ class GatedVariantsMLP(torch.nn.Module):
"""

def __init__(
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.GELU(),
bias: bool = True,
self,
idim: int,
hidden_units: int,
dropout_rate: float,
activation: torch.nn.Module = torch.nn.GELU(),
bias: bool = True,
*dummy_args,
**dummy_kwargs,
):
"""Construct a PositionwiseFeedForward object."""
super(GatedVariantsMLP, self).__init__()
Expand All @@ -140,7 +145,7 @@ def __init__(
# w_2 as down proj
self.w_2 = torch.nn.Linear(hidden_units, idim, bias=bias)

def forward(self, x):
def forward(self, x) -> torch.Tensor:
"""Foward function.
Args:
xs: input tensor (B, L, D)
Expand Down
Loading