@@ -101,10 +101,6 @@ class HunYuanMoEV1Config(PretrainedConfig):
The number of experts for MoE. If it is a list, it is used as the number of experts for each layer.
moe_topk (`int` or `list`, *optional*, defaults to 1):
Number of experts selected per token (top-k routing). A list enables layer-wise customization.
moe_drop_tokens (bool, *optional*, defaults to `False`):
Whether to drop tokens exceeding expert capacity instead of padding.
moe_random_routing_dropped_token (bool, *optional*, defaults to `False`):
If True, randomly routes dropped tokens to available experts.
head_dim (`int`, *optional*, defaults to 128):
The attention head dimension.
"""
@@ -138,9 +134,6 @@ def __init__(
attention_dropout=0.0,
num_experts: Union[int, list] = 1,
moe_topk: Union[int, list] = 1,
# capacity_factor: Union[int, List]=1.0,
moe_drop_tokens=False,
moe_random_routing_dropped_token=False,
head_dim=None,
**kwargs,
):
@@ -152,9 +145,6 @@ def __init__(
self.num_attention_heads = num_attention_heads
self.num_experts = num_experts
self.moe_topk = moe_topk
# self.capacity_factor = capacity_factor
self.moe_drop_tokens = moe_drop_tokens
self.moe_random_routing_dropped_token = moe_random_routing_dropped_token

self.head_dim = head_dim
# for backward compatibility
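For reference, a minimal sketch of how the trimmed configuration might be instantiated after this change. The parameter values are illustrative only, and the top-level import assumes `HunYuanMoEV1Config` is exported from `transformers` like other model configs:

```python
from transformers import HunYuanMoEV1Config

# Illustrative values only; real checkpoints ship their own config.json.
config = HunYuanMoEV1Config(
    num_experts=16,        # int, or a per-layer list such as [16, 16, ...]
    moe_topk=8,            # int, or a per-layer list
    head_dim=128,
    attention_dropout=0.0,
)
# moe_drop_tokens and moe_random_routing_dropped_token are no longer part of the signature.
```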
179 changes: 44 additions & 135 deletions src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py
@@ -23,7 +23,7 @@

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from torch import nn

from transformers.cache_utils import Cache

@@ -228,112 +228,11 @@ def forward(
return attn_output, attn_weights


def topkgating(logits: Tensor, topk: int):
if topk == 1:
"""Implements Top1Gating on logits."""
# everything is in fp32 in this function
logits = logits.float()
gates = F.softmax(logits, dim=1)
capacity = gates.shape[0]

# Create a mask for 1st's expert per token
# noisy gating
indices1_s = torch.argmax(gates, dim=1)
num_experts = int(gates.shape[1])
mask1 = F.one_hot(indices1_s, num_classes=num_experts)

# gating decisions

top_idx = torch.topk(mask1, k=capacity, dim=0)[1]

new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
mask1 = new_mask1
# Compute locations in capacity buffer
locations1 = torch.cumsum(mask1, dim=0) - 1

# Store the capacity location for each token
locations1_s = torch.sum(locations1 * mask1, dim=1)

# Normalize gate probabilities
mask1_float = mask1.float()
gates = gates * mask1_float

locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float() # one hot to float
combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc)

dispatch_mask = combine_weights.bool()
return combine_weights, dispatch_mask

logits = logits.float()
gates = F.softmax(logits, dim=1)
# expert_capacity = topk * gates.shape[0]
expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
num_experts = int(gates.shape[1])
# Top-k router probability and corresponding expert indices for each token.
# Shape: [tokens_per_group, num_selected_experts].
expert_gate, expert_index = torch.topk(gates, topk)
expert_mask = F.one_hot(expert_index, num_experts)
# For a given token, determine if it was routed to a given expert.
# Shape: [tokens_per_group, num_experts]

gates_s = torch.clamp(
torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
)
router_probs = gates / gates_s
# Make num_selected_experts the leading axis to ensure that top-1 choices
# have priority over top-2 choices, which have priority over top-3 choices,
# etc.
expert_index = torch.transpose(expert_index, 0, 1)
# Shape: [num_selected_experts * tokens_per_group]
expert_index = expert_index.reshape(-1)

# Create mask out of indices.
# Shape: [tokens_per_group * num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)

# Experts have a fixed capacity that we cannot exceed. A token's priority
# within the expert's buffer is given by the masked, cumulative capacity of
# its target expert.
# Shape: [tokens_per_group * num_selected_experts, num_experts].
token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
# Shape: [num_selected_experts, tokens_per_group, num_experts].
token_priority = token_priority.reshape((topk, -1, num_experts))
# Shape: [tokens_per_group, num_selected_experts, num_experts].
token_priority = torch.transpose(token_priority, 0, 1)
# For each token, across all selected experts, select the only non-negative
# (unmasked) priority. Now, for group G routing to expert E, token T has
# non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
# is its targeted expert.
# Shape: [tokens_per_group, num_experts].
token_priority = torch.max(token_priority, dim=1)[0]

# Token T can only be routed to expert E if its priority is positive and
# less than the expert capacity. One-hot matrix will ignore indices outside
# the range [0, expert_capacity).
# Shape: [tokens_per_group, num_experts, expert_capacity].
valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)

# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
# combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
router_probs_expanded = router_probs.unsqueeze(-1)
combine_weights = router_probs_expanded * dispatch_mask
return combine_weights, dispatch_mask


class HunYuanTopKGate(nn.Module):
class HunYuanMoEV1Gate(nn.Module):
def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.moe_topk = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
self.drop_tokens = config.moe_drop_tokens
self.random_routing_dropped_token = config.moe_random_routing_dropped_token
num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
self.wg = nn.Linear(config.hidden_size, num_experts, bias=False, dtype=torch.float32)

@@ -343,62 +242,72 @@ def forward(self, hidden_states):
if self.wg.weight.dtype == torch.float32:
hidden_states = hidden_states.float()
logits = self.wg(hidden_states)
gate_output = topkgating(logits, self.moe_topk)
return logits

return gate_output


class HunYuanMoE(nn.Module):
class HunYuanMoEV1Moe(nn.Module):
def __init__(self, config: HunYuanMoEV1Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.moe_topk = config.moe_topk
self.num_experts = config.num_experts if isinstance(config.num_experts, int) else config.num_experts[layer_idx]
self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True)
self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
self.top_k = config.moe_topk if isinstance(config.moe_topk, int) else config.moe_topk[layer_idx]
self.norm_topk_prob = config.norm_topk_prob
self.gate = HunYuanMoEV1Gate(config, layer_idx=layer_idx)
# self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
self.experts = nn.ModuleList(
[HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(self.num_experts)]
)

def forward(self, hidden_states):
bsz, seq_len, hidden_size = hidden_states.shape
self.shared_mlp = HunYuanMoEV1MLP(config, layer_idx=layer_idx, is_shared_mlp=True)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states_mlp = self.shared_mlp(hidden_states)
router_logits = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)

routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

combine_weights, dispatch_mask = self.gate(hidden_states)

reshaped_input = hidden_states.reshape(-1, hidden_size)

# dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
dispatch_mask_expanded = dispatch_mask.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1)
reshaped_input_expanded = reshaped_input.unsqueeze(1).unsqueeze(1) # (s, 1, 1, m)
dispatched_input = (dispatch_mask_expanded * reshaped_input_expanded).sum(dim=0) # (e, c, m)

chunks = dispatched_input.chunk(self.num_experts, dim=0)
expert_outputs = []
for chunk, expert in zip(chunks, self.experts):
expert_outputs.append(expert(chunk))

expert_output = torch.cat(expert_outputs, dim=0)
# combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
combine_exp = combine_weights.type_as(hidden_states).unsqueeze(3) # (s, e, c, 1)
expert_exp = expert_output.unsqueeze(0) # (1, e, c, m)
combined_output = (combine_exp * expert_exp).sum(dim=(1, 2)) # (s, m)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
# Loop over all available experts in the model and perform the computation on each expert
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

output = hidden_states_mlp + combined_output
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (the selected top-k experts)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

return output
# However, `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states + hidden_states_mlp


class HunYuanMoEV1DecoderLayer(GradientCheckpointingLayer):
def __init__(self, config: HunYuanMoEV1Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = HunYuanMoEV1Attention(config=config, layer_idx=layer_idx)
self.mlp = HunYuanMoE(config, layer_idx=layer_idx)
self.mlp = HunYuanMoEV1Moe(config, layer_idx=layer_idx)
self.input_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = HunYuanMoEV1RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.layer_idx = layer_idx
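To see the rewritten routing path in isolation, here is a self-contained sketch of the same softmax top-k dispatch used in `HunYuanMoEV1Moe.forward` above. Tensor sizes are toy values, the expert MLPs are replaced by a placeholder scaling, and `norm_topk_prob` stands in for the config flag of the same name:

```python
import torch
import torch.nn.functional as F

# Toy sizes; the real values come from HunYuanMoEV1Config.
num_tokens, hidden_dim, num_experts, top_k = 6, 8, 4, 2
norm_topk_prob = True

torch.manual_seed(0)
hidden_states = torch.randn(num_tokens, hidden_dim)
router_logits = torch.randn(num_tokens, num_experts)  # stands in for the gate's output

# Softmax over experts, keep each token's top-k weights, optionally renormalize to sum to 1.
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
if norm_topk_prob:
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)

# One-hot mask of shape (num_experts, top_k, num_tokens): which tokens hit which expert.
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

final_hidden_states = torch.zeros(num_tokens, hidden_dim)
for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])  # k-slot and token indices routed here
    if top_x.numel() == 0:
        continue
    expert_out = hidden_states[top_x] * 2.0  # placeholder for self.experts[expert_idx](...)
    # Scale by the token's routing weight for this expert and scatter back into the flat buffer.
    final_hidden_states.index_add_(0, top_x, expert_out * routing_weights[top_x, idx, None])

print(final_hidden_states.shape)  # torch.Size([6, 8])
```

In the model, this per-token result is reshaped back to (batch, seq_len, hidden) and added to the shared-MLP output, replacing the capacity-based dispatch/combine masks of the removed `topkgating`.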