Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/sglang/srt/models/mixtral_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ def __init__(
f"the number of experts {self.num_total_experts}."
)
# Split experts equally between ranks
self.expert_indicies = np.array_split(
self.expert_indices = np.array_split(
range(self.num_total_experts), self.tp_size
)[self.rank].tolist()
if not self.expert_indicies:
if not self.expert_indices:
raise ValueError(f"Rank {self.rank} has no experts assigned to it.")

self.experts = nn.ModuleList(
Expand All @@ -131,7 +131,7 @@ def __init__(
quant_config=quant_config,
prefix=add_prefix(f"experts.{idx}", prefix),
)
if idx in self.expert_indicies
if idx in self.expert_indices
else None
)
for idx in range(self.num_total_experts)
Expand All @@ -155,7 +155,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

final_hidden_states = None
for expert_idx in self.expert_indicies:
for expert_idx in self.expert_indices:
expert_layer = self.experts[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
Expand Down
Loading