Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/amd/wheel/sglang/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ srt_musa = [
"sglang[runtime_common]",
"torch",
"torch_musa",
"torchada>=0.1.54",
"torchada>=0.1.55",
"mthreads-ml-py",
"mate>=0.2.0",
"deep-gemm>=0.1.3",
Expand Down
2 changes: 1 addition & 1 deletion python/pyproject_other.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ srt_musa = [
"sglang[runtime_common]",
"torch",
"torch_musa",
"torchada>=0.1.54",
"torchada>=0.1.55",
"mthreads-ml-py",
"mate>=0.2.0",
"deep-gemm>=0.1.3",
Expand Down
63 changes: 63 additions & 0 deletions python/sglang/srt/hardware_backend/musa/utils/patch_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import re
from dataclasses import replace as _dataclass_replace

import torch
import torch.fx.graph as fx_graph

# Matches the repr() of a torch.device, e.g. "device(type='musa', index=0)"
# or "device(type='musa')".  Group 1 is the device type, group 2 the optional
# index.  The (?<!\.) lookbehind skips occurrences that are already attribute
# accesses such as "torch.device(type=...)": \b alone matches right after the
# dot, and rewriting those would yield the broken "torch.torch.device(...)".
_DEVICE_REPR_RE = re.compile(
    r"(?<!\.)\bdevice\(type='([^']+)'(?:,\s*index=(\d+))?\)"
)


def _replace_device_repr(m: re.Match) -> str:
    """Build the ``torch.device('<type>[:<index>]')`` replacement for a match.

    Args:
        m: A match produced by ``_DEVICE_REPR_RE``; group 1 holds the device
            type and group 2 the index (``None`` when the repr had no index).

    Returns:
        Source text constructing the same device through ``torch.device``
        with a single string argument.
    """
    dev_type, dev_index = m.group(1), m.group(2)
    if dev_index is None:
        return f"torch.device('{dev_type}')"
    return f"torch.device('{dev_type}:{dev_index}')"


def patch_fx_custom_device() -> None:
    """
    Fix FX codegen serialization for non-standard devices (e.g. torch_musa).

    Root cause:
        torch.device is registered as a custom builtin named 'device', imported
        via 'from torch import device'. repr(torch.device('musa', 0)) produces
        "device(type='musa', index=0)", which is syntactically valid but fails
        at runtime because torch.device does not recognize 'musa' as a type when
        invoked through the standard import path.

    Fix:
        Post-process the generated src string, replacing all occurrences of
        device(type='x', index=N) with torch.device('x:N'), and ensure 'torch'
        is present in the graph globals.

    Note:
        _get_repr is a closure inside _gen_python_code and cannot be patched
        directly, so we wrap _gen_python_code and rewrite its output instead.

        The patch is idempotent: the installed wrapper carries a marker
        attribute, and calling this function again while the wrapper is in
        place is a no-op instead of stacking another wrapper.
    """
    import functools

    original = fx_graph.CodeGen._gen_python_code
    if getattr(original, "_sglang_device_repr_patched", False):
        # Already wrapped by a previous call; wrapping again would only add
        # a redundant regex pass over every generated source string.
        return

    @functools.wraps(original)
    def patched(self, nodes, root_module, namespace, **kwargs):
        result = original(self, nodes, root_module, namespace, **kwargs)
        # Decide via the substitution count rather than object identity:
        # whether re.sub hands back the very same string object when nothing
        # matched is an implementation detail, not a documented guarantee.
        new_src, num_rewritten = _DEVICE_REPR_RE.subn(
            _replace_device_repr, result.src
        )
        if num_rewritten == 0:
            return result
        # The rewritten source refers to the name `torch`, so it must be
        # resolvable from the globals the generated code is executed with.
        result.globals.setdefault("torch", torch)
        # Support both result flavours: namedtuple-like (has `_replace`) and
        # dataclass (handled by dataclasses.replace).
        if hasattr(result, "_replace"):
            return result._replace(src=new_src)
        return _dataclass_replace(result, src=new_src)

    patched._sglang_device_repr_patched = True
    fx_graph.CodeGen._gen_python_code = patched
12 changes: 8 additions & 4 deletions python/sglang/srt/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@
elif _is_hip:
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
elif _is_musa:
from sgl_kernel import silu_and_mul
from sglang.srt.utils.patch_torch import register_fake_if_exists

@register_fake_if_exists("aten::_fused_swiglu_forward")
def _(x):
    # Shape-only stand-in for the op: the result keeps every leading dimension
    # of `x` and halves the last one (floor division), with the same dtype and
    # device. The returned tensor is uninitialized.
    half_width = x.shape[-1] // 2
    return torch.empty(
        (*x.shape[:-1], half_width),
        dtype=x.dtype,
        device=x.device,
    )


if is_npu():
import torch_npu
Expand Down Expand Up @@ -106,9 +113,6 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return out

def forward_musa(self, x: torch.Tensor) -> torch.Tensor:
if not get_global_server_args().disable_piecewise_cuda_graph:
return self.forward_native(x)

if not hasattr(self, "_musa_swish_glu"):
# XXX (MUSA): nn.SwishGLU appears to perform better than silu_and_mul on MUSA, so we use it for now. A dedicated silu_and_mul kernel for MUSA can be implemented later if needed.
Comment thread
popsiclexu marked this conversation as resolved.
self._musa_swish_glu = nn.SwishGLU()
Expand Down
3 changes: 0 additions & 3 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,6 @@ def forward_musa(
residual: Optional[torch.Tensor] = None,
post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not get_global_server_args().disable_piecewise_cuda_graph:
return self.forward_native(x, residual, post_residual_addition)

if not x.is_contiguous():
x = x.contiguous()

Expand Down
18 changes: 18 additions & 0 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,24 @@
# Fallback: vllm not available, will use native PyTorch implementation
_has_vllm = False

if _is_musa:

    @register_fake_if_exists("sgl_kernel::sgl_per_token_group_quant_8bit_v2")
    def _(
        input,
        output_q,
        output_s,
        group_size,
        eps,
        fp8_min,
        fp8_max,
        scale_ue8m0,
        fuse_silu_and_mul,
        masked_m,
    ):
        # Stand-in registered for the op name above: it computes nothing,
        # touches none of its arguments and produces no value.
        return None


logger = logging.getLogger(__name__)


Expand Down
10 changes: 10 additions & 0 deletions python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
from sglang.srt.utils import (
get_available_gpu_memory,
is_musa,
is_npu,
log_info_on_rank0,
require_gathered_buffer,
Expand All @@ -73,6 +74,8 @@
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner

_is_musa = is_musa()


@dataclass
class PrefillInputBuffers(ForwardInputBuffers):
Expand Down Expand Up @@ -147,6 +150,13 @@ def set_torch_compile_config():
if hasattr(torch._dynamo.config, "cache_size_limit"):
torch._dynamo.config.cache_size_limit = 1024

if _is_musa:
from sglang.srt.hardware_backend.musa.utils.patch_torch import (
patch_fx_custom_device,
)

patch_fx_custom_device()


class PiecewiseCudaGraphRunner:
"""A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/utils/patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
import torch
from torch.multiprocessing import reductions

from sglang.srt.utils.common import is_npu, torch_release
from sglang.srt.utils.common import is_musa, is_npu, torch_release

_is_npu = is_npu()
_is_musa = is_musa()

if _is_npu:
from torch_npu.multiprocessing import reductions as npu_reductions
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject_musa.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [
"setuptools>=75.0",
"scikit-build-core>=0.10",
"torch",
"torchada>=0.1.54",
"torchada>=0.1.55",
"wheel",
]
build-backend = "setuptools.build_meta"
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/python/sgl_kernel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def wrapper(*args, **kwargs):

@cache_once
def is_arch_support_pdl() -> bool:
if bool(torch.version.hip):
if getattr(torch.version, "hip", None) or getattr(torch.version, "musa", None):
return False
try:
device = torch.cuda.current_device()
Expand Down
Loading