Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,7 +1280,7 @@ def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

@torch.compile(dynamic=True, backend=get_compiler_backend())
def forward_native(
def _forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
Expand Down Expand Up @@ -1340,6 +1340,27 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass with optional Triton kernel acceleration.
Args:
positions:
[num_tokens,] (text only) or
[3, num_tokens] (T/H/W positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 1 or positions.ndim == 2

if positions.ndim == 2 and self.mrope_section and _is_cuda:
return self._forward_triton(positions, query, key)
else:
return self._forward_native(positions, query, key)

def _forward_triton(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
Expand Down
Loading