Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,21 @@ def select_experts(
num_token_non_padded is None
), "num_token_non_padded is not yet supported in fused_topk_native"
assert expert_location_dispatch_info is None
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
device = hidden_states.device
if device == torch.device("cpu") and _is_cpu_amx:
topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
elif custom_routing_function is None:
assert (
num_token_non_padded is None
Expand Down
30 changes: 30 additions & 0 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
elif _is_cpu_amx:
return self.forward_cpu(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)

def forward_native(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -147,6 +157,26 @@ def forward_native(
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

def forward_cpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
positions = torch.add(positions, offsets) if offsets is not None else positions
if positions.device == torch.device("cpu") and _is_cpu_amx:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
positions,
query,
key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
)
else:
return self.forward_native(positions, query, key, offsets)

def forward_cuda(
self,
positions: torch.Tensor,
Expand Down
Loading