diff --git a/3rdparty/amd/wheel/sglang/pyproject.toml b/3rdparty/amd/wheel/sglang/pyproject.toml index d04c3f3bb96c..114648c172e1 100644 --- a/3rdparty/amd/wheel/sglang/pyproject.toml +++ b/3rdparty/amd/wheel/sglang/pyproject.toml @@ -123,7 +123,7 @@ srt_musa = [ "sglang[runtime_common]", "torch", "torch_musa", - "torchada>=0.1.54", + "torchada>=0.1.55", "mthreads-ml-py", "mate>=0.2.0", "deep-gemm>=0.1.3", diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml index d6f40474f26f..aba31a8f0078 100755 --- a/python/pyproject_other.toml +++ b/python/pyproject_other.toml @@ -115,7 +115,7 @@ srt_musa = [ "sglang[runtime_common]", "torch", "torch_musa", - "torchada>=0.1.54", + "torchada>=0.1.55", "mthreads-ml-py", "mate>=0.2.0", "deep-gemm>=0.1.3", diff --git a/python/sglang/srt/hardware_backend/musa/utils/patch_torch.py b/python/sglang/srt/hardware_backend/musa/utils/patch_torch.py new file mode 100644 index 000000000000..bfc91cc8fb23 --- /dev/null +++ b/python/sglang/srt/hardware_backend/musa/utils/patch_torch.py @@ -0,0 +1,63 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import re +from dataclasses import replace as _dataclass_replace + +import torch +import torch.fx.graph as fx_graph + +_DEVICE_REPR_RE = re.compile(r"\bdevice\(type='([^']+)'(?:,\s*index=(\d+))?\)") + + +def _replace_device_repr(m: re.Match) -> str: + dev_type = m.group(1) + dev_index = m.group(2) + if dev_index is not None: + return f"torch.device('{dev_type}:{dev_index}')" + return f"torch.device('{dev_type}')" + + +def patch_fx_custom_device() -> None: + """ + Fix FX codegen serialization for non-standard devices (e.g. torch_musa). + + Root cause: + torch.device is registered as a custom builtin named 'device', imported + via 'from torch import device'. repr(torch.device('musa', 0)) produces + "device(type='musa', index=0)", which is syntactically valid but fails + at runtime because torch.device does not recognize 'musa' as a type when + invoked through the standard import path. + + Fix: + Post-process the generated src string, replacing all occurrences of + device(type='x', index=N) with torch.device('x:N'), and ensure 'torch' + is present in the graph globals. + + Note: + _get_repr is a closure inside _gen_python_code and cannot be patched + directly, so we wrap _gen_python_code and rewrite its output instead. + """ + original = fx_graph.CodeGen._gen_python_code + + def patched(self, nodes, root_module, namespace, **kwargs): + result = original(self, nodes, root_module, namespace, **kwargs) + new_src = _DEVICE_REPR_RE.sub(_replace_device_repr, result.src) + if new_src is result.src: + return result + result.globals.setdefault("torch", torch) + if hasattr(result, "_replace"): + return result._replace(src=new_src) + return _dataclass_replace(result, src=new_src) + + fx_graph.CodeGen._gen_python_code = patched diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 2af8fcea0488..216e37a234ae 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -62,7 +62,14 @@ elif _is_hip: from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul elif _is_musa: - from sgl_kernel import silu_and_mul + from sglang.srt.utils.patch_torch import register_fake_if_exists + + @register_fake_if_exists("aten::_fused_swiglu_forward") + def _(x): + d = x.shape[-1] // 2 + output_shape = x.shape[:-1] + (d,) + return torch.empty(output_shape, dtype=x.dtype, device=x.device) + if is_npu(): import torch_npu @@ -106,9 +113,6 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: return out def forward_musa(self, x: torch.Tensor) -> torch.Tensor: - if not get_global_server_args().disable_piecewise_cuda_graph: - return self.forward_native(x) - if not hasattr(self, "_musa_swish_glu"): # XXX (MUSA): nn.SwishGLU seems to have better performance than silu_and_mul on MUSA, we can switch to it for now. We can consider implementing a silu_and_mul kernel for MUSA in the future if needed. self._musa_swish_glu = nn.SwishGLU() diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 3ac232dd71db..e9c9e7be8c71 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -344,9 +344,6 @@ def forward_musa( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if not get_global_server_args().disable_piecewise_cuda_graph: - return self.forward_native(x, residual, post_residual_addition) - if not x.is_contiguous(): x = x.contiguous() diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 2a36ea8e6b41..258394b64006 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -94,6 +94,24 @@ # Fallback: vllm not available, will use native PyTorch implementation _has_vllm = False +if _is_musa: + + @register_fake_if_exists("sgl_kernel::sgl_per_token_group_quant_8bit_v2") + def _( + input, + output_q, + output_s, + group_size, + eps, + fp8_min, + fp8_max, + scale_ue8m0, + fuse_silu_and_mul, + masked_m, + ): + return + + logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index da547df15e5e..661d1728c848 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -61,6 +61,7 @@ from sglang.srt.model_executor.input_buffers import ForwardInputBuffers from sglang.srt.utils import ( get_available_gpu_memory, + is_musa, is_npu, log_info_on_rank0, require_gathered_buffer, @@ -73,6 +74,8 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner +_is_musa = is_musa() + @dataclass class PrefillInputBuffers(ForwardInputBuffers): @@ -147,6 +150,13 @@ def set_torch_compile_config(): if hasattr(torch._dynamo.config, "cache_size_limit"): torch._dynamo.config.cache_size_limit = 1024 + if _is_musa: + from sglang.srt.hardware_backend.musa.utils.patch_torch import ( + patch_fx_custom_device, + ) + + patch_fx_custom_device() + class PiecewiseCudaGraphRunner: """A PiecewiseCudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" diff --git a/python/sglang/srt/utils/patch_torch.py b/python/sglang/srt/utils/patch_torch.py index 7546502cd826..68cc94002004 100644 --- a/python/sglang/srt/utils/patch_torch.py +++ b/python/sglang/srt/utils/patch_torch.py @@ -16,9 +16,10 @@ import torch from torch.multiprocessing import reductions -from sglang.srt.utils.common import is_npu, torch_release +from sglang.srt.utils.common import is_musa, is_npu, torch_release _is_npu = is_npu() +_is_musa = is_musa() if _is_npu: from torch_npu.multiprocessing import reductions as npu_reductions diff --git a/sgl-kernel/pyproject_musa.toml b/sgl-kernel/pyproject_musa.toml index 160be2792e7d..614936c2333f 100644 --- a/sgl-kernel/pyproject_musa.toml +++ b/sgl-kernel/pyproject_musa.toml @@ -3,7 +3,7 @@ requires = [ "setuptools>=75.0", "scikit-build-core>=0.10", "torch", - "torchada>=0.1.54", + "torchada>=0.1.55", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py index 7e98c6cc3d4d..71c56792394d 100644 --- a/sgl-kernel/python/sgl_kernel/utils.py +++ b/sgl-kernel/python/sgl_kernel/utils.py @@ -56,7 +56,7 @@ def wrapper(*args, **kwargs): @cache_once def is_arch_support_pdl() -> bool: - if bool(torch.version.hip): + if getattr(torch.version, "hip", None) or getattr(torch.version, "musa", None): return False try: device = torch.cuda.current_device()