Merged
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
@@ -390,6 +390,7 @@ class Envs:
SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm"))
SGLANG_DG_USE_NVRTC = EnvBool(False)
SGLANG_USE_DEEPGEMM_BMM = EnvBool(False)
SGLANG_DEEPGEMM_SANITY_CHECK = EnvBool(False)

# DeepSeek MHA Optimization
SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192)
13 changes: 13 additions & 0 deletions python/sglang/srt/layers/activation.py
@@ -36,13 +36,15 @@
is_cpu,
is_cuda,
is_hip,
is_musa,
is_npu,
is_xpu,
set_weight_attrs,
)
from sglang.utils import resolve_obj_by_qualname

_is_cuda = is_cuda()
_is_musa = is_musa()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
@@ -53,6 +55,8 @@
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
elif _is_hip:
from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
elif _is_musa:
from sgl_kernel import silu_and_mul

if is_npu():
import torch_npu
@@ -95,6 +99,15 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
silu_and_mul(x, out)
return out

def forward_musa(self, x: torch.Tensor) -> torch.Tensor:
if not get_global_server_args().disable_piecewise_cuda_graph:
return self.forward_native(x)

if not hasattr(self, "_musa_swish_glu"):
            # XXX (MUSA): nn.SwishGLU seems to outperform silu_and_mul on MUSA,
            # so we use it for now; a dedicated silu_and_mul kernel for MUSA
            # can be added later if needed.
self._musa_swish_glu = nn.SwishGLU()
return self._musa_swish_glu(x)


class GeluAndMul(MultiPlatformOp):
def __init__(self, approximate="tanh"):
17 changes: 14 additions & 3 deletions python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py
@@ -1,6 +1,6 @@
import logging
import os
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from enum import IntEnum, auto
from typing import Dict, List, Tuple

@@ -14,10 +14,12 @@
from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ceil_div, get_available_gpu_memory
from sglang.srt.utils import ceil_div, get_available_gpu_memory, is_musa

logger = logging.getLogger(__name__)

_is_musa = is_musa()

if ENABLE_JIT_DEEPGEMM:
import deep_gemm

@@ -332,9 +334,18 @@ def execute(self, m):
deep_gemm.bf16_gemm_nt(self.lhs[:m], self.rhs, self.out[:m])


def deep_gemm_execution_hook(
m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
if _is_musa:
return nullcontext()

return _deep_gemm_execution_hook(m, n, k, num_groups, kernel_type)


@contextmanager
def _deep_gemm_execution_hook(
m: int, n: int, k: int, num_groups: int, kernel_type: DeepGemmKernelType
):
if m > 0:
_maybe_compile_deep_gemm_one_type_all(kernel_type, n, k, num_groups)
13 changes: 11 additions & 2 deletions python/sglang/srt/layers/deep_gemm_wrapper/configurer.py
@@ -1,14 +1,22 @@
import logging

from sglang.srt.environ import envs
from sglang.srt.utils import get_device_sm, is_blackwell_supported
from sglang.srt.utils import (
get_device_sm,
is_blackwell_supported,
is_cuda,
is_musa,
)

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
_is_musa = is_musa()


def _compute_enable_deep_gemm():
sm_version = get_device_sm()
if sm_version < 90:
if (_is_cuda and sm_version < 90) or (_is_musa and sm_version < 31):
return False

try:
@@ -23,3 +31,4 @@ def _compute_enable_deep_gemm():

DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell_supported()
DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
DEEPGEMM_NEED_TMA_ALIGNED_SCALES = not (DEEPGEMM_SCALE_UE8M0 or _is_musa)
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py
@@ -4,22 +4,23 @@

import torch

from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
DEEPGEMM_BLACKWELL,
DEEPGEMM_NEED_TMA_ALIGNED_SCALES,
DEEPGEMM_SCALE_UE8M0,
ENABLE_JIT_DEEPGEMM,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_bool_env_var

logger = logging.getLogger(__name__)

if ENABLE_JIT_DEEPGEMM:
import deep_gemm
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor # noqa: F401

_SANITY_CHECK = get_bool_env_var("SGLANG_DEEPGEMM_SANITY_CHECK")
_SANITY_CHECK = envs.SGLANG_DEEPGEMM_SANITY_CHECK.get()


# TODO maybe rename these functions
27 changes: 26 additions & 1 deletion python/sglang/srt/layers/layernorm.py
@@ -34,21 +34,23 @@
is_cuda,
is_flashinfer_available,
is_hip,
is_musa,
is_npu,
is_xpu,
)

_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_hip = is_hip()
_is_musa = is_musa()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_xpu = is_xpu()
_flashinfer_layernorm_available = False

if _is_cuda or _is_xpu:
if _is_cuda or _is_xpu or _is_musa:
if _is_flashinfer_available:
try:
from flashinfer.norm import layernorm
@@ -284,6 +286,29 @@ def forward_hip(
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out

def forward_musa(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not get_global_server_args().disable_piecewise_cuda_graph:
return self.forward_native(x, residual, post_residual_addition)

if not x.is_contiguous():
x = x.contiguous()

if residual is not None:
if post_residual_addition is not None:
residual = residual + post_residual_addition
fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual

out = nn.functional.rms_norm(
x, (self.hidden_size,), self.weight.data, self.variance_epsilon
)
return out

def forward_native(
self,
x: torch.Tensor,
14 changes: 11 additions & 3 deletions python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -3,12 +3,14 @@
import torch
import triton

from sglang.srt.utils import ceil_div, is_cuda
from sglang.srt.utils import ceil_div, is_cuda, is_musa

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
if _is_cuda:
_is_musa = is_musa()

if _is_cuda or _is_musa:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
)
@@ -665,6 +667,8 @@ def _fwd_kernel_ep_scatter_2(
HIDDEN_SIZE_PAD: tl.constexpr,
SCALE_HIDDEN_SIZE: tl.constexpr,
SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
    # Platform-specific memory semantics ("sem") for atomic_add performance tuning
ATOMIC_ADD_SEM: tl.constexpr,
):
start_token_id = tl.program_id(0)
grid_num = tl.num_programs(0)
@@ -689,7 +693,9 @@
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if expert_id >= 0:
dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index_int32 = tl.atomic_add(
expert_start_loc + expert_id, 1, sem=ATOMIC_ADD_SEM
)
dest_token_index = dest_token_index_int32.to(tl.int64)

tl.store(
@@ -783,6 +789,8 @@ def ep_scatter(
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
        # XXX (MUSA): use "relaxed" memory semantics for atomic_add on the MUSA backend for better performance
ATOMIC_ADD_SEM=None if not _is_musa else "relaxed",
)
return

11 changes: 7 additions & 4 deletions python/sglang/srt/layers/moe/moe_runner/deep_gemm.py
@@ -22,6 +22,7 @@
get_bool_env_var,
is_cuda,
is_hip,
is_musa,
is_npu,
)
from sglang.srt.utils.offloader import get_offloader
@@ -42,6 +43,7 @@
_is_npu = is_npu()
_is_cuda = is_cuda()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_musa = is_musa()

if not (_is_npu or _is_hip) and _is_cuda:
from sgl_kernel import silu_and_mul
@@ -166,8 +168,9 @@ def _run_contiguous_gemm(
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
if deep_gemm_wrapper.DEEPGEMM_NEED_TMA_ALIGNED_SCALES:
hidden_states_scale = tma_align_input_scale(hidden_states_scale)

deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(hidden_states, hidden_states_scale),
w13_weight_fp8,
@@ -203,7 +206,7 @@ def _run_contiguous_gemm(
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
if deep_gemm_wrapper.DEEPGEMM_NEED_TMA_ALIGNED_SCALES:
down_input_scale = tma_align_input_scale(down_input_scale)

deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
@@ -251,7 +254,7 @@ def _run_masked_gemm(
hidden_states_scale = _cast_to_e8m0_with_rounding_up(
hidden_states_scale
)
else:
elif deep_gemm_wrapper.DEEPGEMM_NEED_TMA_ALIGNED_SCALES:
hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
hidden_states_scale
)
@@ -317,7 +320,7 @@ def _run_masked_gemm(
# GroupGemm-1
n = w2_weight.shape[1]

if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
if deep_gemm_wrapper.DEEPGEMM_NEED_TMA_ALIGNED_SCALES:
down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
down_input_scale
)
@@ -20,6 +20,7 @@
is_cpu,
is_cuda,
is_hip,
is_musa,
is_xpu,
use_intel_xpu_backend,
)
@@ -44,6 +45,7 @@
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_xpu = is_xpu()
_use_sgl_xpu = use_intel_xpu_backend()
_is_musa = is_musa()


if _is_cuda:
@@ -62,6 +64,10 @@
# because the code uses moe_sum_reduce_triton as fallback (line 619)
elif _is_xpu:
from sgl_kernel import moe_sum_reduce, silu_and_mul
elif _is_musa:
from sgl_kernel import moe_sum_reduce

_silu_and_mul_musa = torch.nn.SwishGLU()

# Try to import vllm_ops for non-CUDA/HIP/XPU platforms
_has_vllm_ops = False
@@ -534,6 +540,8 @@ def _fused_moe_kernel_sequence(
down_moe_use_tma,
activation,
)
elif _is_musa:
intermediate_cache2 = _silu_and_mul_musa(intermediate_cache1.view(-1, N))
else:
if _has_vllm_ops:
vllm_ops.silu_and_mul(
@@ -647,7 +655,7 @@ def _fused_moe_kernel_sequence(

if no_combine:
pass
elif _is_cuda:
elif _is_cuda or _is_musa:
if use_fused_moe_sum_all_reduce:
if routed_scaling_factor != 1.0:
assert out_slice is not None
@@ -5,13 +5,14 @@
import torch
import triton

from sglang.srt.utils import is_cuda, is_hip, is_xpu
from sglang.srt.utils import is_cuda, is_hip, is_musa, is_xpu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
_is_musa = is_musa()

if _is_cuda or _is_hip or _is_xpu:
if _is_cuda or _is_hip or _is_xpu or _is_musa:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

