Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
x = x.view(-1, C) # (B*T, C)
router = self.gate(x) # (B*T, n_expert)
router = F.softmax(router, dim=1, dtype=torch.float)
probs, indices = torch.topk(router, self.config.n_expert_per_token) # (B*T, n_expert_per_token)
probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
probs /= probs.sum(dim=1, keepdim=True)
probs = probs.to(dtype=x.dtype)
masks = indices.unsqueeze(-1) == torch.arange(self.config.n_expert, device=x.device)
masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
y = torch.zeros_like(x) # (B*T, C)
Expand Down
Loading