Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,8 @@ def __init__(
def forward(self, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we would still need this custom dispatch in this context - over-general abstraction as _is_hip to forward_hip isn't enough - we will need to add forward_aiter specifically in this block.

if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
elif _is_hip:
return self.forward_hip(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
return self._forward_method(*args, **kwargs)

def forward_cuda(
self,
Expand Down Expand Up @@ -120,10 +116,8 @@ def __init__(
def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)
return self._forward_method(*args, **kwargs)

def forward_native(
self,
Expand Down
14 changes: 6 additions & 8 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
else:
return self._forward_method(*args, **kwargs)

def forward_native(
self,
positions: torch.Tensor,
Expand Down Expand Up @@ -653,14 +659,6 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
def forward_hip(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)

def forward(self, *args, **kwargs):
if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)

def forward_native(
self,
positions: torch.Tensor,
Expand Down
Loading