Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/sglang/srt/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,20 @@ def __init__(self):
super().__init__()
self._forward_method = self.dispatch_forward()

# States for torch.compile
self._original_forward_method = None
self.is_torch_compile = False

def enter_torch_compile(self, num_tokens: int):
# Skip if the Op has already entered compile mode.
# NOTE(alcanderian): Some Ops (for example RotaryEmbedding) are reused
# among layers, so `enter_torch_compile` will be called many times.
# We must prevent `self._original_forward_method` from being overwritten when
# `enter_torch_compile` is called again after the first time.
if self.is_torch_compile:
return

self._original_forward_method = self._forward_method
# NOTE: Temporarily workaround MoE
if "FusedMoE" in self.__class__.__name__:
if num_tokens == 1:
Expand All @@ -27,7 +40,12 @@ def enter_torch_compile(self, num_tokens: int):
self.is_torch_compile = True

def leave_torch_compile(self):
    """Restore the forward method that was active before compile mode.

    Does nothing when the op is not currently in compile mode, so repeated
    calls (or a call without a matching ``enter_torch_compile``) leave the
    current forward method untouched.
    """
    # Skip if the Op has already left compile mode.
    if not self.is_torch_compile:
        return

    # Restore the method saved by `enter_torch_compile` rather than assuming
    # the CUDA implementation was the one in use.
    self._forward_method = self._original_forward_method
    self._original_forward_method = None
    self.is_torch_compile = False

# Please do not override this method, because `self._forward_method` can change when in torch compile mode
Expand Down
18 changes: 0 additions & 18 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,6 @@ def __init__(
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, *args, **kwargs):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would still need this custom dispatch in this context. The over-general abstraction of mapping `_is_hip` to `forward_hip` isn't enough: we will need to add `forward_aiter` specifically in this block.

if torch.compiler.is_compiling():
return self.forward_native(*args, **kwargs)
if _is_cuda:
return self.forward_cuda(*args, **kwargs)
elif _is_hip:
return self.forward_hip(*args, **kwargs)
else:
return self.forward_native(*args, **kwargs)

def forward_cuda(
self,
x: torch.Tensor,
Expand Down Expand Up @@ -117,14 +107,6 @@ def __init__(
self.weight = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward(self, *args, **kwargs):
    """Route the call to the CUDA or native implementation.

    The native path is used while torch is compiling and whenever
    `_is_cuda` is false; otherwise the CUDA path is used.
    """
    use_cuda_path = not torch.compiler.is_compiling() and _is_cuda
    chosen = self.forward_cuda if use_cuda_path else self.forward_native
    return chosen(*args, **kwargs)

def forward_native(
self,
x: torch.Tensor,
Expand Down
11 changes: 0 additions & 11 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,17 +650,6 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward_hip(self, *args, **kwargs):
    """HIP entry point; forwards all arguments unchanged to `forward_native`."""
    native_impl = self.forward_native
    return native_impl(*args, **kwargs)

def forward(self, *args, **kwargs):
    """Pick the forward implementation for the current runtime.

    Falls back to `forward_native` while torch is compiling or when
    `_is_cuda` is false; otherwise calls `forward_cuda`.
    """
    if not torch.compiler.is_compiling() and _is_cuda:
        target = self.forward_cuda
    else:
        target = self.forward_native
    return target(*args, **kwargs)

def forward_native(
self,
positions: torch.Tensor,
Expand Down
Loading