diff --git a/docker/deepseek_v4_b200.Dockerfile b/docker/deepseek_v4_b200.Dockerfile
new file mode 100644
index 000000000000..f532684dbc0f
--- /dev/null
+++ b/docker/deepseek_v4_b200.Dockerfile
@@ -0,0 +1,34 @@
+FROM lmsysorg/sglang:v0.5.7
+
+# need: cu12.9, x86_64 docker
+# Same dependency set as H200 (preset.py treats H200/B200 as one flavor).
+
+RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
+    git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git
+
+# tilelang 0.1.8 pinned: mhc.py uses T.gemm(wg_wait=0), removed in 0.1.9.
+RUN pip install --no-cache-dir tilelang==0.1.8
+
+RUN pip install --no-cache-dir flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu129
+
+RUN cd /tmp && rm -rf flash-mla && \
+    git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
+    cd flash-mla && git submodule update --init --recursive && \
+    pip install --no-cache-dir --no-build-isolation -v . && \
+    cd /tmp && rm -rf flash-mla
+
+RUN pip install --no-cache-dir -e /workspace/sglang/python/
+
+# DeepGEMM must come after sglang install: sglang pyproject pulls
+# cuda-python / sgl-kernel / nvidia-cutlass-dsl, which
+# DeepGEMM depends on at the resolved versions.
+RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
+    cd /tmp && rm -rf DeepGEMM && \
+    git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
+    cd DeepGEMM && git checkout 7f2a70 && \
+    git submodule update --init --recursive && \
+    bash install.sh
+
+# DeepGEMM install.sh bumps apache-tvm-ffi to 0.1.10, which breaks tilelang
+# 0.1.8 ABI. Re-pin to 0.1.9 (--no-deps so pip does not touch deep-gemm).
+RUN pip install --no-cache-dir --no-deps apache-tvm-ffi==0.1.9
diff --git a/docker/deepseek_v4_b300.Dockerfile b/docker/deepseek_v4_b300.Dockerfile
new file mode 100644
index 000000000000..8e1a8bff5ed8
--- /dev/null
+++ b/docker/deepseek_v4_b300.Dockerfile
@@ -0,0 +1,48 @@
+FROM lmsysorg/sglang:v0.5.7-cu130-runtime
+
+ENV PIP_BREAK_SYSTEM_PACKAGES=1
+
+# tilelang's bundled libtvm.so depends on libz3.so (no version suffix).
+# Base image ships nothing matching, and apt's libz3-4 only installs libz3.so.4.
+RUN apt-get update && apt-get install -y --no-install-recommends libz3-4 && \
+    ln -sf /usr/lib/x86_64-linux-gnu/libz3.so.4 /usr/lib/x86_64-linux-gnu/libz3.so && \
+    ldconfig && rm -rf /var/lib/apt/lists/*
+
+RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
+    git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git
+
+RUN pip install --no-cache-dir cuda-python --upgrade
+RUN pip install --no-cache-dir flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu130
+
+RUN pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.21/sgl_kernel-0.3.21+cu130-cp310-abi3-manylinux2014_x86_64.whl
+
+RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
+    cd /tmp && rm -rf DeepGEMM && \
+    git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
+    cd DeepGEMM && git checkout 7f2a70 && \
+    git submodule update --init --recursive && \
+    ln -sf $(pwd)/third-party/cutlass/include/cutlass $(pwd)/deep_gemm/include/cutlass && \
+    ln -sf $(pwd)/third-party/cutlass/include/cute $(pwd)/deep_gemm/include/cute && \
+    bash install.sh
+
+RUN pip install --no-cache-dir -e /workspace/sglang/python/
+
+RUN pip install --no-cache-dir --force-reinstall --no-deps tilelang==0.1.8
+
+# Link the pip-installed nvidia-cuda-cccl headers into the CUDA include dirs (toolkit copy kept as cccl.bak); fail the build if they are not found.
+RUN pip install --no-cache-dir nvidia-cuda-cccl && \
+    CCCL_INC=$(find /usr/local/lib -path "*/include/cccl/cuda/std" -type d 2>/dev/null | head -1 | sed 's|/cuda/std$||') && \
+    { [ -n "$CCCL_INC" ] || { echo "nvidia-cuda-cccl headers not found under /usr/local/lib" >&2; exit 1; }; } && \
+    ln -sf "$CCCL_INC/cuda" /usr/local/cuda/include/cuda && \
+    mv /usr/local/cuda/targets/x86_64-linux/include/cccl /usr/local/cuda/targets/x86_64-linux/include/cccl.bak && \
+    ln -sf "$CCCL_INC" /usr/local/cuda/targets/x86_64-linux/include/cccl
+# FlashMLA — required by deepseek_v4_backend_radix.py
+RUN cd /tmp && rm -rf flash-mla && \
+    git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
+    cd flash-mla && git submodule update --init --recursive && \
+    pip install --no-cache-dir --no-build-isolation . && cd /tmp && rm -rf flash-mla
+# fast_hadamard_transform — sgl_kernel 0.3.21 lacks hadamard_transform on B300
+RUN pip install --no-cache-dir --no-build-isolation git+https://github.com/Dao-AILab/fast-hadamard-transform.git
+
+# Install mooncake
+RUN pip install --no-cache-dir mooncake-transfer-engine-cuda13
diff --git a/docker/deepseek_v4_grace_blackwell.Dockerfile b/docker/deepseek_v4_grace_blackwell.Dockerfile
new file mode 100644
index 000000000000..9a75d2f73ff6
--- /dev/null
+++ b/docker/deepseek_v4_grace_blackwell.Dockerfile
@@ -0,0 +1,28 @@
+FROM lmsysorg/sglang:v0.5.7-cu130-runtime
+
+# need: cu13, arm docker
+RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
+    git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git
+
+RUN pip install --no-cache-dir https://github.com/sgl-project/whl/releases/download/v0.3.21/sgl_kernel-0.3.21+cu130-cp310-abi3-manylinux2014_aarch64.whl
+RUN pip install --no-cache-dir cuda-python --upgrade
+RUN cd /tmp && rm -rf flash-mla && git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && cd flash-mla && ln -s /usr/local/cuda/include/cccl/cuda /usr/local/cuda/include/cuda && git submodule update --init --recursive && pip install --no-cache-dir --no-build-isolation -v . && cd /tmp && rm -rf flash-mla
+
+RUN pip install --no-cache-dir flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu130
+RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
+    cd /tmp && rm -rf DeepGEMM && git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
+    cd DeepGEMM && git checkout 003ed71 && \
+    git submodule update --init --recursive && \
+    ln -sf $(pwd)/third-party/cutlass/include/cutlass $(pwd)/deep_gemm/include/cutlass && \
+    ln -sf $(pwd)/third-party/cutlass/include/cute $(pwd)/deep_gemm/include/cute && \
+    bash install.sh
+RUN pip install --no-cache-dir -e /workspace/sglang/python/
+
+# Install TileLang for arm
+RUN pip install --no-cache-dir https://github.com/tile-ai/tilelang/releases/download/v0.1.8/tilelang-0.1.8-cp38-abi3-manylinux_2_34_aarch64.whl
+
+# Install hadamard transform
+RUN pip install --no-cache-dir --no-build-isolation git+https://github.com/Dao-AILab/fast-hadamard-transform.git
+
+# Install mooncake
+RUN pip install --no-cache-dir mooncake-transfer-engine-cuda13
diff --git a/docker/deepseek_v4_h200.Dockerfile b/docker/deepseek_v4_h200.Dockerfile
new file mode 100644
index 000000000000..17205bb3b082
--- /dev/null
+++ b/docker/deepseek_v4_h200.Dockerfile
@@ -0,0 +1,36 @@
+FROM lmsysorg/sglang:v0.5.7
+
+# need: cu12.9, x86_64 docker
+
+RUN mkdir -p /workspace && cd /workspace && rm -rf sglang && \
+    git clone -b deepseek_v4 https://github.com/sgl-project/sglang.git
+
+# tilelang 0.1.8 pinned: mhc.py uses T.gemm(wg_wait=0), removed in 0.1.9.
+RUN pip install --no-cache-dir tilelang==0.1.8
+
+RUN pip install --no-cache-dir flashinfer-jit-cache==0.6.8 --index-url https://flashinfer.ai/whl/cu129
+
+RUN cd /tmp && rm -rf flash-mla && \
+    git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla && \
+    cd flash-mla && git submodule update --init --recursive && \
+    pip install --no-cache-dir --no-build-isolation -v . && \
+    cd /tmp && rm -rf flash-mla
+
+RUN pip install --no-cache-dir -e /workspace/sglang/python/
+
+# Build kernel for w4a16 marlin
+RUN cd /workspace/sglang/sgl-kernel && make build
+
+# DeepGEMM must come after sglang install: sglang pyproject pulls
+# cuda-python / sgl-kernel / nvidia-cutlass-dsl, which
+# DeepGEMM depends on at the resolved versions.
+RUN pip uninstall -y deep-gemm deep_gemm 2>/dev/null; \
+    cd /tmp && rm -rf DeepGEMM && \
+    git clone https://github.com/sgl-project/DeepGEMM.git -b release && \
+    cd DeepGEMM && git checkout 7f2a70 && \
+    git submodule update --init --recursive && \
+    bash install.sh
+
+# DeepGEMM install.sh bumps apache-tvm-ffi to 0.1.10, which breaks tilelang
+# 0.1.8 ABI. Re-pin to 0.1.9 (--no-deps so pip does not touch deep-gemm).
+RUN pip install --no-cache-dir --no-deps apache-tvm-ffi==0.1.9
diff --git a/python/pyproject.toml b/python/pyproject.toml
index ad2faf1888d3..56a670529111 100755
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -28,8 +28,8 @@ dependencies = [
   "datasets",
   "einops",
   "fastapi",
-  "flashinfer_python==0.6.2", # keep it aligned with jit-cache version in Dockerfile
-  "flashinfer_cubin==0.6.2",
+  "flashinfer_python==0.6.8", # keep it aligned with jit-cache version in Dockerfile
+  "flashinfer_cubin==0.6.8",
   "gguf",
   "hf_transfer",
   "huggingface_hub",
@@ -55,7 +55,7 @@ dependencies = [
   "pydantic",
   "python-multipart",
   "pyzmq>=25.1.2",
-  "quack-kernels==0.2.4",
+  # "quack-kernels==0.2.4", # conflicts with flashinfer 0.6.8 on nvidia-cutlass-dsl (<4.4.0 vs >=4.4.2); not used by current bench flows
   "requests",
   "scipy",
   "sentencepiece",
diff --git a/python/sglang/jit_kernel/.clang-format b/python/sglang/jit_kernel/.clang-format
index 56acfb8b8f5c..690cc3fea0d7 100644
--- a/python/sglang/jit_kernel/.clang-format
+++ b/python/sglang/jit_kernel/.clang-format
@@ -17,7 +17,7 @@ PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name
 IncludeCategories:
   - Regex: '^$'
     Priority: 0
-  - Regex: '^$'
+  - Regex: '^$'
     Priority: 2
   - Regex: '^$'
     Priority: 1
diff --git a/python/sglang/jit_kernel/all_reduce.py b/python/sglang/jit_kernel/all_reduce.py
new file mode 100644
index 000000000000..ef808a871c78
--- /dev/null
+++ b/python/sglang/jit_kernel/all_reduce.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+
+import enum
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, cast
+
+import torch
+import tvm_ffi
+from tvm_ffi import Module
+
+from sglang.jit_kernel.utils import (
+    cache_once,
+    is_arch_support_pdl,
+    load_jit,
+    make_cpp_args,
+)
+
+
+class ConfigResult(NamedTuple):  # launch-config pair returned by config_pull
+    num_blocks: int
+    num_threads: int
+
+
+class AllReduceAlgo(enum.Enum):  # algorithm selector passed to all_reduce()
+    ONE_SHOT_PUSH = enum.auto()
+    ONE_SHOT_PULL = enum.auto()
+    TWO_SHOT_PULL = enum.auto()
+
+    def is_push(self) -> bool:
+        return self is AllReduceAlgo.ONE_SHOT_PUSH
+
+    @property
+    def shot(self) -> int:
+        return 2 if self is AllReduceAlgo.TWO_SHOT_PULL else 1
+
+
+if TYPE_CHECKING:
+    CUSTOM_AR_HANDLE = List[int]
+    CUSTOM_AR_PAIR = Tuple[int, CUSTOM_AR_HANDLE]
+
+    class CustomAllReduceObj:
+        def __init__(
+            self,
+            rank: int,
+            world_size: int,
+            pull_buffer_bytes: int,
+            push_buffer_bytes: int,
+            graph_input_count: int,
+            *,
+            max_pull_blocks: Optional[int] = None,
+            max_push_blocks: Optional[int] = None,
+        ) -> None:
+            """
+            Create a CustomAllReduceObj instance.
+
+            :param rank: The rank of the current process.
+            :param world_size: The total number of processes in the group.
+            :param pull_buffer_bytes: The size of the buffer (in bytes) used for pull-based all-reduce.
+            :param push_buffer_bytes: The size of the buffer (in bytes) used for push-based all-reduce.
+            :param graph_input_count: The maximum number of inputs in all CUDA graphs.
+            :param max_pull_blocks: The maximum number of thread blocks to launch for pull-based all-reduce.
+                If None, it will be determined by the implementation.
+            :param max_push_blocks: The maximum number of thread blocks to launch for push-based all-reduce.
+                If None, it will be determined by the implementation.
+            """
+
+        @property
+        def world_size(self) -> int: ...
+        def share_storage(self) -> CUSTOM_AR_HANDLE: ...
+        def share_graph_inputs(self) -> List[CUSTOM_AR_PAIR]: ...
+        def post_init(self, handles: List[CUSTOM_AR_HANDLE]) -> None: ...
+        def register_inputs(self, handles: List[List[CUSTOM_AR_PAIR]]) -> None: ...
+        def set_cuda_graph_capture(self, is_capturing: bool) -> None: ...
+        def free(self, tp_cpu_group: torch.distributed.ProcessGroup) -> None: ...
+        def all_reduce(
+            self, input: torch.Tensor, algo: AllReduceAlgo
+        ) -> tvm_ffi.Tensor: ...
+        def config_pull(
+            self, num_blocks: int = -1, num_threads: int = -1
+        ) -> ConfigResult:
+            """
+            Configure the CUDA kernel's grid and block dimensions.
+            This provides only the upper bound of the configuration,
+            and the actual launch configuration may be determined by implementation.
+            Note that push-based all-reduce can not be configured currently.
+
+            :param num_blocks: The maximum number of thread blocks to launch. -1 keeps the current value.
+            :param num_threads: The maximum number of threads per block. -1 keeps the current value.
+
+            :return: The previous configuration as a ConfigResult named tuple.
+            """
+            ...
+ + +@cache_once +def _jit_custom_all_reduce_pull_module(dtype: torch.dtype, world_size: int) -> Module: + args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) + return load_jit( + "custom_all_reduce_pull", + *args, + extra_ldflags=["-lcuda"], + cuda_files=["distributed/custom_all_reduce_pull.cuh"], + cuda_wrappers=[("all_reduce", f"custom_all_reduce<{args}>")], + ) + + +@cache_once +def _jit_custom_all_reduce_push_module(dtype: torch.dtype, world_size: int) -> Module: + args = make_cpp_args(dtype, world_size, is_arch_support_pdl()) + return load_jit( + "custom_all_reduce_push", + *args, + extra_ldflags=["-lcuda"], + cuda_files=["distributed/custom_all_reduce_push.cuh"], + cuda_wrappers=[("all_reduce", f"custom_all_reduce<{args}>")], + ) + + +@cache_once +def get_custom_all_reduce_cls() -> type[CustomAllReduceObj]: + module = load_jit( + "custom_all_reduce_base", + extra_ldflags=["-lcuda"], + cuda_files=["distributed/custom_all_reduce_base.cuh"], + cuda_wrappers=[("register_once", "register_custom_all_reduce")], + ) + module.register_once() + device = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device) + NUM_CTA = props.multi_processor_count + MAX_THREADS = 512 + + @tvm_ffi.register_object("sgl.CustomAllReduce") + class CustomAllReduceObjReal(tvm_ffi.Object): + __slots__ = ("__dict__",) + + def __init__( + self, + rank: int, + world_size: int, + pull_buffer_bytes: int, + push_buffer_bytes: int, + graph_input_count: int, + *, + max_pull_blocks: Optional[int] = None, + max_push_blocks: Optional[int] = None, + ) -> None: + max_pull_blocks = NUM_CTA if max_pull_blocks is None else max_pull_blocks + max_push_blocks = NUM_CTA if max_push_blocks is None else max_push_blocks + self.__ffi_init__( + rank, + world_size, + max_pull_blocks, + max_push_blocks, + pull_buffer_bytes, + push_buffer_bytes, + graph_input_count, + ) + self._world_size = world_size + self._pull_config = ConfigResult(min(NUM_CTA, max_pull_blocks), MAX_THREADS) + if 
max_pull_blocks > 0: # special case: cannot configure 0 blocks + self.configure_pull(*self._pull_config) # type: ignore + + @property + def world_size(self) -> int: + return self._world_size + + def all_reduce( + self, + input: torch.Tensor, + algo: AllReduceAlgo, + ) -> tvm_ffi.Tensor: + compile_fn = ( + _jit_custom_all_reduce_push_module + if algo.is_push() + else _jit_custom_all_reduce_pull_module + ) + module = compile_fn(input.dtype, self._world_size) + return module.all_reduce(self, input, algo.shot) + + def config_pull( + self, num_blocks: int = -1, num_threads: int = -1 + ) -> ConfigResult: + old_config = self._pull_config + num_blocks = num_blocks if num_blocks != -1 else old_config.num_blocks + num_threads = num_threads if num_threads != -1 else old_config.num_threads + new_config = ConfigResult(num_blocks, num_threads) + if new_config != old_config: + result = ConfigResult(*self.configure_pull(*new_config)) # type: ignore + assert result == self._pull_config + self._pull_config = new_config + return old_config + + def free(self, tp_cpu_group: torch.distributed.ProcessGroup) -> None: + self.free_ipc_handles() # type: ignore + torch.distributed.barrier(group=tp_cpu_group) + self.free_storage() # type: ignore + + return cast(type["CustomAllReduceObj"], CustomAllReduceObjReal) diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh new file mode 100644 index 000000000000..3a89e8114ce5 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh @@ -0,0 +1,522 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr 
int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. 
*/ + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = 
gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * 
kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = 
kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// 
NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + 
static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + 
const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = 
static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh new file mode 100644 index 000000000000..b497470606cf --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh @@ -0,0 +1,726 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace device::compress { + +/// \brief Plan entry for online compress 128 prefill. +/// Each entry describes a contiguous segment of tokens that lies inside a +/// single 128-chunk. Multiple segments can map to the same batch id when the +/// extend tokens span chunk boundaries. +/// +/// **Layout compatibility:** the field order/types match `PrefillPlan` so that +/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can +/// consume the compress_plan tensor as-if it were a `PrefillPlan` tensor -- +/// they only read `ragged_id` and `position`, both of which carry identical +/// semantics here (the LAST token of the segment in q-ragged and global +/// coordinates respectively). 
+/// +/// Note that `window_len` here means "number of real tokens in this segment" +/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels +/// that share the tensor MUST NOT read it under that name. +struct alignas(16) OnlinePrefillPlan { + /// \brief Ragged-q position of the LAST token in this segment. + /// Equal to `segment_start_ragged + window_len - 1`. + uint32_t ragged_id; + /// \brief Index into the `indices` / `load_indices` arrays. + uint32_t batch_id; + /// \brief Global position of the LAST token in this segment. + /// For compress plans, `position % 128 == 127` (chunk-closing); for write + /// plans, `position % 128 < 127`. + uint32_t position; + /// \brief Number of real tokens in this segment (1..128). + /// The first segment token sits at `position - window_len + 1` (global) and + /// at `ragged_id - window_len + 1` (ragged). + uint32_t window_len; +}; + +static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan)); + +} // namespace device::compress + +namespace host::compress { + +using device::compress::OnlinePrefillPlan; +using OnlinePrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kOnlinePrefillPlanDim = 16; + +static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype)); + +} // namespace host::compress + +namespace { + +using OnlinePlan = device::compress::OnlinePrefillPlan; +using IndiceT = int32_t; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlineDecodeParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: 
`[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \note `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlinePrefillParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ load_indices; + /// \brief Plan for segments that close a chunk (write to `kv_compressed_output`). + /// Shape: `[num_compress, 16]` (uint8). + const OnlinePlan* __restrict__ compress_plan; + /// \brief Plan for the trailing partial segment of each batch (write back to + /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8). 
+ const OnlinePlan* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +// 4 elements per thread, kHeadDim / 4 threads per block +template +__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + const auto index = params.indices[batch_id]; + const auto seq_len = params.seq_lens[batch_id]; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads, + /// writes, and the prefill kernel must all agree on this order. + const auto max_score_vec = gmem.load(kv_buf, 0); + const auto sum_score_vec = gmem.load(kv_buf, 1); + const auto old_kv_vec = gmem.load(kv_buf, 2); + + /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching + /// the offline c128 kernel and the online prefill kernel. + const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + + /// NOTE: the new token sits at global position `seq_len - 1`, so its + /// position inside the 128-chunk is `(seq_len - 1) % 128`. The previous + /// `seq_len % 128` was off by one (`bias[127]` vs `bias[0]`, etc.). + const auto pos_in_chunk = (seq_len - 1) % 128; + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token via online softmax. 
+#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmaxf(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new 128-chunk: initialize state with this token alone. +#pragma unroll + for (uint32_t i = 0; i < kVecSize; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; // exp(score - max) with max == score + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit the compressed kv. No need to update the buffer + // -- the next chunk's first token will overwrite it. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + // Otherwise persist the running [max, sum, kv] state for the next step. 
+ gmem.store(kv_buf, out_max_vec, 0); + gmem.store(kv_buf, out_sum_vec, 1); + gmem.store(kv_buf, out_kv_vec, 2); + } +} + +constexpr int32_t kTileElements = 2; // split (along head-dim) +/// \brief Each warp will handle this many elements (split along softmax-128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_prefill_forward( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* kv_out, + float* max_out, + float* sum_out, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 
0.0f; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = global_product; + if constexpr (kNeedData) { + max_out[local_elem_id] = global_val_max; + sum_out[local_elem_id] = global_exp_sum; + } + } + if constexpr (kNeedData) __syncthreads(); +} + +/// \brief Sentinel score for padded positions in a 128-segment. +/// Must be finite so that `score - max` never produces NaN even when an +/// entire warp has only padded positions. +constexpr float kPadScore = -FLT_MAX; + +/// \brief Online compress 128 prefill. Two passes share this body: +/// - `kWrite=false` (compress pass): handles segments that close a chunk. +/// May load prior partial state from the buffer, but never writes to it, +/// so concurrent blocks can read the same slot without racing. +/// - `kWrite=true` (write pass): handles the trailing partial segment of each +/// batch. Each batch contributes at most one such plan, so concurrent blocks +/// touch disjoint buffer slots. +/// +/// The two passes MUST run as separate kernel launches (in stream order) so +/// that all reads in pass 1 finish before any writes in pass 2 start. 
+template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + /// NOTE: the compiler folds the if-else at compile time. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? params.write_plan : params.compress_plan; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + if (global_pid >= num_plans) return; + const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid]; + if (ragged_id == 0xFFFFFFFFu) [[unlikely]] + return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; // int32 is enough + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + const uint32_t chunk_offset = (position % 128u) + 1u - window_len; + const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128] + const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe + const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1; + const int32_t store_index = kWrite ? params.indices[batch_id] : -1; + + PDLWaitPrimary(); + + // 2 * 8 = 16 register per elem. 
in theory we should consume 48 register here + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const auto warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < window_end) { + const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < window_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward` + /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ~= 0`, + /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An + /// uninitialized register can hold a NaN/inf bit pattern, so without + /// this reset a single padded warp can poison the whole softmax. + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + + c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + if (chunk_offset != 0) { + /// NOTE: load (max, sum, kv) of the in-progress chunk for this index. 
+ /// `load_indices` may differ from `indices` when the prior partial state + /// lives on a different slot than the slot we ultimately write to. + const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + 
static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}).with_dtype().with_device(device).verify(indices); + TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + 
const tvm::ffi::Optional extra) { + using namespace host; + using host::compress::kOnlinePrefillPlanDim; + using host::compress::OnlinePrefillPlanTensorDtype; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) -- effectively 2D (unit middle dim) + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + /// NOTE: `extra` is `load_indices`. When the previous partial state lives + /// on a slot different from the destination slot (e.g. paged buffers), the + /// caller must supply this; otherwise it defaults to `indices`. 
+ const auto load_indices = extra.value_or(indices); + TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices); + + const auto device = device_.unwrap(); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + + /// NOTE: pass 1 reads the buffer (for the first segment of each batch + /// that started mid-chunk) and writes only to `kv_compressed_output`. + /// Pass 2 then writes the trailing partial state of each batch back to + /// the buffer. Stream serialization between the two launches enforces + /// read-before-write on shared buffer slots. + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace + +namespace host::compress { + +using OnlinePlanResult = tvm::ffi::Tuple; + +struct OnlinePrefillCompressParams { + OnlinePrefillPlan* __restrict__ compress_plan; + OnlinePrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; +}; + +/// \brief Build the compress + write plans for online compress 128 prefill. 
+/// +/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at +/// 128-aligned boundaries. Every resulting segment falls into one of: +/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`). +/// These plans only read the buffer (when starting mid-chunk) and write the +/// compressed kv to `kv_compressed_output`. +/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`). +/// May read the buffer and always writes the new partial state back to it. +/// Each batch produces at most one such plan. +/// +/// The two plans MUST be dispatched as separate kernel launches in stream +/// order so that pass-1 reads of a buffer slot complete before any pass-2 +/// write of the same slot. +inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) { + const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params; + + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = static_cast(seq_lens[i]); + const uint32_t extend_len = static_cast(extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + /// NOTE: split the extend range into per-128-chunk segments. Each segment + /// stays inside one chunk, so the kernel can decide load/store from + /// `chunk_offset` and `window_len` alone. + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + /// NOTE: store last-token coordinates so that downstream consumers + /// (e.g. 
`fused_norm_rope`) can read `ragged_id` and `position` with the + /// same semantics as `PrefillPlan`. The segment start is recoverable as + /// `ragged_id - window_len + 1` and `position - window_len + 1`. + const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + const auto plan = OnlinePrefillPlan{ + .ragged_id = last_ragged, + .batch_id = i, + .position = last_pos, + .window_len = seg_len, + }; + if (chunk_off + seg_len == 128u) { + // full chunk, must be complete, maybe read the buffer, no write + RuntimeCheck(compress_count < num_tokens); + compress_plan[compress_count++] = plan; + } else { + // last chunk, must be incomplete, maybe read the buffer, must write + RuntimeCheck(write_count < num_tokens); + write_plan[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count}; + /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see + /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1. + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_count, num_tokens)) { + compress_plan[i] = kInvalidPlan; + } + for (const auto i : irange(write_count, num_tokens)) { + write_plan[i] = kInvalidPlan; + } + return OnlinePlanResult{num_tokens, num_tokens}; +} + +inline OnlinePlanResult plan_online_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + /// NOTE: only host (CPU/cuda-host) planning is implemented for now. 
The + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kOnlinePrefillPlanDim}) // + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + const auto params = OnlinePrefillCompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + }; + return plan_online_prefill_host(params, use_cuda_graph); +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh new file mode 100644 index 000000000000..9cdc3dbeb6a3 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh @@ -0,0 +1,540 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ 
+ void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. 
*/ + + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + + uint32_t num_compress; + uint32_t num_write; + + uint32_t num_q_tokens; + uint32_t batch_size; + uint32_t num_indices; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t 
i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: 
online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = 
kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write, // prefill plan + _num_q_tokens, _batch_size, _num_indices + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = 
global_id % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + if (ragged_id >= _num_q_tokens) [[unlikely]] return; + if (global_bid >= _batch_size) [[unlikely]] return; + + const int32_t index = indices_ptr[global_bid]; + + if (index < 0 || static_cast(index) >= _num_indices) [[unlikely]] return; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * 
device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const 
tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto K = SymbolicSize{"num_indices"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({K, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // optional per-batch indices read by the compress (load) kernel; falls back to `indices` when `extra` is absent + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto num_indices = static_cast(K.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = 
kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + .num_q_tokens = num_q_tokens, + .batch_size = batch_size, + .num_indices = num_indices, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh new file mode 100644 index 000000000000..145ab1fb081e --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh @@ -0,0 +1,549 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan4 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) + +enum class PageMode { + RingBuffer = 8, + Page4Align = 4, +}; + +struct alignas(16) C4IndexBundle { + int32_t load_first_page; + int32_t 
load_second_page; + int32_t write_first_page; + int32_t last_position; +}; + +struct Compress4DecodeParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \brief Shape: `[batch_size, 1]` */ + const int32_t* __restrict__ extra; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, 4]` */ + const C4IndexBundle* __restrict__ extra; + /** \brief The following part is plan info. 
*/ + + const Plan4* __restrict__ compress_plan; + const Plan4* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +template +SGL_DEVICE void c4_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 4; + const auto gmem = tile::Memory::warp(); + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv overlap | [1] = kv | [2] = score overlap | [3] = score | + Storage kv_score[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t seq_len, + const int32_t window_len, + [[maybe_unused]] const InFloat* kv_score_overlap_buf = nullptr) { + using namespace device; + + const auto element_size = head_dim * 4; + const auto score_offset = head_dim * 2; + const auto overlap_stride = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const bool is_overlap = i < 4; + const InFloat* src; + if (i < window_len) { + /// NOTE: `seq_len` must be a multiple of 4 here + if constexpr (kPaged) { + const auto kv_score_ptr = is_overlap ? 
kv_score_overlap_buf : kv_score_buf; + const int32_t k = i % 4; + src = kv_score_ptr + k * element_size; + } else { + const int32_t k = (seq_len + i) % 8; + src = kv_score_buf + k * element_size; + } + } else { + /// NOTE: k in [-7, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = i - 7; + src = kv_score_src + k * element_size; + } + src += (is_overlap ? 0 : overlap_stride); + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + if (seq_len == 4) { + [[unlikely]]; + constexpr float kFloatNegInf = -1e9f; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[8]; + +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + gmem_out.store(kv_out, result); +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must 
be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, extra, batch_size // decode info + ] = params; + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / kNumSplit; // batch id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + page_size - 1` + if constexpr (kMode == PageMode::Page4Align) { + const auto index_prev = extra[global_bid]; + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 3) % 4); + if (seq_len % 4 == 0) { + const auto kv_overlap = kv_buf + (index_prev - index) * (kElementSize * 4); + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_len=*/8, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 7) % 8); + if (seq_len % 4 == 0) { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_len=*/8); + } + } + + PDLTriggerSecondary(); +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, extra, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / kNumSplit; // plan id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? 
write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + const auto seq_len = position + 1; + const int32_t index = indices[global_bid]; + + PDLWaitPrimary(); + + if constexpr (kMode == PageMode::Page4Align) { + const auto write_second_page = index; + const auto [load_first_page, load_second_page, write_first_page, last_pos] = extra[global_bid]; + if constexpr (kWrite) { + int32_t index; + if (position < static_cast(last_pos)) { + index = write_first_page; + } else { + index = write_second_page; + } + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 4); + } else { + int32_t index_overlap, index_normal; + if (window_len <= 4) { + index_overlap = load_second_page; + index_normal = load_second_page; // not used + } else { + index_overlap = load_first_page; + index_normal = load_second_page; + } + const auto kv_buf = kv_score_buffer + index_normal * (kElementSize * 4) + split_offset; + const auto kv_overlap = kv_score_buffer + index_overlap * (kElementSize * 4) + split_offset; + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = 
kv_score_buffer + index * (kElementSize * 8) + split_offset; + if constexpr (kWrite) { + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 8); + } else { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len); + } + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress4Kernel { + template + static constexpr auto decode_kernel = flash_c4_decode; + template + static constexpr auto prefill_kernel = flash_c4_prefill; + template + static constexpr auto prefill_c_kernel = prefill_kernel; + template + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + + using Self = FlashCompress4Kernel; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional extra) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extra = static_cast(extra_ptr), + .batch_size = batch_size, + }; + const auto kernel = extra_ptr != nullptr ? 
decode_kernel // + : decode_kernel; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra, /*is_prefill=*/true); + const auto page_size = extra_ptr != nullptr ? 4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = 
static_cast(Y.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .extra = static_cast(extra_ptr), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + const auto c_kernel = extra_ptr != nullptr ? prefill_c_kernel // + : prefill_c_kernel; + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + const auto w_kernel = extra_ptr != nullptr ? prefill_w_kernel // + : prefill_w_kernel; + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(w_kernel, params); + } + } + + // some auxiliary functions + private: + static const void* _get_extra_pointer( + host::SymbolicSize& B, // batch_size + host::SymbolicDevice& device, + const tvm::ffi::Optional& extra, + bool is_prefill = false) { + // only have value when using page-aligned mode + if (!extra.has_value()) return nullptr; + const auto& extra_tensor = extra.value(); + /// NOTE: the metadata layout is different for prefill and decode: + /// for prefill, last 4 are: + /// load overlap | load normal | write overlap | last written page + /// for decode, last 1 is the write (also load) overlap + host::TensorMatcher({B, is_prefill ? 
4 : 1}) // extra tensor + .with_dtype() + .with_device(device) + .verify(extra_tensor); + const auto data_ptr = extra_tensor.data_ptr(); + host::RuntimeCheck(data_ptr != nullptr, "extra tensor data ptr is null"); + if (is_prefill) { + static_assert(alignof(C4IndexBundle) == 16); + host::RuntimeCheck(std::bit_cast(data_ptr) % 16 == 0, "extra tensor is not properly aligned"); + } + return data_ptr; + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh new file mode 100644 index 000000000000..46acaa9c46b3 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh @@ -0,0 +1,208 @@ +#include +#include + +#include + +#include + +namespace host::compress { + +using PlanResult = tvm::ffi::Tuple; + +struct CompressParams { + PrefillPlan* __restrict__ compress_plan; + PrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; + uint32_t compress_ratio; + bool is_overlap; +}; + +inline constexpr uint32_t kBlockSize = 1024; + +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) inline + +PLAN_KERNEL void plan_prefill_cuda(const __grid_constant__ CompressParams params) { + const auto &[ + compress_plan, write_plan, seq_lens, extend_lens, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + __shared__ uint32_t compress_counter; + __shared__ uint32_t write_counter; + + uint32_t batch_id = 0; + uint32_t counter = 0; + uint32_t extend_len = extend_lens[0]; + + const auto tid = threadIdx.x; + if (tid == 0) { + compress_counter = 0; + write_counter = 0; + } + __syncthreads(); + + for (uint32_t i = tid; i < num_tokens; i += blockDim.x) { + const uint32_t ragged_id = i; + uint32_t j = ragged_id - counter; + while (j >= extend_len) { + j -= extend_len; + batch_id += 1; + if (batch_id >= batch_size) [[unlikely]] + break; + 
counter += extend_len; + extend_len = extend_lens[batch_id]; + } + if (batch_id >= batch_size) [[unlikely]] + break; + const uint32_t seq_len = seq_lens[batch_id]; + const uint32_t extend_len = extend_lens[batch_id]; + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t ratio = compress_ratio * (1 + is_overlap); + const uint32_t window_len = j + 1 < ratio ? ratio - (j + 1) : 0; + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = ragged_id, + .batch_id = batch_id, + .position = position, + .window_len = window_len, + }; + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + if ((position + 1) % compress_ratio == 0) { + const auto write_pos = atomicAdd(&compress_counter, 1); + compress_plan[write_pos] = plan; + } + if (position >= start_write_pos) { + const auto write_pos = atomicAdd(&write_counter, 1); + write_plan[write_pos] = plan; + } + } + __syncthreads(); + constexpr auto kInvalid = static_cast(-1); + const auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + const auto compress_count = compress_counter; + const auto write_count = write_counter; + for (uint32_t i = compress_count + tid; i < num_tokens; i += blockDim.x) { + compress_plan[i] = kInvalidPlan; + } + for (uint32_t i = write_count + tid; i < num_tokens; i += blockDim.x) { + write_plan[i] = kInvalidPlan; + } +} + +inline PlanResult plan_prefill_host(const CompressParams& params, const bool use_cuda_graph) { + const auto &[ + compress_ptr, write_ptr, seq_lens_ptr, extend_lens_ptr, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + uint32_t counter = 0; + uint32_t compress_counter = 0; + uint32_t write_counter = 0; + const auto ratio = compress_ratio * (1 + is_overlap); + for (const auto i : 
irange(batch_size)) { + const uint32_t seq_len = seq_lens_ptr[i]; + const uint32_t extend_len = extend_lens_ptr[i]; + const uint32_t prefix_len = seq_len - extend_len; + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + /// NOTE: `start_write_pos` must be a multiple of `compress_ratio` + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + /// NOTE: to avoid unsigned integer underflow, don't use `pos - compress_ratio` + return pos >= compress_ratio ? pos - compress_ratio : 0; + }(); + /// NOTE: `position` is within [prefix_len, seq_len) + for (const auto j : irange(extend_len)) { + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = counter + j, + .batch_id = i, + .position = position, + .window_len = ratio - std::min(j + 1, ratio), + }; + RuntimeCheck(plan.is_valid(compress_ratio, is_overlap), "Internal error!"); + if ((position + 1) % compress_ratio == 0) { + compress_ptr[compress_counter++] = plan; + } + if (position >= start_write_pos) { + write_ptr[write_counter++] = plan; + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return PlanResult{compress_counter, write_counter}; + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_counter, num_tokens)) { + compress_ptr[i] = kInvalidPlan; + } + for (const auto i : irange(write_counter, num_tokens)) { + write_ptr[i] = kInvalidPlan; + } + return PlanResult{num_tokens, num_tokens}; +} + +inline PlanResult plan_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const uint32_t compress_ratio, + const bool is_overlap, // 
for overlap transform, we have to keep 1 more extra window + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + const bool is_cuda = [&] { + if (extend_lens.device().device_type == kDLCUDA) { + device.set_options(); + return true; + } else { + device.set_options(); + return false; + } + }(); + TensorMatcher({N}) // extend_lens and seq_lens + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kPrefillPlanDim}) // compress_plan and write_plan + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + + const auto params = CompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + .compress_ratio = compress_ratio, + .is_overlap = is_overlap, + }; + + if (!is_cuda) return plan_prefill_host(params, use_cuda_graph); + /// NOTE: cuda kernel plan is naturally compatible with cuda graph + LaunchKernel(1, kBlockSize, device.unwrap())(plan_prefill_cuda, params); + return PlanResult{params.num_tokens, params.num_tokens}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_prefill = host::compress::plan_prefill; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh new file mode 100644 index 000000000000..d3953578b925 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh @@ -0,0 +1,254 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using Plan = device::compress::PrefillPlan; + +/// \brief 
common block size for memory-bound kernel +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedNormRopeParams { + void* __restrict__ input; + const void* __restrict__ weight; + float eps; + uint32_t num_works; + const void* __restrict__ handle; + const float* __restrict__ freqs_cis; + uint32_t compress_ratio; +}; + +enum class ForwardMode { + CompressExtend = 0, + CompressDecode = 1, + DefaultForward = 2, +}; + +template +__global__ void fused_norm_rope(const __grid_constant__ FusedNormRopeParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr int64_t kRopeVecSize = kRopeDim / (kWarpThreads * 2); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % (kWarpThreads * 2) == 0); + static_assert(kRopeDim % (kVecSize * kLocalSize) == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeVecSize == 1, "only support rope dim = 64"); + + const auto& [ + _input, _weight, eps, num_works, // norm + handle, freqs_cis, compress_ratio // rope + ] = params; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + + if (work_id >= num_works) return; + + DType* input; + int32_t position; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(handle)[work_id]; + input = static_cast(_input) + plan.ragged_id * kHeadDim; + position = plan.position + 1 - compress_ratio; + if (plan.ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + } else if constexpr (kMode == CompressDecode) { + input = 
static_cast(_input) + work_id * kHeadDim; + const auto seq_len = static_cast(handle)[work_id]; + if (seq_len % compress_ratio != 0) return; + position = seq_len - compress_ratio; + } else if constexpr (kMode == DefaultForward) { + input = static_cast(_input) + work_id * kHeadDim; + position = static_cast(handle)[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + + using Storage = AlignedVector; + __shared__ Storage s_rope_input[kNumWarps][kRopeSize]; + + // prefetch freq + const auto mem_freq = tile::Memory::warp(); + const auto freq = mem_freq.load(freqs_cis + position * kRopeDim); + + PDLWaitPrimary(); + + // part 1: norm + { + const auto gmem = tile::Memory::warp(); + Storage input_vec[kLocalSize]; + Storage weight_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input, i); + } + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + weight_vec[i] = gmem.load(_weight, i); + } + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + sum_of_squares += fp32_input * fp32_input; + } + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + const auto fp32_weight = cast(weight_vec[i][j]); + input_vec[i][j] = cast(fp32_input * norm_factor * fp32_weight); + } + } + + const bool is_rope_lane = lane_id >= kWarpThreads - kRopeSize; + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope_input[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(input, input_vec[i], i); + } + } + + 
__syncwarp(); + } + + // part 2: rope + { + // mem elem = DType x 2 + using DTypex2_t = packed_t; + const auto mem_elem = tile::Memory::warp(); + const auto elem = mem_elem.load(s_rope_input[warp_id]); + const auto [x_real, x_imag] = cast(elem); + const auto [freq_real, freq_imag] = freq; + const fp32x2_t output = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(input + (kHeadDim - kRopeDim), cast(output)); + } + + PDLTriggerSecondary(); +} + +template +struct FusedNormRopeKernel { + template + static constexpr auto fused_kernel = fused_norm_rope; + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView handle, + const tvm::ffi::TensorView freqs_cis, + int32_t _mode, + float eps, + uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(_mode); + + auto B = SymbolicSize{"num_q_tokens"}; + auto N = SymbolicSize{"num_compress_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + switch (mode) { + case CompressExtend: + TensorMatcher({N, compress::kPrefillPlanDim}) // plan + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case CompressDecode: + TensorMatcher({N}) // seq_len + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case DefaultForward: + TensorMatcher({N}) // position + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio == 0); + break; + default: + Panic("unsupported forward mode: ", static_cast(mode)); + } + + // 
launch kernel + const auto num_compress_tokens = static_cast(N.unwrap()); + if (num_compress_tokens == 0) return; + const auto params = FusedNormRopeParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .eps = eps, + .num_works = num_compress_tokens, + .handle = handle.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .compress_ratio = compress_ratio, + }; + const auto num_blocks = div_ceil(num_compress_tokens, kNumWarps); + using KernelType = std::decay_t)>; + static constexpr KernelType kernel_table[3] = { + [static_cast(CompressExtend)] = fused_kernel, + [static_cast(CompressDecode)] = fused_kernel, + [static_cast(DefaultForward)] = fused_kernel, + }; + const auto kernel = kernel_table[static_cast(mode)]; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh new file mode 100644 index 000000000000..8c65954d1356 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh @@ -0,0 +1,213 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +[[maybe_unused]] +SGL_DEVICE float act_sqrt_softplus(float x) { + const float softplus = fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); + return sqrtf(softplus); +} + +struct MoEHashTopKParams { + const float* __restrict__ router_logits; + const int64_t* __restrict__ input_id; + const int32_t* __restrict__ tid2eid; + int32_t* __restrict__ topk_ids; + float* __restrict__ topk_weights; + uint32_t num_tokens; + uint32_t topk; + uint32_t num_routed_experts; + uint32_t num_shared_experts; + float routed_scaling_factor; +}; + +template +__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) { + using namespace device; + const auto& [ + router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers + num_tokens, 
topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + params; + + const uint32_t topk_fused = topk + num_shared_experts; + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / kWarpThreads; + const uint32_t lane_id = tid % kWarpThreads; + if (warp_id >= num_tokens) return; + // we can safely prefetch the token id + const auto token_id = input_id[warp_id]; + + PDLWaitPrimary(); + + float routed_weight = 0.0f; + int32_t expert_id = 0; + if (lane_id < topk) { + expert_id = tid2eid[token_id * topk + lane_id]; + routed_weight = Fn(router_logits[warp_id * num_routed_experts + expert_id]); + } + + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk_fused) { + const bool is_shared = lane_id >= topk; + const auto output_offset = warp_id * topk_fused + lane_id; + topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; + topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + } + + PDLTriggerSecondary(); +} + +struct TopKParams { + int32_t* __restrict__ topk_ids; + // Exactly one is active: ntn_ptr == nullptr means use ntn_value. + const int32_t* __restrict__ ntn_ptr; + int32_t ntn_value; + int64_t stride; + uint32_t topk; + uint32_t num_tokens; +}; + +__global__ void mask_topk_ids_padded_region(const TopKParams __grid_constant__ params) { + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / device::kWarpThreads; + const uint32_t lane_id = tid % device::kWarpThreads; + if (warp_id >= params.num_tokens || lane_id >= params.topk) return; + device::PDLWaitPrimary(); + const uint32_t num = (params.ntn_ptr != nullptr) // + ? 
static_cast(params.ntn_ptr[0]) + : static_cast(params.ntn_value); + if (warp_id >= num) params.topk_ids[warp_id * params.stride + lane_id] = -1; + device::PDLTriggerSecondary(); +} + +template +struct HashTopKKernel { + static constexpr auto kernel = moe_hash_topk_fused; + + static void + run(const tvm::ffi::TensorView router_logits, + const tvm::ffi::TensorView input_id, + const tvm::ffi::TensorView tid2eid, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView topk_ids, + float routed_scaling_factor) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto E = SymbolicSize{"num_routed_experts"}; + auto K = SymbolicSize{"topk_fused"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, E}) // + .with_dtype() + .with_device(device) + .verify(router_logits); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input_id); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device) + .verify(tid2eid); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_ids); + + const auto num_tokens = static_cast(N.unwrap()); + const auto topk_fused = static_cast(K.unwrap()); + const auto topk = static_cast(tid2eid.size(1)); + const auto shared_experts = topk_fused - topk; + RuntimeCheck(topk <= topk_fused, "HashTopKKernel requires topk <= topk_fused"); + RuntimeCheck(topk_fused <= device::kWarpThreads, "HashTopKKernel requires topk_fused <= warp size"); + + const auto params = MoEHashTopKParams{ + .router_logits = static_cast(router_logits.data_ptr()), + .input_id = static_cast(input_id.data_ptr()), + .tid2eid = static_cast(tid2eid.data_ptr()), + .topk_ids = static_cast(topk_ids.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .num_tokens = num_tokens, + .topk = topk, + .num_routed_experts = static_cast(E.unwrap()), + .num_shared_experts = 
shared_experts, + .routed_scaling_factor = routed_scaling_factor, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +struct MaskKernel { + static constexpr auto kernel = mask_topk_ids_padded_region; + + static void run(tvm::ffi::TensorView topk_ids, tvm::ffi::TensorView num_token_non_padded) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto K = SymbolicSize{"topk"}; + auto D = SymbolicSize{"stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N, K}) // + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(topk_ids); + RuntimeCheck(num_token_non_padded.numel() == 1, "num_token_non_padded should be a scalar"); + RuntimeCheck(K.unwrap() <= device::kWarpThreads, "MaskKernel requires topk <= warp size"); + const int32_t* ntn_ptr = nullptr; + int32_t ntn_value = 0; + const auto ntn_dev = num_token_non_padded.device().device_type; + if (ntn_dev == kDLCUDA) { + RuntimeCheck(is_type(num_token_non_padded.dtype()), "num_token_non_padded on CUDA must be int32"); + ntn_ptr = static_cast(num_token_non_padded.data_ptr()); + } else if (ntn_dev == kDLCPU) { + if (is_type(num_token_non_padded.dtype())) { + ntn_value = *static_cast(num_token_non_padded.data_ptr()); + } else if (is_type(num_token_non_padded.dtype())) { + ntn_value = static_cast(*static_cast(num_token_non_padded.data_ptr())); + } else { + RuntimeCheck(false, "num_token_non_padded on CPU must be int32 or int64"); + } + } else { + RuntimeCheck(false, "num_token_non_padded must be on CPU or CUDA"); + } + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = TopKParams{ + .topk_ids = static_cast(topk_ids.data_ptr()), + .ntn_ptr = ntn_ptr, + .ntn_value = ntn_value, + .stride = static_cast(D.unwrap()), + .topk = 
static_cast(K.unwrap()), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(true)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh new file mode 100644 index 000000000000..aefec24372a8 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh @@ -0,0 +1,82 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace { + +/// NOTE: for offload to cpu kernel, we use persistent kernel +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kBlockQuota = 4; + +#define OFFLOAD_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct OffloadParams { + void** gpu_caches; + void** cpu_caches; + const int64_t* gpu_indices; + const int64_t* cpu_indices; + uint32_t num_items; + uint32_t num_layers; +}; + +OFFLOAD_KERNEL void offload_to_cpu(const __grid_constant__ OffloadParams params) { + using namespace device::hisparse; + const auto [gpu_caches, cpu_caches, gpu_indices, cpu_indices, num_items, num_layers] = params; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + constexpr auto kNumWarps = (kBlockSize / 32) * kBlockQuota; + for (auto i = global_tid / 32; i < num_items; i += kNumWarps) { + const int32_t gpu_index = gpu_indices[i]; + const int32_t cpu_index = cpu_indices[i]; + for (auto j = 0u; j < num_layers; ++j) { + const auto gpu_cache = gpu_caches[j]; + const auto cpu_cache = cpu_caches[j]; + transfer_item( + /*dst_cache=*/cpu_cache, + /*src_cache=*/gpu_cache, + /*dst_index=*/cpu_index, + /*src_index=*/gpu_index); + } + } +} + +[[maybe_unused]] +void hisparse_transfer( + tvm::ffi::TensorView gpu_ptrs, + tvm::ffi::TensorView cpu_ptrs, + 
tvm::ffi::TensorView gpu_indices, + tvm::ffi::TensorView cpu_indices) { + using namespace host; + auto N = SymbolicSize{"num_items"}; + auto L = SymbolicSize{"num_layers"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({L}) // 1D cache pointers + .with_dtype() + .with_device(device_) + .verify(gpu_ptrs) + .verify(cpu_ptrs); + TensorMatcher({N}) // 1D indices + .with_dtype() + .with_device(device_) + .verify(gpu_indices) + .verify(cpu_indices); + const auto params = OffloadParams{ + .gpu_caches = static_cast(gpu_ptrs.data_ptr()), + .cpu_caches = static_cast(cpu_ptrs.data_ptr()), + .gpu_indices = static_cast(gpu_indices.data_ptr()), + .cpu_indices = static_cast(cpu_indices.data_ptr()), + .num_items = static_cast(N.unwrap()), + .num_layers = static_cast(L.unwrap()), + }; + LaunchKernel(kBlockQuota, kBlockSize, device_.unwrap())(offload_to_cpu, params); +} + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh new file mode 100644 index 000000000000..a35d3dbf1092 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh @@ -0,0 +1,219 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct MegaMoEPreDispatchParams { + const bf16_t* __restrict__ x; // [num_tokens, hidden] + const int32_t* __restrict__ topk_idx; // [num_tokens, top_k] + const float* __restrict__ topk_weights; // [num_tokens, top_k] + + fp8_e4m3_t* __restrict__ buf_x; // [padded_max, hidden] + int32_t* __restrict__ buf_x_sf; // contiguous int32 [P, G/4]; see layout comment + int64_t* __restrict__ buf_topk_idx; // [padded_max, top_k] + float* __restrict__ buf_topk_weights; // [padded_max, top_k] + + uint32_t num_tokens; + uint32_t padded_max; + uint32_t hidden; + uint32_t num_groups; 
// hidden / group_size + uint32_t top_k; +}; + +// kGroupSize must match sglang_per_token_group_quant_fp8_ue8m0(group_size=). +template +__global__ __launch_bounds__(1024, 2) void // + mega_moe_pre_dispatch_kernel(const MegaMoEPreDispatchParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kVecElems = 8; // 8 bf16 = 16B load per thread + static_assert(kGroupSize % kVecElems == 0, "group_size must be a multiple of 8"); + constexpr uint32_t kThreadsPerGroup = kGroupSize / kVecElems; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + + PDLWaitPrimary(); + if (bid < params.num_tokens) { + // ---- Quantize path: one CTA per valid token ---- + + const uint32_t token_id = bid; + const auto token_in = params.x + static_cast(token_id) * params.hidden; + const auto token_out = params.buf_x + static_cast(token_id) * params.hidden; + + InputVec in_vec; + in_vec.load(token_in, tid); + + float local_max = 0.0f; + float vals[kVecElems]; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + const auto [v0, v1] = cast(in_vec[i]); + vals[2 * i + 0] = v0; + vals[2 * i + 1] = v1; + local_max = fmaxf(local_max, fmaxf(fabsf(v0), fabsf(v1))); + } + + // Absmax across the kThreadsPerGroup threads that cover one group. + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + const float raw_scale = absmax / math::FP8_E4M3_MAX; + const uint32_t ue8m0_exp = cast_to_ue8m0(raw_scale); + // 2^-ue8m0_exp as fp32 (equivalent to 1 / __uint_as_float(ue8m0 << 23)). 
+ const float inv_scale = __uint_as_float((127u + 127u - ue8m0_exp) << 23); + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + out_vec[i] = pack_fp8(vals[2 * i + 0] * inv_scale, vals[2 * i + 1] * inv_scale); + } + out_vec.store(token_out, tid); + + // One thread per group writes its UE8M0 byte into the contiguous + // row-major int32-packed layout: byte address = t*num_groups + g + // (see layout comment at the top of the file). + const uint32_t group_id = tid / kThreadsPerGroup; + const uint32_t within_group_id = tid % kThreadsPerGroup; + if (within_group_id == 0 && group_id < params.num_groups) { + const uint32_t byte_off = token_id * params.num_groups + group_id; + reinterpret_cast(params.buf_x_sf)[byte_off] = static_cast(ue8m0_exp); + } + + // Copy this token's topk row (no alignment assumptions; top_k is small). + if (tid < params.top_k) { + const uint32_t off = token_id * params.top_k + tid; + params.buf_topk_idx[off] = params.topk_idx[off]; + params.buf_topk_weights[off] = params.topk_weights[off]; + } + } else { + // ---- Pad path: trailing blocks fill [num_tokens, padded_max) with (-1, 0) ---- + const uint32_t copy_bid = bid - params.num_tokens; + const uint32_t pad_base = params.num_tokens * params.top_k; + const uint32_t slot = pad_base + copy_bid * blockDim.x + tid; + const uint32_t total_slots = params.padded_max * params.top_k; + + if (slot < total_slots) { + params.buf_topk_idx[slot] = -1; + params.buf_topk_weights[slot] = 0.0f; + } + } + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct MegaMoEPreDispatchKernel { + static_assert(kGroupSize == 32 || kGroupSize == 64 || kGroupSize == 128, "unsupported group_size"); + static constexpr auto kernel = mega_moe_pre_dispatch_kernel(kGroupSize), kUsePDL>; + + static void + run(const tvm::ffi::TensorView x, + const 
tvm::ffi::TensorView topk_idx, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView buf_x, + const tvm::ffi::TensorView buf_x_sf, + const tvm::ffi::TensorView buf_topk_idx, + const tvm::ffi::TensorView buf_topk_weights) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto P = SymbolicSize{"padded_max"}; + auto H = SymbolicSize{"hidden"}; + auto K = SymbolicSize{"top_k"}; + auto G4 = SymbolicSize{"num_groups_div_4"}; + device.set_options(); + + TensorMatcher({M, H}) // input x + .with_dtype() + .with_device(device) + .verify(x); + TensorMatcher({M, K}) // topk_idx + .with_dtype() + .with_device(device) + .verify(topk_idx); + TensorMatcher({M, K}) // topk_weights + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({P, H}) // buf.x + .with_dtype() + .with_device(device) + .verify(buf_x); + // buf.x_sf is the contiguous row-major int32 view from DeepGEMM's mega + // symm buffer (DeepGEMM/csrc/apis/mega.hpp): shape (P, G/4), strides + // (G/4, 1). No explicit strides required -> TensorMatcher enforces + // is_contiguous(). 
+ TensorMatcher({P, G4}) // buf_x_sf + .with_dtype() + .with_device(device) + .verify(buf_x_sf); + TensorMatcher({P, K}) // buf.topk_idx + .with_dtype() + .with_device(device) + .verify(buf_topk_idx); + TensorMatcher({P, K}) // buf.topk_weights + .with_dtype() + .with_device(device) + .verify(buf_topk_weights); + + const auto num_tokens = static_cast(M.unwrap()); + const auto padded_max = static_cast(P.unwrap()); + const auto hidden = static_cast(H.unwrap()); + const auto top_k = static_cast(K.unwrap()); + const auto num_groups_div_4 = static_cast(G4.unwrap()); + + RuntimeCheck(num_tokens <= padded_max, "num_tokens must not exceed padded_max"); + RuntimeCheck(hidden % kGroupSize == 0, "hidden must be a multiple of group_size"); + const auto num_groups = hidden / static_cast(kGroupSize); + RuntimeCheck(num_groups == num_groups_div_4 * 4u, "num_groups must be a multiple of 4"); + RuntimeCheck(hidden % 8u == 0, "hidden must be a multiple of 8 (16B bf16 loads)"); + const auto num_threads = hidden / 8u; + RuntimeCheck(num_threads <= 1024, "hidden too large for single-block-per-row quant"); + RuntimeCheck(num_threads >= top_k, "top_k must fit into one quant CTA"); + + const auto pad_slots = (padded_max - num_tokens) * top_k; + const uint32_t num_pad_blocks = pad_slots == 0 ? 
0u : ((pad_slots + num_threads - 1u) / num_threads); + const auto num_total_blocks = num_tokens + num_pad_blocks; + + const auto params = MegaMoEPreDispatchParams{ + .x = static_cast(x.data_ptr()), + .topk_idx = static_cast(topk_idx.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .buf_x = static_cast(buf_x.data_ptr()), + .buf_x_sf = static_cast(buf_x_sf.data_ptr()), + .buf_topk_idx = static_cast(buf_topk_idx.data_ptr()), + .buf_topk_weights = static_cast(buf_topk_weights.data_ptr()), + .num_tokens = num_tokens, + .padded_max = padded_max, + .hidden = hidden, + .num_groups = num_groups, + .top_k = top_k, + }; + + if (num_total_blocks == 0) return; + LaunchKernel(num_total_blocks, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh new file mode 100644 index 000000000000..38be97555853 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kBlockSize = 1024; +constexpr uint32_t kSplitKV = 256; // const for both SM90 and SM100 + +struct MetadataParams { + /// NOTE: batch_size > 0 + uint32_t batch_size; + uint32_t num_sm; + const uint32_t* __restrict__ context_lens; + uint32_t* __restrict__ schedule_metadata; + bool use_smem = true; +}; + +__global__ __launch_bounds__(kBlockSize, 1) // + void smxx_paged_mqa_logits_metadata(const MetadataParams params) { + using namespace device; + extern __shared__ uint32_t s_length[]; + static constexpr auto kNumWarps = kBlockSize / kWarpThreads; + static_assert(kNumWarps == kWarpThreads); + + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + __shared__ uint32_t s_warp_sum[kNumWarps]; + + uint32_t local_sum = 0; 
+ for (uint32_t i = tx; i < params.batch_size; i += kBlockSize) { + const auto length = params.context_lens[i]; + local_sum += (length + kSplitKV - 1) / kSplitKV; + if (params.use_smem) s_length[i] = length; + } + + s_warp_sum[warp_id] = warp::reduce_sum(local_sum); + __syncthreads(); + + const auto global_sum = warp::reduce_sum(s_warp_sum[lane_id]); + if (lane_id != 0) return; + + const auto length_ptr = params.use_smem ? s_length : params.context_lens; + + const auto avg = global_sum / params.num_sm; + const auto ret = global_sum % params.num_sm; + uint32_t q = 0; + uint32_t num_work = (length_ptr[0] + kSplitKV - 1) / kSplitKV; + uint32_t sum_work = num_work; + for (auto i = warp_id; i <= params.num_sm; i += kNumWarps) { + const auto target = i * avg + min(i, ret); + while (sum_work <= target) { + if (++q >= params.batch_size) break; + num_work = (length_ptr[q] + kSplitKV - 1) / kSplitKV; + sum_work += num_work; + } + if (q >= params.batch_size) { + params.schedule_metadata[2 * i + 0] = params.batch_size; + params.schedule_metadata[2 * i + 1] = 0; + } else { + // sum > target && (sum - length) <= target + params.schedule_metadata[2 * i + 0] = q; + params.schedule_metadata[2 * i + 1] = target - (sum_work - num_work); + } + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct IndexerMetadataKernel { + static constexpr auto kMaxBatchSizeInSmem = 16384 * 2; // 128 KB smeme + static void run(tvm::ffi::TensorView seq_lens, tvm::ffi::TensorView metadata) { + using namespace host; + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_sm"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(seq_lens); + 
TensorMatcher({M, 2}) // + .with_dtype() + .with_device(device) + .verify(metadata); + const auto batch_size = static_cast(N.unwrap()); + const auto num_sm = static_cast(M.unwrap()) - 1; + RuntimeCheck(num_sm <= 1024); + const auto use_smem = batch_size <= kMaxBatchSizeInSmem; + const auto params = MetadataParams{ + .batch_size = batch_size, + .num_sm = num_sm, + .context_lens = static_cast(seq_lens.data_ptr()), + .schedule_metadata = static_cast(metadata.data_ptr()), + .use_smem = use_smem, + }; + constexpr auto kernel = smxx_paged_mqa_logits_metadata; + setup_kernel_smem_once(); + const auto smem = use_smem ? (batch_size + 1) * sizeof(uint32_t) : 0; + LaunchKernel(1, kBlockSize, device.unwrap(), smem)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh new file mode 100644 index 000000000000..f9407ec84db0 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh @@ -0,0 +1,133 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct RMSNormSelfParams { + const void* __restrict__ input; + void* __restrict__ output; + int64_t stride_batch_bytes; + int64_t stride_head_bytes; + uint32_t batch_size; + uint32_t num_head; + float eps; +}; + +template +__global__ __launch_bounds__(kBlockSize, 20) // + void rmsnorm_self(const __grid_constant__ RMSNormSelfParams params) { + using namespace device; + constexpr int64_t kVecSize = 16 / sizeof(DType); + constexpr uint32_t kNumLoop = kHeadDim / (kVecSize * kWarpThreads); + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + using DType2 = packed_t; + using Vec = AlignedVector; + + const auto warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads; + const auto batch_id = warp_id / params.num_head; + const auto head_id 
= warp_id % params.num_head; + const auto gmem = tile::Memory::warp(); + if (batch_id >= params.batch_size) return; + const auto input_ptr = pointer::offset( // + params.input, + batch_id * params.stride_batch_bytes, + head_id * params.stride_head_bytes); + // use contiguous layout + const auto output_ptr = pointer::offset( // + params.output, + warp_id * kHeadDim * sizeof(DType)); + PDLWaitPrimary(); // wait for primary kernel + + Vec inputs[kNumLoop]; +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { + inputs[i] = gmem.load(input_ptr, i); + } + + // compute sum of squares + float local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { +#pragma unroll + for (uint32_t j = 0; j < kVecSize / 2; ++j) { + const auto [x, y] = cast(inputs[i][j]); + local_sum += x * x + y * y; + } + } + + const auto sum_of_squares = warp::reduce_sum(local_sum); + const auto factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + + // weight must be identity (null, not used) +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { +#pragma unroll + for (uint32_t j = 0; j < kVecSize / 2; ++j) { + const auto [x, y] = cast(inputs[i][j]); + inputs[i][j] = cast(fp32x2_t{x * factor, y * factor}); + } + gmem.store(output_ptr, inputs[i], i); + } + + PDLTriggerSecondary(); // launch secondary kernel +} + +template +struct RMSNormKernel { + static constexpr auto kernel_self = rmsnorm_self; + + static void run_self(tvm::ffi::TensorView input, tvm::ffi::TensorView output, float eps) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto Dn = SymbolicSize{"stride_head"}; + auto Dh = SymbolicSize{"stride_batch"}; + constexpr auto D = kHeadDim; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, H, D}) // input + .with_strides({Dh, Dn, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({N, H, D}) // output, must be contiguous + .with_dtype() + 
.with_device(device) + .verify(output); + + const auto batch_size = static_cast(N.unwrap()); + const auto num_head = static_cast(H.unwrap()); + const auto stride_head_bytes = static_cast(Dn.unwrap() * sizeof(DType)); + const auto stride_batch_bytes = static_cast(Dh.unwrap() * sizeof(DType)); + const auto params = RMSNormSelfParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .stride_batch_bytes = stride_batch_bytes, + .stride_head_bytes = stride_head_bytes, + .batch_size = batch_size, + .num_head = num_head, + .eps = eps, + }; + if (batch_size == 0 || num_head == 0) return; + const auto needed_warps = batch_size * num_head; + const auto num_blocks = div_ceil(needed_warps, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel_self, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh new file mode 100644 index 000000000000..2239d3972d64 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +using DType = bf16_t; +constexpr int64_t kRopeDim = 64; +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedQKRopeParams { + void* __restrict__ q; + void* __restrict__ k; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + int64_t q_stride_batch; + int64_t k_stride_batch; + int64_t q_stride_head; + int64_t k_stride_head; + uint32_t num_q_heads; + uint32_t num_k_heads; + uint32_t batch_size; +}; + +template +__global__ __launch_bounds__(kBlockSize, 16) // + void deepseek_rope_kernel(const __grid_constant__ FusedQKRopeParams param) { + using namespace device; + using DType2 = packed_t; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto 
global_warp_id = blockIdx.x * kNumWarps + warp_id; + + const auto& [ + q, k, freqs_cis, positions, // + q_stride_batch, k_stride_batch, q_stride_head, k_stride_head, // + num_q_heads, num_k_heads, batch_size + ] = param; + + const auto num_total_heads = num_q_heads + num_k_heads; + const auto head_id = global_warp_id % num_total_heads; + const auto batch_id = global_warp_id / num_total_heads; + if (batch_id >= batch_size) return; + + const auto position = static_cast(positions)[batch_id]; + const auto is_q = head_id < num_q_heads; + const auto local_head = is_q ? head_id : (head_id - num_q_heads); + const auto stride_batch = is_q ? q_stride_batch : k_stride_batch; + const auto stride_head = is_q ? q_stride_head : k_stride_head; + const auto base_ptr = is_q ? q : k; + const auto input = static_cast(pointer::offset(base_ptr, batch_id * stride_batch, local_head * stride_head)); + + const auto freq_ptr = reinterpret_cast(freqs_cis + position * kRopeDim); + const auto [f_real, f_imag] = freq_ptr[lane_id]; + PDLWaitPrimary(); + + const auto data = input[lane_id]; + const auto [x_real, x_imag] = cast(data); + fp32x2_t output; + if constexpr (kInverse) { + // (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + output = { + x_real * f_real + x_imag * f_imag, + x_imag * f_real - x_real * f_imag, + }; + } else { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + output = { + x_real * f_real - x_imag * f_imag, + x_real * f_imag + x_imag * f_real, + }; + } + input[lane_id] = cast(output); + + PDLTriggerSecondary(); +} + +template +struct FusedQKRopeKernel { + // 4 kernel variants: {forward, inverse} x {int32, int64} + static constexpr auto kernel_fwd_i32 = deepseek_rope_kernel; + static constexpr auto kernel_fwd_i64 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i32 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i64 = deepseek_rope_kernel; + + static void forward( + const tvm::ffi::TensorView q, + const tvm::ffi::Optional k, + const tvm::ffi::TensorView 
freqs_cis, + const tvm::ffi::TensorView positions, + bool inverse) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto Q = SymbolicSize{"num_q_heads"}; + auto K = SymbolicSize{"num_k_heads"}; + constexpr auto D = kRopeDim; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, Q, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(q); + if (k.has_value()) { + TensorMatcher({B, K, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(k.value()); + } else { + K.set_value(0); + } + TensorMatcher({-1, D}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + const bool pos_i32 = pos_dtype.is_type(); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto num_q_heads = static_cast(Q.unwrap()); + const auto num_k_heads = static_cast(K.unwrap()); + const auto num_total_heads = num_q_heads + num_k_heads; + const auto total_warps = batch_size * num_total_heads; + const auto num_blocks = div_ceil(total_warps, kNumWarps); + + const auto elem_size = static_cast(sizeof(DType)); + const auto params = FusedQKRopeParams{ + .q = q.data_ptr(), + .k = k ? k.value().data_ptr() : nullptr, + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_stride_batch = q.stride(0) * elem_size, + .k_stride_batch = k ? k.value().stride(0) * elem_size : 0, + .q_stride_head = q.stride(1) * elem_size, + .k_stride_head = k ? k.value().stride(1) * elem_size : 0, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .batch_size = batch_size, + }; + + // dispatch: {inverse} x {pos_i32} + using KernelType = decltype(kernel_fwd_i32); + const KernelType kernel = + inverse ? (pos_i32 ? kernel_inv_i32 : kernel_inv_i64) : (pos_i32 ? 
kernel_fwd_i32 : kernel_fwd_i64); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh new file mode 100644 index 000000000000..be0e759445f9 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh @@ -0,0 +1,540 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantVarlenParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +template +SGL_DEVICE fp32x2_t silu_and_mul(DType2 gate, DType2 up, float limit) { + using namespace device; + // refer to as implementation. 
TL;DR: must clamp in bf16 + // https://github.com/deepseek-ai/DeepGEMM/blob/7f2a703ed51ac1f7af07f5e1453b2d3267d37d50/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh#L984-L997 + if constexpr (kApplySwigluLimit) { + static_assert(std::is_same_v); + gate = __hmin2(gate, {limit, limit}); + up = __hmax2(up, {-limit, -limit}); + up = __hmin2(up, {limit, limit}); + } + const auto [g0, g1] = cast(gate); + const auto [u0, u1] = cast(up); + const auto silu0 = g0 / (1.0f + __expf(-g0)); + const auto silu1 = g1 / (1.0f + __expf(-g1)); + const float val0 = silu0 * u0; + const float val1 = silu1 * u1; + if constexpr (kPrecise) { // I don't know if we should enable this? + return {val0, val1}; + } else { + return cast(cast(fp32x2_t{val0, val1})); + } +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantVarlenParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? 
s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_varlen_kernel(const SiluMulQuantVarlenParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + // gran=8 interleaved: every 16-element chunk on the N axis is + // [gate[0..7], up[0..7]]. 
Each thread handles 8 consecutive output + // elements, so its gate chunk lives at vec index 2*threadIdx.x and its + // up chunk at 2*threadIdx.x+1. + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = 
blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for (uint32_t i = 0; i < kVecSize / 2; ++i) { + out[i] = cast(silu_and_mul(cast(gate[i]), cast(up[i]), limit)); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_varlen_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_varlen_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + 
.with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantVarlenParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +struct SiluMulQuantContigParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t scale_row_stride_int32; // only used when kTransposed=true +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_contig_kernel(const SiluMulQuantContigParams __grid_constant__ 
params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto token_id = blockIdx.x; + const auto work_id = threadIdx.x / kWorkThreads; + + const auto input = params.input + token_id * params.hidden_dim * 2; + const auto output = params.output + token_id * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + // Physical layout is (G//4_pad, M_pad) int32; each int32 packs 4 + // consecutive UE8M0 exponents for the same token. Byte address: + // (work_id / 4) * M_pad * 4 + token * 4 + (work_id % 4). + const auto base = reinterpret_cast(params.output_scale); + return base + (work_id / 4u) * (params.scale_row_stride_int32 * 4u) + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + token_id * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = 
__uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +template +struct SiluAndMulContigPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_contig_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_contig_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate/up, natural or gran=8 interleaved on last dim) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, N}) // fp8 output + .with_dtype() + .with_device(device) + .verify(output); + + const auto hidden_dim = N.unwrap(); + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + const auto num_groups = static_cast(hidden_dim / kGroupSize); + + uint32_t scale_row_stride_int32 = 0; + if (!transposed) { + G.set_value(num_groups); + TensorMatcher({M, G}) // (M, G) fp32 natural row-major + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + 
RuntimeCheck(num_groups % 4 == 0, "transposed layout requires num_groups % 4 == 0"); + auto G_ = SymbolicSize{"G // 4"}; + G_.set_value(num_groups / 4); + auto M_pad = SymbolicSize{"M padded"}; + TensorMatcher({M, G_}) // `.transpose(-1,-2)[:M,:]` view of (G//4_pad, M_pad) int32 + .with_strides({int64_t{1}, M_pad}) // col-major transposed + .with_dtype() + .with_device(device) + .verify(output_scale); + scale_row_stride_int32 = static_cast(M_pad.unwrap()); + } + + const auto num_tokens = static_cast(M.unwrap()); + + const auto params = SiluMulQuantContigParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .scale_row_stride_int32 = scale_row_stride_int32, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh new file mode 100644 index 000000000000..3e2bd92589b7 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh @@ -0,0 +1,371 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. 
params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_kernel(const SiluMulQuantParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / 
kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + if constexpr (kApplySwigluLimit) { + // Fused fp32 path: bf16 load -> fp32 clamp -> fp32 silu -> fp32 mul -> fp32 result. + // Avoids the silu->bf16->mul->fp32 round-trip of the non-fused path since we already + // have gate/up in fp32 registers after clamp. + const float limit = params.swiglu_limit; + + const auto [g0_raw, g1_raw] = cast(gate_vec[i]); + const float g0 = fminf(g0_raw, limit); + const float g1 = fminf(g1_raw, limit); + + const float silu0 = g0 / (1.0f + expf(-g0)); + const float silu1 = g1 / (1.0f + expf(-g1)); + + const auto [u0_raw, u1_raw] = cast(up_vec[i]); + const float u0 = fmaxf(fminf(u0_raw, limit), -limit); + const float u1 = fmaxf(fminf(u1_raw, limit), -limit); + + const float val0 = u0 * silu0; + const float val1 = u1 * silu1; + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } else { + // original code path -- must stay byte-equal to pre-fusion kernel. 
+ const auto [g0, g1] = cast(gate_vec[i]); + + float silu0 = g0 / (1.0f + expf(-g0)); + float silu1 = g1 / (1.0f + expf(-g1)); + + bf16x2_t silu_d = cast(fp32x2_t{silu0, silu1}); + auto [val0, val1] = cast(up_vec[i] * silu_d); + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for 
(uint32_t i = 0; i < kVecSize / 2; ++i) { + const auto [g0_raw, g1_raw] = cast(gate[i]); + const float g0 = fminf(g0_raw, limit); + const float g1 = fminf(g1_raw, limit); + const float silu0 = g0 / (1.0f + expf(-g0)); + const float silu1 = g1 / (1.0f + expf(-g1)); + const auto [u0_raw, u1_raw] = cast(up[i]); + const float u0 = fmaxf(fminf(u0_raw, limit), -limit); + const float u1 = fmaxf(fminf(u1_raw, limit), -limit); + const float val0 = u0 * silu0; + const float val1 = u1 * silu1; + out[i] = cast(fp32x2_t{val0, val1}); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = silu_mul_quant_kernel; + static constexpr auto kernel_transposed = silu_mul_quant_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + 
.with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + .with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? 
kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh new file mode 100644 index 000000000000..49f6f5596377 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh @@ -0,0 +1,205 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using 
deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +template +__global__ void fused_store_flashmla_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 584 = 576 + 8 + constexpr int64_t kPageBytes = host::div_ceil(584 << kPageBits, 576) * 576; + + // each warp handles 64 elements, 8 warps, each block handles 1 row + const auto& [input, cache, indices, num_tokens] = param; + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t wid = tid / 32; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[bid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + const auto elems = static_cast(input)[tid + bid * 256]; + if (wid != 7) { + const auto [x, y] = cast(elems); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576); + const auto scale_ptr = pointer::offset(page_ptr, 576 << kPageBits, offset * 8); + static_cast(value_ptr)[tid] = result; + static_cast(scale_ptr)[wid] = scale_ue8m0; + } else { + const auto result = cast(elems); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576, 448); + 
static_cast(value_ptr)[tid - 7 * 32] = result; + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, 1 warp, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); +} + +template +struct FusedStoreCacheFlashMLAKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t 
kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr auto kernel = fused_store_flashmla_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 512}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 256; + const auto num_blocks = num_tokens; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = 132 * kPageSize; + static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); 
+ TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh new file mode 100644 index 000000000000..ef2be43c07e2 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kTopK = 512; +constexpr uint32_t kTopKBlockSize = 512; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopK512Params { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + 
_s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? 
raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_512_transform(const __grid_constant__ TopK512Params params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const 
uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopK512Kernel { + static constexpr auto kernel = topk_512_transform; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + 
TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 512}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, 512}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopK512Params{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh new file mode 100644 index 000000000000..6774734ec187 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kTopK = 1024; +constexpr uint32_t kTopKBlockSize = 1024; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopK1024Params { + 
const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t 
_s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin 
== threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx 
^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_1024_transform(const __grid_constant__ TopK1024Params params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopK1024Kernel { + static constexpr auto kernel = topk_1024_transform; + 
+ static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 1024}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, 1024}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopK1024Params{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), 
kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh new file mode 100644 index 000000000000..8c4a526575ea --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh @@ -0,0 +1,493 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define SGL_TOPK 512 +#endif + +inline constexpr uint32_t K = SGL_TOPK; + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +namespace impl = device::top512; +using Large = impl::ClusterTopK; +using Medium = impl::StreamingTopK; +using Small = impl::RegisterTopK; + +using Metadata = Large::Metadata; +constexpr uint32_t kBlockSize = impl::kBlockSize; +constexpr uint32_t kNumClusters = 15; // based on hardware limits +constexpr uint32_t kClusterSize = Large::kClusterSize; +constexpr uint32_t kMax2PassLength = Small::kMax2PassLength; +constexpr uint32_t kMaxSupportedLength = Large::kMaxLength; + +/// Common metadata lives at metadata[0] (first row of the [batch_size+1, 4] tensor). +/// Per-item metadata starts at metadata[1..batch_size]. The plan kernel writes both. 
+struct alignas(16) GlobalMetadata { + uint32_t cluster_threshold; // decided per-batch in plan kernel + uint32_t num_cluster_items; // N = number of items routed to the cluster path + uint32_t reserved[2]; +}; +static_assert(sizeof(GlobalMetadata) == sizeof(Metadata), "layout: row 0 must occupy one Metadata-sized slot"); + +// optimize occupancy for prefill +#define SMALL_TOPK_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +// cluster at y dim +#define LARGE_CLUSTER __cluster_dims__(1, kClusterSize, 1) +// stage-1 is persistent cluster, and shared memory usage is huge (can not 2) +#define LARGE_TOPK_STAGE_1 __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// stage-2 is non-persistent non-cluster, with less shared memory and higher occupancy +#define LARGE_TOPK_STAGE_2 __global__ __launch_bounds__(kBlockSize, 2) +// fused into 1 stage when batch-size <= kNumPersistentClusters +#define FUSED_COMBINE_KERNEL __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// plan runs once as a single block before the combine kernels +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct TopKParams { + const uint32_t* __restrict__ seq_lens; + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int64_t score_stride; + int64_t page_table_stride; + uint8_t* __restrict__ workspace; // [batch, kWorkspaceBytes] -- internally allocated + /// Pointer to the full metadata tensor: metadata[0] is GlobalMetadata, metadata[1..] + /// are per-item entries (at most kNumClusters * rounds of them). 
+ const Metadata* __restrict__ metadata = nullptr; + int64_t workspace_stride; // bytes per batch + uint32_t batch_size; + uint32_t page_bits; + + SGL_DEVICE const float* get_scores(const uint32_t batch_id) const { + return scores + batch_id * score_stride; + } + SGL_DEVICE impl::TransformParams get_transform(const uint32_t batch_id, int32_t* indices) const { + return { + .page_table = page_table + batch_id * page_table_stride, + .indices_in = indices, + .indices_out = page_indices + batch_id * K, + .page_bits = page_bits, + }; + } + SGL_DEVICE const GlobalMetadata& get_global_metadata() const { + return *reinterpret_cast(metadata); + } + SGL_DEVICE const Metadata& get_item_metadata(uint32_t work_id) const { + return metadata[1 + work_id]; // +1 to skip the GlobalMetadata row + } +}; + +SGL_DEVICE uint2 partition_work(uint32_t length, uint32_t rank) { + constexpr uint32_t kTMAAlign = 4; + const auto total_units = (length + kTMAAlign - 1) / kTMAAlign; + const auto base = total_units / kClusterSize; + const auto extra = total_units % kClusterSize; + const auto local_units = base + (rank < extra ? 1u : 0u); + const auto offset_units = rank * base + min(rank, extra); + const auto offset = offset_units * kTMAAlign; + const auto finish = min(offset + local_units * kTMAAlign, length); + return {offset, finish - offset}; +} + +/// Persistent scheduler. A single block: +/// 1. Decides a cluster_threshold from the real seq_lens distribution (or +/// uses the caller-supplied `static_cluster_threshold` when non-zero). +/// 2. Writes that threshold + N into metadata[0] (the GlobalMetadata row). +/// 3. Compacts items with seq_len > threshold into metadata[1..N+1), laid out +/// to match the persistent consumer's round-robin stride (kNumClusters). +/// Entries for clusters that get no work are zero-filled. 
/// \param seq_lens   per-item sequence lengths; entries [0, batch_size) are read.
/// \param metadata   output rows: row 0 receives the GlobalMetadata header,
///                   rows 1.. receive the compacted per-item entries.
/// \param batch_size number of entries in `seq_lens`.
/// \param static_cluster_threshold when non-zero, used instead of the heuristic.
///
/// Every loop strides by kBlockSize and all counters live in shared memory, so
/// the kernel is written for exactly one block of kBlockSize threads (this is
/// how CombinedTopKKernel::plan launches it).
///
/// NOTE(review): in this copy of the source several template argument lists
/// are missing (e.g. after `reinterpret_cast`, `static_cast`, `with_dtype`,
/// `set_options`, `setup_kernel_smem_once`) -- presumably lost in extraction.
/// The code below is left token-for-token as found; confirm against upstream.
PLAN_KERNEL void topk_plan(
    const uint32_t* __restrict__ seq_lens,
    Metadata* __restrict__ metadata,
    const uint32_t batch_size,
    const uint32_t static_cluster_threshold) {
  // Candidate thresholds, strictly increasing. Picked to give the auto-heuristic
  // reasonable granularity without needing a full sort. Must all be >= kMax2PassLength.

  struct Pair {
    uint32_t threshold;       // candidate cluster threshold (sequence length)
    uint32_t max_batch_size;  // largest number of items above `threshold` that is still accepted
  };
  /// NOTE: only tuned on B200
  constexpr Pair kCandidates[] = {
      {32768, 30},
      {40960, 45},
      {49152, 45},
      {65536, 60},
      {98304, 60},
      {131072, 75},
      {196608, 90},
      {262144, 105},
  };
  constexpr uint32_t kNumCandidates = std::size(kCandidates);
  constexpr uint32_t kMinBatchSize = kCandidates[0].max_batch_size;
  // The table must span exactly [kMax2PassLength, kMaxSupportedLength].
  static_assert(kCandidates[0].threshold == kMax2PassLength);
  static_assert(kCandidates[kNumCandidates - 1].threshold == kMaxSupportedLength);

  __shared__ uint32_t s_count;  // final N after compaction
  __shared__ uint32_t s_counts[kNumCandidates];  // histogram used by the heuristic in phase 1
  __shared__ uint32_t s_threshold;               // threshold chosen in phase 1

  const auto tx = threadIdx.x;
  if (tx == 0) s_count = 0;
  if (tx < kNumCandidates) s_counts[tx] = 0;
  __syncthreads();

  // --- Phase 1: decide threshold ------------------------------------------
  // All three branch conditions depend only on kernel arguments, so every
  // thread takes the same branch and the __syncthreads() inside is safe.
  if (static_cluster_threshold > 0) {
    if (tx == 0) s_threshold = static_cluster_threshold;
  } else if (batch_size <= kMinBatchSize) {
    if (tx == 0) s_threshold = kMax2PassLength;  // always prefer cluster
  } else {
    // Count items above each candidate threshold. Monotonically non-increasing in T.
    // s_counts[c - 1] ends up holding the number of items whose length exceeds
    // exactly the first c candidates.
    for (uint32_t i = tx; i < batch_size; i += kBlockSize) {
      const uint32_t sl = seq_lens[i];
      assert(sl <= kMaxSupportedLength);
      uint32_t count = 0;
#pragma unroll
      for (uint32_t j = 0; j < kNumCandidates; ++j) {
        count += (sl > kCandidates[j].threshold ? 1 : 0);
      }
      if (count > 0) {
        atomicAdd(&s_counts[count - 1], 1);
      }
    }
    __syncthreads();
    if (tx == 0) {
      // Walk the candidates from the largest to the smallest. After adding
      // s_counts[j], `accum` is the number of items longer than candidate j.
      // Keep lowering the threshold while that number stays within the
      // candidate's max_batch_size.
      uint32_t accum = 0;
      uint32_t chosen = kMaxSupportedLength;
#pragma unroll
      for (uint32_t i = 0; i < kNumCandidates; ++i) {
        const auto j = kNumCandidates - 1 - i;
        accum += s_counts[j];
        /// NOTE: `accum` increasing, while `max_batch_size` decreasing
        if (accum > kCandidates[j].max_batch_size) break;
        chosen = kCandidates[j].threshold;
      }
      s_threshold = chosen;
    }
  }
  __syncthreads();
  // sanity check: anything at or below the 2-pass limit must stay on the small
  // path, so the threshold is never allowed to drop under kMax2PassLength
  const auto cluster_threshold = max(s_threshold, kMax2PassLength);

  // --- Phase 2: compact items with seq_len > threshold into metadata[1..] -
  // Per-item rows live at metadata[1 + pos]; metadata[0] is the GlobalMetadata row.
  // Slots are claimed with an atomic counter, so the order of the entries is
  // not deterministic.
  for (uint32_t i = tx; i < batch_size; i += kBlockSize) {
    const uint32_t sl = seq_lens[i];
    if (sl > cluster_threshold) {
      const auto pos = atomicAdd(&s_count, 1);
      // presumably {batch_id, seq_len, has_next}; the field order is defined
      // by Large::Metadata -- TODO confirm
      metadata[1 + pos] = {i, sl, false};
    }
  }
  __syncthreads();
  const auto N = s_count;

  // --- Phase 3: has_next + sentinels + GlobalMetadata ---------------------
  // An entry has a successor when the entry kNumClusters slots later exists;
  // this mirrors the consumer, which advances its work id by kNumClusters.
  for (uint32_t i = tx; i < N; i += kBlockSize) {
    if (i + kNumClusters < N) metadata[1 + i].has_next = true;
  }
  // Zero-fill the first kNumClusters sentinel slots that got no valid entry.
  // (topk_combine_preprocess returns immediately when it reads seq_len == 0.)
  if (tx < kNumClusters && tx >= N) metadata[1 + tx] = {0, 0, false};
  // Write global metadata (row 0).
  if (tx == 0) {
    auto* g = reinterpret_cast(metadata);
    *g = {
        .cluster_threshold = cluster_threshold,
        .num_cluster_items = N,
        .reserved = {0, 0},
    };
  }
}

/// Short-context kernel: one block per batch item (blockIdx.x is the batch id).
/// The host launches it only when the score tensor's second dimension is at
/// most Small::kMax1PassLength.
SMALL_TOPK_KERNEL void  // short context
topk_short_transform(const __grid_constant__ TopKParams params) {
  alignas(128) extern __shared__ uint8_t smem[];
  __shared__ int32_t s_topk_indices[K];
  const auto batch_id = blockIdx.x;
  const auto seq_len = params.seq_lens[batch_id];
  const auto transform = params.get_transform(batch_id, s_topk_indices);
  // trivial case
  // (at most K candidates, so no selection is needed; see impl::trivial_transform)
  if (seq_len <= K) {
    impl::trivial_transform(transform, seq_len, K);
  } else {
    Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true);
    // presumably lets a dependent (PDL) launch start while this block still
    // runs the transform below -- confirm against device::PDLTriggerSecondary
    device::PDLTriggerSecondary();
    Small::transform(transform);
  }
}

/// Stage 1 of the two-stage path. The grid is {clusters, kClusterSize}:
/// blockIdx.x selects the cluster's first work item and blockIdx.y is the
/// block's rank inside the cluster. Each cluster then walks the per-item
/// metadata entries work_id, work_id + kNumClusters, ... for as long as the
/// current entry has `has_next` set.
LARGE_TOPK_STAGE_1 void  // long context, middle to large batch size
topk_combine_preprocess(const __grid_constant__ TopKParams params) {
  alignas(128) extern __shared__ uint8_t smem[];
  __shared__ int32_t s_topk_indices[K];
  // State of the item whose prologue was launched most recently.
  uint32_t work_id = blockIdx.x;
  uint32_t batch_id;
  uint32_t seq_len;
  bool has_next;
  uint32_t length;
  uint32_t offset;
  const auto cluster_rank = blockIdx.y;

  // Loads the metadata entry at `work_id` into the state above and moves
  // `work_id` on to this cluster's next entry.
  const auto prefetch_metadata = [&] {
    const auto metadata = params.get_item_metadata(work_id);
    batch_id = metadata.batch_id;
    seq_len = metadata.seq_len;
    has_next = metadata.has_next;
    work_id += kNumClusters;  // advance to the next item for this cluster
  };
  // Computes this rank's slice of the current item and starts its stage-1 prologue.
  const auto launch_prologue = [&] {
    const auto partition = partition_work(seq_len, cluster_rank);
    offset = partition.x;
    length = partition.y;
    Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem);
  };

  device::PDLWaitPrimary();
  device::PDLTriggerSecondary();

  prefetch_metadata();
  if (seq_len == 0) return;  // zero-filled sentinel entry: this cluster has no work
  Large::stage1_init(smem);
  launch_prologue();
  while (true) {
    // Snapshot the current item before the shared state is overwritten with
    // the next item's metadata.
    const auto this_length = length;
    const auto this_offset = offset;
    const auto need_prefetch = has_next;
    const auto transform = params.get_transform(batch_id, s_topk_indices);
    const auto ws = params.workspace + batch_id * params.workspace_stride;
    // The statement order below is deliberate: the next item's metadata is
    // read before stage1 of the current item, and the next item's prologue is
    // started before the current item's epilogue.
    if (need_prefetch) prefetch_metadata();
    Large::stage1(s_topk_indices, this_length, smem, /*reuse=*/true);
    if (need_prefetch) launch_prologue();
    Large::stage1_epilogue(transform, this_offset, ws, smem);
    if (!need_prefetch) break;
  }
}

/// Stage 2 of the two-stage path: one block per batch item. The algorithm is
/// chosen from the item's sequence length; only items longer than the
/// cluster threshold consume the workspace filled by stage 1.
LARGE_TOPK_STAGE_2 void  // long context, middle to large batch size
topk_combine_transform(const __grid_constant__ TopKParams params) {
  alignas(128) extern __shared__ uint8_t smem[];
  __shared__ int32_t s_topk_indices[K];
  const auto batch_id = blockIdx.x;
  const auto seq_len = params.seq_lens[batch_id];
  const auto cluster_threshold = params.get_global_metadata().cluster_threshold;
  const auto transform = params.get_transform(batch_id, s_topk_indices);
  if (seq_len <= K) {
    impl::trivial_transform(transform, seq_len, K);
  } else if (seq_len <= kMax2PassLength) {
    // NOTE(review): apart from the __syncwarp() the two calls below read the
    // same in this copy; a template argument distinguishing the 1-pass and
    // 2-pass variants may have been lost in extraction -- confirm.
    if (seq_len <= Small::kMax1PassLength) {
      Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem);
    } else {
      __syncwarp();
      Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem);
    }
    Small::transform(transform);
  } else if (seq_len <= cluster_threshold) {
    Medium::run(params.get_scores(batch_id), seq_len, s_topk_indices, smem);
    Medium::transform(transform, smem);
  } else {
    const auto ws = params.workspace + batch_id * params.workspace_stride;
    // Only this branch waits on the preceding launch before reading the workspace.
    device::PDLWaitPrimary();
    Large::transform(transform, ws, smem);
  }
}

/// Single-launch variant used by the host when batch_size <= kNumClusters.
/// The grid is {batch_size, kClusterSize}: blockIdx.x is the batch id and
/// blockIdx.y the rank inside the item's cluster.
FUSED_COMBINE_KERNEL void  // long context, small batch size
topk_fused_transform(const __grid_constant__ TopKParams params) {
  alignas(128) extern __shared__ uint8_t smem[];
  __shared__ int32_t s_topk_indices[K];
  const auto batch_id = blockIdx.x;
  const auto cluster_rank = blockIdx.y;
  const auto seq_len = params.seq_lens[batch_id];
  const auto transform = params.get_transform(batch_id, s_topk_indices);
  if (seq_len <= K) {
    if (cluster_rank != 0) return;  // only the first rank works
    impl::trivial_transform(transform, seq_len, K);
  } else if (seq_len <= Small::kMax1PassLength) {
    if (cluster_rank != 0) return;  // only the first rank works
    Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true);
    Small::transform(transform);
  } else {
    // Every rank runs stage 1 on its own slice of the scores.
    const auto [offset, length] = partition_work(seq_len, cluster_rank);
    const auto ws = params.workspace + batch_id * params.workspace_stride;
    Large::stage1_init(smem);
    device::PDLWaitPrimary();
    Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem);
    Large::stage1(s_topk_indices, length, smem);
    Large::stage1_epilogue(transform, offset, ws, smem);
    // Cluster-wide barrier: all ranks finish stage 1 before rank 0 continues.
    cooperative_groups::this_cluster().sync();
    if (cluster_rank != 0) return;  // only the first rank does stage 2
    Large::transform(transform, ws, smem);
  }
}

/// Host-side entry points: `plan` fills the metadata tensor and `transform`
/// validates the tensors and launches the matching kernel(s).
struct CombinedTopKKernel {
  // Dynamic shared memory requested per launch. The extra 128 bytes are
  // presumably slack for aligning the dynamic shared-memory base -- confirm.
  static constexpr auto kStage1SMEM = sizeof(Large::Smem) + 128;
  static constexpr auto kStage2SMEM = std::max(sizeof(Small::Smem), sizeof(Medium::Smem)) + 128;

  /// Checks that `seq_lens` is [batch_size] and `metadata` is
  /// [batch_size + 1, 4] on the same device, then launches topk_plan as a
  /// single block. Nothing is launched when batch_size <= kNumClusters.
  static void plan(  //
      const tvm::ffi::TensorView seq_lens,
      const tvm::ffi::TensorView metadata,
      const uint32_t static_cluster_threshold) {
    using namespace host;
    auto B = SymbolicSize{"batch_size"};
    auto Bp1 = SymbolicSize{"batch_size_plus_1"};
    auto device_ = SymbolicDevice{};
    device_.set_options();

    TensorMatcher({B})  //
        .with_dtype()
        .with_device(device_)
        .verify(seq_lens);
    TensorMatcher({Bp1, 4})  //
        .with_dtype()
        .with_device(device_)
        .verify(metadata);

    const auto batch_size = static_cast(B.unwrap());
    RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1);
    if (batch_size <= kNumClusters) return;  // metadata unused in fused path

    const auto device = device_.unwrap();
    constexpr auto kernel = topk_plan;
    LaunchKernel(1, kBlockSize, device)(  //
        kernel,
        static_cast(seq_lens.data_ptr()),
        static_cast(metadata.data_ptr()),
        batch_size,
        static_cluster_threshold);
  }

  /// Validates all tensors, packs them into TopKParams and dispatches:
  ///   - scores' second dimension <= Small::kMax1PassLength: short kernel only;
  ///   - otherwise, batch_size <= kNumClusters: fused cluster kernel;
  ///   - otherwise: stage 1 (cluster) followed by stage 2.
  /// `page_size` must be a power of two and the score row stride a multiple of 4.
  static void transform(
      const tvm::ffi::TensorView scores,
      const tvm::ffi::TensorView seq_lens,
      const tvm::ffi::TensorView page_table,
      const tvm::ffi::TensorView page_indices,
      const uint32_t page_size,
      const tvm::ffi::TensorView workspace,
      const tvm::ffi::TensorView metadata) {
    using namespace host;
    auto B = SymbolicSize{"batch_size"};
    auto Bp1 = SymbolicSize{"batch_size_plus_1"};
    auto L = SymbolicSize{"max_seq_len"};
    auto S = SymbolicSize{"score_stride"};
    auto P = SymbolicSize{"page_table_stride"};
    auto W = SymbolicSize{"workspace_stride"};
    constexpr auto D = Large::kWorkspaceInts;
    auto device_ = SymbolicDevice{};
    device_.set_options();

    // scores and page_table may have a row stride; the innermost stride must be 1.
    TensorMatcher({B, L})  //
        .with_strides({S, 1})
        .with_dtype()
        .with_device(device_)
        .verify(scores);
    TensorMatcher({B})  //
        .with_dtype()
        .with_device(device_)
        .verify(seq_lens);
    TensorMatcher({B, -1})  //
        .with_strides({P, 1})
        .with_dtype()
        .with_device(device_)
        .verify(page_table);
    TensorMatcher({B, K})  //
        .with_dtype()
        .with_device(device_)
        .verify(page_indices);
    TensorMatcher({B, D})  //
        .with_strides({W, 1})
        .with_dtype()
        .with_device(device_)
        .verify(workspace);
    TensorMatcher({Bp1, 4})  //
        .with_dtype()
        .with_device(device_)
        .verify(metadata);

    // page_bits is computed before page_size is validated; the power-of-two
    // check below still runs before the value is used in a launch.
    const auto page_bits = static_cast(std::countr_zero(page_size));
    const auto batch_size = static_cast(B.unwrap());
    const auto max_seq_len = static_cast(L.unwrap());
    const auto device = device_.unwrap();
    RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2");
    RuntimeCheck(S.unwrap() % 4 == 0, "score_stride must be a multiple of 4 (TMA 16-byte alignment)");
    RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1, "invalid metadata shape");

    // NOTE: this should be fixed later
    // RuntimeCheck(max_seq_len <= kMaxSupportedLength, max_seq_len, " exceeds the maximum supported length");

    const auto params = TopKParams{
        .seq_lens = static_cast(seq_lens.data_ptr()),
        .scores = static_cast(scores.data_ptr()),
        .page_table = static_cast(page_table.data_ptr()),
        .page_indices = static_cast(page_indices.data_ptr()),
        .score_stride = S.unwrap(),
        .page_table_stride = P.unwrap(),
        .workspace = static_cast(workspace.data_ptr()),
        .metadata = static_cast(metadata.data_ptr()),
        // the matched stride counts int32 elements; the kernels index the workspace in bytes
        .workspace_stride = W.unwrap() * static_cast(sizeof(int32_t)),
        .batch_size = batch_size,
        .page_bits = page_bits,
    };

    if (max_seq_len <= Small::kMax1PassLength) {
      // All items fit in the short path -- no stage-1 needed
      constexpr auto kernel = topk_short_transform;
      setup_kernel_smem_once();
      LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM)  //
          .enable_pdl(true)(kernel, params);
    } else {
      // Some items may be large -- launch stage-1 + main
      if (batch_size <= kNumClusters) {
        // can fuse into 1 stage
        constexpr auto kernel = topk_fused_transform;
        // the fused kernel runs both stages, so it needs the larger of the two budgets
        constexpr auto kSMEM = std::max(kStage1SMEM, kStage2SMEM);
        setup_kernel_smem_once();
        LaunchKernel({batch_size, kClusterSize}, kBlockSize, device, kSMEM)
            .enable_cluster({1, kClusterSize})
            .enable_pdl(true)(kernel, params);
      } else {
        // stage 1 + stage 2
        constexpr auto kernel_stage_1 = topk_combine_preprocess;
        setup_kernel_smem_once();
        // batch_size > kNumClusters in this branch, so this evaluates to kNumClusters
        const auto num_clusters = std::min(batch_size, kNumClusters);
        LaunchKernel({num_clusters, kClusterSize}, kBlockSize, device, kStage1SMEM)
            .enable_cluster({1, kClusterSize})
            .enable_pdl(true)(kernel_stage_1, params);
        constexpr auto kernel_stage_2 = topk_combine_transform;
        setup_kernel_smem_once();
        LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM)  //
            .enable_pdl(true)(kernel_stage_2, params);
      }
    }
  }
};

}  // namespace
diff --git a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_base.cuh b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_base.cuh new file mode 100644 index 000000000000..dc5f5beea347 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_base.cuh @@ -0,0
+1,27 @@ +#include +#include +#include + +#include +#include +#include + +#include + +#include +#include + +inline void register_custom_all_reduce() { + namespace refl = tvm::ffi::reflection; + using Class = host::distributed::CustomAllReduceBase; + refl::ObjectDef() + .def(refl::init(), "__init__") + .def("share_storage", &Class::share_storage) + .def("share_graph_inputs", &Class::share_graph_inputs) + .def("post_init", &Class::post_init) + .def("register_inputs", &Class::register_inputs) + .def("set_cuda_graph_capture", &Class::set_cuda_graph_capture) + .def("free_ipc_handles", &Class::free_ipc_handles) + .def("free_storage", &Class::free_storage) + .def("configure_pull", &Class::configure_pull); +} diff --git a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh new file mode 100644 index 000000000000..e8837af4cd34 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_pull.cuh @@ -0,0 +1,205 @@ +// Partially migrated from AOT kernel: +// https://github.com/sgl-project/sglang/blob/v0.5.9/sgl-kernel/csrc/allreduce/custom_all_reduce.cu +// Which was originally adapted from: +// https://github.com/vllm-project/vllm/blob/v0.8.2/csrc/custom_all_reduce.cu +// We redesign the controller interface to minimize control plane traffic, +// and fuse the reduce-scatter and broadcast in the 2-shot all reduce +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace { + +using device::distributed::PullController; +using host::distributed::AllReduceData; +using host::distributed::CustomAllReduceBase, host::distributed::CustomAllReduceRef; + +struct AllReduceParams { + void* __restrict__ output; + uint32_t rank; + uint32_t num_items; // NOTE: support at most 4G, but that's too much +}; + +[[maybe_unused]] +SGL_DEVICE void prefetch_uniform_ptr(const void* ptr) { + asm volatile("prefetchu.L1 [%0];" 
::"l"(ptr) : "memory"); +} + +#define CUSTOM_AR_KERNEL __global__ __launch_bounds__(1024, 1) + +template +SGL_DEVICE void all_reduce_impl(const AllReduceParams& params, DType* (&input)[kNumGPU]) { + using namespace device; + + constexpr uint32_t kVecSize = 16 / (sizeof(DType) * 2); + using DType2 = packed_t; + using Storage = AlignedVector; + const auto& [output, rank, num_items] = params; + + for (auto i = blockIdx.x;; i += gridDim.x) { + const auto offset = i * blockDim.x + threadIdx.x; + if (offset * kVecSize * 2 >= num_items) break; + Storage storage[kNumGPU]; + +#pragma unroll + for (uint32_t i = 0; i < kNumGPU; ++i) { + storage[i].load(input[i], offset); + } + const Storage result = distributed::reduce_impl(storage); + if constexpr (kBroadcast) { +#pragma unroll + for (uint32_t i = 0; i < kNumGPU; ++i) { + result.store(input[i], offset); + } + } else { + result.store(output, offset); + } + } +} + +template +CUSTOM_AR_KERNEL void all_reduce_one_shot_kernel( + const AllReduceData* __restrict__ data, + const AllReduceParams __grid_constant__ params, + const PullController __grid_constant__ ctrl) { + /// NOTE: we assume the data array is ready before the previous kernel + DType* input[kNumGPU]; + prefetch_uniform_ptr(data); +#pragma unroll + for (uint32_t i = 0; i < kNumGPU; ++i) + input[i] = static_cast(data->input[i]); + device::PDLWaitPrimary(); + + ctrl.sync(params.rank, kNumGPU); + all_reduce_impl(params, input); + + device::PDLTriggerSecondary(); + ctrl.sync(params.rank, kNumGPU); +} + +template +CUSTOM_AR_KERNEL void all_reduce_two_shot_kernel( + const AllReduceData* __restrict__ data, + const AllReduceParams __grid_constant__ params, + const PullController __grid_constant__ ctrl) { + // get the range of this rank + using device::kWarpThreads, device::div_ceil; + + prefetch_uniform_ptr(data); + DType* input[kNumGPU]; +#pragma unroll + for (uint32_t i = 0; i < kNumGPU; ++i) + input[i] = static_cast(data->input[i]); + + constexpr uint32_t kVecSize = 16 / 
(sizeof(DType) * 2);
+  const uint32_t num_items = params.num_items;
+  const uint32_t total_vec = num_items / (kVecSize * 2);  // must be divisible here
+  // Per-rank share in vectors, rounded up to a multiple of kWarpThreads.
+  const uint32_t vec_per_rank = div_ceil(div_ceil(total_vec, kNumGPU), kWarpThreads) * kWarpThreads;
+  // Both bounds are clamped to total_vec, so trailing ranks may end up with an empty slice.
+  const uint32_t local_vec_start = min(params.rank * vec_per_rank, total_vec);
+  const uint32_t local_vec_finish = min(local_vec_start + vec_per_rank, total_vec);
+  const uint32_t local_start = local_vec_start * kVecSize * 2;
+  const uint32_t local_length = (local_vec_finish - local_vec_start) * kVecSize * 2;
+  const auto local_params = AllReduceParams{
+      .output = nullptr,  // this is not used for 2-shot all reduce
+      .rank = params.rank,
+      .num_items = local_length,
+  };
+
+  // Shift every rank's pointer to the start of this rank's slice.
+#pragma unroll
+  for (uint32_t i = 0; i < kNumGPU; ++i)
+    input[i] += local_start;
+
+  device::PDLWaitPrimary();
+
+  ctrl.sync(params.rank, kNumGPU);
+  all_reduce_impl(local_params, input);
+
+  device::PDLTriggerSecondary();
+  ctrl.sync(params.rank, kNumGPU);
+}
+
+// Host-side driver of the pull all-reduce for a fixed (DType, kNumGPU, ...) instantiation.
+template
+struct CustomAllReducePull : public CustomAllReduceBase {
+  // Packed pairs per 16-byte vector; one vector therefore carries kVecSize * 2 items.
+  static constexpr uint32_t kVecSize = 16 / (sizeof(DType) * 2);
+  static constexpr auto one_shot_kernel = all_reduce_one_shot_kernel;
+  static constexpr auto two_shot_kernel = all_reduce_two_shot_kernel;
+  static_assert(kNumGPU <= device::distributed::kMaxNumGPU, "kNumGPU exceeds the maximum supported GPUs");
+
+  // Launches a 1-shot (shot == 1) or 2-shot (shot == 2) all-reduce over `input`.
+  //
+  // Returns `input` itself, except for 1-shot while m_is_graph_capturing is set, where a
+  // freshly allocated tensor (ffi::empty_like) receives the result.
+  // Fails a RuntimeCheck when the input is non-contiguous, not on CUDA, of the wrong dtype,
+  // not 16-byte aligned, not a whole number of 16-byte vectors, has 2^32 or more items, or
+  // (outside graph capture) does not fit in the pull buffer.
+  tvm::ffi::Tensor all_reduce(tvm::ffi::Tensor input, int shot) {
+    using namespace host;
+    const bool use_2shot = (shot == 2);
+    const auto device = input.device();
+    const auto input_ptr = input.data_ptr();
+    const auto buffer_ptr = get_pull_buffer(m_storage);
+    const auto num_items_int64 = input.numel();
+    const auto num_items = static_cast(num_items_int64);
+    const auto items_per_block = m_cta_size * kVecSize * 2;
+    const auto needed_blocks = div_ceil(num_items, items_per_block);
+    const auto num_blocks = std::min(needed_blocks, m_num_cta);
+    const auto kernel = use_2shot ? two_shot_kernel : one_shot_kernel;
+    // only 1-shot + graph capture need extra output buffer
+    const auto output = (m_is_graph_capturing && !use_2shot) ? ffi::empty_like(input) : input;
+    const auto params = AllReduceParams{
+        .output = use_2shot ? nullptr : output.data_ptr(),
+        .rank = m_rank,
+        .num_items = num_items,
+    };
+
+    RuntimeCheck(input.IsContiguous(), "Input tensor must be contiguous");
+    RuntimeCheck(m_num_gpu == kNumGPU, "Mismatch GPU count");
+    RuntimeCheck(shot == 1 || shot == 2, "Invalid shot count: ", shot);
+    RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
+    RuntimeCheck(is_type(input.dtype()), "Input dtype mismatch");
+    RuntimeCheck(std::bit_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
+    RuntimeCheck(m_pull_ctrl.has_value(), "Controller is not initialized");
+    RuntimeCheck(static_cast(num_items) == num_items_int64, "Number of items exceeds 4G limit");
+    // The kernels only move whole 16-byte vectors (kVecSize * 2 items each): a ragged tail
+    // would make the last vector read and write past the end of the tensor, and the 2-shot
+    // partitioning assumes exact divisibility.
+    RuntimeCheck(num_items % (kVecSize * 2) == 0, "Number of items must be a multiple of ", kVecSize * 2);
+
+    const auto& ctrl = *m_pull_ctrl;
+    const auto stream = LaunchKernel::resolve_device(device);
+    auto launch = LaunchKernel{num_blocks, m_cta_size, stream};
+    launch.enable_pdl(kUsePDL);
+    // True only when the flag is set AND the stream is actively capturing.
+    const auto check_capturing = [&] {
+      if (!m_is_graph_capturing) return false;  // override to avoid cudaRT call overhead
+      cudaStreamCaptureStatus status;
+      RuntimeDeviceCheck(cudaStreamIsCapturing(stream, &status));
+      return status == cudaStreamCaptureStatusActive;
+    };
+    if (check_capturing()) {
+      // no-op if not really capturing, we're in a dummy run
+      const auto data_ptr = allocate_graph_capture_input(input_ptr);
+      /// NOTE: we assume when the graph is replayed, the data_ptr should be ready
+      launch(kernel, data_ptr, params, ctrl);
+    } else {
+      // 1.copy the input to the buffer
+      const auto input_bytes = static_cast(sizeof(DType) * num_items);
+      RuntimeCheck(input_bytes <= m_pull_buffer_bytes, "Input is too large, num items: ", num_items);
+      RuntimeDeviceCheck(cudaMemcpyAsync(buffer_ptr, input_ptr, input_bytes, cudaMemcpyDeviceToDevice, stream));
+      // 2. launch the all reduce kernel
+      const auto data_ptr = get_data_ptr();  // use default buffer
+      launch(kernel, data_ptr, params, ctrl);
+      if (use_2shot) {  // 3. copy the reduced result back to the output, because 2-shot doesn't write to output
+        RuntimeDeviceCheck(cudaMemcpyAsync(input_ptr, buffer_ptr, input_bytes, cudaMemcpyDeviceToDevice, stream));
+      }
+    }
+    return output;
+  }
+};
+
+// FFI-facing entry point: downcasts the object held by `obj` to the matching
+// CustomAllReducePull instantiation and forwards to its all_reduce().
+template
+tvm::ffi::Tensor custom_all_reduce(CustomAllReduceRef obj, tvm::ffi::Tensor input, int shot) {
+  using Impl = CustomAllReducePull;
+  return static_cast(*obj.get()).all_reduce(input, shot);
+}
+
+}  // namespace
diff --git a/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh
new file mode 100644
index 000000000000..c4523c27eec3
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/distributed/custom_all_reduce_push.cuh
@@ -0,0 +1,253 @@
+// Partially adapted from:
+// https://github.com/flashinfer-ai/flashinfer/blob/v0.6.4/include/flashinfer/comm/trtllm_allreduce_fusion.cuh
+// We simplify the lamport design and minimize the ring buffer count (from 3 -> 2)
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+
+namespace {
+
+using device::distributed::PushController;
+using host::distributed::CustomAllReduceBase, host::distributed::CustomAllReduceRef;
+
+// Kernel arguments of the push all-reduce, passed by value.
+struct AllReducePushData {
+  // One base pointer per rank, indexed by rank.
+  void* __restrict__ buffer[device::distributed::kMaxNumGPU];
+  const void* input;
+  void* output;
+  uint32_t rank;
+  uint32_t num_items;
+  // Byte stride between two per-rank regions inside one buffer.
+  uint32_t buffer_bytes;
+  // Byte stride between two epochs inside one buffer.
+  uint32_t epoch_bytes;
+};
+
+#define CUSTOM_AR_KERNEL __global__ __launch_bounds__(1024, 1)
+
+// Maps a floating-point DType to a same-sized unsigned integer `type` plus the bit
+// patterns of +0.0 (pos_zero) and -0.0 (neg_zero). The primary template is empty.
+template
+struct fp_trait {};
+
+// TODO: support more dtypes
+template <>
+struct fp_trait {
+  using type = uint16_t;
+  [[maybe_unused]]
+  static constexpr uint16_t pos_zero = 0x0000u;
+  [[maybe_unused]]
+  static constexpr uint16_t neg_zero = 0x8000u;
+};
+
+template <>
+struct fp_trait {
+  using type =
uint16_t;
+  [[maybe_unused]]
+  static constexpr uint16_t pos_zero = 0x0000u;
+  [[maybe_unused]]
+  static constexpr uint16_t neg_zero = 0x8000u;
+};
+
+template <>
+struct fp_trait {
+  using type = uint32_t;
+  [[maybe_unused]]
+  static constexpr uint32_t pos_zero = 0x00000000u;
+  [[maybe_unused]]
+  static constexpr uint32_t neg_zero = 0x80000000u;
+};
+
+// Rewrites `val` in place to the neg_zero bit pattern if it currently holds pos_zero.
+template
+SGL_DEVICE void clear_pos_zero(DType& val) {
+  using Trait = fp_trait;
+  const auto ptr = reinterpret_cast(&val);
+  if (*ptr == Trait::pos_zero) *ptr = Trait::neg_zero;
+}
+
+// Returns true when the bits of `val` equal the pos_zero pattern.
+template
+SGL_DEVICE bool is_pos_zero(const DType& val) {
+  using Trait = fp_trait;
+  const auto ptr = reinterpret_cast(&val);
+  return *ptr == Trait::pos_zero;
+}
+
+// Returns a DType whose bits are the pos_zero pattern.
+template
+SGL_DEVICE DType get_pos_zero() {
+  using Trait = fp_trait;
+  const auto value = Trait::pos_zero;
+  return *reinterpret_cast(&value);
+}
+
+// Volatile 16-byte global load of element `offset` of `addr` into `x`
+// (a single ld.volatile.global.v4.b32).
+template
+SGL_DEVICE void ld_global_volatile_16B(T& x, const void* addr, int64_t offset) {
+  static_assert(alignof(T) == 16 && sizeof(T) == 16);
+  addr = device::pointer::offset(addr, offset);
+  uint4 val;
+  asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
+               : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
+               : "l"(addr));
+  x = *reinterpret_cast(&val);
+}
+
+// Volatile 16-byte global store of `x` to element `offset` of `addr`
+// (a single st.volatile.global.v4.b32).
+template
+SGL_DEVICE void st_global_volatile_16B(const T& x, void* addr, int64_t offset) {
+  static_assert(alignof(T) == 16 && sizeof(T) == 16);
+  const uint4 val = *reinterpret_cast(&x);
+  addr = device::pointer::offset(addr, offset);
+  asm volatile(
+      "st.volatile.global.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w), "l"(addr));
+}
+
+// Push phase: each thread loads one 16-byte vector from `data`, replaces every +0.0 lane
+// with -0.0 (the +0.0 pattern is what poll_impl treats as "not written yet"), and stores
+// the vector to the same offset of every push_buf[i].
+template
+SGL_DEVICE void push_impl(DType* (&push_buf)[kNumGPU], const void* data, uint32_t num_items) {
+  using namespace device;
+  constexpr uint32_t kVecSize = 16 / (sizeof(DType) * 2);
+  using Storage = AlignedVector, kVecSize>;
+
+  for (auto i = blockIdx.x;; i += gridDim.x) {
+    const auto offset = i * blockDim.x + threadIdx.x;
+    if (offset * kVecSize * 2 >= num_items) break;
+    Storage vec;
+    vec.load(data, offset);
+#pragma unroll
+    for (uint32_t j = 0; j < kVecSize; ++j) {
+      clear_pos_zero(vec[j].x);
+      clear_pos_zero(vec[j].y);
+    }
+#pragma unroll
+    for (uint32_t i = 0; i < kNumGPU; ++i) {
+      st_global_volatile_16B(vec, push_buf[i], offset);
+    }
+  }
+}
+
+// Poll phase: each thread spins on its 16-byte vector in every poll_buf[i] until no lane
+// holds the +0.0 pattern, reduces the kNumGPU vectors, stores the result to `data`, and
+// finally overwrites the polled slots with +0.0 again.
+template
+SGL_DEVICE void poll_impl(DType* (&poll_buf)[kNumGPU], void* data, uint32_t num_items) {
+  using namespace device;
+  constexpr uint32_t kVecSize = 16 / (sizeof(DType) * 2);
+  using Storage = AlignedVector, kVecSize>;
+
+  for (auto i = blockIdx.x;; i += gridDim.x) {
+    const auto offset = i * blockDim.x + threadIdx.x;
+    if (offset * kVecSize * 2 >= num_items) break;
+    Storage storage[kNumGPU];
+
+    // Reload all ranks' vectors until none of them contains a +0.0 lane.
+    while (true) {
+      bool has_pos_zero = false;
+#pragma unroll
+      for (uint32_t i = 0; i < kNumGPU; ++i) {
+        ld_global_volatile_16B(storage[i], poll_buf[i], offset);
+#pragma unroll
+        for (auto j = 0; j < kVecSize; ++j) {
+          has_pos_zero |= is_pos_zero(storage[i][j].x);
+          has_pos_zero |= is_pos_zero(storage[i][j].y);
+        }
+      }
+      if (!has_pos_zero) break;
+    }
+
+    const Storage result = distributed::reduce_impl(storage);
+    result.store(data, offset);
+
+    // Re-arm the slots that were just consumed.
+    Storage pos_zeros;
+    pos_zeros.fill({get_pos_zero(), get_pos_zero()});
+#pragma unroll
+    for (uint32_t i = 0; i < kNumGPU; ++i) {
+      pos_zeros.store(poll_buf[i], offset);
+    }
+  }
+}
+
+// One-shot push kernel: pushes this rank's input into region `rank` of every rank's
+// buffer, then polls the kNumGPU regions of its own buffer and reduces them into output.
+// Both phases address the buffers at ctrl.epoch() * epoch_bytes.
+template
+CUSTOM_AR_KERNEL void all_reduce_one_shot_push_kernel(
+    const AllReducePushData __grid_constant__ params,  //
+    const PushController __grid_constant__ ctrl) {
+  using namespace device;
+
+  const auto [buffer, input, output, rank, num_items, buffer_bytes, epoch_bytes] = params;
+
+  PDLWaitPrimary();
+
+  // Phase 1: Push data from input to all ranks' buffers
+  const auto epoch_offset = ctrl.epoch() * epoch_bytes;
+  DType* push_buf[kNumGPU];
+#pragma unroll
+  for (uint32_t i = 0; i < kNumGPU; ++i) {
+    push_buf[i] = static_cast(pointer::offset(buffer[i], rank * buffer_bytes, epoch_offset));
+  }
+  push_impl(push_buf, input, num_items);
+
+  PDLTriggerSecondary();
+
+  // Phase 2: Poll local data
+  DType* poll_buf[kNumGPU];
+#pragma unroll
+  for (uint32_t i = 0; i < kNumGPU; ++i) {
+    poll_buf[i] = static_cast(pointer::offset(buffer[rank], i * buffer_bytes, epoch_offset));
+  }
+  poll_impl(poll_buf, output, num_items);
+  ctrl.exit();
+}
+
+// Host-side driver of the push all-reduce for a fixed (DType, kNumGPU, ...) instantiation.
+template
+struct CustomAllReducePush : public CustomAllReduceBase {
+  // Packed pairs per 16-byte vector; one vector therefore carries kVecSize * 2 items.
+  static constexpr uint32_t kVecSize = 16 / (sizeof(DType) * 2);
+  static_assert(kNumGPU <= device::distributed::kMaxNumGPU, "kNumGPU exceeds the maximum supported GPUs");
+
+  // Launches the 1-shot push all-reduce in place on `input` and returns `input`.
+  //
+  // Fails a RuntimeCheck when `shot` is not 1, or the input is non-contiguous, not on CUDA,
+  // of the wrong dtype, not 16-byte aligned, not a whole number of 16-byte vectors, has
+  // 2^32 or more items, or is larger than m_push_buffer_bytes.
+  tvm::ffi::Tensor all_reduce(tvm::ffi::Tensor input, int shot) {
+    using namespace host;
+    const auto device = input.device();
+    const auto input_ptr = input.data_ptr();
+    const auto num_items_int64 = input.numel();
+    const auto num_items = static_cast(num_items_int64);
+    const auto num_blocks = m_max_num_cta_push;  // must be constant to ensure correctness
+    // Smallest block size from {128, 256, 512} whose grid covers the input in one pass;
+    // 1024 otherwise.
+    const auto num_threads = [&] {
+      for (const auto t : {128u, 256u, 512u}) {
+        if (t * num_blocks * 2 * kVecSize >= num_items) return t;
+      }
+      return 1024u;
+    }();
+    const auto output = input;
+    AllReducePushData params;
+    for (uint32_t i = 0; i < kNumGPU; ++i) {
+      params.buffer[i] = get_push_buffer(m_peer_storage[i]);
+    }
+    params.input = input_ptr;
+    params.output = input_ptr;
+    params.rank = m_rank;
+    params.num_items = num_items;
+    params.buffer_bytes = m_push_buffer_bytes;
+    params.epoch_bytes = kNumGPU * params.buffer_bytes;
+
+    RuntimeCheck(input.IsContiguous(), "Input must be contiguous");
+    RuntimeCheck(m_num_gpu == kNumGPU, "Number of GPUs mismatch");
+    RuntimeCheck(device.device_type == kDLCUDA, "Only CUDA device is supported");
+    RuntimeCheck(is_type(input.dtype()), "Input dtype mismatch");
+    RuntimeCheck(std::bit_cast(input_ptr) % 16 == 0, "Input pointer is not properly aligned");
+    RuntimeCheck(m_push_ctrl.has_value(), "Controller is not initialized");
+    RuntimeCheck(shot == 1, "Push all-reduce only supports 1-shot, got: ", shot);
+    RuntimeCheck(static_cast(num_items) == num_items_int64, "Number of items exceeds 4G limit");
+    // push_impl / poll_impl only move whole 16-byte vectors (kVecSize * 2 items each); with a
+    // ragged tail the last vector would be read from and written back past the end of `input`.
+    RuntimeCheck(num_items % (kVecSize * 2) == 0, "Number of items must be a multiple of ", kVecSize * 2);
+
+    const auto input_bytes = static_cast(sizeof(DType) * num_items_int64);
+    RuntimeCheck(input_bytes <= m_push_buffer_bytes, "Input is too large, num items: ", num_items);
+
+    const auto kernel = all_reduce_one_shot_push_kernel;
+    LaunchKernel(num_blocks, num_threads, device)  //
+        .enable_pdl(kUsePDL)(kernel, params, *m_push_ctrl);
+    return output;
+  }
+};
+
+// FFI-facing entry point: downcasts the object held by `obj` to the matching
+// CustomAllReducePush instantiation and forwards to its all_reduce().
+template
+tvm::ffi::Tensor custom_all_reduce(CustomAllReduceRef obj, tvm::ffi::Tensor input, int shot) {
+  using Impl = CustomAllReducePush;
+  return static_cast(*obj.get()).all_reduce(input, shot);
+}
+
+}  // namespace
diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh
new file mode 100644
index 000000000000..7482b5cd28c4
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/hisparse.cuh
@@ -0,0 +1,472 @@
+#include
+#include
+
+#include
+
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace {
+
+constexpr int WARP_SIZE = 32;
+constexpr int32_t TOKEN_HIT = 0xFFFFFFFF;
+constexpr int32_t HASH_EMPTY = -1;
+
+// Knuth multiplicative hash for open-addressing table of size hash_size.
+// Returns a slot index in [0, hash_size): the key is reinterpreted as uint32_t, multiplied
+// by 2654435761 (wrapping at 32 bits) and reduced modulo hash_size.
+__device__ __forceinline__ int hash_slot(int32_t key, int hash_size) {
+  return ((uint32_t)key * 2654435761u) % (uint32_t)hash_size;
+}
+
+// Copies one item of item_size_bytes bytes from src_addr to dst_addr, with the work split
+// across the lanes of a warp: lane `lane_id` handles 16-byte chunks lane_id,
+// lane_id + WARP_SIZE, ... and then at most one 8-byte tail chunk.
+// NOTE(review): bytes beyond the last full 8-byte chunk are not copied, so item_size_bytes
+// is presumably always a multiple of 8 — confirm against callers.
+__device__ __forceinline__ void
+transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
+  // 128-bit bulk transfer via paired 64-bit loads (avoids alignment issues with uint4)
+  // NOTE(review): the .v2.b64 vector forms used below still access 16 bytes at once; verify
+  // that both addresses are 16-byte aligned as the PTX vector ld/st rules require.
+  const int total_pairs = item_size_bytes / 16;  // number of 16-byte chunks
+  {
+    const uint64_t* __restrict__ src = static_cast(src_addr);
+    uint64_t* __restrict__ dst = static_cast(dst_addr);
+    for (int j = lane_id; j < total_pairs; j += WARP_SIZE) {
+      uint64_t lo, hi;
+      const uint64_t* s = src + j * 2;
+      asm volatile("ld.global.nc.v2.b64 {%0,%1},[%2];" : "=l"(lo), "=l"(hi) : "l"(s) : "memory");
+      uint64_t* d = dst + j * 2;
+      asm volatile("st.global.cg.v2.b64 [%0],{%1,%2};" ::"l"(d), "l"(lo), "l"(hi) : "memory");
+    }
+  }
+
+  // Tail: 64-bit for remaining 8-byte chunk (if item_size not multiple of 16)
+  const int tail_8B = (item_size_bytes - total_pairs * 16) / 8;
+  if (tail_8B > 0 && lane_id < tail_8B) {
+    const uint64_t* __restrict__ src8 =
+        reinterpret_cast(static_cast(src_addr) + total_pairs * 16);
+    uint64_t* __restrict__ dst8 = reinterpret_cast(static_cast(dst_addr) + total_pairs * 16);
+    uint64_t tmp;
+    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src8 + lane_id) : "memory");
+    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst8 + lane_id), "l"(tmp) : "memory");
+  }
+}
+
+// Warp-wide inclusive prefix sum over s_data[offset .. offset + 31], written back in place.
+// Indices at or beyond `count` contribute 0 and are not written. `accumulator` is added to
+// every prefix; the return value is lane 31's prefix, i.e. the new running total.
+// Uses full-mask shuffles, so all 32 lanes of the warp must call it together.
+__device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int offset, int count, int accumulator) {
+  int idx = lane_id + offset;
+  int val = (idx < count) ?
s_data[idx] : 0;
+
+  // Shuffle-up scan: after the step with stride i, val holds the sum of up to 2*i lanes.
+#pragma unroll
+  for (int i = 1; i < 32; i *= 2) {
+    int n = __shfl_up_sync(0xffffffff, val, i);
+    if (lane_id >= i) val += n;
+  }
+  val += accumulator;
+  if (idx < count) {
+    s_data[idx] = val;
+  }
+  accumulator = __shfl_sync(0xffffffff, val, 31);
+  return accumulator;
+}
+
+// Shared memory size calculation for dynamic allocation.
+// Layout: int32_t region (4-byte aligned) followed by int16_t region (2-byte aligned).
+template
+struct SmemLayout {
+  // Hash table capacity: twice the number of top-k entries.
+  static constexpr int HASH_SIZE = NUM_TOP_K * 2;
+  // Number of WARP_SIZE-wide chunks needed to cover HOT_BUFFER_SIZE slots (rounded up).
+  static constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+  // int32_t region: top_k_tokens + chunk_offset + evict_chunk_offset + hash_keys + total_hits + newest_hit
+  static constexpr int TOTAL_INT32 = NUM_TOP_K + (NUM_BUFFER_CHUNKS + 1) + (NUM_BUFFER_CHUNKS + 1) + HASH_SIZE + 2;
+  // int16_t region: lru_slots_out + hash_vals
+  static constexpr int TOTAL_INT16 = HOT_BUFFER_SIZE + HASH_SIZE;
+  static constexpr size_t BYTES = TOTAL_INT32 * sizeof(int32_t) + TOTAL_INT16 * sizeof(int16_t);
+};
+
+// Each block processes one request
+// req_pool_indices and seq_lens can each be int32_t or int64_t
+// Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token]
+// newest_slot is at HOT_BUFFER_SIZE (first position of extra page)
+//
+// Phases for sequences longer than HOT_BUFFER_SIZE:
+//   1. build a shared-memory hash table from the request's top-k token ids;
+//   2. scan the slots listed in lru_slots, classify each as hit / evictable and compact the
+//      slot ids into s_lru_slots_out (hits from the front, evictables from the back);
+//   3. give every top-k token that was not hit an evictable slot and record its location;
+//   4. write the new slot order back to lru_slots;
+//   5. copy the missed items into their assigned slots, one warp per miss.
+// Blocks with blockIdx.x >= num_real_reqs[0] return immediately. `page_size` is not read
+// in the kernel body; `item_size_bytes` is only used on the non-MLA copy path.
+template
+__global__ void load_cache_to_device_buffer_kernel(
+    const int32_t* __restrict__ top_k_tokens,
+    int32_t* __restrict__ device_buffer_tokens,
+    const int64_t* __restrict__ host_cache_locs,
+    const int32_t* __restrict__ device_buffer_locs,
+    const void* __restrict__ host_cache_k,
+    const void* __restrict__ host_cache_v,
+    void* __restrict__ device_buffer_k,
+    void* __restrict__ device_buffer_v,
+    int32_t* __restrict__ top_k_device_locs,
+    const ReqPoolIndicesT* __restrict__ req_pool_indices,
+    const SeqLensT* __restrict__ seq_lens,
+    int16_t* __restrict__ lru_slots,
+    const int32_t* __restrict__ num_real_reqs,
+    int64_t buffer_stride_0,
+    int64_t host_stride,
+    int64_t lru_slot_stride_0,
+    int64_t top_k_tokens_stride,
+    int64_t top_k_device_locs_stride,
+    int64_t page_size,
+    int64_t item_size_bytes) {
+  // todo hisparse: support page wise sparsity
+  constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
+  constexpr int NUM_TOKEN_CHUNKS = (NUM_TOP_K + WARP_SIZE - 1) / WARP_SIZE;
+  constexpr int NUM_BUFFER_CHUNKS = (HOT_BUFFER_SIZE + WARP_SIZE - 1) / WARP_SIZE;
+
+  const int bid = blockIdx.x;
+  // Early exit for padded blocks (CUDA graph pads batch to a captured size)
+  if (bid >= num_real_reqs[0]) return;
+
+  const int tid = threadIdx.x;
+  const int warp_id = tid / WARP_SIZE;
+  const int lane_id = tid % WARP_SIZE;
+  // Bit mask of the lanes with a lower index than this one (for ballot-based ranking).
+  const unsigned int lanes_before = ((unsigned int)1 << lane_id) - 1;
+
+  const int64_t rid = req_pool_indices[bid];
+  const int64_t seq_len = seq_lens[bid];
+
+  // Calculate offsets for this request
+  // Top-k inputs/outputs are indexed by batch position (bid); the buffers below by rid.
+  const int32_t* req_top_k_tokens = top_k_tokens + bid * top_k_tokens_stride;
+  int32_t* req_top_k_device_locs = top_k_device_locs + bid * top_k_device_locs_stride;
+
+  // NOTE(review): device_buffer_locs is offset with the same stride as device_buffer_tokens —
+  // confirm both tensors share a row stride.
+  const int64_t buffer_offset = rid * buffer_stride_0;
+  int32_t* req_device_buffer_tokens = device_buffer_tokens + buffer_offset;
+  const int32_t* req_device_buffer_locs = device_buffer_locs + buffer_offset;
+  const int64_t* req_host_cache_locs = host_cache_locs + rid * host_stride;
+  int16_t* req_lru_slots = lru_slots + rid * lru_slot_stride_0;
+
+  // Fast path: short sequences have all tokens in the device buffer in order.
+  if (seq_len <= HOT_BUFFER_SIZE) {
+    const int count = (seq_len < NUM_TOP_K) ? static_cast(seq_len) : NUM_TOP_K;
+    for (int i = tid; i < count; i += BLOCK_SIZE) {
+      int32_t token_pos = req_top_k_tokens[i];
+      // Negative entries are skipped; their output location is left untouched.
+      if (token_pos >= 0) {
+        req_top_k_device_locs[i] = req_device_buffer_locs[token_pos];
+      }
+    }
+    return;
+  }
+
+  // Dynamic shared memory layout: int32_t arrays first, then int16_t arrays.
+  extern __shared__ char smem_raw[];
+  using Layout = SmemLayout;
+  constexpr int HASH_SIZE = Layout::HASH_SIZE;
+
+  int32_t* smem_i32 = reinterpret_cast(smem_raw);
+  // Top-k token positions; reused as miss-token scratch in the copy phase
+  int32_t* s_top_k_tokens = smem_i32;
+  // Prefix-sum offsets for hit counting and miss counting
+  int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K;
+  // Prefix-sum offsets for evictable counting
+  int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1);
+  // Open-addressing hash table: top-k token_id -> top-k index (keys)
+  int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1);
+  // Scalar counters
+  int32_t& s_total_hits = s_hash_keys[HASH_SIZE];
+  int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1];
+
+  int16_t* smem_i16 = reinterpret_cast(smem_i32 + Layout::TOTAL_INT32);
+  // Compacted slot ordering: [hits fwd-> ... <-evictables bwd]
+  int16_t* s_lru_slots_out = smem_i16;
+  // Open-addressing hash table: top-k token_id -> top-k index (values)
+  int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE;
+
+  // Initialize shared memory: counters, hash table, prefix-sum offsets.
+  if (tid == 0) {
+    s_total_hits = 0;
+    s_newest_hit = 0;
+  }
+  for (int i = tid; i < HASH_SIZE; i += BLOCK_SIZE) {
+    s_hash_keys[i] = HASH_EMPTY;
+  }
+  for (int i = tid; i < NUM_BUFFER_CHUNKS + 1; i += BLOCK_SIZE) {
+    s_chunk_offset[i] = 0;
+    s_evict_chunk_offset[i] = 0;
+  }
+  __syncthreads();
+
+  const int newest_slot = HOT_BUFFER_SIZE;
+  const int32_t newest_token = seq_len - 1;
+
+  // Insert top-k tokens into shared-memory hash table.
+  for (int i = tid; i < NUM_TOP_K; i += BLOCK_SIZE) {
+    int32_t token_idx = req_top_k_tokens[i];
+    if (token_idx == newest_token) {
+      // If topk includes the latest token, bind its canonical occurrence to newest_slot (at HOT_BUFFER_SIZE) and mark
+      // it as a hit. newest_slot is at the first position of the extra page, excluded from LRU tracking.
+      s_top_k_tokens[i] = TOKEN_HIT;
+      req_top_k_device_locs[i] = req_device_buffer_locs[newest_slot];
+      s_newest_hit = 1;
+    } else {
+      // Linear probing: claim the first slot that is empty or already holds this token.
+      int slot = hash_slot(token_idx, HASH_SIZE);
+      while (true) {
+        int32_t old = atomicCAS(&s_hash_keys[slot], HASH_EMPTY, token_idx);
+        if (old == HASH_EMPTY || old == token_idx) {
+          s_hash_vals[slot] = static_cast(i);
+          break;
+        }
+        slot = (slot + 1) % HASH_SIZE;
+      }
+      s_top_k_tokens[i] = token_idx;
+    }
+  }
+  __syncthreads();
+
+  // Second pass: every lane inspects one slot from lru_slots per iteration.
+  constexpr int ITERATIONS_PER_WARP_BUFFER = (NUM_BUFFER_CHUNKS + NUM_WARPS - 1) / NUM_WARPS;
+  int total_hit_count = 0;
+  int total_evict_count = 0;
+  for (int iter = 0; iter < ITERATIONS_PER_WARP_BUFFER; iter++) {
+    int chunk_idx = warp_id + iter * NUM_WARPS;
+    bool has_valid_chunk = chunk_idx < NUM_BUFFER_CHUNKS;
+
+    const int slot_idx = chunk_idx * WARP_SIZE + lane_id;
+    const bool has_valid_slot = has_valid_chunk && (slot_idx < HOT_BUFFER_SIZE);
+    const int16_t buf_slot = has_valid_slot ? req_lru_slots[slot_idx] : -1;
+    int32_t my_buffer_token = (buf_slot >= 0) ? req_device_buffer_tokens[buf_slot] : -1;
+    int my_found_top_k_idx = -1;
+    // Look the resident token up in the top-k hash table (probe until match or empty key).
+    if (my_buffer_token >= 0) {
+      int h = hash_slot(my_buffer_token, HASH_SIZE);
+      while (true) {
+        int32_t k = s_hash_keys[h];
+        if (k == my_buffer_token) {
+          my_found_top_k_idx = static_cast(s_hash_vals[h]);
+          break;
+        }
+        if (k == HASH_EMPTY) break;
+        h = (h + 1) % HASH_SIZE;
+      }
+    }
+    bool is_hit = my_found_top_k_idx >= 0;
+    bool is_evictable = has_valid_slot && !is_hit;
+
+    // Record hits
+    if (is_hit) {
+      s_top_k_tokens[my_found_top_k_idx] = TOKEN_HIT;
+      req_top_k_device_locs[my_found_top_k_idx] = req_device_buffer_locs[buf_slot];
+    }
+
+    // Rank of this lane among the hit / evictable lanes of its warp, plus per-chunk totals.
+    int local_hit_offset = 0;
+    int local_evict_offset = 0;
+    if (has_valid_chunk) {
+      const unsigned int hit_mask = __ballot_sync(0xFFFFFFFF, is_hit);
+      const unsigned int evict_mask = __ballot_sync(0xFFFFFFFF, is_evictable);
+      local_hit_offset = __popc(hit_mask & lanes_before);
+      local_evict_offset = __popc(evict_mask & lanes_before);
+      if (lane_id == 0) {
+        s_chunk_offset[chunk_idx + 1] = __popc(hit_mask);
+        s_evict_chunk_offset[chunk_idx + 1] = __popc(evict_mask);
+      }
+    }
+    __syncthreads();
+
+    // Warp 0 turns the per-chunk counts into running prefix sums (its chunk_idx is iter * NUM_WARPS).
+    if (warp_id == 0) {
+      total_hit_count =
+          warp_inclusive_scan(s_chunk_offset, lane_id, chunk_idx + 1, NUM_BUFFER_CHUNKS + 1, total_hit_count);
+      total_evict_count =
+          warp_inclusive_scan(s_evict_chunk_offset, lane_id, chunk_idx + 1, NUM_BUFFER_CHUNKS + 1, total_evict_count);
+      if (tid == 0) {
+        s_total_hits = total_hit_count;
+      }
+    }
+    __syncthreads();
+
+    // Hits grow forward from index 0
+    if (is_hit) {
+      int hit_offset = s_chunk_offset[chunk_idx] + local_hit_offset;
+      s_lru_slots_out[hit_offset] = buf_slot;
+    }
+    // Evictables grow backward from HOT_BUFFER_SIZE - 1
+    if (is_evictable) {
+      int evict_offset = s_evict_chunk_offset[chunk_idx] + local_evict_offset;
+      s_lru_slots_out[HOT_BUFFER_SIZE - 1 - evict_offset] = buf_slot;
+    }
+  }
+  __syncthreads();
+
+  // Reset offsets for the miss counting phase (only NUM_TOKEN_CHUNKS + 1 entries needed).
+  // NOTE(review): s_chunk_offset was sized for NUM_BUFFER_CHUNKS + 1 entries, so this reuse
+  // presumably relies on NUM_TOP_K <= HOT_BUFFER_SIZE — confirm.
+  for (int i = tid; i < NUM_TOKEN_CHUNKS + 1; i += BLOCK_SIZE) {
+    s_chunk_offset[i] = 0;
+  }
+  __syncthreads();
+
+  // Third pass to identify misses and their evictable slots
+  int total_misses = 0;
+  constexpr int ITERATIONS_PER_WARP_TOKEN = (NUM_TOKEN_CHUNKS + NUM_WARPS - 1) / NUM_WARPS;
+  for (int iter = 0; iter < ITERATIONS_PER_WARP_TOKEN; iter++) {
+    int chunk_idx = warp_id + iter * NUM_WARPS;
+    bool has_valid_chunk = chunk_idx < NUM_TOKEN_CHUNKS;
+
+    const int chunk_token_start = chunk_idx * WARP_SIZE;
+    const int my_token_idx = chunk_token_start + lane_id;
+    const bool has_valid_token = has_valid_chunk && (my_token_idx < NUM_TOP_K);
+
+    int32_t my_token = 0;
+    bool is_miss = false;
+    int local_miss_offset = 0;
+
+    // A top-k entry is a miss if neither earlier pass replaced it with TOKEN_HIT.
+    if (has_valid_token) {
+      is_miss = s_top_k_tokens[my_token_idx] != TOKEN_HIT;
+      if (is_miss) {
+        my_token = s_top_k_tokens[my_token_idx];
+      }
+    }
+
+    if (has_valid_chunk) {
+      const unsigned int miss_mask = __ballot_sync(0xFFFFFFFF, is_miss);
+      local_miss_offset = __popc(miss_mask & lanes_before);
+      const int warp_miss_count = __popc(miss_mask);
+      if (lane_id == 0) {
+        s_chunk_offset[chunk_idx + 1] = warp_miss_count;
+      }
+    }
+    __syncthreads();
+
+    if (warp_id == 0) {
+      total_misses = warp_inclusive_scan(s_chunk_offset, lane_id, chunk_idx + 1, NUM_TOKEN_CHUNKS + 1, total_misses);
+    }
+    __syncthreads();
+
+    if (is_miss) {
+      // The k-th miss takes the k-th evictable slot, counted from the back of s_lru_slots_out.
+      int miss_offset = s_chunk_offset[chunk_idx] + local_miss_offset;
+      int16_t evict_slot = s_lru_slots_out[HOT_BUFFER_SIZE - 1 - miss_offset];
+      // Reuse s_top_k_tokens as miss scratch: miss_offset < my_token_idx always
+      // holds (hits are skipped), so compacted writes never overrun pending reads.
+      s_top_k_tokens[miss_offset] = my_token;
+      req_top_k_device_locs[my_token_idx] = req_device_buffer_locs[evict_slot];
+      req_device_buffer_tokens[evict_slot] = my_token;
+    }
+  }
+  __syncthreads();
+
+  // Every thread recomputes the miss count from the shared counters.
+  total_misses = NUM_TOP_K - s_total_hits - s_newest_hit;
+  // Write back LRU order: evictables at front (LRU), hits at back (MRU).
+  {
+    const int total_evictable = HOT_BUFFER_SIZE - s_total_hits;
+    for (int i = tid; i < HOT_BUFFER_SIZE; i += BLOCK_SIZE) {
+      if (i < total_misses) {
+        // Misses: just loaded from host, place right before hits
+        req_lru_slots[total_evictable - total_misses + i] = s_lru_slots_out[HOT_BUFFER_SIZE - 1 - i];
+      } else if (i < total_evictable) {
+        // Remaining evictables: truly stale, dest at LRU front
+        req_lru_slots[i - total_misses] = s_lru_slots_out[HOT_BUFFER_SIZE - 1 - i];
+      } else {
+        // Hits: source at forward end, dest at MRU back
+        req_lru_slots[i] = s_lru_slots_out[i - total_evictable];
+      }
+    }
+  }
+
+  // each warp copies one miss directly, can be separated into a new kernel if parallelism is a concern
+  for (int miss_idx = warp_id; miss_idx < total_misses; miss_idx += NUM_WARPS) {
+    const int32_t miss_token = s_top_k_tokens[miss_idx];
+    const int16_t evict_slot = s_lru_slots_out[HOT_BUFFER_SIZE - 1 - miss_idx];
+
+    // Source row in the host cache and destination row in the device buffer.
+    const int64_t src_loc = req_host_cache_locs[miss_token];
+    const int64_t dst_loc = static_cast(req_device_buffer_locs[evict_slot]);
+
+    if constexpr (IsMLA) {
+      // MLA path: page-padded device layout, linear CPU layout (see kvcacheio.cuh).
+      // transfer_item handles both sides' pointer math and uses 8-byte-aligned loads.
+      device::hisparse::transfer_item(
+          /*dst_cache=*/device_buffer_k,
+          /*src_cache=*/const_cast(host_cache_k),
+          /*dst_index=*/static_cast(dst_loc),
+          /*src_index=*/static_cast(src_loc));
+    } else {
+      // Non-MLA path: K and V rows are item_size_bytes each and are copied by the whole warp.
+      const auto src_k = static_cast(host_cache_k) + src_loc * item_size_bytes;
+      auto dst_k = static_cast(device_buffer_k) + dst_loc * item_size_bytes;
+      transfer_item_warp(lane_id, src_k, dst_k, item_size_bytes);
+      const auto src_v = static_cast(host_cache_v) + src_loc * item_size_bytes;
+      auto dst_v = static_cast(device_buffer_v) + dst_loc * item_size_bytes;
+      transfer_item_warp(lane_id, src_v, dst_v, item_size_bytes);
+    }
+  }
+}
+
+// Host launcher for load_cache_to_device_buffer_kernel.
+//
+// Launches one block of BLOCK_SIZE threads per row of `top_k_tokens`, with SmemLayout::BYTES
+// of dynamic shared memory. The kernel variant is chosen at runtime from the dtypes of
+// `seq_lens` and `req_pool_indices`, each of which must be int32 or int64. The V-cache
+// pointers are passed as nullptr when IsMLA is set or the corresponding tensor has ndim() == 0.
+// Fails a RuntimeCheck on an unsupported index dtype and a RuntimeDeviceCheck if the
+// shared-memory opt-in is rejected by the CUDA runtime.
+template
+void load_cache_to_device_buffer(
+    tvm::ffi::TensorView top_k_tokens,
+    tvm::ffi::TensorView device_buffer_tokens,
+    tvm::ffi::TensorView host_cache_locs,
+    tvm::ffi::TensorView device_buffer_locs,
+    tvm::ffi::TensorView host_cache_k,
+    tvm::ffi::TensorView host_cache_v,
+    tvm::ffi::TensorView device_buffer_k,
+    tvm::ffi::TensorView device_buffer_v,
+    tvm::ffi::TensorView top_k_device_locs,
+    tvm::ffi::TensorView req_pool_indices,
+    tvm::ffi::TensorView seq_lens,
+    tvm::ffi::TensorView lru_slots,
+    tvm::ffi::TensorView num_real_reqs,
+    int64_t page_size,
+    int64_t item_size_bytes) {
+  using namespace host;
+
+  const int64_t bs = top_k_tokens.shape()[0];
+  // NOTE(review): unlike the strides below this uses shape()[1], i.e. it assumes
+  // host_cache_locs rows are densely packed — confirm.
+  const int64_t host_stride = host_cache_locs.shape()[1];
+  const int64_t buffer_stride_0 = device_buffer_tokens.strides()[0];
+  const int64_t lru_slot_stride_0 = lru_slots.strides()[0];
+  const int64_t top_k_tokens_stride = top_k_tokens.strides()[0];
+  const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0];
+  const auto device = LaunchKernel::resolve_device(top_k_tokens.device());
+
+  // Generic lambda: int32/int64 kernel variants are compiled for both
+  // seq_lens and req_pool_indices; the correct combo is selected at runtime.
+  auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr, const auto* req_pool_indices_ptr) {
+    constexpr size_t smem_bytes = SmemLayout::BYTES;
+    if constexpr (smem_bytes > 48u * 1024u) {
+      // More than 48 KiB of dynamic shared memory needs an explicit opt-in; surface a
+      // rejected request here instead of as an opaque launch failure later.
+      RuntimeDeviceCheck(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes));
+    }
+    LaunchKernel(bs, BLOCK_SIZE, device, smem_bytes)(
+        kernel_fn,
+        static_cast(top_k_tokens.data_ptr()),
+        static_cast(device_buffer_tokens.data_ptr()),
+        static_cast(host_cache_locs.data_ptr()),
+        static_cast(device_buffer_locs.data_ptr()),
+        host_cache_k.data_ptr(),
+        (IsMLA || host_cache_v.ndim() == 0) ? (const void*)nullptr : host_cache_v.data_ptr(),
+        device_buffer_k.data_ptr(),
+        (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(),
+        static_cast(top_k_device_locs.data_ptr()),
+        req_pool_indices_ptr,
+        seq_lens_ptr,
+        static_cast(lru_slots.data_ptr()),
+        static_cast(num_real_reqs.data_ptr()),
+        buffer_stride_0,
+        host_stride,
+        lru_slot_stride_0,
+        top_k_tokens_stride,
+        top_k_device_locs_stride,
+        page_size,
+        item_size_bytes);
+  };
+
+  const auto seq_dtype = seq_lens.dtype();
+  const auto rpi_dtype = req_pool_indices.dtype();
+  const bool seq_is_i64 = (seq_dtype.code == kDLInt && seq_dtype.bits == 64);
+  const bool rpi_is_i64 = (rpi_dtype.code == kDLInt && rpi_dtype.bits == 64);
+  const bool seq_is_i32 = (seq_dtype.code == kDLInt && seq_dtype.bits == 32);
+  const bool rpi_is_i32 = (rpi_dtype.code == kDLInt && rpi_dtype.bits == 32);
+  // Anything that is not int64 is dispatched to the int32 variant below, so reject every
+  // other dtype instead of reinterpreting its bytes.
+  RuntimeCheck(seq_is_i64 || seq_is_i32, "seq_lens must be int32 or int64");
+  RuntimeCheck(rpi_is_i64 || rpi_is_i32, "req_pool_indices must be int32 or int64");
+
+  if (seq_is_i64 && rpi_is_i64) {
+    launch(
+        load_cache_to_device_buffer_kernel,
+        static_cast(seq_lens.data_ptr()),
+        static_cast(req_pool_indices.data_ptr()));
+  } else if (seq_is_i64 && !rpi_is_i64) {
+    launch(
+        load_cache_to_device_buffer_kernel,
+        static_cast(seq_lens.data_ptr()),
+        static_cast(req_pool_indices.data_ptr()));
+  } else if (!seq_is_i64 && rpi_is_i64) {
+    launch(
+        load_cache_to_device_buffer_kernel,
+        static_cast(seq_lens.data_ptr()),
+        static_cast(req_pool_indices.data_ptr()));
+  } else {
+    launch(
+        load_cache_to_device_buffer_kernel,
+        static_cast(seq_lens.data_ptr()),
+        static_cast(req_pool_indices.data_ptr()));
+  }
+}
+
+}  // namespace
diff
--git a/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh
new file mode 100644
index 000000000000..6476a3be232f
--- /dev/null
+++ b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh
@@ -0,0 +1,363 @@
+#include
+#include
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace {
+
+constexpr uint32_t kWarpSize = 32;
+constexpr uint32_t kWarpsPerCTA = 6;
+constexpr uint32_t kSmallTokenThreshold = 512;
+constexpr uint32_t kMaxExperts = 512;
+constexpr uint32_t kMaxTopK = 16;
+
+// Activation applied to each router logit before the bias is added.
+enum class ScoringFunc : uint32_t {
+  kSigmoid = 0,
+  kSqrtSoftplus = 1,
+};
+
+// Kernel arguments of the fused MoE gate, passed by value.
+struct MoEFusedGateParams {
+  const float* __restrict__ input;
+  const float* __restrict__ bias;
+  float* __restrict__ output;
+  int32_t* __restrict__ indices;
+  uint32_t num_rows;
+  uint32_t num_experts;
+  uint32_t topk;
+  uint32_t num_fused_shared_experts;
+  bool renormalize;
+  float routed_scaling_factor;
+  bool apply_routed_scaling_factor_on_output;
+};
+
+// Maps a raw logit `x` to its gating score: sigmoid(x) for kSigmoid, otherwise
+// sqrt(softplus(x)).
+template
+__device__ __forceinline__ float compute_score(float x) {
+  if constexpr (kScoringFunc == ScoringFunc::kSigmoid) {
+    // sigmoid(x) = 1 / (1 + exp(-x))
+    return 1.0f / (1.0f + expf(-x));
+  } else {
+    // sqrt(softplus(x)) = sqrt(log(1 + exp(x)))
+    // expf(x) overflows to +inf once x exceeds roughly 88.7, which would turn the score
+    // into +inf. Above the threshold softplus(x) equals x to within float precision, so
+    // return x directly there (the same cut-off torch.nn.functional.softplus uses by default).
+    constexpr float kSoftplusThreshold = 20.0f;
+    float softplus = (x > kSoftplusThreshold) ? x : log1pf(expf(x));
+    return sqrtf(softplus);
+  }
+}
+
+template
+__global__ void moe_fused_gate_kernel_small_token(const MoEFusedGateParams __grid_constant__ params) {
+  const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] =
+      params;
+
+  uint32_t row_idx = blockIdx.x;
+  if (row_idx >= num_rows) return;
+
+  // number of routed experts to select (excluding fused shared experts)
+  const uint32_t topk_routed = topk - num_fused_shared_experts;
+
+  uint32_t tid = threadIdx.x;
+  uint32_t warp_id = tid / kWarpSize;
+  uint32_t lane_id = tid % kWarpSize;
+
+  extern __shared__ float shared_mem[];
+  float*
shared_scores = shared_mem; + float* shared_original_scores = shared_mem + num_experts; + + // For warp-level reduction + __shared__ float warp_maxs[kWarpsPerToken]; + __shared__ int warp_experts[kWarpsPerToken]; + __shared__ int selected_experts[kMaxTopK]; + + for (uint32_t e = tid; e < num_experts; e += blockDim.x) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncthreads(); + + // only select topk_routed experts (excluding shared experts) + for (uint32_t k = 0; k < topk_routed; k++) { + float my_val = -FLT_MAX; + int my_expert = -1; + for (uint32_t e = tid; e < num_experts; e += blockDim.x) { + if (shared_scores[e] > my_val) { + my_val = shared_scores[e]; + my_expert = e; + } + } + + float warp_max_val = my_val; + int warp_max_expert = my_expert; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, warp_max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, warp_max_expert, offset); + if (other_val > warp_max_val) { + warp_max_val = other_val; + warp_max_expert = other_expert; + } + } + + if (lane_id == 0 && warp_id < kWarpsPerToken) { + warp_maxs[warp_id] = warp_max_val; + warp_experts[warp_id] = warp_max_expert; + } + + __syncthreads(); + + if (warp_id == 0) { + float final_max = (lane_id < kWarpsPerToken) ? warp_maxs[lane_id] : -FLT_MAX; + int final_expert = (lane_id < kWarpsPerToken) ? 
warp_experts[lane_id] : -1; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, final_max, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, final_expert, offset); + if (other_val > final_max) { + final_max = other_val; + final_expert = other_expert; + } + } + + if (lane_id == 0) { + selected_experts[k] = final_expert; + } + } + + __syncthreads(); + + int selected = selected_experts[k]; + if (selected >= 0 && tid == 0) { + shared_scores[selected] = -FLT_MAX; + } + + __syncthreads(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + if (tid >= device::kWarpThreads) return; + + // only use the first warp to perform write to global operation + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (tid < topk_routed) { + int expert_id = selected_experts[tid]; + float score = shared_original_scores[expert_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = score; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (tid < topk) { + const bool is_shared = tid >= topk_routed; + const auto output_offset = row_idx * topk + tid; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + tid - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? 
routed_sum : 1.0f; + output[output_offset] = weight / norm * scale; + indices[output_offset] = expert_id; + } +} + +template +__global__ void moe_fused_gate_kernel(const MoEFusedGateParams __grid_constant__ params) { + const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] = + params; + + uint32_t row_idx = blockIdx.x * kWarpsPerCTA + threadIdx.y; + if (row_idx >= num_rows) return; + + // number of routed experts to select (excluding fused shared experts) + const uint32_t topk_routed = topk - num_fused_shared_experts; + + uint32_t lane_id = threadIdx.x; + uint32_t warp_id = threadIdx.y; + + extern __shared__ float shared_mem[]; + float* shared_scores = shared_mem + warp_id * num_experts * 2; + float* shared_original_scores = shared_scores + num_experts; + __shared__ int selected_experts[kWarpsPerCTA][kMaxTopK]; + int* warp_selected_experts = selected_experts[warp_id]; + + for (uint32_t e = lane_id; e < num_experts; e += kWarpSize) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncwarp(); + + // only select topk_routed experts + for (uint32_t k = 0; k < topk_routed; k++) { + float max_val = -FLT_MAX; + int max_expert = -1; + + for (uint32_t expert = lane_id; expert < num_experts; expert += kWarpSize) { + if (shared_scores[expert] > max_val) { + max_val = shared_scores[expert]; + max_expert = expert; + } + } + + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, max_expert, offset); + + if (other_val > max_val || (other_val == max_val && other_expert < max_expert)) { + max_val = other_val; + max_expert = other_expert; + 
} + } + + if (lane_id == 0) { + warp_selected_experts[k] = max_expert; + if (max_expert != -1) { + shared_scores[max_expert] = -FLT_MAX; + } + } + + __syncwarp(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (lane_id < topk_routed) { + int expert_id = warp_selected_experts[lane_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = shared_original_scores[expert_id]; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk) { + const bool is_shared = lane_id >= topk_routed; + const auto output_idx = row_idx * topk + lane_id; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + lane_id - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? 
routed_sum : 1.0f; + output[output_idx] = weight / norm * scale; + indices[output_idx] = expert_id; + } +} + +template +void dispatch_small_token_kernel( + uint32_t num_rows, + uint32_t threads_per_block, + uint32_t warps_per_token, + DLDevice device, + size_t smem_per_row, + const MoEFusedGateParams& params) { + using namespace host; + if (warps_per_token <= 8) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<8, kScoringFunc>, params); + } else if (warps_per_token <= 12) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<12, kScoringFunc>, params); + } else { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<16, kScoringFunc>, params); + } +} + +struct MoEFusedGateKernel { + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView bias, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView indices, + uint32_t topk, + uint32_t scoring_func, // 0 = sigmoid, 1 = sqrtsoftplus + uint32_t num_fused_shared_experts, + bool renormalize, + float routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { + using namespace host; + + auto N = SymbolicSize{"num_rows"}; + auto E = SymbolicSize{"num_experts"}; + auto K = SymbolicSize{"topk"}; + auto device = SymbolicDevice{}; + K.set_value(topk); + device.set_options(); + + TensorMatcher({N, E}).with_dtype().with_device(device).verify(input); + TensorMatcher({E}).with_dtype().with_device(device).verify(bias); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(output); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(indices); + + const auto num_rows = static_cast(N.unwrap()); + const auto num_experts = static_cast(E.unwrap()); + + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum supported value"); + RuntimeCheck(scoring_func <= 1, "scoring_func must be 0 (sigmoid) or 1 
(sqrtsoftplus)"); + RuntimeCheck(topk > num_fused_shared_experts, "topk must be greater than num_fused_shared_experts"); + + const auto params = MoEFusedGateParams{ + .input = static_cast(input.data_ptr()), + .bias = static_cast(bias.data_ptr()), + .output = static_cast(output.data_ptr()), + .indices = static_cast(indices.data_ptr()), + .num_rows = num_rows, + .num_experts = num_experts, + .topk = topk, + .num_fused_shared_experts = num_fused_shared_experts, + .renormalize = renormalize, + .routed_scaling_factor = routed_scaling_factor, + .apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output, + }; + + const size_t smem_per_row = 2 * num_experts * sizeof(float); + + bool use_small_token_kernel = num_rows <= kSmallTokenThreshold; + + if (use_small_token_kernel) { + // 1 token per block + uint32_t warps_per_token = div_ceil(num_experts, kWarpSize); + warps_per_token = std::min(warps_per_token, 16u); + uint32_t threads_per_block = warps_per_token * kWarpSize; + + if (scoring_func == 0) { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } else { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } + } else { + // multiple tokens per block + uint32_t num_blocks = div_ceil(num_rows, kWarpsPerCTA); + dim3 block_dim(kWarpSize, kWarpsPerCTA); + size_t large_smem = smem_per_row * kWarpsPerCTA; + + if (scoring_func == 0) { + LaunchKernel(num_blocks, block_dim, device.unwrap(), large_smem)( + moe_fused_gate_kernel, params); + } else { + LaunchKernel(num_blocks, block_dim, device.unwrap(), large_smem)( + moe_fused_gate_kernel, params); + } + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/deepseek_v4.py b/python/sglang/jit_kernel/deepseek_v4.py new file mode 100644 index 000000000000..0878f45e9c56 --- /dev/null +++ b/python/sglang/jit_kernel/deepseek_v4.py @@ -0,0 +1,1304 @@ +from __future__ 
import annotations + +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, +) +from sglang.srt.environ import envs + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +def make_name(name: str) -> str: + return f"dpsk_v4_{name}" + + +@cache_once +def _jit_common_module() -> Module: + return load_jit( + make_name(f"common"), + cuda_files=[f"deepseek_v4/common.cuh"], + cuda_wrappers=[("plan_compress_prefill", "plan_compress_prefill")], + ) + + +@cache_once +def _jit_compress_128_online_plan_module() -> Module: + """Host-side plan generator for online compress 128 (no template args).""" + return load_jit( + make_name("compress_128_online_plan"), + cuda_files=["deepseek_v4/c128_online.cuh"], + cuda_wrappers=[ + ("plan_compress_online_prefill", "plan_compress_online_prefill"), + ], + ) + + +@cache_once +def _jit_compress_128_online_module(head_dim: int) -> Module: + """Online compress 128 kernel: ring_size=1, per-index (max, sum, kv) state.""" + args = make_cpp_args(head_dim, is_arch_support_pdl()) + kernel_class = f"FlashCompress128OnlineKernel<{args}>" + return load_jit( + make_name("compress_128_online"), + *args, + cuda_files=["deepseek_v4/c128_online.cuh"], + cuda_wrappers=[ + ("decode", f"{kernel_class}::run_decode"), + ("prefill", f"{kernel_class}::run_prefill"), + ], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_topk_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("topk"), + *args, + cuda_files=["deepseek_v4/topk.cuh"], + cuda_wrappers=[("topk_transform", f"TopK512Kernel<{args}>::transform")], + ) + + +@cache_once +def _jit_topk1024_module() -> Module: + args = 
make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("topk1024"), + *args, + cuda_files=["deepseek_v4/topk_1024.cuh"], + cuda_wrappers=[("topk_transform", f"TopK1024Kernel<{args}>::transform")], + ) + + +@cache_once +def _jit_topk_v2_module(topk: int) -> Module: + return load_jit( + make_name("topk_v2"), + str(topk), + cuda_files=["deepseek_v4/topk_v2.cuh"], + cuda_wrappers=[ + ("topk_transform", "CombinedTopKKernel::transform"), + ("topk_plan", "CombinedTopKKernel::plan"), + ], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +@cache_once +def _jit_mask_topk_module() -> Module: + return load_jit( + make_name("mask_topk"), + cuda_files=["deepseek_v4/hash_topk.cuh"], + cuda_wrappers=[("run", "MaskKernel::run")], + ) + + +@cache_once +def _jit_hash_topk_module() -> Module: + args = make_cpp_args("act_sqrt_softplus", is_arch_support_pdl()) + return load_jit( + make_name("hash_topk"), + *args, + cuda_files=["deepseek_v4/hash_topk.cuh"], + cuda_wrappers=[("hash_topk", f"HashTopKKernel<{args}>::run")], + ) + + +@cache_once +def _jit_compress_module( + head_dim: int, + dtype_in: torch.dtype, + dtype_out: torch.dtype, + ratio: Literal[4, 128], +) -> Module: + args = make_cpp_args(head_dim, dtype_in, dtype_out, is_arch_support_pdl()) + kernel_class = f"FlashCompress{ratio}Kernel<{args}>" + return load_jit( + make_name(f"compress_{ratio}"), + *args, + cuda_files=[f"deepseek_v4/c{ratio}.cuh"], + cuda_wrappers=[ + ("decode", f"{kernel_class}::run_decode"), + ("prefill", f"{kernel_class}::run_prefill"), + ], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_compress_module_v2_defensive( + head_dim: int, + dtype_in: torch.dtype, + dtype_out: torch.dtype, +) -> Module: + args = make_cpp_args(head_dim, dtype_in, dtype_out, is_arch_support_pdl()) + kernel_class = f"FlashCompress128Kernel<{args}>" + return load_jit( + make_name("compress_128_v2_defensive"), + *args, + cuda_files=["deepseek_v4/c128_v2.cuh"], + cuda_wrappers=[ + ("prefill", 
f"{kernel_class}::run_prefill"), + ], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_rmsnorm_head_module(head_dim: int, dtype: torch.dtype): + args = make_cpp_args(head_dim, dtype, is_arch_support_pdl()) + kernel_class = f"RMSNormKernel<{args}>" + return load_jit( + make_name("rmsnorm_head"), + *args, + cuda_files=["deepseek_v4/rmsnorm.cuh"], + cuda_wrappers=[("run_self", f"{kernel_class}::run_self")], + ) + + +@cache_once +def _jit_fused_rope_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("fused_rope"), + *args, + cuda_files=["deepseek_v4/rope.cuh"], + cuda_wrappers=[("forward", f"FusedQKRopeKernel<{args}>::forward")], + ) + + +@cache_once +def _jit_norm_rope_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, +) -> Module: + args = make_cpp_args(dtype, head_dim, rope_dim, is_arch_support_pdl()) + return load_jit( + make_name(f"fused_norm_rope"), + *args, + cuda_files=[f"deepseek_v4/fused_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedNormRopeKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_fused_store_module( + name: Literal["flashmla", "indexer"], + input_dtype: torch.dtype, + index_dtype: torch.dtype, + page_size: int, +) -> Module: + args = make_cpp_args(input_dtype, index_dtype, page_size, is_arch_support_pdl()) + cname = "FlashMLA" if name == "flashmla" else "Indexer" + kernel_class = f"FusedStoreCache{cname}Kernel<{args}>" + return load_jit( + make_name("store_" + name), + *args, + cuda_files=["deepseek_v4/store.cuh"], + cuda_wrappers=[("run", f"{kernel_class}::run")], + ) + + +@cache_once +def _jit_metadata_module(): + return load_jit( + make_name("metadata"), + cuda_files=["deepseek_v4/paged_mqa_metadata.cuh"], + cuda_wrappers=[("run", "IndexerMetadataKernel::run")], + ) + + +@cache_once +def _jit_silu_mul_quant_varlen_module( + quant_group_size: int, + scale_ue8m0: bool, + swizzle: bool, + apply_swiglu_limit: bool, +) -> Module: + args = 
make_cpp_args( + quant_group_size, + scale_ue8m0, + swizzle, + is_arch_support_pdl(), + apply_swiglu_limit, + ) + return load_jit( + make_name("silu_mul_quant_varlen"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulMaskedPostQuantKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_silu_mul_quant_contig_module( + quant_group_size: int, + scale_ue8m0: bool, + swizzle: bool, + apply_swiglu_limit: bool, +) -> Module: + args = make_cpp_args( + quant_group_size, + scale_ue8m0, + swizzle, + is_arch_support_pdl(), + apply_swiglu_limit, + ) + return load_jit( + make_name("silu_mul_quant_contig"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulContigPostQuantKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_silu_and_mul_clamp_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("silu_and_mul_clamp"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulClampKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +# --------------------------------------------------------------------------- +# Byte-equal fallbacks: when SGLANG_OPT_FIX_MEGA_MOE_MEMORY is off, route +# silu_and_mul_masked_post_quant / silu_and_mul_clamp through these _tmp +# modules, which load a copy of the optimize-branch kernel (different +# precision behavior ??? bf16 silu roundtrip, expf, fp32 clamp). 
+# --------------------------------------------------------------------------- + + +@cache_once +def _jit_silu_mul_quant_tmp_module( + quant_group_size: int, scale_ue8m0: bool, apply_swiglu_limit: bool +) -> Module: + args = make_cpp_args( + quant_group_size, scale_ue8m0, is_arch_support_pdl(), apply_swiglu_limit + ) + return load_jit( + make_name("silu_mul_quant_tmp"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh"], + cuda_wrappers=[("run", f"SiluAndMulMaskedPostQuantKernel<{args}>::run")], + ) + + +@cache_once +def _jit_silu_and_mul_clamp_tmp_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("silu_and_mul_clamp_tmp"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh"], + cuda_wrappers=[("run", f"SiluAndMulClampKernel<{args}>::run")], + ) + + +@cache_once +def _jit_mega_moe_pre_dispatch_module(quant_group_size: int) -> Module: + args = make_cpp_args(quant_group_size, is_arch_support_pdl()) + return load_jit( + make_name("mega_moe_pre_dispatch"), + *args, + cuda_files=["deepseek_v4/mega_moe_pre_dispatch.cuh"], + cuda_wrappers=[("run", f"MegaMoEPreDispatchKernel<{args}>::run")], + ) + + +@cache_once +def _jit_hisparse_transfer_module() -> Module: + return load_jit( + make_name("hisparse_transfer"), + cuda_files=["deepseek_v4/hisparse_transfer.cuh"], + cuda_wrappers=[("hisparse_transfer", "hisparse_transfer")], + ) + + +def hisparse_offload_to_host( + gpu_ptrs: torch.Tensor, + cpu_ptrs: torch.Tensor, + gpu_indices: torch.Tensor, + cpu_indices: torch.Tensor, +) -> None: + module = _jit_hisparse_transfer_module() + module.hisparse_transfer(gpu_ptrs, cpu_ptrs, gpu_indices, cpu_indices) + + +def topk_transform_512( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + if out_page_indices.shape[1] == 512: + 
module = _jit_topk_module() + else: + module = _jit_topk1024_module() + module.topk_transform( + scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices + ) + + +_WORKSPACE_INTS_PER_BATCH = 2 + 1024 * 2 +_PLAN_METADATA_INTS_PER_BATCH = 4 + + +def plan_topk_v2(seq_lens: torch.Tensor, static_threshold: int = 0) -> torch.Tensor: + module = _jit_topk_v2_module(512) # does not matter + bs = seq_lens.shape[0] + metadata = seq_lens.new_empty(bs + 1, _PLAN_METADATA_INTS_PER_BATCH) + module.topk_plan(seq_lens, metadata, static_threshold) + return metadata + + +def topk_transform_512_v2( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + metadata: torch.Tensor, +) -> None: + module = _jit_topk_v2_module(out_page_indices.shape[1]) + bs = scores.shape[0] + workspace = seq_lens.new_empty(bs, _WORKSPACE_INTS_PER_BATCH) + module.topk_transform( + scores, + seq_lens, + page_tables, + out_page_indices, + page_size, + workspace, + metadata, + ) + + +def hash_topk( + router_logits: torch.Tensor, + input_ids: torch.Tensor, + tid2eid: torch.Tensor, + num_fused_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + scoring_func: str = "sqrtsoftplus", +) -> Tuple[torch.Tensor, torch.Tensor]: + assert scoring_func == "sqrtsoftplus" + num_tokens = router_logits.size(0) + topk_routed = tid2eid.size(1) + topk_fused = topk_routed + num_fused_shared_experts + topk_ids = torch.empty( + (num_tokens, topk_fused), dtype=torch.int32, device=router_logits.device + ) + topk_weights = torch.empty( + (num_tokens, topk_fused), dtype=torch.float32, device=router_logits.device + ) + module = _jit_hash_topk_module() + module.hash_topk( + router_logits, + input_ids, + tid2eid, + topk_weights, + topk_ids, + routed_scaling_factor, + ) + return topk_weights, topk_ids + + +def mask_topk_ids(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor): + return _jit_mask_topk_module().run(topk_ids, 
num_token_non_padded) + + +class CompressorPrefillPlan(NamedTuple): + compress_ratio: int + compress_plan: torch.Tensor + write_plan: torch.Tensor + + def copy_(self, other: CompressorPrefillPlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.compress_plan.copy_(other.compress_plan) + self.write_plan.copy_(other.write_plan) + + @staticmethod + def generate( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + device: torch.device, + use_cuda_graph: bool = False, + ) -> CompressorPrefillPlan: + from sglang.srt.environ import envs + + # Online c128 keeps the same NamedTuple shape (compress_plan, write_plan) + # so call sites that splat `*plan[1:]` continue to work, but the C++ + # plan struct semantics differ (last-token coords + window_len). + if compress_ratio == 128 and envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + return CompressorPrefillPlan._generate_online( + num_q_tokens=num_q_tokens, + seq_lens=seq_lens, + extend_lens=extend_lens, + device=device, + use_cuda_graph=use_cuda_graph, + ) + assert seq_lens.device == extend_lens.device + seq_lens = seq_lens.to(torch.int64) + extend_lens = extend_lens.to(torch.int64) + plan_tensor = torch.empty( + (2, num_q_tokens, 16), + dtype=torch.uint8, + device=seq_lens.device, + pin_memory=seq_lens.is_cpu, + ) + module = _jit_common_module() + is_overlap = compress_ratio == 4 + plan_lens = module.plan_compress_prefill( + extend_lens, + seq_lens, + plan_tensor[0], + plan_tensor[1], + compress_ratio, + is_overlap, + use_cuda_graph, + ) + return CompressorPrefillPlan( + compress_ratio, + plan_tensor[0, : plan_lens[0]].to(device, non_blocking=True), + plan_tensor[1, : plan_lens[1]].to(device, non_blocking=True), + ) + + @staticmethod + def _generate_online( + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + device: torch.device, + use_cuda_graph: bool, + ) -> CompressorPrefillPlan: + # Online plan host-side path: only 
CPU/cuda-host implemented today. + # Move inputs to CPU pinned memory then bounce the result to device. + seq_lens_cpu = seq_lens.detach().to(torch.int64).cpu() + extend_lens_cpu = extend_lens.detach().to(torch.int64).cpu() + plan_tensor = torch.empty( + (2, num_q_tokens, 16), + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + module = _jit_compress_128_online_plan_module() + plan_lens = module.plan_compress_online_prefill( + extend_lens_cpu, + seq_lens_cpu, + plan_tensor[0], + plan_tensor[1], + use_cuda_graph, + ) + return CompressorPrefillPlan( + 128, + plan_tensor[0, : plan_lens[0]].to(device, non_blocking=True), + plan_tensor[1, : plan_lens[1]].to(device, non_blocking=True), + ) + + @property + def is_decode(self) -> bool: + return False + + +class CompressorDecodePlan(NamedTuple): + compress_ratio: int + seq_lens: torch.Tensor + + def copy_(self, other: CompressorDecodePlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.seq_lens.copy_(other.seq_lens) + + @property + def is_decode(self) -> bool: + return True + + +def compress_plan( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: Optional[torch.Tensor], + device: torch.device, +) -> Union[CompressorDecodePlan, CompressorPrefillPlan]: + if extend_lens is not None: + return CompressorPrefillPlan.generate( + compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + device, + ) + else: + assert num_q_tokens == len(seq_lens) + seq_lens = seq_lens.to(device, non_blocking=True) + return CompressorDecodePlan(compress_ratio, seq_lens) + + +def compress_forward( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan, None] = None, + extra_data: Optional[torch.Tensor] = None, + *, + head_dim: int, + compress_ratio: Literal[4, 128], + out: Optional[torch.Tensor] = None, + seq_lens: Optional[torch.Tensor] = None, + extend_lens: 
Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert head_dim % 128 == 0 + num_q_tokens = kv_score_input.shape[0] + if out is None: + out = kv_score_input.new_empty((num_q_tokens, head_dim)) + if plan is None: + assert seq_lens is not None + plan = compress_plan( + compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + kv_score_input.device, + ) + assert plan.compress_ratio == compress_ratio, "Mismatched compress ratio in plan!" + # Online c128: separate JIT module, fp32 state, no compile-time dtypes. + if compress_ratio == 128 and envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + online_module = _jit_compress_128_online_module(head_dim=head_dim) + F = online_module.decode if plan.is_decode else online_module.prefill + F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) + return out + module = _jit_compress_module( + head_dim, + kv_score_input.dtype, + out.dtype, + compress_ratio, + ) + if plan.is_decode: + F = module.decode + elif compress_ratio == 128 and _should_use_c128_prefill_defensive(): + F = _jit_compress_module_v2_defensive( + head_dim, + kv_score_input.dtype, + out.dtype, + ).prefill + else: + F = module.prefill + F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) + return out + + +def _should_use_c128_prefill_defensive() -> bool: + from sglang.srt.environ import envs + + return envs.SGLANG_HANDLE_C128_PREFILL_KERNEL.get() + + +def compress_fused_norm_rope_inplace( + kv: torch.Tensor, + weight: torch.Tensor, + eps: float, + freq_cis: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan], +) -> None: + freq_cis = torch.view_as_real(freq_cis).flatten(-2) + module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) + module.forward( + kv, + weight, + plan[1], + freq_cis, + int(plan.is_decode), + eps, + plan.compress_ratio, + ) + + +def fused_norm_rope_inplace( + kv: torch.Tensor, + weight: torch.Tensor, + eps: float, + freq_cis: torch.Tensor, + positions: 
torch.Tensor, +) -> None: + freq_cis = torch.view_as_real(freq_cis).flatten(-2) + module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) + module.forward( + kv, + weight, + positions, + freq_cis, + 2, + eps, + 0, + ) + + +def fused_rope( + q: torch.Tensor, + k: Optional[torch.Tensor], + freqs_cis: torch.Tensor, + positions: torch.Tensor, + inverse: bool = False, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + + +@cache_once +def _tilelang_make_swa_indices_kernel(swa_window_size: int, threads: int = 128) -> Any: + import tilelang + import tilelang.language as T + + batch_size = T.dynamic("batch_size") + batch_size_plus_1 = T.dynamic("batch_size_plus_1") + num_q_tokens = T.dynamic("num_q_tokens") + num_warps = threads // 32 + assert swa_window_size % 32 == 0 + + @tilelang.jit + def make_swa_prefill_indices( + seq_lens_k: T.Tensor[(batch_size,), T.int32], + seq_lens_q: T.Tensor[(batch_size,), T.int32], + cu_seqlens_q: T.Tensor[(batch_size_plus_1,), T.int32], + swa_indices: T.Tensor[(num_q_tokens, swa_window_size), T.int32], + ): + _ = batch_size_plus_1 + with T.Kernel(T.ceildiv(num_q_tokens, num_warps), threads=threads) as bx: + tx = T.get_thread_binding() + warp_id = tx // 32 + lane_id = tx % 32 + s_batch_id = T.alloc_shared((num_warps,), dtype=T.int32) + + token_id = warp_id + bx * num_warps + if token_id >= num_q_tokens: + return + for i in T.serial(0, batch_size, step=32): + j = i + lane_id + if cu_seqlens_q[j] <= token_id < cu_seqlens_q[j + 1]: + s_batch_id[warp_id] = j + T.sync_warp() + + seq_idx = s_batch_id[warp_id] + kv_len = seq_lens_k[seq_idx] + qo_len = seq_lens_q[seq_idx] + cum_qo_len = cu_seqlens_q[seq_idx] + prefix_len = kv_len - qo_len + curr_seq_qo_idx = token_id - cum_qo_len + end_abs_pos = prefix_len + curr_seq_qo_idx + 1 + start_abs_pos = T.max(end_abs_pos - swa_window_size, 0) + old_kv_start = seq_idx * 
swa_window_size + new_kv_start = batch_size * swa_window_size + cum_qo_len + + for i in T.unroll(0, swa_window_size, step=32): + j = i + lane_id + abs_pos = start_abs_pos + j + swa_indices[token_id, j] = T.if_then_else( + abs_pos < end_abs_pos, + T.if_then_else( + abs_pos < prefix_len, + old_kv_start + abs_pos % swa_window_size, + new_kv_start + (abs_pos - prefix_len), + ), + -1, + ) + + return make_swa_prefill_indices + + +def tilelang_make_swa_prefill_indices( + seq_lens_k: torch.Tensor, + seq_lens_q: torch.Tensor, + swa_indices: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if cu_seqlens_q is None: + cu_seqlens_q = torch.cumsum(seq_lens_q, dim=0, dtype=torch.int32) + cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0), value=0) + swa_window_size = swa_indices.shape[1] + kernel = _tilelang_make_swa_indices_kernel(swa_window_size) + kernel(seq_lens_k, seq_lens_q, cu_seqlens_q, swa_indices) + return swa_indices + + +@triton.jit +def create_paged_compress_data_kernel( + req_pool_indices_ptr, + seq_lens_ptr, + extend_seq_lens_ptr, + req_to_token_ptr, + full_to_swa_index_mapping_ptr, + out_0_ptr, + out_1_ptr, + batch_size, + stride_req_to_token_0, + stride_req_to_token_1: tl.constexpr, + stride_out_1_0, + stride_out_1_1: tl.constexpr, + compress_ratio: tl.constexpr, + is_overlap: tl.constexpr, + swa_page_size: tl.constexpr, + ring_size: tl.constexpr, + BLOCK: tl.constexpr, +) -> None: + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < batch_size + + rid = tl.load(req_pool_indices_ptr + offs, mask=mask, other=0).to(tl.int32) + seq_len = tl.load(seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + extend_len = tl.load(extend_seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + prefix_len = seq_len - extend_len + + cr = compress_ratio + write_pos = ((seq_len - 1) // cr) * cr + load_pos = ((prefix_len - 1) // cr) * cr + write_overlap_pos = write_pos - cr + load_overlap_pos = load_pos - cr 
+ v0 = tl.zeros([BLOCK], tl.int32) + v1 = tl.zeros([BLOCK], tl.int32) + v2 = tl.zeros([BLOCK], tl.int32) + v3 = tl.zeros([BLOCK], tl.int32) + + for i in tl.static_range(4): + if i == 0: + pos = load_pos + elif i == 1: + pos = write_pos + elif i == 2: + pos = load_overlap_pos + else: + pos = write_overlap_pos + pos = tl.maximum(pos, 0) + loc = tl.load( + req_to_token_ptr + + rid.to(tl.int64) * stride_req_to_token_0 + + pos.to(tl.int64) * stride_req_to_token_1, + mask=mask, + other=0, + ).to(tl.int32) + swa_loc = tl.load(full_to_swa_index_mapping_ptr + loc, mask=mask, other=0).to( + tl.int32 + ) + swa_page = swa_loc // swa_page_size + state_loc = swa_page * ring_size + (swa_loc % ring_size) + state_loc = state_loc // cr + if i == 0: + v0 = state_loc + elif i == 1: + v1 = state_loc + elif i == 2: + v2 = state_loc + else: + v3 = state_loc + + tl.store(out_0_ptr + offs, v1, mask=mask) + + if is_overlap: + base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v2, mask=mask) + tl.store(base + 1 * stride_out_1_1, v0, mask=mask) + tl.store(base + 2 * stride_out_1_1, v3, mask=mask) + tl.store(base + 3 * stride_out_1_1, write_pos.to(tl.int32), mask=mask) + else: + base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v0, mask=mask) + + +_mmap_dumper = None + + +def _get_mmap_dumper(): + global _mmap_dumper + if _mmap_dumper is None: + from sglang.srt.debug_utils.mmap_dumper import MmapDumper + from sglang.srt.environ import envs + + dump_dir = envs.SGLANG_HACK_DEBUG_DUMP_CREATE_PAGED_COMPRESS_DATA.get() + _mmap_dumper = MmapDumper(dump_dir or None) + return _mmap_dumper + + +_dumped_static_meta_once = False + + +def _maybe_dump_create_paged_compress_data_inputs( + *, + compress_ratio: int, + is_overlap: bool, + swa_page_size: int, + ring_size: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + req_to_token: torch.Tensor, + full_to_swa_index_mapping: torch.Tensor, + block: int, 
) -> None:
    # Debug helper: dump the inputs of the paged-compress planning step when
    # the mmap dumper is active; otherwise return without doing anything.
    d = _get_mmap_dumper()
    if not d.is_active():
        return

    # Print static config (constant after server init) once per process.
    global _dumped_static_meta_once
    if not _dumped_static_meta_once:
        print(
            f"[c128_dump_static] swa_page_size={swa_page_size} ring_size={ring_size} "
            f"block={block} req_to_token_shape={tuple(req_to_token.shape)} "
            f"full_to_swa_shape={tuple(full_to_swa_index_mapping.shape)}",
            flush=True,
        )
        _dumped_static_meta_once = True

    # Per-ratio dump (small): req_pool_indices / seq_lens / extend_seq_lens.
    # These are forward_batch fields, identical across c4 and c128 within the
    # same forward — but small (KB-level) so dumping twice is cheap.
    # Keys are prefixed with "c<ratio>_plan" so both ratios can coexist.
    p = f"c{compress_ratio}_plan"
    d.dump(
        {
            f"{p}_compress_ratio": compress_ratio,
            f"{p}_is_overlap": is_overlap,
            f"{p}_req_pool_indices": req_pool_indices,
            f"{p}_seq_lens": seq_lens,
            f"{p}_extend_seq_lens": extend_seq_lens,
        }
    )

    # Global tensors shared between c4 and c128 (multi-MB-GB). Only dump on
    # the first call per forward to avoid 2x GPU->CPU copy (~184 MB + ~33 MB).
    # Backends call create_paged_compressor_data with c4 first, then c128
    # (deepseek_v4_backend_radix.py:457-458), so dump on c4 only.
    if compress_ratio == 4:
        # Only the first (at most) 10000 columns of req_to_token are dumped;
        # the number actually kept is recorded next to the tensor.
        cols = min(10000, req_to_token.shape[1])
        req_to_token_partial = req_to_token[:, :cols].contiguous()
        d.dump(
            {
                "global_req_to_token_dumped_cols": cols,
                "global_req_to_token_partial": req_to_token_partial,
                "global_full_to_swa_index_mapping": full_to_swa_index_mapping,
            }
        )


def _maybe_dump_create_paged_compress_data_outputs(
    *,
    compress_ratio: int,
    out_0: torch.Tensor,
    out_1: torch.Tensor,
) -> None:
    """Dump both planning outputs and their shapes when the dumper is active.

    Keys use the same ``c<ratio>_plan`` prefix as the input dump. This is a
    no-op when the dumper is inactive.
    """
    d = _get_mmap_dumper()
    if not d.is_active():
        return
    p = f"c{compress_ratio}_plan"
    d.dump(
        {
            f"{p}_out_0": out_0,
            f"{p}_out_1": out_1,
            f"{p}_out_0_shape": list(out_0.shape),
            f"{p}_out_1_shape": list(out_1.shape),
        }
    )


# compress_ratio -> True once the buffer shape line was printed for that ratio.
_printed_buffer_shape_once: dict = {}


def maybe_dump_compress_metadata_extras(
    *,
    compress_ratio: int,
    kv_score_buffer_shape: Tuple[int, ...],
    kv_score_buffer_dtype: torch.dtype,
    plan_compress_plan: torch.Tensor,
    plan_write_plan: torch.Tensor,
) -> None:
    """Public helper to be called from compressor.py at metadata-prepare time
    (once per forward per ratio, not per layer). Dumps the prefill kernel's
    real bound (kv_score_buffer.shape) plus the actual plan tensors that get
    fed to flash_c{ratio}_prefill.
    """
    d = _get_mmap_dumper()
    if not d.is_active():
        return

    # Print kv_score_buffer.shape once per ratio (constant after init).
+ if compress_ratio not in _printed_buffer_shape_once: + print( + f"[c128_dump_static] c{compress_ratio} " + f"kv_score_buffer_shape={tuple(kv_score_buffer_shape)} " + f"dtype={kv_score_buffer_dtype}", + flush=True, + ) + _printed_buffer_shape_once[compress_ratio] = True + + p = f"c{compress_ratio}_meta" + d.dump( + { + f"{p}_plan_compress_plan": plan_compress_plan, + f"{p}_plan_write_plan": plan_write_plan, + f"{p}_plan_compress_count": int(plan_compress_plan.shape[0]), + f"{p}_plan_write_count": int(plan_write_plan.shape[0]), + } + ) + + +def triton_create_paged_compress_data( + *, + compress_ratio: int, + is_overlap: bool, + swa_page_size: int, + ring_size: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + req_to_token: torch.Tensor, + full_to_swa_index_mapping: torch.Tensor, + block: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + _should_dump = bool(envs.SGLANG_HACK_DEBUG_DUMP_CREATE_PAGED_COMPRESS_DATA.get()) + if _should_dump: + torch.cuda.synchronize() + _maybe_dump_create_paged_compress_data_inputs( + compress_ratio=compress_ratio, + is_overlap=is_overlap, + swa_page_size=swa_page_size, + ring_size=ring_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_seq_lens=extend_seq_lens, + req_to_token=req_to_token, + full_to_swa_index_mapping=full_to_swa_index_mapping, + block=block, + ) + + batch_size = req_pool_indices.shape[0] + out_dim = 4 if is_overlap else 1 + device_args: dict = dict(device=req_pool_indices.device, dtype=torch.int32) + out_0 = torch.empty((batch_size,), **device_args) + out_1 = torch.empty((batch_size, out_dim), **device_args) + grid = (triton.cdiv(batch_size, block),) + create_paged_compress_data_kernel[grid]( + req_pool_indices, + seq_lens, + extend_seq_lens, + req_to_token, + full_to_swa_index_mapping, + out_0, + out_1, + batch_size=batch_size, + stride_req_to_token_0=req_to_token.stride(0), + stride_req_to_token_1=req_to_token.stride(1), + 
        stride_out_1_0=out_1.stride(0),
        stride_out_1_1=out_1.stride(1),
        compress_ratio=compress_ratio,
        # Passed as 0/1 rather than a Python bool.
        is_overlap=1 if is_overlap else 0,
        swa_page_size=swa_page_size,
        ring_size=ring_size,
        BLOCK=block,
    )

    if _should_dump:
        # Wait for the kernel to finish before reading the outputs back.
        torch.cuda.synchronize()
        _maybe_dump_create_paged_compress_data_outputs(
            compress_ratio=compress_ratio, out_0=out_0, out_1=out_1
        )

    if not is_overlap:
        # Non-overlap callers get a 1-D tensor; squeeze_ modifies out_1 in place.
        out_1.squeeze_(1)
    return out_0, out_1


def fused_store_cache(
    input: torch.Tensor,
    cache: torch.Tensor,
    indices: torch.Tensor,
    *,
    page_size: int,
    type: Literal["flashmla", "indexer"],
) -> None:
    """Run the JIT fused-store module selected by ``type`` on the given tensors.

    The module is keyed on ``type``, the dtypes of ``input`` and ``indices``,
    and ``page_size``. Nothing is returned.
    NOTE(review): presumably the module writes ``input`` into ``cache`` at
    ``indices`` — confirm against the kernel source.
    """
    module = _jit_fused_store_module(
        name=type,
        input_dtype=input.dtype,
        index_dtype=indices.dtype,
        page_size=page_size,
    )
    module.run(input, cache, indices)


def silu_and_mul_clamp(
    input: torch.Tensor,
    output: torch.Tensor,
    swiglu_limit: float,
) -> None:
    # Runs a JIT silu-and-mul-clamp module on (input, output, swiglu_limit);
    # which module is chosen depends on SGLANG_OPT_FIX_MEGA_MOE_MEMORY.
    # Fallback path is hacky on purpose: when the mega-moe-memory flag is off
    # we must be bitwise-identical to the optimize branch, which used the
    # pre-refactor kernel.
    from sglang.srt.environ import envs

    # Record that this code path was taken.
    deepseek_v4_moe_code_path_checker.observed += 1
    if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get():
        module = _jit_silu_and_mul_clamp_module(input.dtype)
    else:
        module = _jit_silu_and_mul_clamp_tmp_module(input.dtype)
    module.run(input, output, float(swiglu_limit))


def silu_and_mul_masked_post_quant(
    input: torch.Tensor,
    output: torch.Tensor,
    output_scale: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
    scale_ue8m0: bool = False,
    topk: int = 8,
    transposed: bool = False,
    swiglu_limit: Optional[float] = None,
    swizzle: bool = False,
) -> None:
    """Dispatch the masked silu-mul + quantization JIT kernel.

    ``swizzle`` selects the varlen module; otherwise the "tmp" module is used
    (which is not parameterized on ``swizzle``). ``swiglu_limit`` enables the
    clamp variant when it is not ``None``; 0.0 is passed to the kernel when
    it is ``None``.
    """
    apply_swiglu_limit = swiglu_limit is not None
    if apply_swiglu_limit:
        # Only the clamped variant counts towards the code-path checker.
        deepseek_v4_moe_code_path_checker.observed += 1
    if swizzle:
        module = _jit_silu_mul_quant_varlen_module(
            quant_group_size, scale_ue8m0, swizzle, apply_swiglu_limit
        )
    else:
        module = _jit_silu_mul_quant_tmp_module(
            quant_group_size, scale_ue8m0, apply_swiglu_limit
        )
    module.run(
        input,
        output,
        output_scale,
        masked_m,
        topk,
        transposed,
        float(swiglu_limit) if apply_swiglu_limit else 0.0,
    )


def silu_and_mul_contig_post_quant(
    input: torch.Tensor,
    output: torch.Tensor,
    output_scale: torch.Tensor,
    quant_group_size: int,
    scale_ue8m0: bool = False,
    transposed: bool = False,
    swiglu_limit: Optional[float] = None,
    swizzle: bool = False,
) -> None:
    """Dispatch the contiguous silu-mul + quantization JIT kernel.

    Same ``swiglu_limit`` convention as the masked variant: ``None`` disables
    the clamp and 0.0 is passed to the kernel in that case.
    """
    apply_swiglu_limit = swiglu_limit is not None
    if apply_swiglu_limit:
        deepseek_v4_moe_code_path_checker.observed += 1
    module = _jit_silu_mul_quant_contig_module(
        quant_group_size, scale_ue8m0, swizzle, apply_swiglu_limit
    )
    module.run(
        input,
        output,
        output_scale,
        transposed,
        float(swiglu_limit) if apply_swiglu_limit else 0.0,
    )


def mega_moe_pre_dispatch(
    x: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    buf_x: torch.Tensor,
    buf_x_sf: torch.Tensor,
    buf_topk_idx: torch.Tensor,
    buf_topk_weights: torch.Tensor,
    quant_group_size: int = 32,
) -> None:
    """Run the mega-MoE pre-dispatch JIT module for ``quant_group_size``.

    NOTE(review): the ``buf_*`` arguments look like caller-provided output
    buffers for ``x`` / its scale factors / the top-k tensors — confirm
    against the kernel source.
    """
    module = _jit_mega_moe_pre_dispatch_module(quant_group_size)
    module.run(
        x,
        topk_idx,
        topk_weights,
        buf_x,
        buf_x_sf,
        buf_topk_idx,
        buf_topk_weights,
    )


def get_paged_mqa_logits_metadata(seq_lens: torch.Tensor, page_size: int, num_sm: int):
    """Compute the metadata tensor for the paged MQA logits kernel.

    Only ``page_size == 64`` is accepted. ``seq_lens`` is flattened and cast
    to int32; the result has shape ``(num_sm + 1, 2)`` and the same
    dtype/device as that flattened tensor.
    """
    assert page_size == 64
    seq_lens = seq_lens.view(-1).to(torch.int32)
    metadata = seq_lens.new_empty(num_sm + 1, 2)
    module = _jit_metadata_module()
    module.run(seq_lens, metadata)
    return metadata


def rmsnorm_self(q: torch.Tensor, eps: float) -> torch.Tensor:
    """Run the JIT RMSNorm-head module's ``run_self`` on ``q``.

    The module is selected by the last dimension and dtype of ``q``; the
    result is written to, and returned as, a new tensor with ``q``'s shape.
    """
    module = _jit_rmsnorm_head_module(q.shape[-1], q.dtype)
    out = q.new_empty(q.shape)
    module.run_self(q, out, eps)
    return out



def linear_bf16_fp32(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Compute ``x @ y.T`` via the backend named by SGLANG_OPT_BF16_FP32_GEMM_ALGO.

    With ``"auto"`` the backend is picked from the problem size
    (m = x.size(0), n = y.size(0), k = x.size(1)).
    """
    from sglang.srt.environ import envs

    algo = envs.SGLANG_OPT_BF16_FP32_GEMM_ALGO.get()

    if algo == "auto":
        from sglang.srt.layers.linear_bf16_fp32.selector import pick_backend

        algo = pick_backend(m=x.size(0), n=y.size(0), k=x.size(1))

    return _dispatch_bf16_fp32_backend(x, y, algo=algo)


def _dispatch_bf16_fp32_backend(
    x: torch.Tensor, y: torch.Tensor, *, algo: str
) -> torch.Tensor:
    # Select the GEMM implementation by name. Any value other than "cublas"
    # or "deep_gemm" falls through to the float32 F.linear path.
    if algo == "cublas":
        # cuBLAS BF16xBF16 -> FP32 GEMM via PyTorch native API (torch >= 2.9).
        # Bit-exact and matches the previous JIT cublasGemmEx kernel.
        return torch.mm(x, y.t(), out_dtype=torch.float32)
    elif algo == "deep_gemm":
        import deep_gemm

        # Output buffer is fp32 with shape (x.size(0), y.size(0)).
        z = x.new_empty(x.size(0), y.size(0), dtype=torch.float32)
        deep_gemm.bf16_gemm_nt(x, y, z)
        return z
    else:
        # Upcast both operands and use the regular linear op.
        return torch.nn.functional.linear(x.float(), y.float())


def _compile_one(*input_tuple) -> None:
    """Run one compile job given as ``(name, job_fn, *args)``.

    Prints a start and a finish line around ``job_fn(*args)``.
    """
    name, job_fn, *args = input_tuple
    print(f"Compiling {name}...", flush=True)
    job_fn(*args)
    print(f"Finished compiling {name}.", flush=True)


def compile_aot():
    """Build a fixed list of JIT modules in parallel worker processes.

    Each job is ``(label, factory, *factory_args)``. The pool size is the
    smaller of the job count and the CPU count.
    """
    c_dtype = torch.float32
    jobs = [
        ("common", _jit_common_module),
        ("mask_topk", _jit_mask_topk_module),
        ("topk", _jit_topk_module),
        ("topk_v2", _jit_topk_v2_module),
        ("hash_topk", _jit_hash_topk_module),
        ("rope", _jit_fused_rope_module),
        ("metadata", _jit_metadata_module),
        (
            "compress_128_4",
            _jit_compress_module,
            128,
            c_dtype,
            c_dtype,
            4,
        ),
        (
            "compress_512_4",
            _jit_compress_module,
            512,
            c_dtype,
            c_dtype,
            4,
        ),
        (
            "compress_512_128",
            _jit_compress_module,
            512,
            c_dtype,
            c_dtype,
            128,
        ),
        (
            "norm_rope_128_64",
            _jit_norm_rope_module,
            c_dtype,
            128,
            64,
        ),
        (
            "norm_rope_512_64",
            _jit_norm_rope_module,
            c_dtype,
            512,
            64,
        ),
        (
            "store_flashmla_bf16_swa_256",
            _jit_fused_store_module,
            "flashmla",
            torch.bfloat16,
            torch.int32,
            256,
        ),
        (
            "store_flashmla_fp32_c4_64",
            _jit_fused_store_module,
            "flashmla",
            torch.float32,
            torch.int32,
            64,
        ),
        (
            "store_flashmla_fp32_c128_2",
            _jit_fused_store_module,
            "flashmla",
            torch.float32,
            torch.int32,
            2,
        ),
        (
            "store_indexer_fp32_c4_64",
            _jit_fused_store_module,
            "indexer",
            torch.float32,
            torch.int32,
            64,
        ),
        (
            "rmsnorm_head_512_bf16",
            _jit_rmsnorm_head_module,
            512,
            torch.bfloat16,
        ),
    ]
    import multiprocessing

    max_parallel_jobs = min(len(jobs), multiprocessing.cpu_count())
    with multiprocessing.Pool(processes=max_parallel_jobs) as pool:
        # starmap unpacks each job tuple into _compile_one(*job).
        pool.starmap(_compile_one, jobs)


if __name__ == "__main__":
    compile_aot()

# diff --git a/python/sglang/jit_kernel/hisparse.py b/python/sglang/jit_kernel/hisparse.py
# new file mode 100644
# index 000000000000..25db57b0b6ae
# --- /dev/null
# +++ b/python/sglang/jit_kernel/hisparse.py
# @@ -0,0 +1,88 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING

import torch

from sglang.jit_kernel.utils import load_jit, make_cpp_args

if TYPE_CHECKING:
    from tvm_ffi.module import Module


# Cached per argument tuple. item_size_bytes is part of the cache key
# (cache_args) but not of the C++ template argument list (template_args).
@functools.cache
def _jit_sparse_module(
    item_size_bytes: int,
    block_size: int,
    num_top_k: int,
    hot_buffer_size: int,
    is_mla: bool = False,
) -> Module:
    template_args = make_cpp_args(block_size, num_top_k, hot_buffer_size, is_mla)
    cache_args = make_cpp_args(
        item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla
    )
    return load_jit(
        "sparse_cache",
        *cache_args,
        cuda_files=["hisparse.cuh"],
        cuda_wrappers=[
            (
                "load_cache_to_device_buffer",
                f"load_cache_to_device_buffer<{template_args}>",
            )
        ],
    )


def load_cache_to_device_buffer_mla(
    top_k_tokens: torch.Tensor,
    device_buffer_tokens: torch.Tensor,
    host_cache_locs: torch.Tensor,
    device_buffer_locs: torch.Tensor,
    host_cache: torch.Tensor,
    device_buffer: torch.Tensor,
    top_k_device_locs: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    lru_slots: torch.Tensor,
    item_size_bytes: int,
    num_top_k: int,
    hot_buffer_size: int,
    page_size: int = 1,
    block_size: int = 256,
    num_real_reqs: torch.Tensor | None = None,
) -> None:
    # MLA entry point for the JIT `load_cache_to_device_buffer` kernel
    # (module built with is_mla=True). Requires hot_buffer_size >= num_top_k.
    assert (
        hot_buffer_size >= num_top_k
    ), f"hot_buffer_size ({hot_buffer_size}) must be >= num_top_k ({num_top_k})"

    module = _jit_sparse_module(
        item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla=True
    )

    # Zero-element placeholder passed for two kernel arguments below.
    # NOTE(review): created on the default device, not on the inputs' device.
    empty = torch.empty(0)

    if num_real_reqs is None:
        # Default: every row of top_k_tokens is a real request.
        num_real_reqs = torch.tensor(
            [top_k_tokens.size(0)], dtype=torch.int32, device=top_k_tokens.device
        )

    module.load_cache_to_device_buffer(
        top_k_tokens,
        device_buffer_tokens,
        host_cache_locs,
        device_buffer_locs,
        host_cache,
        empty,
        device_buffer,
        empty,
        top_k_device_locs,
        req_pool_indices,
        seq_lens,
        lru_slots,
        num_real_reqs,
        page_size,
        item_size_bytes,
    )

// diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh
// new file mode 100644
// index 000000000000..02b166d01c73
// --- /dev/null
// +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh
// @@ -0,0 +1,37 @@
// NOTE(review): the #include targets below are missing in this dump.
#pragma once

#include

#include

#include
#include

#include

namespace device::compress {

// One prefill work item: four 32-bit fields, 16-byte aligned.
struct alignas(16) PrefillPlan {
  uint32_t ragged_id;
  uint32_t batch_id;
  uint32_t position;
  uint32_t window_len;  // must be in `[0, compress_ratio * (1 + is_overlap))`

  // True when window_len is below ratio * (1 + is_overlap).
  bool is_valid(const uint32_t ratio, const bool is_overlap) const {
    const uint32_t max_window_len = ratio * (1 + is_overlap);
    return window_len < max_window_len;
  }
};

}  // namespace device::compress

namespace host::compress {

using device::compress::PrefillPlan;
// Host tensors carry plans as raw bytes: kPrefillPlanDim uint8 per plan.
using PrefillPlanTensorDtype = uint8_t;
inline constexpr int64_t kPrefillPlanDim = 16;

// Size equals alignment, and the byte count matches the tensor row width.
static_assert(alignof(PrefillPlan) == sizeof(PrefillPlan));
static_assert(sizeof(PrefillPlan) == kPrefillPlanDim * sizeof(PrefillPlanTensorDtype));

}  // namespace host::compress
// diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh
// new file mode 100644
// index 000000000000..4fdbb062c3cd
// --- /dev/null
// +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh
// @@ -0,0 +1,43 @@
#pragma once

#include
#include
#include

#include
#include

// Small helpers shared by the DeepSeek-V4 FP8/UE8M0 quantization kernels
// (silu_and_mul_masked_post_quant, store, mega_moe_pre_dispatch, ...).
+// All functions are `SGL_DEVICE` (= `__forceinline__ __device__`) so +// including this header in multiple translation units is ODR-safe. + +namespace deepseek_v4::fp8 { + +// Round `x` to the nearest representable UE8M0 value. Returns the raw +// 8-bit biased exponent; the actual fp32 scale is `2^(exp - 127)` +// (i.e. `__uint_as_float(exp << 23)`). +SGL_DEVICE int32_t cast_to_ue8m0(float x) { + uint32_t u = __float_as_uint(x); + int32_t exp = int32_t((u >> 23) & 0xFF); + uint32_t mant = u & 0x7FFFFF; + return exp + (mant != 0); +} + +// 1 / 2^(exp - 127) as fp32. Equivalent to `1.0f / __uint_as_float(exp << 23)`. +SGL_DEVICE float inv_scale_ue8m0(int32_t exp) { + return __uint_as_float((127 + 127 - exp) << 23); +} + +// Clamp to [-FP8_E4M3_MAX, FP8_E4M3_MAX]. +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); +} + +// Pack two fp32 values into a single fp8x2_e4m3 with clamping. +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} + +} // namespace deepseek_v4::fp8 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh new file mode 100644 index 000000000000..0a3acc47734a --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh @@ -0,0 +1,96 @@ +#include +#include + +#include + +#include + +namespace device::hisparse { + +/// NOTE: We call nope+rope as a "value" here. +/// GPU Cache layout: +/// VALUE 0, VALUE 1, ..., VALUE 63, +/// SCALE 0, SCALE 1, ..., SCALE 63, +/// [Padding to align to 576 bytes] +/// CPU Cache follow a trivial linear layout without any padding. 
inline constexpr int64_t kGPUPageSize = 64;
inline constexpr int64_t kGPUPageBits = 6;  // log2(kGPUPageSize)
inline constexpr int64_t kValueBytes = 576;
inline constexpr int64_t kScaleBytes = 8;
// One CPU item is a value immediately followed by its scale.
inline constexpr int64_t kCPUItemBytes = kValueBytes + kScaleBytes;
/// NOTE: FlashMLA requires each page to be aligned to 576 bytes
/// (hence the round-up of the GPU page size to a multiple of 576).
inline constexpr int64_t kGPUPageBytes = host::div_ceil(kCPUItemBytes * kGPUPageSize, 576) * 576;
// Scales start right after the kGPUPageSize values of a page.
inline constexpr int64_t kGPUScaleOffset = kValueBytes * kGPUPageSize;

// Addresses of one item's value block and scale word, viewed as int64_t.
struct PointerInfo {
  int64_t* value_ptr;
  int64_t* scale_ptr;
};

// Locate item `index` in the paged GPU layout: page = index >> kGPUPageBits,
// slot = index & (kGPUPageSize - 1); values and scales live in separate
// regions of the page.
// NOTE(review): the static_cast target types are missing in this dump.
SGL_DEVICE PointerInfo get_pointer_gpu(void* cache, int32_t index) {
  using namespace device;
  static_assert(1 << kGPUPageBits == kGPUPageSize);
  const int32_t page_num = index >> kGPUPageBits;
  const int32_t page_offset = index & (kGPUPageSize - 1);
  const auto page_ptr = pointer::offset(cache, page_num * kGPUPageBytes);
  const auto value_ptr = pointer::offset(page_ptr, page_offset * kValueBytes);
  const auto scale_ptr = pointer::offset(page_ptr, kGPUScaleOffset + page_offset * kScaleBytes);
  return {static_cast(value_ptr), static_cast(scale_ptr)};
}

// Locate item `index` in the linear CPU layout: items are kCPUItemBytes apart
// and the scale directly follows the value.
SGL_DEVICE PointerInfo get_pointer_cpu(void* cache, int32_t index) {
  using namespace device;
  const auto value_ptr = pointer::offset(cache, index * kCPUItemBytes);
  const auto scale_ptr = pointer::offset(value_ptr, kValueBytes);
  return {static_cast(value_ptr), static_cast(scale_ptr)};
}

enum class TransferDirection {
  DeviceToDevice = 0,
  DeviceToHost = 1,
  HostToDevice = 2,
};

// Copy one item (576-byte value + 8-byte scale) using the lanes of one warp.
// `direction` selects which side uses the paged GPU layout and which the
// linear CPU layout.
// NOTE(review): the template parameter list is missing in this dump; the body
// uses a compile-time `direction` value.
template
SGL_DEVICE void transfer_item(void* dst_cache, void* src_cache, const int32_t dst_index, const int32_t src_index) {
  constexpr bool is_dst_device = (direction != TransferDirection::DeviceToHost);
  constexpr bool is_src_device = (direction != TransferDirection::HostToDevice);
  constexpr auto dst_fn = is_dst_device ? get_pointer_gpu : get_pointer_cpu;
  constexpr auto src_fn = is_src_device ? get_pointer_gpu : get_pointer_cpu;

  const auto [dst_value_ptr, dst_scale_ptr] = dst_fn(dst_cache, dst_index);
  const auto [src_value_ptr, src_scale_ptr] = src_fn(src_cache, src_index);

  int64_t local_items[2];
  const int64_t* tail_src_ptr;
  int64_t* tail_dst_ptr;

  const int32_t lane_id = threadIdx.x % 32;

  // Words 0..63 of the value: two 8-byte words per lane (2 * 32 * 8 = 512 B).
  for (int i = 0; i < 2; ++i) {
    const auto j = lane_id + i * 32;
    local_items[i] = src_value_ptr[j];
  }

  // Remaining 8 value words (64..71) go to lanes 0..7; every other lane
  // copies the single 8-byte scale word (all writing the same value).
  if (lane_id < 8) {  // handle the tail element safely
    const auto last_id = 64 + lane_id;
    tail_src_ptr = src_value_ptr + last_id;
    tail_dst_ptr = dst_value_ptr + last_id;
  } else {  // broadcast load/store is safe
    tail_src_ptr = src_scale_ptr;
    tail_dst_ptr = dst_scale_ptr;
  }

  const auto tail_item = *tail_src_ptr;

  // store first 512 bytes of value
  for (int i = 0; i < 2; ++i) {
    const auto j = lane_id + i * 32;
    dst_value_ptr[j] = local_items[i];
  }

  // store the tail element
  *tail_dst_ptr = tail_item;
}

}  // namespace device::hisparse
// diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh
// new file mode 100644
// index 000000000000..e58214c95148
// --- /dev/null
// +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh
// @@ -0,0 +1,257 @@
#pragma once
#include
#include
#include

#include "common.cuh"
#include "ptx.cuh"
#include
#include

namespace device::top512 {

// Top-K selection spread over a thread-block cluster of kClusterSize blocks.
// NOTE(review): the template parameter list is missing in this dump; the body
// uses a compile-time `K`.
template
struct ClusterTopK {
  static constexpr uint32_t kClusterSize = 8;
  static constexpr uint32_t kHistBits = 10;
  static constexpr uint32_t kHistBins = 1 << kHistBits;
  static constexpr uint32_t kRadixBins = 256;
  // Each block stages its scores in kNumStages tiles of kSizePerStage floats.
  static constexpr uint32_t kElemPerStage = 8;
  static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize;
  static constexpr uint32_t kNumStages = 4;
  static constexpr uint32_t kMaxLength = kClusterSize * kNumStages * kSizePerStage;
  static constexpr uint32_t kStoreLane = kBlockSize - 1;
  // Low kAboveBits bits of a packed count hold "above"; the rest hold "equal".
  static constexpr uint32_t kAboveBits = 11;

  // ---------------------------------------------------------------------------
  // Shared memory layouts
  // ---------------------------------------------------------------------------

  struct Smem {
    uint64_t barrier[kNumStages];              // one mbarrier per staged tile
    uint32_t local_above_equal[kClusterSize];  // packed counts pushed by each cluster rank
    uint32_t prefix_above_equal;               // packed prefix over lower ranks
    alignas(128) uint32_t counter_gt;
    alignas(128) uint32_t counter_eq;
    alignas(128) MatchBin match;
    alignas(128) uint32_t warp_sum[kNumWarps];
    uint32_t histogram[kHistBins];
    alignas(128) float score_buffer[kNumStages][kSizePerStage];
    Tie tie_buffer[kMaxTies];
  };

  struct alignas(16) Metadata {
    uint32_t batch_id;
    uint32_t seq_len;
    bool has_next;
  };

  struct WorkSpace {
    uint2 metadata;  // {num_above, num_ties}
    Tie ties[kMaxTies];
  };

  static constexpr uint32_t kWorkspaceInts = sizeof(WorkSpace) / sizeof(uint32_t);

  // ---------------------------------------------------------------------------
  // Stage 1: histogram + cluster reduce + find threshold + scatter
  // ---------------------------------------------------------------------------

  // Zero the histogram and initialise one mbarrier per stage (arrival count
  // 1), then block-sync.
  SGL_DEVICE static void stage1_init(void* _smem) {
    const auto tx = threadIdx.x;
    __builtin_assume(tx < kBlockSize);
    const auto smem = static_cast(_smem);
    if (tx < kHistBins) smem->histogram[tx] = 0;
    if (tx < kNumStages) ptx::mbarrier_init(&smem->barrier[tx], 1);
    __syncthreads();
  }

  // Thread 0 issues one bulk copy per needed stage from `scores` into the
  // staged shared-memory buffers, tracked by that stage's mbarrier.
  SGL_DEVICE static void stage1_prologue(const float* scores, uint32_t length, void* _smem) {
    if (threadIdx.x == 0) {
      const auto smem = static_cast(_smem);
      const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage;
      const auto length_aligned = (length + 3u) & ~3u;  // align to 4 for TMA
#pragma unroll
      for (uint32_t stage = 0; stage < kNumStages; stage++) {
        if (stage >= num_stages) break;
        const auto offset = stage * kSizePerStage;
        const auto size = min(kSizePerStage, length_aligned - offset);
        const auto size_bytes = size * sizeof(float);
        const auto bar = &smem->barrier[stage];
        ptx::tma_load(smem->score_buffer[stage], scores + offset, size_bytes, bar);
        ptx::mbarrier_arrive_expect_tx(bar, size_bytes);
      }
    }
  }

  // Build the coarse histogram, reduce it across the cluster, locate the
  // threshold bin, then scatter indices above the threshold to `indices` and
  // elements in the threshold bin to the shared tie buffer.
  SGL_DEVICE static void stage1(int32_t* indices, uint32_t length, void* _smem, bool reuse = false) {
    const auto smem = static_cast(_smem);
    const auto tx = threadIdx.x;
    __builtin_assume(tx < kBlockSize);
    const auto lane_id = tx % kWarpThreads;
    const auto warp_id = tx / kWarpThreads;

    // Wait for each staged tile (lane 0 polls the barrier, then the warp
    // syncs) and accumulate its scores into the coarse histogram.
#pragma unroll
    for (uint32_t stage = 0; stage < kNumStages; stage++) {
      const auto offset = stage * kSizePerStage;
      if (offset >= length) break;
      const auto size = min(kSizePerStage, length - offset);
      if (lane_id == 0) ptx::mbarrier_wait(&smem->barrier[stage], 0);
      __syncwarp();
#pragma unroll
      for (uint32_t i = 0; i < kElemPerStage; ++i) {
        const auto idx = tx + i * kBlockSize;
        if (idx >= size) break;
        const auto score = smem->score_buffer[stage][idx];
        const auto bin = extract_coarse_bin(score);
        atomicAdd(&smem->histogram[bin], 1);
      }
    }

    static_assert(kHistBins <= kBlockSize);

    // 2-shot all-reduce
    // This block owns the histogram slice [offset, offset + kLocalSize); each
    // thread addresses one (bin, peer rank) pair through distributed shared
    // memory and writes the reduced value back to that peer.
    // NOTE(review): any template argument of warp::reduce_sum (e.g. a
    // reduction width) is not visible in this dump.
    {
      auto cluster = cooperative_groups::this_cluster();
      cluster.sync();
      const auto cluster_rank = blockIdx.y;
      const auto kLocalSize = kHistBins / kClusterSize;
      const auto offset = kLocalSize * cluster_rank;

      const auto src_tx = tx / kClusterSize;
      const auto src_rank = tx % kClusterSize;

      if (tx < kHistBins) {
        const auto addr = &smem->histogram[offset + src_tx];
        const auto src_addr = cluster.map_shared_rank(addr, src_rank);
        *src_addr = warp::reduce_sum(*src_addr);
      }
      cluster.sync();
    }

    // now each block holds the whole histogram, find the threshold bin
    {
      const auto value = tx < kHistBins ? smem->histogram[tx] : 0;
      const auto warp_inc = warp_inclusive_sum(lane_id, value);
      if (lane_id == kWarpThreads - 1) {
        smem->warp_sum[warp_id] = warp_inc;
      }

      __syncthreads();
      const auto tmp = smem->warp_sum[lane_id];
      // total_length = sum of all bins in the globally-reduced histogram
      // (problem.length is block-local; after cluster reduction we need the global total)
      const auto total_length = warp::reduce_sum(tmp);
      uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0);
      prefix_sum += warp_inc;
      // `above` = number of elements in bins strictly higher than this one.
      const auto above = total_length - prefix_sum;
      // The threshold bin is the one where the running count first reaches K.
      if (tx < kHistBins && above < K && above + value >= K) {
        smem->counter_gt = smem->counter_eq = 0;
        smem->match = {
            .bin = tx,
            .above_count = above,
            .equal_count = value,
        };
      }
      __syncthreads();
    }

    const auto [thr_bin, num_above, num_equal] = smem->match;

    // write above and equal results to global memory
#pragma unroll
    for (uint32_t stage = 0; stage < kNumStages; stage++) {
      const auto offset = stage * kSizePerStage;
      if (offset >= length) break;
#pragma unroll
      for (uint32_t i = 0; i < kElemPerStage; ++i) {
        const auto buf_idx = tx + i * kBlockSize;
        const auto global_idx = offset + buf_idx;
        if (global_idx >= length) break;
        const auto score = smem->score_buffer[stage][buf_idx];
        const auto bin = extract_coarse_bin(score);
        if (bin > thr_bin) {
          indices[atomicAdd(&smem->counter_gt, 1)] = global_idx;
        } else if (bin == thr_bin) {
          // Ties beyond kMaxTies are counted but not stored.
          const auto pos = atomicAdd(&smem->counter_eq, 1);
          if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score};
        }
      }
    }
    if (reuse) {
      // Reset the histogram and arrive on the used barriers for a next pass.
      const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage;
      if (tx < kHistBins) smem->histogram[tx] = 0;
      if (tx < num_stages) ptx::mbarrier_arrive(&smem->barrier[tx]);
    }
    __syncthreads();
  }

  // ---------------------------------------------------------------------------
  // Stage 1 epilogue: cross-block prefix sum + page translate + tie store
  // ---------------------------------------------------------------------------

  SGL_DEVICE static void stage1_epilogue(const TransformParams params, const uint32_t offset, void* _ws, void* _smem) {
    auto cluster = cooperative_groups::this_cluster();
    const auto smem = static_cast(_smem);
    const auto tx = threadIdx.x;
    const auto local_above = smem->counter_gt;
    const auto local_equal = smem->counter_eq;
    const auto cluster_rank = blockIdx.y;

    constexpr uint32_t kAboveMask = (1 << kAboveBits) - 1;
    static_assert(kAboveMask >= K);

    // Pack local counts -- NO alignment rounding (contiguous layout)
    static_assert(kMaxTies <= kBlockSize);
    const auto idx_above = tx < local_above ? params.indices_in[tx] : 0;
    const auto tie_value = tx < local_equal ? smem->tie_buffer[tx] : Tie{0, 0.0f};

    // push to remote shared memory, can reduce latency of reading remote
    if (tx < kClusterSize) {
      const auto value = (local_equal << kAboveBits) | local_above;
      const auto dst_addr = cluster.map_shared_rank(smem->local_above_equal, tx);
      dst_addr[cluster_rank] = value;
    }
    // after this last sync, only read local shared memory
    // so that it is safe when peer rank has already exited the kernel
    cluster.sync();
    if (tx < kClusterSize) {
      // Exclusive prefix: sum the packed counts of all lower cluster ranks.
      const auto value = tx < cluster_rank ? smem->local_above_equal[tx] : 0;
      const auto kActiveMask = (1u << kClusterSize) - 1;
      smem->prefix_above_equal = warp::reduce_sum(value, kActiveMask);
    }
    __syncthreads();

    const auto prefix_packed = smem->prefix_above_equal;
    const auto prefix_above = prefix_packed & kAboveMask;
    const auto prefix_equal = prefix_packed >> kAboveBits;

    // Page-translate above elements
    if (tx < local_above) {
      params.write(tx + prefix_above, idx_above + offset);
    }
    // Contiguous tie store via regular global writes (no TMA, no gaps)
    const auto ws = static_cast(_ws);
    if (tx < local_equal && tx + prefix_equal < kMaxTies) {
      ws->ties[tx + prefix_equal] = {tie_value.idx + offset, tie_value.score};
    }
    // The last cluster rank writes global metadata {num_above, num_ties}:
    // its prefix covers every lower rank, so prefix + local is the total.
    if (cluster_rank == kClusterSize - 1 && tx == 0) {
      const auto sum_above = prefix_above + local_above;
      const auto sum_equal = prefix_equal + local_equal;
      ws->metadata = make_uint2(sum_above, sum_equal);
    }
  }

  // Resolve ties from the workspace: nothing to do if K elements are already
  // above the threshold or no ties were recorded; otherwise hand at most
  // kMaxTies ties to tie_handle_transform.
  SGL_DEVICE static void transform(const TransformParams params, const void* _ws, void* _smem) {
    const auto ws = static_cast(_ws);
    const auto meta = &ws->metadata;
    const auto [num_above, num_equal] = *meta;
    if (num_above >= K || num_equal == 0) return;
    const auto clamped_ties = min(num_equal, kMaxTies);
    tie_handle_transform(ws->ties, clamped_ties, num_above, K, params, _smem);
  }
};

}  // namespace device::top512
// diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh
// new file mode 100644
// index 000000000000..d553032d799a
// --- /dev/null
// +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh
// @@ -0,0 +1,176 @@
#pragma once
#include
#include
#include
#include

#include

namespace device::top512 {

inline constexpr uint32_t kMaxTopK = 1024;
inline constexpr uint32_t kBlockSize = 1024;
inline constexpr uint32_t kNumWarps = kBlockSize / kWarpThreads;
inline constexpr uint32_t kMaxTies = 1024;  // == kBlockSize: 1 element per thread in stage2
static constexpr uint32_t kRadixBins = 256;
static_assert(kMaxTopK <= kBlockSize && kMaxTies <= kBlockSize);

// always use float4 to load from global memory
// NOTE(review): the AlignedVector template arguments are missing in this dump.
using Vec4 = AlignedVector;

// Translate a logical index through the page table: look up the page for
// i >> page_bits and re-attach the low page_bits bits of i.
SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) {
  const uint32_t mask = (1u << page_bits) - 1u;
  return (page_table[i >> page_bits] << page_bits) | (i & mask);
}

struct TransformParams {
  const int32_t* __restrict__ page_table;
  const int32_t* __restrict__ indices_in;
  int32_t* __restrict__ indices_out;
  uint32_t page_bits;

  // indices_out[idx] = translated indices_in[idx].
  SGL_DEVICE void transform(const uint32_t idx) const {
    indices_out[idx] = page_to_indices(page_table, indices_in[idx], page_bits);
  }
  // indices_out[dst] = translated `src` (src is a logical index, not a slot).
  SGL_DEVICE void write(const uint32_t dst, const uint32_t src) const {
    indices_out[dst] = page_to_indices(page_table, src, page_bits);
  }
};

// Threshold bin together with the counts above it and inside it.
struct alignas(16) MatchBin {
  uint32_t bin;
  uint32_t above_count;
  uint32_t equal_count;
};

// An element that fell into the threshold bin: its index and score.
struct alignas(8) Tie {
  uint32_t idx;
  float score;
};

struct TieHandleSmem {
  alignas(128) uint32_t counter;  // output position counter
  alignas(128) MatchBin match;
  uint32_t histogram[kRadixBins];  // 256-bin radix histogram
  uint32_t warp_sum[kNumWarps];    // for 2-pass prefix sum
};

// Map a score to a coarse bin that keeps ordering: convert to a 16-bit
// representation, flip it so larger floats give larger unsigned keys, and
// keep the top kBits bits.
// NOTE(review): the template parameter list and the cast/reinterpret_cast
// target types are missing in this dump.
template
SGL_DEVICE uint32_t extract_coarse_bin(float x) {
  static_assert(0 < kBits && kBits < 15);
  const auto hx = cast(x);
  const uint16_t bits = *reinterpret_cast(&hx);
  const uint16_t key = (bits & 0x8000) ? ~bits : bits | 0x8000;
  return key >> (16 - kBits);
}

// Inclusive prefix sum across the 32 lanes of a warp using shuffle-up.
SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) {
  static_assert(kWarpThreads == 32);
#pragma unroll
  for (uint32_t offset = 1; offset < 32; offset *= 2) {
    uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset);
    if (lane_id >= offset) val += n;
  }
  return val;
}

/// Order-preserving float32 -> uint32 for radix select
SGL_DEVICE uint32_t extract_exact_bin(float x) {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// Output for the no-selection case: slot tx gets the translated index tx for
// tx < length; the remaining slots below K are set to -1.
SGL_DEVICE void trivial_transform(const TransformParams& params, uint32_t length, uint32_t K) {
  if (const auto tx = threadIdx.x; tx < length) {
    params.write(tx, tx);
  } else if (tx < K) {
    params.indices_out[tx] = -1;
  }
}

// Choose the remaining (K - num_above) elements among `ties` with a 4-round,
// 8-bits-per-round radix select on the exact 32-bit key, and write them to
// output slots [num_above, K). One thread handles one tie.
SGL_DEVICE void tie_handle_transform(
    const Tie* __restrict__ ties,  //
    const uint32_t num_ties,
    const uint32_t num_above,
    const uint32_t K,
    const TransformParams params,
    void* _smem) {
  auto* smem = static_cast(_smem);
  const auto tx = threadIdx.x;
  const auto lane_id = tx % kWarpThreads;
  const auto warp_id = tx / kWarpThreads;

  // Each thread loads one element (or becomes inactive)
  const bool has_elem = tx < num_ties;
  const auto tie = has_elem ? ties[tx] : Tie{0, 0.0f};
  const uint32_t key = extract_exact_bin(tie.score);
  const uint32_t idx = tie.idx;
  bool active = has_elem;
  uint32_t topk_remain = K - num_above;
  // write_pos == K means "nothing to write" (see the final check).
  uint32_t write_pos = K;

  smem->counter = 0;
  __syncthreads();

  // Number of warps covering the 256-bin histogram (256/32 = 8)
  constexpr uint32_t kRadixWarps = kRadixBins / kWarpThreads;

#pragma unroll
  for (int round = 0; round < 4; round++) {
    // Most-significant byte first.
    const uint32_t shift = 24 - round * 8;
    const uint32_t bin = (key >> shift) & 0xFFu;

    // 1. Build histogram
    if (tx < kRadixBins) smem->histogram[tx] = 0;
    __syncthreads();
    if (active) atomicAdd(&smem->histogram[bin], 1);
    __syncthreads();

    // 2. v2-style 2-pass prefix sum on 256 bins
    // Only first 256 threads (8 warps) carry histogram bins.
    // Other threads get hist_val=0 and harmless prefix results.
    uint32_t hist_val = 0;
    uint32_t warp_inc = 0;
    if (tx < kRadixBins) {
      hist_val = smem->histogram[tx];
      warp_inc = warp_inclusive_sum(lane_id, hist_val);
      if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc;
    }
    __syncthreads();
    if (tx < kRadixBins) {
      // Inter-warp prefix (only first kRadixWarps warp totals matter)
      const auto tmp = (lane_id < kRadixWarps) ? smem->warp_sum[lane_id] : 0;
      const auto total = warp::reduce_sum(tmp);
      const auto inter = warp::reduce_sum(lane_id < warp_id ? tmp : 0);
      const auto prefix = inter + warp_inc;  // inclusive prefix through this bin
      const auto above = total - prefix;     // elements in bins ABOVE this one
      // 3. Find threshold bin
      // equal_count is set to the number of slots still open for this bin.
      if (above < topk_remain && above + hist_val >= topk_remain) {
        smem->match = {tx, above, topk_remain - above};
      }
    }
    __syncthreads();

    const auto [thr, n_above, _] = smem->match;

    // 4. Scatter
    if (active) {
      if (bin > thr) {
        // Strictly above the threshold: take the next free output slot.
        write_pos = num_above + atomicAdd(&smem->counter, 1);
        active = false;
      } else if (bin < thr) {
        active = false;
      } else if (round == 3) {
        // Full-key ties: the atomic decrement hands out slots counted back
        // from K; once the counter is exhausted write_pos is no longer < K.
        write_pos = K - atomicAdd(&smem->match.equal_count, -1u);
      }
      // bin == thr && round < 3: stay active for next round
    }

    topk_remain -= n_above;
    if (topk_remain == 0) break;
  }

  if (write_pos < K) params.write(write_pos, idx);
}

}  // namespace device::top512
// diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh
// new file mode 100644
// index 000000000000..73eef555f4db
// --- /dev/null
// +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh
// @@ -0,0 +1,54 @@
#pragma once
#include

#include

#include

namespace device::top512 {

// Thin wrappers over cuda::ptx mbarrier / bulk-copy primitives.
namespace ptx {

// Spin until the barrier phase with the given parity completes.
SGL_DEVICE void mbarrier_wait(uint64_t* addr, uint32_t phase) {
  while (!cuda::ptx::mbarrier_try_wait_parity(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, addr, phase))
    ;
}

SGL_DEVICE void mbarrier_init(uint64_t* addr, uint32_t arrives) {
  cuda::ptx::mbarrier_init(addr, arrives);
}

// Arrive on the barrier and register `tx` expected transaction bytes.
SGL_DEVICE void mbarrier_arrive_expect_tx(uint64_t* addr, uint32_t tx) {
  cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr, tx);
}

SGL_DEVICE void mbarrier_arrive(uint64_t* addr) {
  cuda::ptx::mbarrier_arrive(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr);
}

// Bulk async copy global -> shared of num_bytes, completing on `mbar`.
SGL_DEVICE void tma_load(void* dst, const void* src, uint32_t num_bytes, uint64_t* mbar) {
  cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, cuda::ptx::space_global, dst, src, num_bytes, mbar);
}

// Returns 1 in the lane chosen by `elect.sync` over the full-warp mask and 0
// in every other lane.
SGL_DEVICE uint32_t elect_sync() {
  uint32_t pred = 0;
  asm volatile(
      "{\n\t"
      ".reg .pred %%px;\n\t"
      "elect.sync _|%%px, %1;\n\t"
      "@%%px mov.s32 %0, 1;\n\t"
      "}"
      : "+r"(pred)
      : "r"(0xFFFFFFFF));
  return pred;
}

SGL_DEVICE bool
elect_sync_cta(uint32_t tx) { + const auto warp_id = tx / 32; + const auto uniform_warp_id = __shfl_sync(0xFFFFFFFF, warp_id, 0); + return (uniform_warp_id == 0 && elect_sync()); +} + +} // namespace ptx + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh new file mode 100644 index 000000000000..77d7361ee871 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct RegisterTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kVecsPerThread = 4; + static constexpr uint32_t kMaxTolerance = 0; + static constexpr uint32_t kMax1PassLength = kVecsPerThread * 4 * kBlockSize; + static constexpr uint32_t kMaxExtraLength = kMax1PassLength; + static constexpr uint32_t kMax2PassLength = kMax1PassLength + kMaxExtraLength; + + struct Smem { + using HistVec = AlignedVector; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + uint64_t mbarrier; // for cp.async + MatchBin match; + uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + alignas(16) float score_buffer[kMaxExtraLength]; + }; + + template + SGL_DEVICE static void + run(const float* scores, // + int32_t* indices, + const uint32_t length, + void* _smem, + const bool use_pdl = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram + { + typename Smem::HistVec hist_vec; + hist_vec.fill(0); + smem->histogram_vec[tx] = hist_vec; + if 
(tx == 0) { + smem->counter_gt = smem->counter_eq = 0; + if constexpr (kIs2Pass) { + ptx::mbarrier_init(&smem->mbarrier, 1); + } + } + __syncthreads(); + } + + if (use_pdl) device::PDLWaitPrimary(); + + // Load scores into registers + Vec4 local[kVecsPerThread]; +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { + const uint32_t base = (tx + v * kBlockSize) * 4; + if (base >= length) break; + local[v].load(scores, tx + v * kBlockSize); + } + + // Fetch the next chunk of scores + if constexpr (kIs2Pass) { + if (ptx::elect_sync_cta(tx)) { + const auto length_aligned = (length + 3u - kMax1PassLength) & ~3u; + const auto size_bytes = length_aligned * sizeof(float); + ptx::tma_load(smem->score_buffer, scores + kMax1PassLength, size_bytes, &smem->mbarrier); + ptx::mbarrier_arrive_expect_tx(&smem->mbarrier, size_bytes); + } + __syncwarp(); // avoid warp divergence on + } + + // Accumulate histogram via shared-memory atomics +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + if constexpr (!kIs2Pass) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= length) goto LABEL_ACC_FINISH; + } + atomicAdd(&smem->histogram[extract_coarse_bin(local[v][e])], 1); + } + } + if constexpr (kIs2Pass) { + // 16K ~ 32K. 
`i` is a float4 index + if (lane_id == 0) ptx::mbarrier_wait(&smem->mbarrier, 0); + __syncwarp(); + for (uint32_t i = tx; i + kMax1PassLength < length; i += kBlockSize) { + const auto val = smem->score_buffer[i]; + atomicAdd(&smem->histogram[extract_coarse_bin(val)], 1); + } + } + [[maybe_unused]] LABEL_ACC_FINISH: + __syncthreads(); + + // Phase 2: Exclusive prefix scan -> find threshold bin + { + constexpr uint32_t kItems = kHistBins / kBlockSize; + uint32_t orig[kItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t tmp_local_sum = 0; + +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + orig[i] = hist_vec[i]; + tmp_local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, tmp_local_sum); + const auto warp_exc = warp_inc - tmp_local_sum; + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + // Exactly one bin satisfies: above < K && above + count >= K + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // Phase 3: Scatter + // Elements strictly above threshold go directly to output. + // Tied elements: simple path admits first-come; tiebreak path collects into tie_buffer. 
+ const bool need_tiebreak = (num_equal + num_above > K + kMaxTolerance); + const auto topk_indices = indices; + const auto tie_buffer = smem->tie_buffer; + +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if constexpr (!kIs2Pass) { + if (idx >= length) goto LABEL_SCATTER_DONE; + } + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + // prefetch the next scores + if constexpr (kIs2Pass) { + local[v].load(smem->score_buffer, tx + v * kBlockSize); + } + } + + // 16K ~ 32K, already in registers: similar loop as above but read from smem->score_buffer + if constexpr (kIs2Pass) { +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e + kMax1PassLength; + if (idx >= length) goto LABEL_SCATTER_DONE; + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + } + } + + [[maybe_unused]] LABEL_SCATTER_DONE: + if (!need_tiebreak) return; + + // Phase 4: Tie-breaking within the threshold bin. + // Assume num_ties <= kBlockSize (at most 1 block of ties). 
+ // Each thread takes one tied element, computes its rank (number of + // elements with strictly higher score, breaking exact float ties by + // original index), and writes to output if rank < topk_remain. + __syncthreads(); + static_assert(kMaxTies <= kBlockSize); + + const uint32_t num_ties = min(num_equal, kMaxTies); + const uint32_t topk_remain = K - num_above; + + const auto is_greater = [](const Tie& a, const Tie& b) { + return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); + }; + + if (num_ties <= kWarpThreads) { + static_assert(kWarpThreads <= kNumWarps); + if (lane_id >= num_ties || warp_id >= num_ties) return; // some threads are idle + /// NOTE: use long long to avoid mask overflow when num_ties == 32 + const uint32_t mask = (1ull << num_ties) - 1u; + const auto tie = tie_buffer[lane_id]; + const auto target_tie = tie_buffer[warp_id]; + const bool pred = is_greater(tie, target_tie); + const auto rank = static_cast(__popc(__ballot_sync(mask, pred))); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } else if (num_ties <= kWarpThreads * 2) { + // 64 x 64 topk implementation: each thread takes 2 elements + const auto lane_id_1 = lane_id + kWarpThreads; + const auto warp_id_1 = warp_id + kWarpThreads; + const auto invalid = Tie{.idx = 0xFFFFFFFF, .score = -FLT_MAX}; + const auto tie_0 = tie_buffer[lane_id]; + const auto tie_1 = lane_id_1 < num_ties ? 
tie_buffer[lane_id_1] : invalid; + if (true) { + const auto target = tie_buffer[warp_id]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + if (warp_id_1 < num_ties) { + const auto target = tie_buffer[warp_id_1]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + } else { + /// NOTE: Based on my observation, this path is very rarely reached + [[unlikely]]; + // Block-level: each thread reads from tie_buffer in shared memory + for (auto i = warp_id; i < num_ties; i += kNumWarps) { + const auto target_tie = tie_buffer[i]; + uint32_t local_rank = 0; + for (auto j = lane_id; j < num_ties; j += kWarpThreads) { + const auto tie = tie_buffer[j]; + if (is_greater(tie, target_tie)) local_rank++; + } + // sum the rank across the warp + const auto rank = warp::reduce_sum(local_rank); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } + } + } + + SGL_DEVICE static void transform(const TransformParams params) { + __syncthreads(); + if (const auto tx = threadIdx.x; tx < K) params.transform(tx); + } +}; + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh new file mode 100644 index 000000000000..4462b89a1930 --- 
/dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct StreamingTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 2; // double buffer + + static constexpr uint32_t kHistItems = kHistBins / kBlockSize; // 4 + static_assert(kHistItems * kBlockSize == kHistBins); + using HistVec = AlignedVector; + + struct Smem { + uint64_t barrier[2][kNumStages]; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + union { + float score_buffer[kNumStages][kSizePerStage]; + TieHandleSmem stage2; // reuse smem for tie handling in phase D + }; + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// NOTE: length must be 4-aligned since we load 4 floats/thread. Caller should round up. 
+ template + SGL_DEVICE static void issue_tma(const float* scores, uint32_t stage, uint32_t length, Smem* smem) { + const auto buf_idx = stage % kNumStages; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[kIsScatter][buf_idx]; + ptx::tma_load(smem->score_buffer[buf_idx], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + + // --------------------------------------------------------------------------- + // Unified streaming pass. Used for both phase A (kIsScatter=false) and + // phase C (kIsScatter=true). Each buffer is reused across iterations via the + // reuse-arrive trick (same pattern as ClusterTopKImpl::stage1). + // --------------------------------------------------------------------------- + + template + SGL_DEVICE static void stream_pass( + const float* scores, + const uint32_t length, + const uint32_t thr_bin, // ignored when !kIsScatter + int32_t* s_topk_indices, // ignored when !kIsScatter + Smem* smem) { + const auto tx = threadIdx.x; + const auto num_iters = (length + kSizePerStage - 1) / kSizePerStage; + const auto lane_id = tx % kWarpThreads; + + // Initial double-buffer TMA prologue. 
+ const auto length_aligned = (length + 3u) & ~3u; + if (tx == 0) { +#pragma unroll + for (uint32_t i = 0; i < kNumStages; i++) { + if (i >= num_iters) break; + issue_tma(scores, i, length_aligned, smem); + } + } + + for (uint32_t iter = 0; iter < num_iters; iter++) { + const auto buf_idx = iter % kNumStages; + const auto offset = iter * kSizePerStage; + const auto this_size = min(kSizePerStage, length - offset); + + if (lane_id == 1) { + const auto phase_bit = (iter / kNumStages) & 1; + ptx::mbarrier_wait(&smem->barrier[kIsScatter][buf_idx], phase_bit); + } + __syncwarp(); + +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; i++) { + const auto local_idx = tx + i * kBlockSize; + if (local_idx >= this_size) break; + const auto score = smem->score_buffer[buf_idx][local_idx]; + const auto bin = extract_coarse_bin(score); + if constexpr (kIsScatter) { + const auto global_idx = offset + local_idx; + if (bin > thr_bin) { + const auto pos = atomicAdd(&smem->counter_gt, 1); + if (pos < K) s_topk_indices[pos] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } else { + atomicAdd(&smem->histogram[bin], 1); + } + } + + __syncthreads(); + if (tx == 0) { + if (const auto next_iter = iter + kNumStages; next_iter < num_iters) { + issue_tma(scores, next_iter, length_aligned, smem); + } + } + } + } + + // --------------------------------------------------------------------------- + // Phase B: find the threshold bin via a warp-level prefix scan. + // Same structure as SmallTopKImpl's phase 2 (4 bins/thread, warp_sum relay). 
+ // --------------------------------------------------------------------------- + + SGL_DEVICE static void find_threshold(uint32_t length, Smem* smem) { + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + uint32_t orig[kHistItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + orig[i] = hist_vec[i]; + local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, local_sum); + const auto warp_exc = warp_inc - local_sum; + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kHistItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + SGL_DEVICE static void run(const float* scores, const uint32_t length, int32_t* topk_indices, void* _smem) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + + // Init histogram, barriers, counters. + { + HistVec zero; + zero.fill(0); + smem->histogram_vec[tx] = zero; + if (tx < 2 * kNumStages) { + const auto base_barrier = &smem->barrier[0][0]; + ptx::mbarrier_init(&base_barrier[tx], 1); + } + if (tx == 0) { + smem->counter_gt = 0; + smem->counter_eq = 0; + } + __syncthreads(); + } + + // Phase A: histogram pass (pipelined TMA stream). + stream_pass(scores, length, 0, nullptr, smem); + + // Phase B: locate threshold bin & re-init barriers + find_threshold(length, smem); + + // Phase C: scatter pass. 
+ stream_pass(scores, length, smem->match.bin, topk_indices, smem); + } + + SGL_DEVICE static void transform(const TransformParams params, void* _smem) { + // Phase D: page-translate above entries, then refine ties. + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto num_above = smem->match.above_count; + if (tx < num_above) params.transform(tx); + const auto num_equal = smem->counter_eq; + if (num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(smem->tie_buffer, clamped_ties, num_above, K, params, &smem->stage2); + } +}; + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/distributed/common.cuh b/python/sglang/jit_kernel/include/sgl_kernel/distributed/common.cuh new file mode 100644 index 000000000000..e0ce2dc086c1 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/distributed/common.cuh @@ -0,0 +1,120 @@ +#pragma once +#include + +namespace device::distributed { + +inline constexpr uint32_t kMaxNumGPU = 8; + +struct alignas(128) Semaphore { + public: + constexpr Semaphore() : m_flag(0), m_counter(0) {} + + template + SGL_DEVICE uint32_t get() const { + uint32_t val; + if constexpr (kFence) { + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } else { + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(val) : "l"(&m_flag)); + } + return val; + } + + template + SGL_DEVICE uint32_t add(uint32_t val) { + uint32_t old_val; + if constexpr (kFence) { + asm volatile("atom.release.sys.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } else { + asm volatile("atom.global.add.u32 %0, [%1], %2;" : "=r"(old_val) : "l"(&m_flag), "r"(val)); + } + return old_val; + } + + // Only called by the owning GPU - plain load is sufficient + SGL_DEVICE uint32_t get_counter() const { + return m_counter; + } + + // Only called by the owning GPU - plain store is sufficient + 
SGL_DEVICE void set_counter(uint32_t val) { + m_counter = val; + } + + private: + uint32_t m_flag; + uint32_t m_counter; +}; + +struct PullController { + public: + using SignalType = Semaphore; + + PullController(void** signals, uint32_t num_gpu) { + for (uint32_t i = 0; i < num_gpu; ++i) { + m_signals[i] = static_cast(signals[i]); + } + } + + /// Synchronize all GPUs. + /// When kFence is true, establishes happens-before across GPUs using + /// release/acquire semantics, ensuring prior writes are visible system-wide. + template + SGL_DEVICE void sync(uint32_t rank, uint32_t num_gpu) const { + // For fenced sync: ensure all threads in this block have completed their writes, + // so the signaling thread's release carries them transitively. + static_assert(!(kFence && kStart), "Start stage does not need to wait fence"); + if constexpr (kFence || !kStart) __syncthreads(); + constexpr auto kStage = kStart ? 1 : 2; + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + if (lane_id == 0 && warp_id < num_gpu) { + auto& signal = m_signals[warp_id][blockIdx.x]; + signal.add(1); + if (warp_id == rank) { + const auto target = num_gpu * kStage; + /// NOTE: correctness here: + /// - base is only read/updated locally by the owning GPU + const auto base = signal.get_counter(); + while (signal.get() - base < target) + ; + if constexpr (!kStart) { + signal.set_counter(base + target); + } + } + } + if constexpr (kStart) __syncthreads(); + } + + private: + Semaphore* __restrict__ m_signals[kMaxNumGPU]; +}; + +struct PushController { + public: + using SignalType = uint32_t; + static constexpr int64_t kNumStages = 2; + + PushController(void* ptr) : m_local_signal(static_cast(ptr)) {} + + SGL_DEVICE SignalType epoch() const { + return m_local_signal[blockIdx.x]; + } + + SGL_DEVICE void exit() const { + __syncthreads(); + if (threadIdx.x == 0) { + this->exit_unsafe(blockIdx.x); + } + } + + SGL_DEVICE void exit_unsafe(uint32_t which) const 
{ + auto& signal = m_local_signal[which]; + signal = (signal + 1) % kNumStages; + } + + private: + SignalType* m_local_signal; +}; + +} // namespace device::distributed diff --git a/python/sglang/jit_kernel/include/sgl_kernel/distributed/custom_all_reduce.cuh b/python/sglang/jit_kernel/include/sgl_kernel/distributed/custom_all_reduce.cuh new file mode 100644 index 000000000000..239fac71a198 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/distributed/custom_all_reduce.cuh @@ -0,0 +1,354 @@ +#pragma once +#include + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace host::distributed { + +using device::distributed::PullController, device::distributed::PushController; + +struct AllReduceData { + constexpr AllReduceData() {} + void* __restrict__ input[device::distributed::kMaxNumGPU]; +}; + +using ExternHandle = tvm::ffi::Array; + +inline ExternHandle to_extern_handle(void* ptr) { + ExternHandle array; + cudaIpcMemHandle_t handle; + RuntimeDeviceCheck(cudaIpcGetMemHandle(&handle, ptr)); + for (size_t i = 0; i < sizeof(handle); ++i) { + array.push_back(handle.reserved[i]); + } + return array; +} + +inline void* from_extern_handle(const ExternHandle& array) { + cudaIpcMemHandle_t handle; + RuntimeCheck(array.size() == sizeof(handle), "Invalid IPC handle size: ", array.size()); + for (size_t i = 0; i < sizeof(handle); ++i) { + handle.reserved[i] = array[i]; + } + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + return ptr; +} + +struct HandleHash { + std::size_t operator()(const cudaIpcMemHandle_t& handle) const { + return std::hash{}({handle.reserved, sizeof(handle.reserved)}); + } +}; + +struct HandleEqual { + bool operator()(const cudaIpcMemHandle_t& a, const cudaIpcMemHandle_t& b) const { + return std::memcmp(a.reserved, b.reserved, sizeof(a.reserved)) == 0; + } +}; + +/** + * \brief 
The control plane of the custom all-reduce implementation. + * It manages the internal state and synchronization of the participating GPUs. + */ +struct CustomAllReduceBase : public tvm::ffi::Object { + public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("sgl.CustomAllReduce", CustomAllReduceBase, tvm::ffi::Object); + + static constexpr bool _type_mutable = true; + using InputPair = tvm::ffi::Tuple; // (offset, ipc handle) + + CustomAllReduceBase( + uint32_t rank, + uint32_t num_gpu, + uint32_t max_num_cta_pull, + uint32_t max_num_cta_push, + int64_t pull_buffer_size, + int64_t push_buffer_size, + int64_t graph_buffer_count) + : m_pull_buffer_bytes(pull_buffer_size), + m_push_buffer_bytes(push_buffer_size), + m_graph_buffer_count(graph_buffer_count), + m_rank(rank), + m_num_gpu(num_gpu), + m_max_num_cta_pull(max_num_cta_pull), + m_max_num_cta_push(max_num_cta_push), + // default config for pull kernel, can be updated by `configure()` + m_num_cta(max_num_cta_pull), + m_cta_size(256) { + RuntimeCheck(pull_buffer_size % 128 == 0, "Pull buffer size should be aligned to 128 bytes"); + RuntimeCheck(push_buffer_size % 128 == 0, "Push buffer size should be aligned to 128 bytes"); + RuntimeCheck(rank < num_gpu, "Invalid rank: ", rank); + const int64_t kU32Max = static_cast(std::numeric_limits::max()); + const int64_t push_buffer_size_all = push_all_ranks_bytes(); + RuntimeCheck(pull_buffer_size <= kU32Max, "Pull buffer size is too large: ", pull_buffer_size); + RuntimeCheck(push_buffer_size_all <= kU32Max, "Push buffer size is too large: ", push_buffer_size_all); + RuntimeDeviceCheck(cudaMalloc(&m_storage, storage_bytes())); + } + + ExternHandle share_storage() { + return to_extern_handle(m_storage); + } + + tvm::ffi::Array share_graph_inputs() { + tvm::ffi::Array result; + const auto new_inputs_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_inputs_count >= 0, "Invalid new count: ", new_inputs_count); + result.reserve(new_inputs_count); + std::unordered_map 
ipc_cache; + const auto get_handle = [&](void* ptr) -> ExternHandle { + const auto it = ipc_cache.find(ptr); + if (it != ipc_cache.end()) return it->second; + const auto handle = to_extern_handle(ptr); + ipc_cache.try_emplace(ptr, handle); + return handle; + }; + for (const auto ptr : std::span(m_graph_capture_inputs).subspan(m_cum_registered_count)) { + // note: must share the base address of each allocation, or we get wrong address + void* base_ptr; + const auto cu_result = cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr); + RuntimeCheck(cu_result == CUDA_SUCCESS, "failed to get pointer attr"); + const auto offset = reinterpret_cast(ptr) - reinterpret_cast(base_ptr); + result.push_back(InputPair{offset, get_handle(base_ptr)}); + } + return result; + } + + void post_init(tvm::ffi::Array ipc_storages) { + RuntimeCheck(ipc_storages.size() == m_num_gpu, "Invalid array size: ", ipc_storages.size()); + m_peer_storage.resize(m_num_gpu); + for (const auto i : irange(m_num_gpu)) { + if (i == m_rank) { + m_peer_storage[i] = m_storage; + } else { + m_peer_storage[i] = from_extern_handle(ipc_storages[i]); + } + } + + // set signal buffer to zero + const auto pull_signal = get_pull_signal(m_storage); + RuntimeDeviceCheck(cudaMemset(pull_signal, 0, pull_signal_bytes())); + + // update the pull controller and data pointer + RuntimeCheck(!m_pull_ctrl.has_value(), "Controller is already initialized"); + m_pull_ctrl.emplace(m_peer_storage.data(), m_num_gpu); + AllReduceData data; + for (const auto i : irange(m_num_gpu)) { + data.input[i] = get_pull_buffer(m_peer_storage[i]); + } + const auto default_data_ptr = get_data_ptr(); + RuntimeDeviceCheck(cudaMemcpy(default_data_ptr, &data, sizeof(AllReduceData), cudaMemcpyHostToDevice)); + + // update the push controller and data pointer + RuntimeCheck(!m_push_ctrl.has_value(), "Controller is already initialized"); + const auto push_signal = get_push_signal(m_storage); + 
RuntimeDeviceCheck(cudaMemset(push_signal, 0, push_signal_bytes())); + m_push_ctrl.emplace(push_signal); + const auto push_buffer = get_push_buffer(m_storage); + RuntimeDeviceCheck(cudaMemset(push_buffer, 0, push_all_ranks_bytes())); + } + + void register_inputs(tvm::ffi::Array> ipc_graph_inputs) { + RuntimeCheck(ipc_graph_inputs.size() == m_num_gpu); + const auto new_registered_count = registered_count() - m_cum_registered_count; + RuntimeCheck(new_registered_count >= 0, "Invalid registered count: ", new_registered_count); + if (new_registered_count == 0) return; // avoid `m_get_data_ptr()` out-of-bounds + std::vector data; + data.resize(new_registered_count); + const auto open_cached = [&](const ExternHandle& h) -> void* { + RuntimeCheck(h.size() == sizeof(cudaIpcMemHandle_t), "Invalid IPC handle size: ", h.size()); + cudaIpcMemHandle_t handle; + for (size_t i = 0; i < sizeof(handle); ++i) + handle.reserved[i] = h[i]; + const auto [it, success] = m_ipc_cache.try_emplace(handle, nullptr); + if (success) { + void* ptr; + RuntimeDeviceCheck(cudaIpcOpenMemHandle(&ptr, handle, cudaIpcMemLazyEnablePeerAccess)); + it->second = ptr; + } + return it->second; + }; + for (const auto i : irange(ipc_graph_inputs.size())) { + const auto& array = ipc_graph_inputs[i]; + RuntimeCheck(int64_t(array.size()) == new_registered_count); + if (i == m_rank) { + for (const auto j : irange(new_registered_count)) { + data[j].input[i] = m_graph_capture_inputs[m_cum_registered_count + j]; + } + } else { + for (const auto j : irange(new_registered_count)) { + /// NOTE: structural binding will cause intern compiler error... 
+ const auto elem = array[j]; + const auto offset = elem.get<0>(); + const auto ipc_handle = elem.get<1>(); + data[j].input[i] = pointer::offset(open_cached(ipc_handle), offset); + } + } + } + + const auto new_registered_bytes = sizeof(AllReduceData) * new_registered_count; + const auto dst_ptr = get_data_ptr(m_cum_registered_count); + m_cum_registered_count += new_registered_count; + RuntimeDeviceCheck(cudaMemcpy(dst_ptr, data.data(), new_registered_bytes, cudaMemcpyHostToDevice)); + } + + void set_cuda_graph_capture(bool enabled) { + m_is_graph_capturing = enabled; + } + + void free_ipc_handles() { + for (const auto& pair : m_ipc_cache) { + host::RuntimeDeviceCheck(cudaIpcCloseMemHandle(pair.second)); + } + m_ipc_cache.clear(); + } + + void free_storage() { + host::RuntimeDeviceCheck(cudaFree(m_storage)); + m_storage = nullptr; + } + + tvm::ffi::Tuple configure_pull(uint32_t num_cta, uint32_t cta_size) { + using host::RuntimeCheck; + const auto min_cta_size = m_num_gpu * device::kWarpThreads; + RuntimeCheck(num_cta > 0 && num_cta <= m_max_num_cta_pull, "Invalid number of CTAs: ", num_cta); + RuntimeCheck(cta_size >= min_cta_size, "Block size must be at least ", min_cta_size); + const auto old_num_cta = m_num_cta; + const auto old_block_size = m_cta_size; + m_num_cta = num_cta; + m_cta_size = cta_size; + return tvm::ffi::Tuple{old_num_cta, old_block_size}; + } + + protected: + AllReduceData* allocate_graph_capture_input(void* data_ptr) { + const auto count = registered_count(); + RuntimeCheck(count < m_graph_buffer_count, "Graph buffer overflow, increase `graph_buffer_count`!"); + m_graph_capture_inputs.push_back(data_ptr); + return get_data_ptr(count); + } + AllReduceData* get_data_ptr(int64_t which = -1) { + const auto count = registered_count(); + RuntimeCheck(which >= -1 && which < count, "Invalid graph buffer index: ", which, ", count: ", count); + const auto start = get_pull_params(m_storage); + return static_cast(start) + (1 + which); + } + int64_t 
registered_count() const { + return static_cast(m_graph_capture_inputs.size()); + } + int64_t pull_signal_bytes() const { + return _align_bytes(sizeof(PullController::SignalType) * m_max_num_cta_pull); + } + int64_t push_signal_bytes() const { + return _align_bytes(sizeof(PushController::SignalType) * m_max_num_cta_push); + } + int64_t graph_param_bytes() const { + return _align_bytes(sizeof(AllReduceData) * (1 + m_graph_buffer_count)); // 1 for default + } + int64_t push_all_ranks_bytes() const { + return _align_bytes(PushController::kNumStages * m_num_gpu * m_push_buffer_bytes); + } + int64_t storage_bytes() const { + return _get_offset_impl(5); + } + void* get_pull_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(0)); + } + void* get_push_signal(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(1)); + } + void* get_pull_params(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(2)); + } + void* get_pull_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(3)); + } + void* get_push_buffer(void* ptr) const { + return pointer::offset(ptr, _get_offset_impl(4)); + } + int64_t _get_offset_impl(int64_t which) const { + // | SignalArray (pull + push) | GraphBuffers (pull params) | Buffers (pull + push) | + const int64_t offset_map[5] = { + /*[0]=*/pull_signal_bytes(), + /*[1]=*/push_signal_bytes(), + /*[2]=*/graph_param_bytes(), + /*[3]=*/m_pull_buffer_bytes, + /*[4]=*/push_all_ranks_bytes(), + }; + RuntimeCheck(which >= 0 && which <= 5, "Invalid offset index: ", which); + return std::accumulate(offset_map, offset_map + which, int64_t(0)); + } + static int64_t _align_bytes(int64_t size) { + return div_ceil(size, 128) * 128; + } + + const int64_t m_pull_buffer_bytes; + const int64_t m_push_buffer_bytes; + const int64_t m_graph_buffer_count; + const uint32_t m_rank; + const uint32_t m_num_gpu; + const uint32_t m_max_num_cta_pull; + const uint32_t m_max_num_cta_push; + // these 2 config should only 
affect pull kernel + uint32_t m_num_cta; + uint32_t m_cta_size; + // other states + bool m_is_graph_capturing = false; + int64_t m_cum_registered_count = 0; + std::optional m_pull_ctrl; + std::optional m_push_ctrl; + void* m_storage = nullptr; + std::vector m_graph_capture_inputs; + std::vector m_peer_storage; + std::unordered_map m_ipc_cache; +}; + +struct CustomAllReduceRef : public tvm::ffi::ObjectRef { + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(CustomAllReduceRef, tvm::ffi::ObjectRef, CustomAllReduceBase); +}; + +} // namespace host::distributed + +namespace device::distributed { + +template +SGL_DEVICE auto reduce_impl(AlignedVector (&storage)[M]) -> AlignedVector { + fp32x2_t acc[N] = {}; +#pragma unroll // unroll num gpu + for (uint32_t i = 0; i < M; ++i) { +#pragma unroll // unroll vec + for (uint32_t j = 0; j < N; ++j) { + const auto [x, y] = cast(storage[i][j]); + auto& [x_acc, y_acc] = acc[j]; + x_acc += x; + y_acc += y; + } + } + + AlignedVector result; +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + result[j] = cast(acc[j]); + } + + return result; +} + +} // namespace device::distributed diff --git a/python/sglang/jit_kernel/include/sgl_kernel/ffi.h b/python/sglang/jit_kernel/include/sgl_kernel/ffi.h new file mode 100644 index 000000000000..17d9048d4c42 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/ffi.h @@ -0,0 +1,104 @@ +#pragma once +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace host::ffi { + +using tvm::ffi::Tensor, tvm::ffi::TensorView, tvm::ffi::ShapeView; + +inline Tensor empty(ShapeView shape, DLDataType dtype, DLDevice device) { + return Tensor::FromEnvAlloc(::TVMFFIEnvTensorAlloc, shape, dtype, device); +} + +inline Tensor empty_like(TensorView tensor) { + return empty(tensor.shape(), tensor.dtype(), tensor.device()); +} + +struct _dummy_deleter { + void operator()(void*) const {} +}; + +// template + +template +struct FromBlobContext { + 
[[no_unique_address]] Fn deleter; + int64_t dimension; + int64_t* get_shape() { + return reinterpret_cast(this + 1); + } + int64_t* get_stride() { + return this->get_shape() + dimension; + } +}; + +template +inline Tensor from_blob( + void* data, + ShapeView shape, + DLDataType dtype, + DLDevice device, + Fn&& deleter = {}, + std::optional stride = {}, + uint64_t byte_offset = 0) { + using Context = FromBlobContext>; + const auto ndim = shape.size(); + const auto ctx = [&] { + auto ptr = std::malloc(sizeof(Context) + sizeof(int64_t) * ndim * 2); + auto ctx = static_cast(ptr); + std::construct_at(ctx, std::forward(deleter), static_cast(ndim)); + stdr::copy_n(shape.data(), ndim, ctx->get_shape()); + if (stride.has_value()) { + RuntimeCheck(stride->size() == ndim, "Stride ndim mismatch!"); + stdr::copy_n(stride->data(), ndim, ctx->get_stride()); + } else { + int64_t stride_val = 1; + for (const auto i : irange(ndim)) { + const auto j = ndim - 1 - i; + ctx->get_stride()[j] = stride_val; + stride_val *= shape[j]; + } + } + return ctx; + }(); + const auto tensor = DLTensor{ + .data = data, + .device = device, + .ndim = static_cast(ndim), + .dtype = dtype, + .shape = ctx->get_shape(), + .strides = ctx->get_stride(), + .byte_offset = byte_offset, + }; + const auto blob_deleter = [](DLManagedTensor* self) { + auto ctx = static_cast(self->manager_ctx); + ctx->deleter(self->dl_tensor.data); + std::destroy_at(ctx); + std::free(ctx); + }; + auto managed_tensor = DLManagedTensor{tensor, ctx, blob_deleter}; + return Tensor::FromDLPack(&managed_tensor); +} + +template +inline Tensor from_blob_like( + void* data, + TensorView t, + Fn&& deleter = {}, + bool is_contiguous = false, // if override to true, the stride will be ignored + uint64_t byte_offset = 0) { + const auto stride = is_contiguous ? 
std::nullopt : std::optional{t.strides()}; + return from_blob(data, t.shape(), t.dtype(), t.device(), std::forward(deleter), stride, byte_offset); +} + +} // namespace host::ffi diff --git a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh index 01ce21a7a813..f5000d4a147d 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh @@ -53,6 +53,36 @@ SGL_DEVICE void PDLTriggerSecondary() { #endif } +template +SGL_DEVICE constexpr auto div_ceil(T a, U b) { + return (a + b - 1) / b; +} + +/** + * \brief Load data with the specified type and offset from a void pointer. + * \tparam T The type to load. + * \param ptr The base pointer. + * \param offset The offset in number of elements of type T. + */ +template +SGL_DEVICE T load_as(const void* ptr, int64_t offset = 0) { + return static_cast(ptr)[offset]; +} + +/** + * \brief Store data with the specified type and offset to a void pointer. + * \tparam T The type to store. + * \param ptr The base pointer. + * \param val The value to store. + * \param offset The offset in number of elements of type T. + * \note we use type_identity_t to force the caller to explicitly specify + * the template parameter `T`, which can avoid accidentally using the wrong type. 
+ */ +template +SGL_DEVICE void store_as(void* ptr, std::type_identity_t val, int64_t offset = 0) { + static_cast(ptr)[offset] = val; +} + namespace pointer { // we only allow void * pointer arithmetic for safety @@ -112,16 +142,22 @@ struct LaunchKernel { auto enable_pdl(bool enabled = true) -> LaunchKernel& { if (enabled) { - m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - m_attrs[0].val.programmaticStreamSerializationAllowed = true; - m_config.numAttrs = 1; + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = true; m_config.attrs = m_attrs; - } else { - m_config.numAttrs = 0; } return *this; } + auto enable_cluster(dim3 cluster_dim) -> LaunchKernel& { + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z}; + m_config.attrs = m_attrs; + return *this; + } + template auto operator()(T&& kernel, Args&&... 
args) const -> void { RuntimeDeviceCheck(::cudaLaunchKernelEx(&m_config, kernel, std::forward(args)...), m_location); @@ -144,7 +180,7 @@ struct LaunchKernel { cudaLaunchConfig_t m_config; const DebugInfo m_location; - cudaLaunchAttribute m_attrs[1]; + cudaLaunchAttribute m_attrs[2]; }; } // namespace host diff --git a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh index 5510b44746c9..f48d34181d4f 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh @@ -50,14 +50,10 @@ struct AlignedVector { using storage_t = AlignedStorage; public: - template - SGL_DEVICE void load(const U* ptr, std::size_t offset = 0) { - static_assert(std::is_same_v || std::is_same_v); + SGL_DEVICE void load(const void* ptr, std::size_t offset = 0) { m_storage = reinterpret_cast(ptr)[offset]; } - template - SGL_DEVICE void store(U* ptr, std::size_t offset = 0) const { - static_assert(std::is_same_v || std::is_same_v); + SGL_DEVICE void store(void* ptr, std::size_t offset = 0) const { reinterpret_cast(ptr)[offset] = m_storage; } SGL_DEVICE void fill(T value) { diff --git a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh index d69526e97f29..079ac0155872 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh @@ -1,23 +1,24 @@ #pragma once #include +#include // Some warp primitives namespace device::warp { static constexpr uint32_t kFullMask = 0xffffffffu; -template +template SGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) + for (auto mask = kThreads >> 1; mask > 0; mask >>= 1) value = value + __shfl_xor_sync(active_mask, value, mask, 32); return value; } -template +template SGL_DEVICE T reduce_max(T value, uint32_t active_mask = kFullMask) { #pragma unroll - for 
def moe_fused_gate(
    input: torch.Tensor,
    bias: torch.Tensor,
    topk: int,
    scoring_func: str = "sigmoid",
    num_fused_shared_experts: int = 0,
    renormalize: bool = True,
    routed_scaling_factor: float = 1.0,
    apply_routed_scaling_factor_on_output: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the JIT-compiled fused MoE gating kernel.

    Args:
        input: 2D float32 tensor; its second dimension must match ``bias``.
        bias: 1D float32 tensor.
        topk: Number of entries produced per row; must exceed
            ``num_fused_shared_experts``.
        scoring_func: Case-insensitive key of ``_SCORING_FUNC_MAP``.
        num_fused_shared_experts: Forwarded to the kernel unchanged.
        renormalize: Forwarded to the kernel unchanged.
        routed_scaling_factor: Forwarded to the kernel unchanged.
        apply_routed_scaling_factor_on_output: Forwarded to the kernel unchanged.

    Returns:
        ``(output, indices)``: freshly allocated ``(num_rows, topk)`` tensors of
        dtype float32 and int32 on ``input.device``, filled by the kernel.

    Raises:
        AssertionError: If any of the argument checks below fails.
    """
    scoring_func_int = _SCORING_FUNC_MAP.get(scoring_func.lower())
    assert (
        scoring_func_int is not None
    ), f"Unknown scoring_func '{scoring_func}', must be one of {list(_SCORING_FUNC_MAP.keys())}"

    # Argument validation, in the same order as the messages read.
    assert input.dtype == torch.float32, "input must be float32"
    assert bias.dtype == torch.float32, "bias must be float32"
    assert input.ndim == 2, "input must be 2D"
    assert bias.ndim == 1, "bias must be 1D"
    assert input.size(1) == bias.size(0), "input and bias must have same num_experts"
    assert topk > num_fused_shared_experts, "topk must be > num_fused_shared_experts"

    num_rows, _num_experts = input.shape
    result_shape = (num_rows, topk)
    output = torch.empty(result_shape, dtype=torch.float32, device=input.device)
    indices = torch.empty(result_shape, dtype=torch.int32, device=input.device)

    # Positional order is part of the kernel's calling convention.
    kernel_args = (
        input,
        bias,
        output,
        indices,
        topk,
        scoring_func_int,
        num_fused_shared_experts,
        renormalize,
        routed_scaling_factor,
        apply_routed_scaling_factor_on_output,
    )
    _jit_moe_fused_gate_module().moe_fused_gate(*kernel_args)

    return output, indices
def _precompile_kernels() -> None:
    """Compile every (dtype, world_size) kernel variant in parallel subprocesses.

    One child process is spawned per combination of ``TEST_DTYPES`` and world
    sizes 2..8; each child runs ``_compile_one``.

    Raises:
        RuntimeError: If any child exits with a non-zero code. All children are
            joined first, and every failing combination is listed.
    """
    # NOTE: even when device count < 8, we should be able to compile all
    # A private "spawn" context is used instead of `mp.set_start_method`: the
    # global setter raises RuntimeError when a start method was already chosen
    # in this interpreter, and it would also change the default for unrelated
    # code running in the same process.
    spawn_ctx = mp.get_context("spawn")
    process_map: Dict[Tuple[torch.dtype, int], mp.process.BaseProcess] = {}
    for config in itertools.product(TEST_DTYPES, [2, 3, 4, 5, 6, 7, 8]):
        process_map[config] = spawn_ctx.Process(target=_compile_one, args=config)
    for process in process_map.values():
        process.start()

    # Join everything before reporting so no child is left running behind a
    # raised exception.
    failures = []
    for (dtype, world_size), process in process_map.items():
        process.join()
        if process.exitcode != 0:
            failures.append(f"Custom All Reduce {world_size=} {dtype=} failed")
    if failures:
        raise RuntimeError("; ".join(failures))
+ """ + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + rank = local_rank + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + dist.init_process_group(backend="gloo") + ps._WORLD = coord = ps.init_world_group( + ranks=list(range(world_size)), + local_rank=local_rank, + backend="nccl", + ) + + cpu_group = coord.cpu_group + nccl_group = coord.device_group + assert nccl_group is not None + + max_size = max(TEST_SIZES) * 4 + comm = CustomAllReduceV2(cpu_group, device, max_size, max_size) + if comm.disabled: + raise RuntimeError("JIT CustomAllReduceV2 is disabled on this system") + + return rank, device, cpu_group, nccl_group, comm + + +@torch.inference_mode() +def worker_test( + device: torch.device, + nccl_group: dist.ProcessGroup, + comm: CustomAllReduceV2, + size: int, + dtype: torch.dtype, + use_graph: bool, + algo: AllReduceAlgo, +) -> Optional[RuntimeError]: + comm.override_algo = algo + + def get_run_graph_fn(): + graph = torch.cuda.CUDAGraph() + graph_inp = torch.zeros((TEST_LAYERS, size), dtype=dtype, device=device) + out_jits = [] + with comm.capture(): + with torch.cuda.graph(graph): + for i in range(TEST_LAYERS): + out_jits.append(comm.custom_all_reduce(graph_inp[i])) + out_jit = torch.stack(out_jits) + torch.cuda.synchronize() + + def run_graph(x: torch.Tensor) -> torch.Tensor: + graph_inp.copy_(x) + graph.replay() + return out_jit.clone() + + return run_graph + + def get_run_eager_fn(): + def run_eager(x: torch.Tensor) -> torch.Tensor: + eager_inp = x.clone() + out_eagers = [] + for i in range(TEST_LAYERS): + out_eagers.append(comm.custom_all_reduce(eager_inp[i])) + torch.cuda.synchronize() + return torch.stack(out_eagers) + + return run_eager + + run_fn = get_run_graph_fn() if use_graph else get_run_eager_fn() + num_errors = 0 + for _ in range(TEST_LOOP): + # NOTE: 15 * 8 < 128, which is the precision limit for bf16 + inp = torch.randint(0, 16, (TEST_LAYERS, size), dtype=dtype, 
def multiprocess_test(file: str, nproc: int, timeout: int = 90) -> None:
    """Launch this script as a torchrun worker and assert success.

    Args:
        file: Path of the worker script handed to ``torchrun``.
        nproc: Value passed as ``--nproc_per_node``.
        timeout: Seconds to wait for ``torchrun`` before giving up.

    Raises:
        RuntimeError: If ``torchrun`` does not finish within ``timeout``; the
            message carries whatever output was captured so far.
        AssertionError: If ``torchrun`` exits with a non-zero return code.
    """
    cmd = [
        "torchrun",
        f"--nproc_per_node={nproc}",
        file,
    ]
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as e:
        # On timeout the captured output is raw bytes (or None) even though
        # `text=True` was requested, so decode it before formatting.
        captured = e.stdout
        if isinstance(captured, bytes):
            captured = captured.decode("utf-8", errors="replace")
        raise RuntimeError(
            f"torchrun (nproc={nproc}) timed out after {timeout} seconds\n"
            f"{captured or ''}"
        ) from e

    assert result.returncode == 0, (
        f"torchrun (nproc={nproc}) failed with rc={result.returncode}\n"
        f"{result.stdout}"
    )
model.load_model_cls() except Exception: - logger.exception("Ignore import error when loading '%s'", model_arch) + logger.exception( + "In _try_load_model_cls: Ignore import error when loading '%s'", model_arch + ) return None diff --git a/python/sglang/srt/configs/config_backup_large.json b/python/sglang/srt/configs/config_backup_large.json new file mode 100644 index 000000000000..a2d9c21b9637 --- /dev/null +++ b/python/sglang/srt/configs/config_backup_large.json @@ -0,0 +1,68 @@ +{ + "architectures": [ + "DeepseekXYZForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 0, + "eos_token_id": 1, + "hc_eps": 1e-06, + "hc_mult": 4, + "hc_sinkhorn_iters": 20, + "head_dim": 512, + "hidden_act": "silu", + "hidden_size": 7168, + "index_head_dim": 128, + "index_n_heads": 64, + "index_topk": 1024, + "initializer_range": 0.02, + "max_position_embeddings": 1048576, + "model_type": "deepseek_ref", + "moe_intermediate_size": 3072, + "n_routed_experts": 384, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 128, + "num_experts_per_tok": 6, + "num_hidden_layers": 61, + "num_hash_layers": 3, + "num_key_value_heads": 1, + "num_nextn_predict_layers": 1, + "o_groups": 16, + "o_lora_rank": 1024, + "q_lora_rank": 1536, + "qk_rope_head_dim": 64, + "quantization_config": { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "scale_fmt": "ue8m0", + "weight_block_size": [ + 128, + 128 + ] + }, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 16, + "original_max_position_embeddings": 65536, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sqrtsoftplus", + "sliding_window": 128, + "swiglu_limit": 10.0, + "tie_word_embeddings": false, + "n_group": 8, + "topk_group": 8, + "topk_method": "noaux_tc", + "torch_dtype": "bfloat16", + "transformers_version": "4.57.1", + "use_cache": true, + "vocab_size": 129280, + 
"compress_rope_theta": 160000, + "compress_ratios": [128, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0] +} diff --git a/python/sglang/srt/configs/config_backup_small.json b/python/sglang/srt/configs/config_backup_small.json new file mode 100644 index 000000000000..fc5555a8e41d --- /dev/null +++ b/python/sglang/srt/configs/config_backup_small.json @@ -0,0 +1,68 @@ +{ + "architectures": [ + "DeepseekXYZForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 0, + "eos_token_id": 1, + "hc_eps": 1e-06, + "hc_mult": 4, + "hc_sinkhorn_iters": 20, + "head_dim": 512, + "hidden_act": "silu", + "hidden_size": 4096, + "index_head_dim": 128, + "index_n_heads": 64, + "index_topk": 512, + "initializer_range": 0.02, + "max_position_embeddings": 1048576, + "model_type": "deepseek_ref", + "moe_intermediate_size": 2048, + "n_routed_experts": 256, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 64, + "num_experts_per_tok": 6, + "num_hidden_layers": 43, + "num_hash_layers": 3, + "num_key_value_heads": 1, + "num_nextn_predict_layers": 1, + "o_groups": 8, + "o_lora_rank": 1024, + "q_lora_rank": 1024, + "qk_rope_head_dim": 64, + "quantization_config": { + "activation_scheme": "dynamic", + "fmt": "e4m3", + "quant_method": "fp8", + "scale_fmt": "ue8m0", + "weight_block_size": [ + 128, + 128 + ] + }, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 16, + "original_max_position_embeddings": 65536, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 1.5, + "scoring_func": "sqrtsoftplus", + "sliding_window": 128, + "swiglu_limit": 10.0, + "tie_word_embeddings": false, + "n_group": 8, + "topk_group": 8, + "topk_method": "noaux_tc", + "torch_dtype": "bfloat16", + 
"transformers_version": "4.57.1", + "use_cache": true, + "vocab_size": 129280, + "compress_rope_theta": 160000, + "compress_ratios": [0, 0, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 128, 4, 0] +} diff --git a/python/sglang/srt/configs/deepseek_v4.py b/python/sglang/srt/configs/deepseek_v4.py new file mode 100644 index 000000000000..d345688efb17 --- /dev/null +++ b/python/sglang/srt/configs/deepseek_v4.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers import PretrainedConfig + +from sglang.srt.layers.quantization.base_config import QuantizationConfig + + +@dataclass +class DeepSeekV4Config(PretrainedConfig): + architectures: List[str] + attention_bias: bool = False + attention_dropout: float = 0.0 + bos_token_id: int = 0 + eos_token_id: int = 1 + ep_size: int = 1 + first_k_dense_replace: int = 0 + hidden_act: str = "silu" + hidden_size: int = 4096 + index_head_dim: int = 128 + index_n_heads: int = 64 + index_topk: int = 512 + initializer_range: float = 0.02 + intermediate_size: int = 2048 + kv_lora_rank: int = 512 + max_position_embeddings: int = 65536 + model_type: str = "deepseek_ref" + moe_intermediate_size: int = 2048 + moe_layer_freq: int = 1 + n_group: int = 8 + n_routed_experts: int = 256 + n_shared_experts: int = 1 + norm_topk_prob: bool = True + + num_attention_heads: int = 64 + num_experts_per_tok: int = 6 + num_hidden_layers: int = 43 + num_key_value_heads: int = 1 + + q_lora_rank: int = 1024 + qk_nope_head_dim: int = 448 + qk_rope_head_dim: int = 64 + + quantization_config: QuantizationConfig = field(default_factory=QuantizationConfig) + + rms_norm_eps: float = 1e-6 + + rope_scaling: Dict[str, float] = field(default_factory=dict) + rope_theta: int = 10000 + + routed_scaling_factor: float = 1.5 + scoring_func: str = "sqrtsoftplus" + + tie_word_embeddings: bool = False + + 
topk_group: int = 8 + topk_method: str = "noaux_tc" + + use_cache: bool = True + v_head_dim: int = 512 + vocab_size: int = 129280 + o_lora_rank: int = 1024 + o_groups: int = 8 + window_size: int = 128 + + compress_rope_theta: int = 40000 + compress_ratios: List[int] = field(default_factory=list) + + n_hash_layers: int = 3 + hc_mult: int = 4 + hc_sinkhorn_iters: int = 20 + hc_eps: float = 1e-6 diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index b3a49f8a410e..5a3997deeac9 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -66,8 +66,15 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool: ) +def is_deepseek_compressed(config: PretrainedConfig) -> bool: + return config.architectures is not None and ( + config.architectures[0] == "DeepseekV4ForCausalLM" + or config.architectures[0] == "DeepseekV4ForCausalLMNextN" + ) + + def get_nsa_index_head_dim(config: PretrainedConfig) -> int: - assert is_deepseek_nsa(config) + assert is_deepseek_nsa(config) or is_deepseek_compressed(config) return config.index_head_dim @@ -81,6 +88,64 @@ def get_nsa_index_n_heads(config: PretrainedConfig) -> int: return config.index_n_heads +import re as _re + +# Matches routed-expert weight keys in both HF-style layouts +# (``...mlp.experts..{gate,up,down}_proj.weight``) and DeepseekV4 2604-style +# layouts (``...ffn.experts..w{1,2,3}.weight``). ``shared_experts`` is +# excluded because the index segment requires a digit after ``.experts.``. +_ROUTED_EXPERT_KEY_RE = _re.compile( + r"\.experts\.\d+\.(?:w[123]|down_proj|up_proj|gate_proj)\.weight$" +) + + +def _probe_routed_expert_weight_dtype(model_path: str) -> Optional[str]: + """Return the safetensors dtype string (e.g. ``F8_E4M3``, ``U8``) of one + routed-expert weight tensor, or ``None`` if the checkpoint is remote or has + no matching key. Reads only the safetensors header of the relevant shard. 
+ """ + import struct + + if not os.path.isdir(model_path): + return None + + index_file = os.path.join(model_path, "model.safetensors.index.json") + target_key = None + target_shard_path = None + + if os.path.exists(index_file): + with open(index_file) as f: + index = json.load(f) + weight_map = index.get("weight_map", {}) or {} + for k, shard in weight_map.items(): + if _ROUTED_EXPERT_KEY_RE.search(k): + target_key = k + target_shard_path = os.path.join(model_path, shard) + break + if target_key is None: + return None + else: + shards = sorted(Path(model_path).glob("*.safetensors")) + if not shards: + return None + target_shard_path = str(shards[0]) + + with open(target_shard_path, "rb") as f: + (header_len,) = struct.unpack(" 1 + and envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + else self.hidden_size + ) self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.num_attention_layers = self.num_hidden_layers if "LongcatFlashForCausalLM" in self.hf_config.architectures: @@ -723,10 +890,11 @@ def _get_modelopt_quant_type(self) -> str: return "fp8" # Default fallback def _get_sliding_window_size(self) -> Optional[int]: - sliding_window_size = getattr(self.hf_text_config, "sliding_window_size", None) - if sliding_window_size is None: - sliding_window_size = getattr(self.hf_text_config, "sliding_window", None) - return sliding_window_size + key_list = ["sliding_window_size", "sliding_window", "window_size"] + for key in key_list: + if hasattr(self.hf_text_config, key): + return getattr(self.hf_text_config, key) + return None def _validate_quantize_and_serve_config(self): """Validate quantize_and_serve configuration.""" @@ -1230,6 +1398,8 @@ def is_hybrid_swa_model(model_architectures: List[str]): hybrid_swa_archs = { "Llama4ForConditionalGeneration", + "DeepseekV4ForCausalLM", + "DeepseekV4ForCausalLMNextN", "GptOssForCausalLM", "MiMoV2FlashForCausalLM", "MiMoV2MTP", diff --git 
_META_CAPACITY = 64 * 1024


class MmapDumper:
    """Dump named tensors and JSON-able scalars into mmap-backed files.

    For the owning process the dumper maintains, inside ``dump_dir``:

    * ``pid<PID>_meta.json.mmap`` -- a fixed-size file of ``_META_CAPACITY``
      bytes holding a 4-byte little-endian payload length followed by a UTF-8
      JSON document (``pid``, ``scalars``, ``tensors``);
    * ``pid<PID>_<name>.bin`` -- one file per tensor name with the raw bytes
      of the most recently dumped tensor.

    A dumper constructed without a directory is inactive and ``dump`` is a
    no-op until ``set_dir`` is called.
    """

    def __init__(self, dump_dir: Optional[str] = None) -> None:
        self._dump_dir = dump_dir
        self._pid = os.getpid()
        self._scalars: dict = {}
        self._tensor_meta: dict = {}
        self._tensor_mmaps: dict = {}
        self._meta_mmap = None
        if dump_dir:
            self._activate(dump_dir)

    def set_dir(self, dump_dir: str) -> None:
        """Switch the output directory; a no-op when it is already current."""
        if self._dump_dir == dump_dir:
            return
        self._activate(dump_dir)

    def is_active(self) -> bool:
        """Return True once a directory is set and the meta file is mapped."""
        return self._dump_dir is not None and self._meta_mmap is not None

    def dump(self, items: dict) -> None:
        """Write ``items`` to the dump files and rewrite the meta document.

        Args:
            items: Mapping of name to value. ``torch.Tensor`` values are
                written to their ``.bin`` file; every other value is passed
                through ``_jsonify`` and stored as a scalar. Scalars and
                tensor entries from earlier calls are retained in the meta.
        """
        if not self.is_active():
            return
        import time

        t0 = time.perf_counter()
        for name, value in items.items():
            if isinstance(value, torch.Tensor):
                self._dump_tensor(name, value)
            else:
                self._scalars[name] = _jsonify(value)
        self._flush_meta()
        elapsed_ms = (time.perf_counter() - t0) * 1000
        print(
            f"[MmapDumper pid={self._pid}] dumped {len(items)} items "
            f"in {elapsed_ms:.2f} ms",
            flush=True,
        )

    @staticmethod
    def _map_file(path: str, capacity: int) -> mmap.mmap:
        """Create/open ``path``, size it to ``capacity`` and map it read-write."""
        fd = os.open(path, os.O_RDWR | os.O_CREAT, 0o644)
        try:
            os.ftruncate(fd, capacity)
            return mmap.mmap(
                fd, capacity, mmap.MAP_SHARED, mmap.PROT_READ | mmap.PROT_WRITE
            )
        finally:
            # A mapping remains valid after its descriptor is closed, so the
            # descriptor is released immediately instead of being leaked.
            os.close(fd)

    def _release_mappings(self) -> None:
        """Close every open mapping and forget per-directory tensor state."""
        for entry in self._tensor_mmaps.values():
            entry["mmap"].close()
        self._tensor_mmaps = {}
        # Tensor metadata names files of the previous directory; drop it so
        # the next meta document only lists files that exist in the new one.
        self._tensor_meta = {}
        if self._meta_mmap is not None:
            self._meta_mmap.close()
            self._meta_mmap = None
        self._dump_dir = None

    def _activate(self, dump_dir: str) -> None:
        """Make ``dump_dir`` the output directory and map its meta file."""
        os.makedirs(dump_dir, exist_ok=True)
        self._release_mappings()
        path = os.path.join(dump_dir, f"pid{self._pid}_meta.json.mmap")
        self._meta_mmap = self._map_file(path, _META_CAPACITY)
        # Recorded only after the mapping succeeded, so a failed activation
        # can be retried with the same directory.
        self._dump_dir = dump_dir

    def _dump_tensor(self, name: str, tensor: torch.Tensor) -> None:
        """Copy ``tensor``'s bytes into its ``.bin`` mapping and record metadata."""
        import numpy as np

        cpu_tensor = tensor.detach().cpu().contiguous()
        nbytes = cpu_tensor.numel() * cpu_tensor.element_size()
        alloc_bytes = max(nbytes, 1)

        entry = self._tensor_mmaps.get(name)
        bin_path = os.path.join(self._dump_dir, f"pid{self._pid}_{name}.bin")
        if entry is None or entry["capacity"] < alloc_bytes:
            if entry is not None:
                # Too small: drop the old mapping before growing the file.
                self._tensor_mmaps.pop(name)
                entry["mmap"].close()
            # Over-allocate so modest growth does not force a remap each time.
            capacity = max(alloc_bytes * 2, 4096)
            entry = {"mmap": self._map_file(bin_path, capacity), "capacity": capacity}
            self._tensor_mmaps[name] = entry

        if nbytes > 0:
            # Reinterpret as bytes on the torch side: numpy has no bfloat16,
            # so converting the typed tensor with ``.numpy()`` would fail.
            src = cpu_tensor.reshape(-1).view(torch.uint8).numpy()
            dst = np.frombuffer(entry["mmap"], dtype=np.uint8, count=nbytes)
            np.copyto(dst, src)

        self._tensor_meta[name] = {
            "shape": list(cpu_tensor.shape),
            "stride": list(cpu_tensor.stride()),
            "dtype": str(cpu_tensor.dtype),
            "nbytes": nbytes,
            "bin_filename": os.path.basename(bin_path),
        }

    def _flush_meta(self) -> None:
        """Serialise scalars and tensor metadata into the meta mapping."""
        meta = {"pid": self._pid, "scalars": self._scalars, "tensors": self._tensor_meta}
        payload = json.dumps(meta).encode("utf-8")
        n = len(payload)
        assert n + 4 <= _META_CAPACITY, f"mmap dumper meta too big: {n}"
        # Layout: 4-byte little-endian length prefix, then the JSON payload.
        self._meta_mmap[0:4] = n.to_bytes(4, "little")
        self._meta_mmap[4 : 4 + n] = payload
def list_dump_pids(dump_dir: str) -> list:
    """List, in ascending order, the pids owning a meta file in `dump_dir`.

    A meta file is any directory entry named ``pid<N>_meta.json.mmap``; the
    text between the prefix and the suffix is parsed with ``int``.
    """
    prefix, suffix = "pid", "_meta.json.mmap"
    return sorted(
        int(entry[len(prefix) : -len(suffix)])
        for entry in os.listdir(dump_dir)
        if entry.startswith(prefix) and entry.endswith(suffix)
    )
True, "c": "hello", "d": None, "e": [1, 2, 3]}, out + assert out["tensors"] == {} + print("[tester] T1 scalars OK") + + # ----- Test 2: tensors of different dtypes ----- + t_i32 = torch.arange(10, dtype=torch.int32) + t_i64 = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.int64) + t_f32 = torch.tensor([1.5, 2.5, 3.5], dtype=torch.float32) + d.dump({"t_i32": t_i32, "t_i64": t_i64, "t_f32": t_f32}) + out = read_dump(tmp_dir, os.getpid()) + assert torch.equal(out["tensors"]["t_i32"], t_i32) + assert torch.equal(out["tensors"]["t_i64"], t_i64) + assert torch.equal(out["tensors"]["t_f32"], t_f32) + print("[tester] T2 tensors OK") + + # ----- Test 3: re-dump with smaller tensor (capacity reuse) ----- + d.dump({"t_i32": torch.arange(3, dtype=torch.int32)}) + out = read_dump(tmp_dir, os.getpid()) + assert torch.equal(out["tensors"]["t_i32"], torch.arange(3, dtype=torch.int32)) + print("[tester] T3 shrink OK") + + # ----- Test 4: re-dump with larger tensor (mmap grow) ----- + big = torch.arange(10000, dtype=torch.int64) + d.dump({"t_i32": big}) # different dtype, much larger + out = read_dump(tmp_dir, os.getpid()) + assert torch.equal(out["tensors"]["t_i32"], big) + print("[tester] T4 grow OK") + + # ----- Test 5: scalars + tensors mixed, multiple flushes ----- + d.dump({"counter": 1}) + d.dump({"counter": 2, "x": torch.zeros(5, dtype=torch.int32)}) + out = read_dump(tmp_dir, os.getpid()) + assert out["scalars"]["counter"] == 2, out + assert torch.equal(out["tensors"]["x"], torch.zeros(5, dtype=torch.int32)) + print("[tester] T5 mixed flushes OK") + + # ----- Test 6: empty tensor ----- + d.dump({"empty": torch.zeros(0, dtype=torch.int32)}) + out = read_dump(tmp_dir, os.getpid()) + assert out["tensors"]["empty"].shape == (0,) + print("[tester] T6 empty OK") + + # ----- Test 7: inactive dumper is a no-op ----- + d2 = MmapDumper(None) + assert not d2.is_active() + d2.dump({"foo": 1}) # should not crash + print("[tester] T7 inactive no-op OK") + + # ----- Test 8: auto mkdir 
def _bench_speed_inner(dump_dir_root: str, n_ranks: int) -> None:
    """Run the dump benchmark with ``n_ranks`` concurrent child processes.

    For every scenario one spawned process per rank runs ``_bench_one_rank``
    and reports through a queue; per-rank min / median / max dump latency is
    then printed.

    Args:
        dump_dir_root: Directory under which the ``concurrent`` scratch
            directory is created and removed again.
        n_ranks: Number of child processes to launch.

    The BIG scenario is only added when the ``SGLANG_DUMP_BENCH_BIG``
    environment variable is set to something other than ``""``, ``"0"``,
    ``"false"`` or ``"False"``.
    """
    import multiprocessing as mp
    import queue as queue_mod
    import shutil
    import time as _time

    # Default: only run SMALL (the new fast path).
    # Set SGLANG_DUMP_BENCH_BIG=1 to also run BIG (old slow path) for comparison.
    big = os.environ.get("SGLANG_DUMP_BENCH_BIG", "0") not in ("", "0", "false", "False")
    scenarios = [
        ("SMALL (4608, 10000) ~ 184 MB", (4608, 10000), 5),
    ]
    if big:
        scenarios.insert(0, ("BIG (4608, 1048580) ~ 18 GB", (4608, 1048580), 3))

    ctx = mp.get_context("spawn")
    concurrent_dir = os.path.join(dump_dir_root, "concurrent")
    for label, shape, n_iters in scenarios:
        nbytes = shape[0] * shape[1] * 4  # int32 elements, 4 bytes each
        print(f"\n=== {label} ===")
        if dump_dir_root.startswith("/host/data") and nbytes * n_ranks > 130 * 2**30:
            print(f" skip: would write {nbytes * n_ranks / 2**30:.1f} GB to /host/data (~93% full)")
            continue

        # Concurrent: all ranks at once
        queue = ctx.Queue()
        procs = [
            ctx.Process(
                target=_bench_one_rank,
                args=(r, concurrent_dir, shape, n_iters, queue),
            )
            for r in range(n_ranks)
        ]

        wall_t0 = _time.perf_counter()
        for p in procs:
            p.start()

        # Poll with a timeout instead of blocking: a child that dies without
        # reporting (any failure other than the handled OOM) would otherwise
        # make ``queue.get()`` wait forever.
        results = []
        while len(results) < n_ranks:
            try:
                results.append(queue.get(timeout=1.0))
            except queue_mod.Empty:
                if any(p.is_alive() for p in procs):
                    continue
                # Every child has exited; pick up results still in flight.
                try:
                    while len(results) < n_ranks:
                        results.append(queue.get(timeout=1.0))
                except queue_mod.Empty:
                    pass
                break
        for p in procs:
            p.join()
        wall_total = _time.perf_counter() - wall_t0

        # The dump files are no longer needed; remove them on every path,
        # including the early exits below.
        shutil.rmtree(concurrent_dir, ignore_errors=True)

        results.sort(key=lambda x: x[0])
        missing = n_ranks - len(results)
        if missing:
            print(f" {missing} rank(s) exited without reporting a result")
        if not results:
            continue
        all_oom = all(r[1] == "OOM" for r in results)
        if all_oom:
            print(" all ranks OOM, skip")
            continue
        print(
            f" CONCURRENT {n_ranks}-rank wall total = {wall_total*1000:.0f} ms "
            f"(includes process spawn + GPU init + tensor alloc)"
        )
        for rank, status, *rest in results:
            if status == "OOM":
                print(f" rank {rank}: OOM ({rest[0]})")
                continue
            times, nb = rest
            t_min = min(times) * 1000
            t_med = sorted(times)[len(times) // 2] * 1000
            t_max = max(times) * 1000
            mb = nb / 2**20
            print(
                f" rank {rank}: dump nbytes={mb:.0f} MB, "
                f"per-call min={t_min:.1f} ms, median={t_med:.1f} ms, max={t_max:.1f} ms"
            )
sys.argv[2] if len(sys.argv) > 2 else "/dev/shm/mmap_dumper_bench" + _bench_speed(dump_dir) + else: + _tester() diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 26df1bcb7d48..d2874c7f6371 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -52,6 +52,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.common import release_kv_cache +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import ( HybridLinearKVPool, HybridReqToTokenPool, @@ -280,6 +281,12 @@ def _init_kv_manager(self) -> BaseKVManager: self.metadata_buffers.get_buf_infos() ) + if isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool): + assert self.prefill_pp_size == 1, ( + "V4 PD disaggregation requires PP=1 " + "(get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized flat list)" + ) + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): state_data_ptrs, state_data_lens, state_item_lens = ( self.token_to_kv_pool.get_state_buf_infos() @@ -288,7 +295,7 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.state_data_lens = state_data_lens kv_args.state_item_lens = state_item_lens - if isinstance(self.token_to_kv_pool, SWAKVPool): + if isinstance(self.token_to_kv_pool, (SWAKVPool, DeepSeekV4TokenToKVPool)): kv_args.state_type = "swa" elif isinstance(self.token_to_kv_pool, HybridLinearKVPool): kv_args.state_type = "mamba" @@ -541,8 +548,9 @@ def pop_preallocated( .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool, SWAKVPool): - # SWA hybrid model: send decode-side SWA window indices + elif isinstance( + self.token_to_kv_pool, (SWAKVPool, DeepSeekV4TokenToKVPool) + ): seq_len = len(decode_req.req.origin_input_ids) window_size = self.scheduler.sliding_window_size diff --git 
a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index 7905aae65ed7..a18c3a12ee8e 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -41,6 +41,11 @@ class TransferInfo: dst_kv_indices: npt.NDArray[np.int32] dst_aux_index: int required_dst_info_num: int + # Decode-side state pool indices for SWA / NSA / Mamba state transfer. + # Empty when the model has no state buffer. + dst_state_indices: npt.NDArray[np.int32] = dataclasses.field( + default_factory=lambda: np.array([], dtype=np.int32) + ) def is_dummy(self): return self.dst_kv_indices.size == 0 @@ -55,6 +60,7 @@ def from_zmq(cls, msg: List[bytes]): dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32), dst_aux_index=int(msg[5].decode("ascii")), required_dst_info_num=int(msg[6].decode("ascii")), + dst_state_indices=np.frombuffer(msg[7], dtype=np.int32), ) @@ -73,6 +79,10 @@ class KVArgsRegisterInfo: decode_tp_size: int decode_tp_rank: int dst_kv_item_len: int + # Decode-side state buffer base pointers and per-tensor item lengths. + # Empty for models without a state buffer (state_type == "none"). + dst_state_data_ptrs: list[int] = dataclasses.field(default_factory=list) + dst_state_item_lens: list[int] = dataclasses.field(default_factory=list) @classmethod def from_zmq(cls, msg: List[bytes]): @@ -88,6 +98,12 @@ def from_zmq(cls, msg: List[bytes]): decode_tp_size=int(msg[8].decode("ascii")), decode_tp_rank=int(msg[9].decode("ascii")), dst_kv_item_len=int(msg[10].decode("ascii")), + dst_state_data_ptrs=list( + struct.unpack(f"{len(msg[11]) // 8}Q", msg[11]) + ), + dst_state_item_lens=list( + struct.unpack(f"{len(msg[12]) // 4}I", msg[12]) + ), ) @@ -105,6 +121,12 @@ class TransferStatus: num_pp_ranks_expected: Optional[int] = None # Whether aux data has been received. received_aux: bool = False + # Whether SWA / NSA / Mamba state pages are expected, and whether they + # have arrived. 
Set to True by the receiver only when the model has a + # state buffer; non-state models leave both False so is_done() ignores + # them. + expects_state: bool = False + received_state: bool = False # Mark as failed is_failure: bool = False @@ -113,6 +135,8 @@ def is_done(self): return True if self.num_pp_ranks_expected is None or not self.received_aux: return False + if self.expects_state and not self.received_state: + return False # All PP ranks must have reported their expected count if len(self.expected_kvs_per_pp) < self.num_pp_ranks_expected: return False @@ -319,6 +343,27 @@ def register_buffer_to_engine(self): if not self.aux_descs: raise Exception("NIXL memory registration failed for aux tensors") + # Register the SWA / NSA / Mamba state pool. It lives in VRAM on the + # prefill side and must be transferred to decode along with the main + # KV cache; without this, decode reads zero-initialised state for + # every state-bearing layer and produces incorrect attention output. + # Mirrors the equivalent logic in the mooncake backend. 
def send_state(
    self,
    peer_name: str,
    prefill_state_indices: npt.NDArray[np.int32],
    dst_state_ptrs: list[int],
    dst_state_indices: npt.NDArray[np.int32],
    dst_state_item_lens: list[int],
    dst_gpu_id: int,
    notif: str,
):
    """Post a WRITE transfer copying state-pool pages to the decode peer.

    For every state tensor, source page ``prefill_state_indices[k]`` is
    paired with destination page ``dst_state_indices[k]``; addresses are
    computed as ``base_ptr + page_index * item_len``. Source pointers and
    item lengths are taken from ``self.kv_args``.

    Returns:
        The transfer handle, or ``None`` when there is nothing to send.

    Raises:
        AssertionError: If the two sides disagree on the number of state
            tensors or on any per-tensor item length.
        Exception: If the transfer cannot be created or posted.
    """
    local_ptrs = self.kv_args.state_data_ptrs
    local_item_lens = self.kv_args.state_item_lens
    assert len(local_ptrs) == len(dst_state_ptrs)
    assert len(local_item_lens) == len(dst_state_item_lens)
    # Pages are addressed by index on both sides, which only works when each
    # tensor uses the same item length locally and remotely. Differing
    # lengths would need a TP-slice style transfer, which is not implemented
    # here.
    for i, (local_len, remote_len) in enumerate(
        zip(local_item_lens, dst_state_item_lens)
    ):
        assert local_len == remote_len, (
            f"State item length mismatch at index {i}: "
            f"{local_len} != {remote_len} "
            "(non-equal item lens require mamba TP-slice transfer, "
            "not yet supported in nixl backend)"
        )

    src_addrs, dst_addrs = [], []
    for tensor_idx, src_base in enumerate(local_ptrs):
        stride = local_item_lens[tensor_idx]
        for src_page, dst_page in zip(prefill_state_indices, dst_state_indices):
            src_addrs.append(
                (src_base + int(src_page) * stride, stride, self.kv_args.gpu_id)
            )
            dst_addrs.append(
                (dst_state_ptrs[tensor_idx] + int(dst_page) * stride, stride, dst_gpu_id)
            )

    if len(src_addrs) == 0:
        return None

    handle = self.agent.initialize_xfer(
        "WRITE",
        self.agent.get_xfer_descs(src_addrs, "VRAM"),
        self.agent.get_xfer_descs(dst_addrs, "VRAM"),
        peer_name,
        notif.encode("ascii"),  # type: ignore
    )
    if not handle:
        raise Exception("KVSender failed to create state transfer")
    if self.agent.transfer(handle) == "ERR":
        raise Exception("KVSender failed to post state transfer")
    return handle
+ peer = self.decode_kv_args_table[req.agent_name] + if ( + self.kv_args.state_data_ptrs + and peer.dst_state_data_ptrs + and state_indices is not None + and req.dst_state_indices.size > 0 + ): + state_xfer_handle = self.send_state( + req.agent_name, + np.asarray(state_indices, dtype=np.int32), + peer.dst_state_data_ptrs, + req.dst_state_indices, + peer.dst_state_item_lens, + peer.gpu_id, + str(req.room) + "_state", + ) + if state_xfer_handle is not None: + handles.append(state_xfer_handle) if is_last: del self.transfer_infos[bootstrap_room] return handles @@ -674,6 +807,8 @@ def update_transfer_status(self): ) elif components[1] == "aux": self.transfer_statuses[room].received_aux = True + elif components[1] == "state": + self.transfer_statuses[room].received_state = True def check_transfer_done(self, room: int): if room not in self.transfer_statuses: @@ -748,6 +883,7 @@ def send( is_last, self.chunk_id, self.aux_index, + state_indices=state_indices, ) self.xfer_handles.extend(new_xfer_handles) self.chunk_id += 1 @@ -810,6 +946,13 @@ def init( logger.debug( f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}" ) + if state_indices is not None and not is_dummy: + state_indices_bytes = np.asarray( + state_indices, dtype=np.int32 + ).tobytes() + else: + state_indices_bytes = b"" + with lock: sock.send_multipart( [ @@ -821,9 +964,20 @@ def init( kv_indices.tobytes() if not is_dummy else b"", str(aux_index).encode("ascii"), str(self.required_dst_info_num).encode("ascii"), + state_indices_bytes, ] ) + # Tell the transfer-status tracker to wait for the state-page + # transfer when this receiver is going to receive one (i.e. the + # model has a state buffer and we sent indices to a non-dummy peer). 
+ if ( + state_indices is not None + and len(state_indices) > 0 + and self.kv_mgr.kv_args.state_data_ptrs + ): + self.kv_mgr.transfer_statuses[self.bootstrap_room].expects_state = True + self.started_transfer = True self.init_time = time.time() @@ -875,6 +1029,14 @@ def _register_kv_args(self): packed_aux_data_ptrs = b"".join( struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs ) + packed_state_data_ptrs = b"".join( + struct.pack("Q", ptr) + for ptr in self.kv_mgr.kv_args.state_data_ptrs + ) + packed_state_item_lens = b"".join( + struct.pack("I", int(length)) + for length in self.kv_mgr.kv_args.state_item_lens + ) with lock: sock.send_multipart( @@ -891,6 +1053,8 @@ def _register_kv_args(self): str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"), str(self.kv_mgr.kv_args.engine_rank).encode("ascii"), str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"), + packed_state_data_ptrs, + packed_state_item_lens, ] ) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 35762a7446dd..35049e6758f8 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -49,6 +49,7 @@ ScheduleBatch, ) from sglang.srt.mem_cache.common import release_kv_cache +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end @@ -149,6 +150,15 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id + if isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool): + assert self.pp_size == 1, ( + "V4 PD disaggregation requires PP=1 " + "(get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized flat list)" + ) + assert ( + 
def is_mla_backend(target_kv_pool) -> bool:
    """Tell whether ``target_kv_pool`` is one of the MLA-style KV pools.

    Both ``MLATokenToKVPool`` and ``DeepSeekV4TokenToKVPool`` instances
    count. The pool classes are imported inside the function body.
    """
    from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool
    from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool

    mla_pool_types = (MLATokenToKVPool, DeepSeekV4TokenToKVPool)
    return isinstance(target_kv_pool, mla_pool_types)
class CustomAllReduceV2:
    """Custom all-reduce communicator backed by the JIT-compiled kernel object.

    The algorithm is chosen per call from the input size: one-shot push for
    the smallest inputs, one-shot pull up to a second threshold, two-shot
    pull beyond that. The thresholds come from ``THRESHOLD_2_SHOT_MAP``,
    keyed by world size.
    """

    def __init__(
        self,
        group: ProcessGroup,
        device: torch.device,
        max_pull_size: Optional[int] = None,
        max_push_size: Optional[int] = None,
    ) -> None:
        """Create the kernel object and exchange storage handles across ranks.

        Args:
            group: Process group used for rank / world-size lookup and for
                the handle exchange.
            device: Accepted for interface compatibility; not read here.
            max_pull_size: Pull buffer size in bytes (16 MB when ``None``).
                Inputs larger than this are rejected by ``should_custom_ar``.
            max_push_size: Push buffer size in bytes; defaults to the
                configured one-shot push threshold and is capped at
                ``max_pull_size``.
        """
        # Set first so ``close()`` / ``__del__`` are safe even when a later
        # step of the constructor raises.
        self.disabled = True
        _init_config()
        self.group = group
        self.rank = dist.get_rank(group=self.group)
        self.world_size = dist.get_world_size(group=self.group)
        self.override_shot(None)
        if max_pull_size is None:
            max_pull_size = 16 * 1024 * 1024  # default to 16MB
        if max_push_size is None:
            max_push_size = self.config.one_shot_push_threshold
        max_push_size = min(max_push_size, max_pull_size)
        self.max_pull_size = max_pull_size
        self.max_push_size = max_push_size
        self.override_algo: Optional[AllReduceAlgo] = None
        self.obj = get_custom_all_reduce_cls()(
            rank=self.rank,
            world_size=self.world_size,
            pull_buffer_bytes=self.max_pull_size,
            push_buffer_bytes=self.max_push_size,
            graph_input_count=131072,
        )
        self._post_init_obj()
        self.disabled = False
        log_info_on_rank0(logger, "Custom allreduce v2 initialized successfully")

    def override_shot(self, shot: Optional[int]):
        """Force one-shot (1) or two-shot (2) pull, or restore defaults (None)."""
        if shot is None:
            self.config = THRESHOLD_2_SHOT_MAP[self.world_size]
        else:
            assert shot in (1, 2)
            # INF makes every size qualify for one-shot pull; 0 makes none.
            threshold = INF if shot == 1 else 0
            self.config = replace(self.config, one_shot_pull_threshold=threshold)

    @contextmanager
    def capture(self):
        """Mark the kernel object as capturing a CUDA graph for the ``with`` body.

        On exit the graph inputs reported by the kernel object are exchanged
        between all ranks of the group and registered.
        """
        try:
            self.obj.set_cuda_graph_capture(True)
            yield
        finally:
            self.obj.set_cuda_graph_capture(False)
            if not self.disabled:
                # cannot call when graph is capturing
                assert (
                    not torch.cuda.is_current_stream_capturing()
                ), "Cannot register graph inputs while capturing CUDA graph"
                pairs = self.obj.share_graph_inputs()
                handles = [handle for _, handle in pairs]
                offsets = [offset for offset, _ in pairs]
                handles_all = self._share_list(handles)
                offsets_all = self._share_list(offsets)
                result = [list(zip(o, h)) for o, h in zip(offsets_all, handles_all)]
                self.obj.register_inputs(result)
                log_info_on_rank0(logger, f"Registering {len(pairs)} cuda graph addresses")

    def should_custom_ar(self, inp: torch.Tensor) -> bool:
        """Check if the input tensor is suitable for custom all-reduce."""
        if self.disabled:
            return False
        inp_size = inp.numel() * inp.element_size()
        # custom allreduce requires input byte size to be multiples of 16
        if inp_size % 16 != 0:
            return False
        if not is_weak_contiguous(inp):
            return False
        return inp_size <= self.max_pull_size

    def custom_all_reduce(self, input: torch.Tensor) -> torch.Tensor:
        """All-reduce ``input`` and return the result tensor."""
        return self._all_reduce(input)

    def close(self):
        """Free the kernel object's resources; safe to call more than once."""
        if getattr(self, "disabled", True) or not hasattr(self, "obj"):
            return
        # Flip the flag before freeing so a second call (for example from
        # ``__del__`` after an explicit ``close()``) cannot free twice.
        self.disabled = True
        self.obj.free(self.group)

    def _all_reduce(self, input: torch.Tensor) -> torch.Tensor:
        """Perform the actual all-reduce via JIT kernel."""
        algo = self._determine_algo(input)
        return torch.from_dlpack(self.obj.all_reduce(input, algo))

    def _determine_algo(self, input: torch.Tensor) -> AllReduceAlgo:
        """Pick the algorithm for ``input`` from its size in bytes."""
        if self.override_algo is not None:
            return self.override_algo
        input_bytes = input.numel() * input.element_size()
        # NOTE(review): also bounded by the allocated push buffer, on the
        # assumption that a one-shot push must fit in ``push_buffer_bytes``.
        # With the default ``max_push_size`` this equals the plain threshold.
        push_limit = min(self.config.one_shot_push_threshold, self.max_push_size)
        if input_bytes <= push_limit:
            return AllReduceAlgo.ONE_SHOT_PUSH
        if input_bytes <= self.config.one_shot_pull_threshold:
            return AllReduceAlgo.ONE_SHOT_PULL
        return AllReduceAlgo.TWO_SHOT_PULL

    def _post_init_obj(self):
        """Exchange each rank's storage handle and finish kernel-object setup."""
        handles = [self.obj.share_storage()]
        result = self._share_list(handles)
        assert all(len(r) == 1 for r in result)
        result = [h[0] for h in result]
        self.obj.post_init(result)

    def _share_list(self, input: List[T]) -> List[List[T]]:
        """All-gather a list of integers; returns one list per rank.

        Values are transported as an int64 CPU tensor, and every rank must
        contribute a list of the same length.
        """
        input_tensor = torch.tensor(input, dtype=torch.int64, device="cpu")
        gather_list = [torch.empty_like(input_tensor) for _ in range(self.world_size)]
        dist.all_gather(gather_list, input_tensor, group=self.group)
        return [g.tolist() for g in gather_list]

    def __del__(self):
        self.close()
else: + # NOTE: This result is based on benchmarks on H200 GPUs + THRESHOLD_2_SHOT_MAP = { + 2: ModeConfig(2 * MB, INF), + 3: ModeConfig(512 * KB, 512 * KB), + 4: ModeConfig(384 * KB, 256 * KB), + 5: ModeConfig(256 * KB, 256 * KB), + 6: ModeConfig(192 * KB, 192 * KB), + 7: ModeConfig(192 * KB, 192 * KB), + 8: ModeConfig(160 * KB, 160 * KB), + } + # TODO: tune on more GPUs, e.g A100 + + +THRESHOLD_2_SHOT_MAP: Dict[int, ModeConfig] = {} diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 8d7ab8716803..56a27f56e4af 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -313,6 +313,7 @@ async def async_generate( lora_path: Optional[List[Optional[str]]] = None, custom_logit_processor: Optional[Union[List[str], str]] = None, return_hidden_states: bool = False, + return_routed_experts: bool = False, stream: bool = False, bootstrap_host: Optional[Union[List[str], str]] = None, bootstrap_port: Optional[Union[List[int], int]] = None, @@ -350,6 +351,7 @@ async def async_generate( token_ids_logprob=token_ids_logprob, lora_path=lora_path, return_hidden_states=return_hidden_states, + return_routed_experts=return_routed_experts, stream=stream, custom_logit_processor=custom_logit_processor, bootstrap_host=bootstrap_host, diff --git a/python/sglang/srt/entrypoints/openai/encoding_dsv4.py b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py new file mode 100644 index 000000000000..9196f9bfe076 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py @@ -0,0 +1,850 @@ +# Adapted from the DeepSeek-V4 release reference implementation. +""" +DeepSeek-V4 Encoding + +A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages +with tool calling, thinking mode, and quick instruction task support. 
+""" + +import copy +import json +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +# ============================================================ +# Special Tokens +# ============================================================ + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" + +USER_SP_TOKEN = "<|User|>" +ASSISTANT_SP_TOKEN = "<|Assistant|>" +LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>" + +# Task special tokens for internal classification tasks +DS_TASK_SP_TOKENS = { + "action": "<|action|>", + "query": "<|query|>", + "authority": "<|authority|>", + "domain": "<|domain|>", + "title": "<|title|>", + "read_url": "<|read_url|>", +} +VALID_TASKS = set(DS_TASK_SP_TOKENS.keys()) + +# ============================================================ +# Templates +# ============================================================ + +system_msg_template: str = "{content}" +user_msg_template: str = "{content}" +latest_reminder_msg_template: str = "{content}" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token +assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}" +thinking_template: str = "{reasoning_content}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = ( + '<{dsml_token}invoke name="{name}">\n{arguments}\n' +) +tool_calls_template = ( + "<{dsml_token}{tc_block_name}>\n{tool_calls}\n" +) +tool_calls_block_name: str = "tool_calls" + +tool_output_template: str = "{content}" + +REASONING_EFFORT_MAX = ( + "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n" + "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial 
scenarios.\n" + "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n" +) + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response. + +Otherwise, output directly after {thinking_end_token} with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. 
def to_json(value: Any) -> str:
    """Serialize ``value`` to a JSON string, keeping non-ASCII text readable.

    Args:
        value: Any object accepted by ``json.dumps``.

    Returns:
        The JSON text. Non-ASCII characters are emitted verbatim; if that
        attempt fails, the value is serialized again with ASCII escapes.

    Raises:
        Whatever ``json.dumps`` raises on the ASCII retry (for example
        ``TypeError`` for an unserializable object).
    """
    try:
        return json.dumps(value, ensure_ascii=False)
    except Exception:
        # Only ordinary errors trigger the retry; KeyboardInterrupt and
        # SystemExit must propagate instead of being swallowed.
        return json.dumps(value, ensure_ascii=True)
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
    """Return the position of the last message whose role is user or developer.

    The list is scanned from the end; ``-1`` is returned when no message has
    one of those two roles (including when the list is empty).
    """
    user_like_roles = ["user", "developer"]
    backwards = reversed(range(len(messages)))
    return next(
        (pos for pos in backwards if messages[pos].get("role") in user_like_roles),
        -1,
    )
+ ) + messages[idx]["task"] = task + + +# ============================================================ +# Message Rendering +# ============================================================ + + +def render_message( + index: int, + messages: List[Dict[str, Any]], + thinking_mode: str, + drop_thinking: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Render a single message at the given index into its encoded string form. + + This is the core function that converts each message in the conversation + into the DeepSeek-V4 format. + + Args: + index: Index of the message to render. + messages: Full list of messages in the conversation. + thinking_mode: Either "chat" or "thinking". + drop_thinking: Whether to drop reasoning content from earlier turns. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + Encoded string for this message. + """ + assert 0 <= index < len(messages) + assert thinking_mode in [ + "chat", + "thinking", + ], f"Invalid thinking_mode `{thinking_mode}`" + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning_content = msg.get("reasoning_content") + wo_eos = msg.get("wo_eos", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + # Reasoning effort prefix (only at index 0 in thinking mode with max effort) + assert reasoning_effort in [ + "max", + None, + "high", + ], f"Invalid reasoning effort: {reasoning_effort}" + if index == 0 and thinking_mode == "thinking" and reasoning_effort == "max": + prompt += REASONING_EFFORT_MAX + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + if response_format: + prompt += 
"\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + + content_developer = USER_SP_TOKEN + content_developer += content + + if tools: + content_developer += "\n\n" + render_tools(tools) + if response_format: + content_developer += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + prompt += user_msg_template.format(content=content_developer) + + elif role == "user": + prompt += USER_SP_TOKEN + + # Handle content blocks (tool results mixed with text) + content_blocks = msg.get("content_blocks") + if content_blocks: + parts = [] + for block in content_blocks: + block_type = block.get("type") + if block_type == "text": + parts.append(block.get("text", "")) + elif block_type == "tool_result": + tool_content = block.get("content", "") + if isinstance(tool_content, list): + text_parts = [] + for b in tool_content: + if b.get("type") == "text": + text_parts.append(b.get("text", "")) + else: + text_parts.append(f"[Unsupported {b.get('type')}]") + tool_content = "\n\n".join(text_parts) + parts.append(tool_output_template.format(content=tool_content)) + else: + parts.append(f"[Unsupported {block_type}]") + prompt += "\n\n".join(parts) + else: + prompt += content or "" + + elif role == "latest_reminder": + prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format( + content=content + ) + + elif role == "tool": + raise NotImplementedError( + "deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()" + ) + + elif role == "assistant": + thinking_part = "" + tc_content = "" + + if tool_calls: + tc_list = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tc.get("name"), + arguments=encode_arguments_to_dsml(tc), + ) + for tc in tool_calls + ] + tc_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, + tool_calls="\n".join(tc_list), + 
tc_block_name=tool_calls_block_name, + ) + + summary_content = content or "" + rc = reasoning_content or "" + + # Check if previous message has a task - if so, this is a task output (no thinking) + prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None + + if thinking_mode == "thinking" and not prev_has_task: + if not drop_thinking or index > last_user_idx: + thinking_part = ( + thinking_template.format(reasoning_content=rc) + thinking_end_token + ) + else: + thinking_part = "" + + if wo_eos: + prompt += assistant_msg_wo_eos_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + # Append transition tokens based on what follows + if index + 1 < len(messages) and messages[index + 1].get("role") not in [ + "assistant", + "latest_reminder", + ]: + return prompt + + task = messages[index].get("task") + if task is not None: + # Task special token for internal classification tasks + assert ( + task in VALID_TASKS + ), f"Invalid task: '{task}'. 
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Fold "tool" messages (and adjacent user text) into user messages that
    carry a ``content_blocks`` list.

    Each input message is deep-copied, so the caller's list is not modified.

    - A "tool" message becomes a ``tool_result`` block. It is appended to the
      previous output message when that message is a user message with
      ``content_blocks``; otherwise a new user message is started.
    - A "user" message becomes a ``text`` block. It is appended to the
      previous output message under the same condition, provided that message
      has no ``task`` yet; otherwise a new user message is started.
    - The ``task`` / ``wo_eos`` / ``mask`` fields of a user message are carried
      onto the output message in both cases, so they are not lost when the
      message is folded into an existing one.
    - A ``content`` of ``None`` is treated like a missing one (empty string).
    - Every other role is passed through as a copy.

    Args:
        messages: List of message dicts in OpenAI format.

    Returns:
        A new message list in which no message has role "tool".
    """
    extra_keys = ("task", "wo_eos", "mask")
    merged: List[Dict[str, Any]] = []

    for original in messages:
        msg = copy.deepcopy(original)
        role = msg.get("role")

        if role not in ("tool", "user"):
            merged.append(msg)
            continue

        # An explicit None must not reach the renderer as a block payload.
        content = msg.get("content")
        if content is None:
            content = ""

        tail = merged[-1] if merged else None
        can_extend = (
            tail is not None
            and tail.get("role") == "user"
            and "content_blocks" in tail
        )

        if role == "tool":
            block = {
                "type": "tool_result",
                "tool_use_id": msg.get("tool_call_id", ""),
                "content": content,
            }
            if can_extend:
                tail["content_blocks"].append(block)
            else:
                merged.append({"role": "user", "content_blocks": [block]})
            continue

        # role == "user"
        block = {"type": "text", "text": content}
        extras = {key: msg[key] for key in extra_keys if key in msg}
        if can_extend and tail.get("task") is None:
            tail["content_blocks"].append(block)
            # Keep the folded message's extra fields (notably `task`).
            tail.update(extras)
        else:
            merged.append(
                {
                    "role": "user",
                    "content": content,
                    "content_blocks": [block],
                    **extras,
                }
            )

    return merged
def encode_messages(
    messages: List[Dict[str, Any]],
    thinking_mode: str,
    context: Optional[List[Dict[str, Any]]] = None,
    drop_thinking: bool = True,
    add_default_bos_token: bool = True,
    reasoning_effort: Optional[str] = None,
) -> str:
    """
    Build the DeepSeek-V4 prompt text for ``messages``.

    Steps, in order:
      1. Tool messages are merged into user messages and tool results are
         sorted by call order (the sort sees ``context`` first so that tool
         calls made there are known).
      2. ``context`` itself is merged and sorted the same way.
      3. The BOS token starts the prompt only when ``add_default_bos_token``
         is set and there is no context.
      4. Reasoning is dropped from earlier turns only if ``drop_thinking`` is
         set, ``thinking_mode`` is "thinking", and no message carries ``tools``.
      5. Every message after the context is rendered and concatenated.

    Args:
        messages: Messages to encode.
        thinking_mode: Either "chat" or "thinking".
        context: Optional preceding messages; they influence rendering but
            are not rendered themselves.
        drop_thinking: Whether reasoning of earlier turns may be dropped.
        add_default_bos_token: Whether to prepend the BOS token.
        reasoning_effort: Passed through to ``render_message``.

    Returns:
        The encoded prompt string.
    """
    prior = context if context else []

    # Sort with the raw context in front, then keep only the new turns.
    turns = merge_tool_messages(messages)
    turns = sort_tool_results_by_call_order(prior + turns)[len(prior) :]
    if prior:
        prior = sort_tool_results_by_call_order(merge_tool_messages(prior))

    conversation = prior + turns
    prefix = bos_token if add_default_bos_token and len(prior) == 0 else ""

    # Any message that defines tools disables reasoning dropping altogether.
    tools_present = any(m.get("tools") for m in conversation)
    effective_drop_thinking = False if tools_present else drop_thinking

    if thinking_mode == "thinking" and effective_drop_thinking:
        conversation = _drop_thinking_messages(conversation)
        # NOTE(review): the context is dropped on its own here, while the
        # conversation was dropped as a whole; the two counts look like they
        # can disagree when `context` ends with a developer message — confirm.
        render_count = len(conversation) - len(_drop_thinking_messages(prior))
        first = len(conversation) - render_count
    else:
        render_count = len(turns)
        first = len(prior)

    rendered = "".join(
        render_message(
            position,
            conversation,
            thinking_mode=thinking_mode,
            drop_thinking=effective_drop_thinking,
            reasoning_effort=reasoning_effort,
        )
        for position in range(first, first + render_count)
    )
    return prefix + rendered
+ """ + last_user_idx = find_last_user_index(messages) + result = [] + keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"} + + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in keep_roles or idx >= last_user_idx: + result.append(msg) + elif role == "assistant": + msg = copy.copy(msg) + msg.pop("reasoning_content", None) + result.append(msg) + # developer and other roles before last_user_idx are dropped + + return result + + +# ============================================================ +# Parsing (Decoding model output) +# ============================================================ + + +def _read_until_stop( + index: int, text: str, stop: List[str] +) -> Tuple[int, str, Optional[str]]: + """ + Read text from index until one of the stop strings is found. + + Returns: + Tuple of (new_index, content_before_stop, matched_stop_string_or_None). + """ + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls( + index: int, text: str +) -> Tuple[int, Optional[str], List[Dict[str, str]]]: + """ + Parse DSML tool calls from text starting at the given index. + + Args: + index: Starting position in text. + text: The full text to parse. + + Returns: + Tuple of (new_index, last_stop_token, list_of_tool_call_dicts). + Each tool call dict has "name" and "arguments" keys. 
+ """ + tool_calls: List[Dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}invoke", tool_calls_end_token] + ) + if _ != ">\n": + raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise ValueError("Missing special token in tool calls") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL + ) + if len(p_tool_name) != 1: + raise ValueError(f"Tool name format error: '{tool_name_content}'") + tool_name = p_tool_name[0] + + tool_args: Dict[str, Tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop( + index, text, [f"/{dsml_token}parameter"] + ) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise ValueError(f"Parameter format error: '{param_content}'") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise ValueError(f"Duplicate parameter name: '{param_name}'") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise ValueError( + f"Parameter format error: expected '>\\n' but got '{content}'" + ) + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]: + """ + Parse a model completion text into a structured assistant message. 
+ + This function takes the raw text output from the model (a single assistant turn) + and extracts: + - reasoning_content (thinking block) + - content (summary/response) + - tool_calls (if any) + + NOTE: This function is designed to parse only correctly formatted strings and + will raise ValueError for malformed output. + + Args: + text: The raw completion text (including EOS token). + thinking_mode: Either "chat" or "thinking". + + Returns: + Dict with keys: "role", "content", "reasoning_content", "tool_calls". + tool_calls are in OpenAI format. + """ + summary_content, reasoning_content, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}" + + is_thinking = thinking_mode == "thinking" + is_tool_calling = False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop( + index, text, [thinking_end_token, tool_calls_start_token] + ) + reasoning_content = content_delta + assert ( + stop_token == thinking_end_token + ), "Invalid thinking format: missing " + + index, content_delta, stop_token = _read_until_stop( + index, text, [eos_token, tool_calls_start_token] + ) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + assert stop_token == eos_token, "Invalid format: missing EOS token" + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + assert not tool_ends_text, "Unexpected content after tool calls" + + assert len(text) == index and stop_token in [ + eos_token, + None, + ], "Unexpected content at end" + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + assert ( + sp_token not in summary_content and sp_token not in reasoning_content + ), f"Unexpected special token '{sp_token}' in content" + + return { + "role": "assistant", + "content": 
summary_content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 40ad9f3fb0b4..e462f0a4e965 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -17,7 +17,17 @@ import time import uuid from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TypeAlias, + Union, + get_args, +) from openai.types.responses import ( ResponseFunctionToolCall, @@ -425,8 +435,14 @@ class ToolCall(BaseModel): function: FunctionResponse +_GenericMessageRole = Literal[ + "system", "assistant", "tool", "function", "developer", "latest_reminder" +] +_GENERIC_MESSAGE_ROLES: Tuple[str, ...] = get_args(_GenericMessageRole) + + class ChatCompletionMessageGenericParam(BaseModel): - role: Literal["system", "assistant", "tool", "function", "developer"] + role: _GenericMessageRole content: Union[str, List[ChatCompletionMessageContentPart], None] = Field( default=None ) @@ -441,10 +457,9 @@ class ChatCompletionMessageGenericParam(BaseModel): def _normalize_role(cls, v): if isinstance(v, str): v_lower = v.lower() - if v_lower not in {"system", "assistant", "tool", "function", "developer"}: - raise ValueError( - "'role' must be one of 'system', 'developer', 'assistant', 'tool', or 'function' (case-insensitive)." 
- ) + if v_lower not in _GENERIC_MESSAGE_ROLES: + allowed = ", ".join(repr(r) for r in _GENERIC_MESSAGE_ROLES) + raise ValueError(f"'role' must be one of {allowed} (case-insensitive).") return v_lower raise ValueError("'role' must be a string") @@ -526,12 +541,23 @@ class ChatCompletionRequest(BaseModel): ) # noqa return_hidden_states: bool = False return_routed_experts: bool = False - reasoning_effort: Optional[Literal["low", "medium", "high"]] = Field( + reasoning_effort: Optional[Literal["low", "medium", "high", "max"]] = Field( default="medium", description="Constrains effort on reasoning for reasoning models. " - "'low' is the least effort, 'high' is the most effort. Reducing reasoning effort can " - "result in faster responses and fewer tokens used on reasoning in a response. " - "Currently only supported for OpenAI models in the harmony path, i.e GPT-OSS models.", + "'low' is the least effort, 'high' is the most effort. Reducing reasoning " + "effort can result in faster responses and fewer tokens used on reasoning " + "in a response. 'max' is an sglang extension to the OpenAI schema for " + "models that expose a maximum-effort tier above 'high'; models that don't " + "support it treat it the same as 'high'.", + ) + task: Optional[ + Literal["action", "query", "authority", "domain", "title", "read_url"] + ] = Field( + default=None, + description="DeepSeek-V4 quick instruction task. When set, the last " + "user/developer message is treated as a single-shot classification prompt " + "and the corresponding task special token (e.g. `<|domain|>`) is appended " + "before generation. Only honored by the dsv4 chat encoder; ignored otherwise.", ) # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 3b8773a551cb..8e57dac038ae 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -13,7 +13,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse from jsonschema import Draft202012Validator, SchemaError -from sglang.srt.entrypoints.openai.encoding_dsv32 import encode_messages +from sglang.srt.entrypoints.openai import encoding_dsv4, encoding_dsv32 from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -41,6 +41,7 @@ process_routed_experts_from_ret, to_openai_style_logprobs, ) +from sglang.srt.environ import envs from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.json_array_parser import JsonArrayParser @@ -112,7 +113,9 @@ def __init__( and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" ) - self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() + # Which Python-based chat encoder (if any) bypasses apply_chat_template. + # Values: "dsv32", "dsv4", or None. 
+ self.chat_encoding_spec = self._resolve_chat_encoding_spec() def _handle_last_assistant_message( self, @@ -170,14 +173,25 @@ def _append_assistant_prefix_to_prompt_ids( encoded = encoded[1:] return prompt_ids + encoded - def _use_dpsk_v32_encoding(self) -> bool: + def _resolve_chat_encoding_spec(self) -> Optional[str]: + if self.tool_call_parser == "deepseekv4": + return "dsv4" + if self.tool_call_parser == "deepseekv32": + return "dsv32" + + architectures = self.tokenizer_manager.model_config.hf_config.architectures + arch = architectures[0] if architectures else "" + + if "DeepseekV4" in arch: + return "dsv4" + has_chat_template = ( self.tokenizer_manager.tokenizer is not None and self.tokenizer_manager.tokenizer.chat_template is not None ) - architectures = self.tokenizer_manager.model_config.hf_config.architectures - is_dpsk_v32 = "DeepseekV3" in architectures[0] if architectures else False - return not has_chat_template and is_dpsk_v32 + if "DeepseekV3" in arch and not has_chat_template: + return "dsv32" + return None def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -377,14 +391,22 @@ def _apply_jinja_template( template_content_format = self.template_manager.jinja_template_content_format - if self.use_dpsk_v32_encoding: - thinking_mode = ( - "thinking" - if (request.chat_template_kwargs or {}).get("thinking") - else "chat" + if self.chat_encoding_spec is not None: + # Per-request wins; env is fallback so existing + # `export SGLANG_ENABLE_THINKING=1` workflow keeps working here. + thinking_requested = (request.chat_template_kwargs or {}).get( + "thinking", envs.SGLANG_ENABLE_THINKING.get() ) - messages = request.messages - messages = [msg.model_dump() for msg in messages] + thinking_mode = "thinking" if thinking_requested else "chat" + messages = [msg.model_dump() for msg in request.messages] + + # dsv4/dsv32 are text-only and consume string content; flatten + # OpenAI parts-list content here so the encoder sees a plain string. 
+ for i, msg in enumerate(messages): + if isinstance(msg.get("content"), list): + messages[i] = process_content_for_template_format( + msg, "string", [], [], [], [] + ) # Handle continue_final_message: separate final assistant message messages, assistant_prefix = self._handle_last_assistant_message( @@ -396,7 +418,32 @@ def _apply_jinja_template( messages.insert(0, {"role": "system", "content": ""}) if request.tools: messages[0]["tools"] = [tool.model_dump() for tool in request.tools] - real_input = encode_messages(messages, thinking_mode=thinking_mode) + + if self.chat_encoding_spec == "dsv4": + # V4 encoder only accepts "max" / "high" / None. + # OpenAI protocol defaults to "medium" which V4 rejects; drop it. + # Fallback: if request didn't set it, try env SGLANG_REASONING_EFFORT. + effort_source = request.reasoning_effort + if effort_source is None: + env_val = envs.SGLANG_REASONING_EFFORT.get() + if env_val: + effort_source = env_val + v4_reasoning_effort = ( + effort_source if effort_source in ("max", "high") else None + ) + if request.task is not None: + encoding_dsv4.attach_task_to_last_user_message( + messages, request.task + ) + real_input = encoding_dsv4.encode_messages( + messages, + thinking_mode=thinking_mode, + reasoning_effort=v4_reasoning_effort, + ) + else: + real_input = encoding_dsv32.encode_messages( + messages, thinking_mode=thinking_mode + ) prompt_ids = self.tokenizer_manager.tokenizer.encode(real_input) # Append assistant prefix if continue_final_message is enabled @@ -446,17 +493,16 @@ def _apply_jinja_template( ) try: + chat_template_kwargs = request.chat_template_kwargs or {} + if envs.SGLANG_ENABLE_THINKING.get(): + chat_template_kwargs["thinking"] = True prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, tools=tools, reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), + 
**(chat_template_kwargs), return_dict=False, ) except Exception as e: @@ -475,11 +521,7 @@ def _apply_jinja_template( add_generation_prompt=True, tools=tools, reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), + **(chat_template_kwargs), return_dict=False, ) except jinja2.TemplateError as template_error: @@ -1194,7 +1236,7 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """Judge whether the request needs reasoning""" if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3"]: + if self.reasoning_parser in ["deepseek-v3", "deepseek-v4"]: return ( request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 928ec998ee99..26566bbe59b7 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -156,6 +156,10 @@ class Envs: # Model & File Download SGLANG_USE_MODELSCOPE = EnvBool(False) SGLANG_DISABLED_MODEL_ARCHS = EnvTuple(tuple()) + # "none" = use checkpoint's config.json, "small"/"large" = force the packaged + # config_backup_{small,large}.json, "auto" = pick small/large based on the + # checkpoint's num_hidden_layers. 
+ SGLANG_APPLY_CONFIG_BACKUP = EnvStr("auto") # Logging Options SGLANG_LOG_GC = EnvBool(False) @@ -332,11 +336,13 @@ class Envs: # DeepGemm SGLANG_ENABLE_JIT_DEEPGEMM = EnvBool(True) SGLANG_JIT_DEEPGEMM_PRECOMPILE = EnvBool(True) + SGLANG_JIT_DEEPGEMM_FAST_WARMUP = EnvBool(False) SGLANG_JIT_DEEPGEMM_COMPILE_WORKERS = EnvInt(4) SGLANG_IN_DEEPGEMM_PRECOMPILE_STAGE = EnvBool(False) SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) SGLANG_DG_USE_NVRTC = EnvBool(False) SGLANG_USE_DEEPGEMM_BMM = EnvBool(False) + SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT = EnvBool(True) SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192) # DeepEP @@ -344,6 +350,12 @@ class Envs: SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK = EnvInt(128) SGLANG_DEEPEP_LL_COMBINE_SEND_NUM_SMS = EnvInt(32) SGLANG_BLACKWELL_OVERLAP_SHARED_EXPERTS_OUTSIDE_SBO = EnvBool(False) + SGLANG_HACK_OVERRIDE_TOPK_IDS_RANDOM = EnvBool(False) + SGLANG_HACK_FORCE_TID2EID_ZERO = EnvBool(False) + # Workaround torch.profiler+kineto first-call dropping all GPU events on + # PyTorch 2.9.1 + CUDA 13.0 + GB300. Run a tiny dummy 1-kernel profile at + # first start() to warm CUPTI activity callbacks. See journal 0427_011. + SGLANG_HACK_WARMUP_KINETO = EnvBool(False) # NSA Backend SGLANG_NSA_FUSE_TOPK = EnvBool(True) @@ -454,6 +466,82 @@ class Envs: # TokenizerManager SGLANG_REQUEST_STATE_WAIT_TIMEOUT = EnvInt(4) + SGLANG_ENABLE_THINKING = EnvBool(False) + # Default reasoning_effort for dsv4 chat encoder when request doesn't set it. + # Accepts "", "max", "high" (empty string means unset). Other values filtered to None. 
+ SGLANG_REASONING_EFFORT = EnvStr("") + + SGLANG_DSV4_MODE = EnvStr("2604") + SGLANG_DSV4_2604_SUBMODE = EnvStr("2604B") + SGLANG_DSV4_FP4_EXPERTS = EnvBool(True) # Set False when using FP4-to-FP8 converted checkpoint with 2604 config + SGLANG_OPT_HISPARSE_C4_SHRINK = EnvInt(1) + SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True) + SGLANG_HACK_FLASHMLA_BACKEND = EnvStr("kernel") + SGLANG_HACK_SKIP_FP4_FP8_GEMM = EnvBool(False) + SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False) + + + SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK = EnvBool(True) + SGLANG_OPT_USE_TILELANG_SWA_PREPARE = EnvBool(True) + SGLANG_OPT_USE_MULTI_STREAM_OVERLAP = EnvBool(True) + + SGLANG_FIX_MTP_HC_HIDDEN = EnvBool(True) + SGLANG_FIX_ATTN_BACKEND_IDLE = EnvBool(True) + SGLANG_FIX_PD_IDLE = EnvBool(True) + SGLANG_FIX_SWA_CHUNKED_REQ_DOUBLE_FREE = EnvBool(True) + SGLANG_OPT_V4_DRAFT_EXTEND_CUDA_GRAPH = EnvBool(False) # usually not useful + SGLANG_OPT_USE_FUSED_STORE_CACHE = EnvBool(True) + SGLANG_OPT_USE_OVERLAP_STORE_CACHE = EnvBool(True) + SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas") + SGLANG_OPT_USE_FUSED_HASH_TOPK = EnvBool(True) + SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True) + SGLANG_OPT_ALLOW_SHARED_EXPERT_DUAL_STREAM = EnvBool(True) # verified in journal 2026-04-21-017 + SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) + SGLANG_OPT_SWA_RADIX_CACHE_COMPACT = EnvBool(True) + SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT = EnvBool(False) + SGLANG_OPT_SWA_EVICT_DROP_PAGE_MARGIN = EnvBool(False) + SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW = EnvBool(False) + SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD = EnvBool(True) + SGLANG_OPT_MXFP4_STATIC_SCALE_ONES = EnvBool(True) + SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING = EnvBool(True) + SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False) + SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(True) + SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE = EnvInt(-1) + SGLANG_DSV4_FIX_ATTN_PADDING = 
@lru_cache(maxsize=1)
def is_large_dummy_model() -> bool:
    """Report whether SGLANG_HACK_ASSERT_CKPT_VERSION equals "large-dummy".

    The environment is read on the first call only; ``lru_cache`` memoizes the
    answer, so later changes to the variable are not observed unless the cache
    is cleared.
    """
    ckpt_version = os.environ.get("SGLANG_HACK_ASSERT_CKPT_VERSION")
    return ckpt_version == "large-dummy"
_print_deprecated_env( - "SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK", - "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", + "SGLANG_PREP_IN_CUDA_GRAPH", "SGLANG_ADVANCED_CUDA_GRAPH_CAPTURE" ) for key, value in os.environ.items(): diff --git a/python/sglang/srt/function_call/deepseekv32_detector.py b/python/sglang/srt/function_call/deepseekv32_detector.py index 7d8742f39ec8..08d25a9cae38 100644 --- a/python/sglang/srt/function_call/deepseekv32_detector.py +++ b/python/sglang/srt/function_call/deepseekv32_detector.py @@ -81,8 +81,13 @@ def __init__(self): self.function_calls_regex = ( r"<|DSML|function_calls>(.*?)" ) + # Long-form `<|DSML|invoke name="x">...` and the + # self-closing `<|DSML|invoke name="x"/>` shape V4 emits for zero-arg + # tools. The `end` group is empty when the closer hasn't streamed in. self.invoke_regex = ( - r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)(|$)' + r'<|DSML|invoke\s+name="(?P[^"]+)"\s*' + r"(?:(?P/>)" + r"|>(?P.*?)(?P(?:|$)))" ) self.prefix_parameter_end_call = [" bool: """Check if the text contains a deepseek v32 format tool call.""" return self.bot_token in text or "<|DSML|invoke" in text + @staticmethod + def _unpack_invoke_match(m: "re.Match[str]") -> tuple[str, str, bool]: + """Returns (name, body, is_complete) for an invoke_regex match. + + Self-closing invokes have empty body and are always complete. + Long-form bodies are always strings (possibly empty); they're + incomplete when matched against `$` because the closing tag + hasn't streamed in yet. 
+ """ + name = m.group("name").strip() + if m.group("self_close"): + return name, "", True + return name, m.group("body"), bool(m.group("end")) + def _parse_parameters_from_xml( self, invoke_content: str, allow_partial: bool = False ) -> dict: @@ -191,14 +210,11 @@ def detect_and_parse(self, text: str, tools: list[Tool]) -> StreamingParseResult function_calls_content = function_calls_match.group(1) # Find all invoke blocks - invoke_matches = re.findall( + for invoke_match in re.finditer( self.invoke_regex, function_calls_content, re.DOTALL - ) - - for func_name, invoke_content, _ in invoke_matches: - # Parse parameters from XML format + ): + func_name, invoke_content, _ = self._unpack_invoke_match(invoke_match) func_args = self._parse_parameters_from_xml(invoke_content) - # construct match_result for parse_base_json match_result = {"name": func_name, "parameters": func_args} calls.extend(self.parse_base_json(match_result, tools)) @@ -253,10 +269,9 @@ def parse_streaming_increment( if not invoke_match: break - func_name = invoke_match.group(1).strip() - invoke_content = invoke_match.group(2) - # group(3) is either "" (complete) or "" (incomplete, matched with $) - is_tool_end = bool(invoke_match.group(3)) + func_name, invoke_content, is_tool_end = self._unpack_invoke_match( + invoke_match + ) # Initialize state if this is the first tool call if self.current_tool_id == -1: diff --git a/python/sglang/srt/function_call/deepseekv4_detector.py b/python/sglang/srt/function_call/deepseekv4_detector.py new file mode 100644 index 000000000000..2bb74ceebf3a --- /dev/null +++ b/python/sglang/srt/function_call/deepseekv4_detector.py @@ -0,0 +1,27 @@ +from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector + + +class DeepSeekV4Detector(DeepSeekV32Detector): + """ + Detector for DeepSeek V4 DSML tool-call format. + + Identical to V3.2 except the outer block wrapper is + ``<|DSML|tool_calls>...`` instead of + ``<|DSML|function_calls>...``. 
@register_attention_backend("compressed")
def create_compressed_backend(runner):
    """Factory registered under the backend name "compressed".

    Returns a ``DeepseekV4BackendRadix`` constructed from ``runner``.
    """
    # Imported inside the factory, so the backend module is only loaded when
    # this factory is actually invoked.
    from sglang.srt.layers.attention.deepseek_v4_backend_radix import (
        DeepseekV4BackendRadix,
    )

    return DeepseekV4BackendRadix(runner)
class FusedCompressMetadata(NamedTuple):
    """Metadata bundle consumed by the fused compress path.

    ``CompressorBackend.forward_compress`` unpacks an instance positionally as
    ``(indices, extra_data, plan)``, so the field order is part of the contract.
    """

    # Location tensor; handed to ``compress_forward`` as ``indices``.
    write_loc: torch.Tensor
    # Optional companion tensor forwarded as ``extra_data``; None when unused.
    extra_data: Optional[torch.Tensor]
    # Decode- or prefill-specific compressor plan.
    plan: Union[CompressorDecodePlan, CompressorPrefillPlan]

    def copy_(self, other: FusedCompressMetadata) -> None:
        """Copy ``other``'s contents into this instance's existing buffers in place."""
        # Function-local import — presumably to avoid a circular import with
        # the metadata module; TODO confirm.
        from .metadata import maybe_copy_inplace

        self.write_loc.copy_(other.write_loc)
        # extra_data may be None; the helper is expected to cope with that
        # (its semantics are not visible here — verify against metadata.py).
        maybe_copy_inplace(self.extra_data, src=other.extra_data)
        self.plan.copy_(other.plan)
def forward_core_compressor(
    self,
    x: torch.Tensor,
    forward_batch: ForwardBatch,
    layer_id: int,
    compressor: Compressor,
) -> None:
    """Run ``compressor`` on ``x`` and store the result in the KV pool's extra key buffer.

    Does nothing for idle batches. The destination location is
    ``c4_out_loc`` when ``compressor.ratio == 4`` and ``c128_out_loc``
    otherwise. The store uses the fused setter when
    ``SGLANG_OPT_USE_FUSED_STORE_CACHE`` is enabled, else an explicit
    quantize-and-pack followed by the plain setter.
    """
    if forward_batch.forward_mode.is_idle():
        return

    # PREP_IN_CG lazy upgrade: the concrete backend (DeepseekV4BackendRadix)
    # owns this helper. MQALayer._forward_prepare calls us before
    # attn_backend.forward(), so Raw -> Radix must happen here too
    # (e.g. 1.6T layer 0 has compress_ratio=128 and needs cX_compress_metadata).
    self._maybe_upgrade_forward_metadata()
    kv_pool = forward_batch.token_to_kv_pool
    if TYPE_CHECKING:
        assert isinstance(kv_pool, DeepSeekV4TokenToKVPool)

    compressed_kv = compressor(x, forward_batch)
    if envs.SGLANG_DEBUG_HACK_CP_CHECK_RANK_CONSISTENCY.get():
        assert_tensor_identical_across_cp_ranks(
            compressed_kv,
            tag=f"compressor(ratio={compressor.ratio}) layer_id={layer_id}",
            forward_batch=forward_batch,
        )

    core_md = self.forward_metadata.core_metadata
    if compressor.ratio == 4:
        dst_loc = core_md.c4_out_loc
    else:
        dst_loc = core_md.c128_out_loc

    if not envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get():
        # Unfused fallback: quantize/pack explicitly, then store.
        packed = quant_to_nope_fp8_rope_bf16_pack_triton(compressed_kv.bfloat16())
        kv_pool.set_extra_key_buffer(layer_id, dst_loc, packed)
        return

    kv_pool.set_extra_key_buffer_fused(
        layer_id=layer_id,
        loc=dst_loc,
        cache_k=compressed_kv,
    )
def make_compressor_plan(
    compress_ratio: Literal[4, 128],
    forward_batch: ForwardBatch,
) -> Union[CompressorDecodePlan, CompressorPrefillPlan]:
    """Build the compressor plan matching ``forward_batch``'s forward mode.

    * decode  -> ``CompressorDecodePlan`` from the int32 sequence lengths.
    * prefill -> ``CompressorPrefillPlan.generate`` from the CPU-side extend
      and sequence lengths (both must be present).

    Raises:
        NotImplementedError: for target-verify mode and any other mode.
    """
    mode = forward_batch.forward_mode

    if mode.is_decode():
        return CompressorDecodePlan(
            compress_ratio, forward_batch.seq_lens.to(torch.int32)
        )

    if mode.is_prefill():
        assert not forward_batch.forward_mode.is_target_verify()
        extend_lens_cpu = forward_batch.extend_seq_lens_cpu
        seq_lens_cpu = forward_batch.seq_lens_cpu
        assert extend_lens_cpu is not None and seq_lens_cpu is not None
        total_q_tokens = sum(extend_lens_cpu)
        return CompressorPrefillPlan.generate(
            compress_ratio=compress_ratio,
            num_q_tokens=total_q_tokens,
            seq_lens=seq_lens_cpu,
            extend_lens=torch.tensor(extend_lens_cpu),
            device=forward_batch.seq_lens.device,
        )

    if mode.is_target_verify():
        raise NotImplementedError("target verify mode to be implemented")
    raise NotImplementedError(f"unsupported mode {forward_batch.forward_mode=}")
def _maybe_dump_metadata_extras(
    *,
    token_to_kv_pool: DeepSeekV4TokenToKVPool,
    compress_ratio: int,
    plan: CompressorPrefillPlan,
) -> None:
    """Best-effort debug hook: forward buffer shape/dtype and plan pieces to the dump helper.

    If the pool does not expose the expected attributes, or has no state pool
    for ``compress_ratio``, the function returns silently without dumping.
    """
    from sglang.jit_kernel.deepseek_v4 import maybe_dump_compress_metadata_extras

    try:
        pool_slot = list(token_to_kv_pool.compression_ratios).index(compress_ratio)
        state_pool = token_to_kv_pool.compress_state_pools[pool_slot]
        kv_score = state_pool.kv_score_buffer.kv_score
        buffer_shape = kv_score.shape
        buffer_dtype = kv_score.dtype
    except (AttributeError, ValueError, IndexError):
        # Pool layout not as expected: skip the dump instead of failing.
        return
    else:
        maybe_dump_compress_metadata_extras(
            compress_ratio=compress_ratio,
            kv_score_buffer_shape=buffer_shape,
            kv_score_buffer_dtype=buffer_dtype,
            plan_compress_plan=plan.compress_plan,
            plan_write_plan=plan.write_plan,
        )
sglang.jit_kernel.deepseek_v4 import topk_transform_512, topk_transform_512_v2 +from sglang.srt.environ import envs +from sglang.srt.layers.attention.compressed.metadata import ( + PagedCoreMetadata, + PagedIndexerMetadata, +) +from sglang.srt.layers.attention.indexer_topk_capturer import ( + get_global_indexer_capturer, +) +from sglang.srt.layers.attention.nsa.triton_kernel import act_quant +from sglang.srt.utils import is_hip + +if TYPE_CHECKING: + from sglang.srt.layers.attention.compressed.compressor import CompressorBackend + from sglang.srt.layers.attention.compressed.metadata import DeepseekV4Metadata + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.models.deepseek_v4 import C4Indexer + + +if is_hip(): + FP8_DTYPE = torch.float8_e4m3fnuz + FP8_MAX = torch.finfo(FP8_DTYPE).max +else: + FP8_DTYPE = torch.float8_e4m3fn + FP8_MAX = torch.finfo(FP8_DTYPE).max + + +def fp8_paged_mqa_logits_torch( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + for i in range(batch_size): + q = q_fp8[i, 0] + q = q.to(torch.float32) + q_scale = weight[i] + seq_len = int(seq_lens[i].item()) + assert seq_len <= max_seq_len + num_pages = (seq_len 
def topk_transform_512_pytorch_vectorized(
    scores: torch.Tensor,
    seq_lens: torch.Tensor,
    page_tables: torch.Tensor,
    out_page_indices: torch.Tensor,
    page_size: int,
    out_raw_indices: Optional[torch.Tensor] = None,
) -> None:
    """Pure-PyTorch top-512 selection plus logical-to-physical index translation.

    For every row of ``scores`` the 512 highest-scoring positions among the
    first ``seq_lens[row]`` entries are selected (unsorted). Rows with
    ``seq_lens[row] <= 512`` take positions ``0 .. seq_len-1`` in order
    instead. Each selected logical position ``p`` is mapped to
    ``page_tables[row, p // page_size] * page_size + p % page_size``.

    Args:
        scores: ``(batch, max_seq_len)`` scores; entries at or beyond a row's
            sequence length are ignored.
        seq_lens: ``(batch,)`` valid length of each row.
        page_tables: ``(batch, num_pages)`` logical page -> physical page.
        out_page_indices: output written in place with the ``(batch, 512)``
            int32 physical indices; unused slots are -1.
        page_size: tokens per page. Any positive value works; powers of two
            give exactly the same result as the former shift/mask arithmetic.
        out_raw_indices: optional output written in place with the
            ``(batch, 512)`` int32 logical positions; unused slots are -1.
    """
    TOPK = 512
    batch_size, max_seq_len = scores.shape[0], scores.shape[1]
    device = scores.device
    invalid = torch.tensor(-1, device=device, dtype=torch.int32)

    # Exclude positions at or beyond each row's sequence length from the top-k.
    positions = torch.arange(max_seq_len, device=device).unsqueeze(0)
    valid_mask = positions < seq_lens.unsqueeze(1)
    masked_scores = scores.masked_fill(~valid_mask, float("-inf"))

    actual_k = min(TOPK, max_seq_len)
    _, raw_indices = torch.topk(
        masked_scores, k=actual_k, dim=1, largest=True, sorted=False
    )
    raw_indices = raw_indices.to(torch.int32)

    if actual_k < TOPK:
        # Fewer than TOPK columns exist: pad to the fixed output width.
        padding = torch.zeros(
            (batch_size, TOPK - actual_k), dtype=torch.int32, device=device
        )
        raw_indices = torch.cat([raw_indices, padding], dim=1)

    # A selected slot is valid unless its (unmasked) score is -inf or it is padding.
    # gather wants int64 indices; this avoids mixing int32/int64 index tensors.
    gathered_scores = torch.gather(scores, 1, raw_indices.clamp(min=0).long())
    valid_topk = gathered_scores != float("-inf")
    if actual_k < TOPK:
        valid_topk[:, actual_k:] = False

    # Short rows (everything fits in TOPK) take positions 0..seq_len-1 in order.
    needs_sequential = (seq_lens <= TOPK).unsqueeze(1)
    if needs_sequential.any():
        sequential_indices = torch.arange(
            TOPK, device=device, dtype=torch.int32
        ).unsqueeze(0)
        sequential_valid = sequential_indices < seq_lens.unsqueeze(1)
        sequential_raw = torch.where(sequential_valid, sequential_indices, invalid)
        raw_indices = torch.where(needs_sequential, sequential_raw, raw_indices)
        valid_topk = torch.where(needs_sequential, sequential_valid, valid_topk)

    # Logical position -> (page, offset). Floor-div/multiply matches the old
    # shift/mask for power-of-two page sizes and stays correct for other sizes.
    page_idx = torch.div(raw_indices, page_size, rounding_mode="floor")
    offset_in_page = raw_indices - page_idx * page_size

    # -1 sentinels map to page 0 for the lookup; they are masked out below.
    physical_pages = torch.gather(
        page_tables, dim=1, index=page_idx.clamp(min=0).long()
    )
    page_indices = (physical_pages * page_size + offset_in_page).to(torch.int32)
    page_indices = torch.where(valid_topk, page_indices, invalid)

    out_page_indices.copy_(page_indices)

    if out_raw_indices is not None:
        out_raw_indices.copy_(torch.where(valid_topk, raw_indices, invalid))
class C4IndexerBackend:
    """Mixin that runs the C4 indexer: logits, top-k selection, page translation.

    The methods call ``self.forward_indexer_compressor`` and
    ``self._maybe_upgrade_forward_metadata``, which are not defined here; the
    ``TYPE_CHECKING`` asserts indicate the class is meant to be combined with
    ``CompressorBackend``.
    """

    def __init__(self):
        super().__init__()
        # Declared only; the attribute is not assigned in this constructor.
        self.forward_metadata: DeepseekV4Metadata
        # When True, forward_c4_indexer returns right after computing logits
        # and leaves c4_sparse_page_indices untouched.
        self.debug_use_external_c4_sparse_indices: bool = False

    def _forward_prepare_multi_stream(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        c4_indexer: C4Indexer,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        token_to_kv_pool: DeepSeekV4TokenToKVPool,
        alt_streams: Optional[List[torch.cuda.Stream]] = None,
        q_lora_ready: Optional[torch.cuda.Event] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Prepare indexer inputs using two side streams.

        The compressor runs on the current stream, the query projection and
        quantization on ``alt_streams[0]``, and the weight computation on
        ``alt_streams[1]``. Returns ``(q_fp8, weights, c4_indexer_kv_cache)``.
        ``alt_streams`` must hold at least two streams despite its default.
        """
        if TYPE_CHECKING:
            assert isinstance(self, CompressorBackend)

        assert alt_streams is not None
        assert len(alt_streams) >= 2
        current_stream = torch.cuda.current_stream()
        stream_q = alt_streams[0]
        stream_weights = alt_streams[1]

        # Both side streams start only after work already queued on the
        # current stream.
        stream_q.wait_stream(current_stream)
        stream_weights.wait_stream(current_stream)

        self.forward_indexer_compressor(
            x=x,
            forward_batch=forward_batch,
            layer_id=c4_indexer.layer_id,
            compressor=c4_indexer.compressor,
        )
        c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=c4_indexer.layer_id,
        )

        with torch.cuda.stream(stream_q):
            if q_lora_ready is not None:
                stream_q.wait_event(q_lora_ready)
            q = c4_indexer.compute_q(q_lora, positions=positions)
            q_fp8, q_scale = act_quant(q)
            q_scale_ready = stream_q.record_event()

        with torch.cuda.stream(stream_weights):
            weights = c4_indexer.compute_weights(x, skip_scale=True)
            # fused_scale reads q_scale, which is produced on stream_q.
            stream_weights.wait_event(q_scale_ready)
            weights = fused_scale(weights, c4_indexer.weight_scale, q_scale)

        # Rejoin: later work on the current stream sees both results.
        current_stream.wait_stream(stream_q)
        current_stream.wait_stream(stream_weights)

        return q_fp8, weights, c4_indexer_kv_cache

    def _forward_prepare_normal(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        c4_indexer: C4Indexer,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        token_to_kv_pool: DeepSeekV4TokenToKVPool,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Single-stream variant of the preparation step.

        Computes and quantizes q, computes the scaled weights, then runs the
        compressor. Returns ``(q_fp8, weights, c4_indexer_kv_cache)``.
        """
        if TYPE_CHECKING:
            assert isinstance(self, CompressorBackend)

        q = c4_indexer.compute_q(q_lora, positions=positions)
        q_fp8, q_scale = act_quant(q)
        weights = c4_indexer.compute_weights(x, skip_scale=True)
        weights = fused_scale(weights, c4_indexer.weight_scale, q_scale)
        self.forward_indexer_compressor(
            x=x,
            forward_batch=forward_batch,
            layer_id=c4_indexer.layer_id,
            compressor=c4_indexer.compressor,
        )
        c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer(
            layer_id=c4_indexer.layer_id,
        )
        return q_fp8, weights, c4_indexer_kv_cache

    def forward_c4_indexer(
        self,
        x: torch.Tensor,
        q_lora: torch.Tensor,
        c4_indexer: C4Indexer,
        forward_batch: ForwardBatch,
        alt_streams: Optional[List[torch.cuda.Stream]] = None,
        enable_multi_stream: bool = False,
        q_lora_ready: Optional[torch.cuda.Event] = None,
    ) -> None:
        """Compute indexer logits and run the top-k transform for one layer.

        Does nothing for idle batches. The transform is handed
        ``core_metadata.c4_sparse_page_indices`` as its output buffer; with a
        hisparse coordinator that attribute is afterwards replaced by the
        coordinator's / pool's translated indices. ``q_lora_ready`` is only
        accepted together with ``enable_multi_stream``.
        """
        if forward_batch.forward_mode.is_idle():
            return
        # PREP_IN_CG lazy upgrade: this runs from MQALayer._forward_prepare,
        # before attn_backend.forward() would trigger the upgrade.
        self._maybe_upgrade_forward_metadata()
        token_to_kv_pool = forward_batch.token_to_kv_pool

        if TYPE_CHECKING:
            assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool)
            assert isinstance(self, CompressorBackend)

        metadata = self.forward_metadata
        indexer_metadata = metadata.indexer_metadata
        core_metadata = metadata.core_metadata

        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import; confirm).
        from sglang.srt.layers.attention.deepseek_v4_backend_radix import (
            DSV4AttnMetadataRadix,
        )

        assert isinstance(core_metadata, (PagedCoreMetadata, DSV4AttnMetadataRadix))
        assert isinstance(indexer_metadata, PagedIndexerMetadata)

        if enable_multi_stream:
            q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_multi_stream(
                x=x,
                q_lora=q_lora,
                c4_indexer=c4_indexer,
                positions=core_metadata.positions,
                forward_batch=forward_batch,
                token_to_kv_pool=token_to_kv_pool,
                alt_streams=alt_streams,
                q_lora_ready=q_lora_ready,
            )
        else:
            assert q_lora_ready is None
            q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_normal(
                x=x,
                q_lora=q_lora,
                c4_indexer=c4_indexer,
                positions=core_metadata.positions,
                forward_batch=forward_batch,
                token_to_kv_pool=token_to_kv_pool,
            )

        # Reshape to the layout passed to the logits kernels below: q gains
        # a singleton dim at index 1, the cache becomes 4-D.
        assert len(q_fp8.shape) == 3
        q_fp8 = q_fp8.unsqueeze(1)
        assert len(c4_indexer_kv_cache.shape) == 2
        block_kv = 64
        num_heads_kv = 1
        # NOTE(review): presumably 128 fp8 value bytes + 4 scale bytes per
        # token; confirm against the pool layout.
        head_dim_with_sf = 132

        c4_indexer_kv_cache = c4_indexer_kv_cache.view(
            c4_indexer_kv_cache.shape[0], block_kv, num_heads_kv, head_dim_with_sf
        )
        assert len(weights.shape) == 3
        weights = weights.squeeze(2)
        # Pick the logits implementation: tilelang, torch reference, or
        # DeepGEMM (chunked wrapper when a chunk size is configured).
        if envs.SGLANG_OPT_USE_TILELANG_INDEXER.get():
            from sglang.srt.layers.attention.nsa.tilelang_kernel import (
                tilelang_fp8_paged_mqa_logits as fn,
            )
        elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
            fn = fp8_paged_mqa_logits_torch
        else:
            if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1:
                from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import (
                    fp8_paged_mqa_logits_chunked as fn,
                )
            else:
                from deep_gemm import fp8_paged_mqa_logits as fn

        # The logits kernels are given 2-D lengths; the top-k calls below
        # still receive indexer_metadata.c4_seq_lens unchanged.
        _c4sl = indexer_metadata.c4_seq_lens
        if _c4sl.dim() == 1:
            _c4sl = _c4sl.unsqueeze(-1)
        logits = fn(
            q_fp8,
            c4_indexer_kv_cache,
            weights,
            _c4sl,
            indexer_metadata.page_table,
            indexer_metadata.deep_gemm_metadata,
            indexer_metadata.max_c4_seq_len,
            False,
        )

        # The logits used the indexer's page table while the transform below
        # uses the core one; they must be the same object.
        assert indexer_metadata.page_table is core_metadata.page_table
        if self.debug_use_external_c4_sparse_indices:
            return

        indexer_capturer = get_global_indexer_capturer()
        capture_enabled = indexer_capturer.is_enabled()

        hisparse_coordinator = forward_batch.hisparse_coordinator
        hisparse_decode = (
            hisparse_coordinator is not None and forward_batch.forward_mode.is_decode()
        )

        # Raw (untranslated) indices are only materialized when something
        # consumes them: the capturer or the hisparse decode swap-in.
        raw_indices = None
        if capture_enabled:
            raw_indices = torch.empty_like(core_metadata.c4_sparse_page_indices)
        elif hisparse_decode:
            raw_indices = hisparse_coordinator.raw_indices_buffer[
                : core_metadata.c4_sparse_page_indices.size(0)
            ]

        if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get():
            topk_transform_512_pytorch_vectorized(
                logits,
                indexer_metadata.c4_seq_lens,
                core_metadata.page_table,
                core_metadata.c4_sparse_page_indices,
                indexer_metadata.c4_page_size,
                raw_indices,
            )
        elif envs.SGLANG_OPT_USE_TOPK_V2.get() and raw_indices is None:
            # The v2 call takes no raw-index buffer, so it is only chosen
            # when none is needed.
            topk_transform_512_v2(
                logits,
                indexer_metadata.c4_seq_lens,
                core_metadata.page_table,
                core_metadata.c4_sparse_page_indices,
                indexer_metadata.c4_page_size,
                indexer_metadata.topk_metadata,
            )
        else:
            topk_transform_512(
                logits,
                indexer_metadata.c4_seq_lens,
                core_metadata.page_table,
                core_metadata.c4_sparse_page_indices,
                indexer_metadata.c4_page_size,
                raw_indices,
            )
        if hisparse_coordinator is not None:
            if hisparse_decode:
                compress_layer_id = token_to_kv_pool.layer_mapping[
                    c4_indexer.layer_id
                ].compress_layer_id
                core_metadata.c4_sparse_page_indices = (
                    hisparse_coordinator.swap_in_selected_pages(
                        req_pool_indices=forward_batch.req_pool_indices,
                        compressed_seq_lens=indexer_metadata.c4_seq_lens,
                        top_k_result=raw_indices,
                        layer_id=compress_layer_id,
                    )
                )
            else:
                core_metadata.c4_sparse_page_indices = token_to_kv_pool.c4_kv_pool.translate_loc_from_compressed_to_hisparse_device(
                    core_metadata.c4_sparse_page_indices
                )

        if capture_enabled:
            compress_layer_id = token_to_kv_pool.layer_mapping[
                c4_indexer.layer_id
            ].compress_layer_id
            indexer_capturer.capture(compress_layer_id, raw_indices)
def _metadata_values_equal(lhs, rhs) -> bool:
    """Equality for check-eq fields that is also safe for tensors.

    ``tensor == tensor`` is element-wise, so asserting on it raises for any
    tensor with more than one element; tensors are compared with
    ``torch.equal`` instead.
    """
    lhs_is_tensor = isinstance(lhs, torch.Tensor)
    rhs_is_tensor = isinstance(rhs, torch.Tensor)
    if lhs_is_tensor and rhs_is_tensor:
        return lhs is rhs or torch.equal(lhs, rhs)
    if lhs_is_tensor or rhs_is_tensor:
        return False
    return lhs == rhs


def copy_metadata(
    *,
    src,
    dst,
    check_eq_fields: List[str],
    copy_fields: List[str],
    assign_fields: Optional[List[str]] = None,
):
    """Copy one metadata dataclass into another, field by field.

    Args:
        src: source dataclass instance.
        dst: destination instance, mutated in place.
        check_eq_fields: fields that must already be equal on both sides.
        copy_fields: fields copied with ``dst_field.copy_(src_field)``; the
            destination value must not be None.
        assign_fields: fields rebound on ``dst`` to the source object.

    Raises:
        AssertionError: if the three lists overlap, do not cover every
            dataclass field of ``src`` exactly, a check-eq field differs, or a
            copy destination is None.
    """
    assign_fields = assign_fields or []

    # Validate the field lists before touching dst, so a bad call cannot
    # leave the destination half-updated.
    provided_fields = check_eq_fields + copy_fields + assign_fields
    assert len(provided_fields) == len(
        set(provided_fields)
    ), f"{provided_fields=} has dup"
    all_fields = {f.name for f in fields(src)}
    assert set(provided_fields) == all_fields, f"{provided_fields=} {all_fields=}"

    for field_name in check_eq_fields:
        src_val = getattr(src, field_name)
        dst_val = getattr(dst, field_name)
        assert _metadata_values_equal(
            src_val, dst_val
        ), f"{field_name=} {src_val=} {dst_val=}"

    for field_name in copy_fields:
        src_val = getattr(src, field_name)
        dst_val = getattr(dst, field_name)
        assert dst_val is not None, f"{field_name=} {src_val=} {dst_val=}"
        dst_val.copy_(src_val)

    for field_name in assign_fields:
        setattr(dst, field_name, getattr(src, field_name))


def create_flashmla_metadata():
    """Return a fresh FlashMLA scheduling-metadata object, or None on HIP."""
    if is_hip():
        return None
    else:
        import flash_mla

        return flash_mla.get_mla_metadata()[0]


@dataclass
class CoreMetadata:
    """Base fields shared by the core attention metadata variants."""

    positions: torch.Tensor
    swa_slice: Optional[torch.Tensor]
    swa_out_loc_sliced: torch.Tensor
    c4_out_loc: torch.Tensor
    c128_out_loc: torch.Tensor

    def init_swa_slice(self, swa_slice: torch.Tensor):
        """Record ``swa_slice`` (once) and index ``swa_out_loc_sliced`` with it."""
        assert self.swa_slice is None, "can only update once"
        self.swa_slice = swa_slice
        self.swa_out_loc_sliced = self.swa_out_loc_sliced[swa_slice]

    def copy_(self, other):
        """In-place copy; subclasses that support it must override."""
        raise NotImplementedError


@dataclass
class IndexerMetadata:
    """Base class for indexer metadata; carries no fields of its own."""

    def copy_(self, other):
        """In-place copy; subclasses that support it must override."""
        raise NotImplementedError


@dataclass
class PagedIndexerMetadata(IndexerMetadata):
    """Indexer metadata for the paged layout.

    ``deep_gemm_metadata`` and ``topk_metadata`` are derived in
    ``__post_init__`` from ``c4_seq_lens`` and the active environment flags.
    """

    page_size: int
    page_table: torch.Tensor
    c4_seq_lens: torch.Tensor
    deep_gemm_metadata: Any = field(init=False, repr=False)
    topk_metadata: torch.Tensor = field(init=False, repr=False)

    def __post_init__(self):
        if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get():
            # The torch reference path needs no scheduling metadata.
            self.deep_gemm_metadata = None
        else:
            import deep_gemm

            if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1:
                from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import (
                    get_paged_mqa_logits_metadata_chunked as get_paged_mqa_logits_metadata,
                )
            elif envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get():
                from sglang.jit_kernel.deepseek_v4 import get_paged_mqa_logits_metadata
            else:
                from deep_gemm import get_paged_mqa_logits_metadata

            # The metadata builders are given int32 lengths with a trailing
            # singleton dimension.
            _c4 = self.c4_seq_lens.to(torch.int32)
            if _c4.dim() == 1:
                _c4 = _c4.unsqueeze(-1)
            self.deep_gemm_metadata = get_paged_mqa_logits_metadata(
                _c4,
                self.c4_page_size,
                deep_gemm.get_num_sms(),
            )

            # Only the non-chunked builders are required to return a tensor.
            if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() == -1:
                assert isinstance(self.deep_gemm_metadata, torch.Tensor)

        from sglang.jit_kernel.deepseek_v4 import plan_topk_v2

        if envs.SGLANG_OPT_USE_TOPK_V2.get():
            self.topk_metadata = plan_topk_v2(self.c4_seq_lens)
        else:
            # Empty placeholder so copy_ always has a tensor to copy.
            self.topk_metadata = torch.empty((0,))

        assert self.page_size == 256

    @property
    def c4_page_size(self) -> int:
        """Page size in compressed-by-4 tokens."""
        return self.page_size // 4

    @property
    def max_seq_len(self) -> int:
        """Largest sequence length addressable through ``page_table``."""
        return self.page_table.shape[1] * self.page_size

    @property
    def max_c4_seq_len(self) -> int:
        """Largest compressed-by-4 length addressable through ``page_table``."""
        return self.page_table.shape[1] * self.c4_page_size

    def copy_(self, other: "PagedIndexerMetadata"):
        """Copy ``other`` into this instance in place.

        ``deep_gemm_metadata`` is copied in place when it exists; on HIP, or
        when it is None (torch fallback), it is rebound instead. Leaving it
        out of every list, as the HIP path used to, trips the field-coverage
        assertion in ``copy_metadata``; copying into None trips its non-None
        assertion.
        """
        copy_fields = ["page_table", "c4_seq_lens"]
        assign_fields: List[str] = []
        if is_hip() or self.deep_gemm_metadata is None:
            assign_fields.append("deep_gemm_metadata")
        else:
            copy_fields.append("deep_gemm_metadata")
        copy_fields.append("topk_metadata")
        copy_metadata(
            src=other,
            dst=self,
            check_eq_fields=["page_size"],
            copy_fields=copy_fields,
            assign_fields=assign_fields,
        )


@dataclass
class PagedCoreMetadata(CoreMetadata):
    """Core attention metadata for the paged layout.

    The ``c4_sparse_*`` buffers and the three FlashMLA metadata objects are
    created in ``__post_init__``.
    """

    page_table: torch.Tensor
    swa_page_indices: torch.Tensor
    swa_topk_lengths: torch.Tensor
    c128_page_indices: torch.Tensor
    c128_topk_lengths_clamp1: torch.Tensor
    c4_topk_lengths_raw: torch.Tensor
    c4_topk_lengths_clamp1: torch.Tensor
    c4_sparse_topk: int
    c4_sparse_topk_lengths: torch.Tensor = field(init=False)
    c4_sparse_page_indices: torch.Tensor = field(init=False)
    c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)

    def get_flashmla_metadata(self, compress_ratio: int):
        """Return the FlashMLA metadata for ``compress_ratio`` (0, 4 or 128).

        Raises:
            ValueError: for any other ratio.
        """
        if compress_ratio == 0:
            return self.c1_flashmla_metadata
        elif compress_ratio == 4:
            return self.c4_flashmla_metadata
        elif compress_ratio == 128:
            return self.c128_flashmla_metadata
        else:
            raise ValueError(f"invalid {compress_ratio=}")

    def __post_init__(self):
        # c4_sparse_topk is set from model_config.index_topk per-model
        # (small model: 512, large model: 1024).
        assert self.c4_sparse_topk in (512, 1024), (
            f"unexpected c4_sparse_topk={self.c4_sparse_topk}; "
            "supported: 512 (small) or 1024 (large)"
        )
        self.c4_sparse_topk_lengths = torch.clamp(
            self.c4_topk_lengths_clamp1, max=self.c4_sparse_topk
        )
        # One row per entry of c4_topk_lengths_clamp1, initialised to -1.
        self.c4_sparse_page_indices = torch.full(
            (self.c4_topk_lengths_clamp1.size(0), self.c4_sparse_topk),
            -1,
            dtype=torch.int32,
            device=self.c4_topk_lengths_clamp1.device,
        )
        self.c1_flashmla_metadata = create_flashmla_metadata()
        self.c4_flashmla_metadata = create_flashmla_metadata()
        self.c128_flashmla_metadata = create_flashmla_metadata()

    def copy_(self, other: PagedCoreMetadata) -> None:
        """Copy ``other`` into this instance in place.

        Tensors are copied element-wise; the FlashMLA metadata objects are
        rebound; ``c4_sparse_topk`` and ``swa_slice`` must already match.
        """
        copy_metadata(
            src=other,
            dst=self,
            check_eq_fields=["c4_sparse_topk", "swa_slice"],
            copy_fields=[
                "positions",
                "swa_out_loc_sliced",
                "c4_out_loc",
                "c128_out_loc",
                "page_table",
                "swa_page_indices",
                "swa_topk_lengths",
                "c128_page_indices",
                "c128_topk_lengths_clamp1",
                "c4_topk_lengths_raw",
                "c4_topk_lengths_clamp1",
                "c4_sparse_topk_lengths",
                "c4_sparse_page_indices",
            ],
            assign_fields=[
                "c1_flashmla_metadata",
                "c4_flashmla_metadata",
                "c128_flashmla_metadata",
            ],
        )


@dataclass
class RaggedCoreMetadata(CoreMetadata):
    """Core metadata for the ragged layout; ``copy_`` is not overridden."""

    swa_ragged_indices: torch.Tensor
    swa_c4_ragged_indices: torch.Tensor
    swa_c128_ragged_indices: torch.Tensor


@dataclass
class RaggedIndexerMetadata(IndexerMetadata):
    """Indexer metadata for the ragged layout; ``copy_`` is not overridden."""

    c4_k_start: torch.Tensor
    c4_k_finish: torch.Tensor


@dataclass
class DeepseekV4Metadata:
    """Bundle of core and indexer metadata for one forward pass."""

    core_metadata: CoreMetadata
    indexer_metadata: IndexerMetadata
    debug_seq_lens_expanded: torch.Tensor

    def copy_(self, other: "DeepseekV4Metadata"):
        """Copy the core and indexer metadata of ``other`` in place."""
        # NOTE(review): debug_seq_lens_expanded is not copied here; confirm
        # that is intended.
        self.core_metadata.copy_(other.core_metadata)
        self.indexer_metadata.copy_(other.indexer_metadata)


def maybe_copy_inplace(dst, *, src) -> None:
    """Call ``dst.copy_(src)`` unless ``dst`` is None; types must match exactly."""
    assert type(src) == type(dst)
    if dst is not None:
        dst.copy_(src)
def expand_seq_lens(
    *,
    seq_lens: List[int],
    extend_seq_lens: List[int],
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expand per-request lengths to one entry per extend token.

    For a request with KV length ``kv_len`` and ``qo_len`` new tokens, the new
    tokens get lengths ``kv_len - qo_len + 1 .. kv_len``.

    Returns:
        ``(seq_lens_expanded, expanded_idx_to_unexpanded_idx)``, both int32 of
        length ``sum(extend_seq_lens)``, built on the host and moved to
        ``device`` with a non-blocking copy. The second maps each token back
        to its request index.
    """
    num_tokens = sum(extend_seq_lens)
    seq_lens_expanded = torch.empty(num_tokens, **_HOST_INT32_KWARGS)
    expanded_idx_to_unexpanded_idx = torch.empty(num_tokens, **_HOST_INT32_KWARGS)
    offset = 0
    for i, (kv_len, qo_len) in enumerate(zip(seq_lens, extend_seq_lens)):
        # Write straight into this request's slice of the output.
        out = seq_lens_expanded[offset : offset + qo_len]
        offset += qo_len
        torch.arange(kv_len - qo_len + 1, kv_len + 1, out=out)
        expanded_idx_to_unexpanded_idx[offset - qo_len : offset].fill_(i)
    return (
        seq_lens_expanded.to(device, non_blocking=True),
        expanded_idx_to_unexpanded_idx.to(device, non_blocking=True),
    )


def make_swa_ring_buffer_indices(
    forward_batch: ForwardBatch,
    device: torch.device,
    *,
    max_seq_len: int,
    swa_window_size: int,
) -> torch.Tensor:
    """Build per-token sliding-window indices for prefill.

    Returns an int32 ``(extend_num_tokens, swa_window_size)`` tensor. In the
    host path each row lists, for one new token, the slots of the last
    ``swa_window_size`` positions up to and including that token, padded with
    -1. Prefix positions map to ``seq_idx * window + pos % window``; newly
    generated positions map to ``batch_size * window + running_token_offset``.
    With ``SGLANG_OPT_USE_TILELANG_SWA_PREPARE`` the tilelang kernel fills the
    tensor instead.
    """
    SWA_WINDOW = swa_window_size
    extend_num_tokens = forward_batch.extend_num_tokens
    assert extend_num_tokens is not None
    if envs.SGLANG_OPT_USE_TILELANG_SWA_PREPARE.get():
        seq_lens = forward_batch.seq_lens
        extend_lens = forward_batch.extend_seq_lens
        assert extend_lens is not None
        seq_lens_k = seq_lens.to(torch.int32)
        seq_lens_q = extend_lens.to(torch.int32)
        swa_indices = torch.empty(
            (extend_num_tokens, SWA_WINDOW), device=device, dtype=torch.int32
        )
        return tilelang_make_swa_prefill_indices(
            seq_lens_k=seq_lens_k,
            seq_lens_q=seq_lens_q,
            swa_indices=swa_indices,
        )
    seq_lens = forward_batch.seq_lens_cpu
    extend_lens = forward_batch.extend_seq_lens_cpu
    assert seq_lens is not None and extend_lens is not None
    batch_size = len(seq_lens)
    num_tokens = extend_num_tokens
    swa_indices = torch.full((num_tokens, swa_window_size), -1, **_HOST_INT32_KWARGS)
    cum_qo_len = 0
    # Shared 0..max_seq_len-1 ramp; each token takes a window-sized slice.
    abs_pos_buf = torch.arange(max_seq_len, dtype=torch.int32)
    for seq_idx, (kv_len, qo_len) in enumerate(zip(seq_lens.tolist(), extend_lens)):
        old_kv_start = seq_idx * SWA_WINDOW
        new_kv_start = batch_size * SWA_WINDOW + cum_qo_len
        prefix_len = kv_len - qo_len
        for curr_seq_qo_idx in range(qo_len):
            end_abs_pos = prefix_len + curr_seq_qo_idx + 1
            start_abs_pos = max(end_abs_pos - SWA_WINDOW, 0)
            chosen_abs_positions = abs_pos_buf[start_abs_pos:end_abs_pos]
            # Prefix positions use the request's own window (mod window);
            # new positions use the region after all request windows.
            torch.where(
                chosen_abs_positions < prefix_len,
                old_kv_start + chosen_abs_positions % SWA_WINDOW,
                new_kv_start + (chosen_abs_positions - prefix_len),
                out=swa_indices[
                    cum_qo_len + curr_seq_qo_idx, : end_abs_pos - start_abs_pos
                ],
            )
        cum_qo_len += qo_len
    return swa_indices.to(device, non_blocking=True)


def prepare_swa_ring_buffer_cache(
    swa_k: torch.Tensor,
    forward_batch: ForwardBatch,
    layer_id: int,
    token_to_kv_pool: DeepSeekV4TokenToKVPool,
    core_metadata: PagedCoreMetadata,
    debug_dump_hook: Any,
) -> Tuple[torch.Tensor, index_buf_accessor_v4.NopeFp8RopeBf16Pack]:
    """Assemble a temporary SWA cache: pooled pages first, then new tokens.

    The first ``batch_size`` pages are copies of each request's pool page
    (selected by ``forward_batch.req_pool_indices``). The quantized new keys
    are written to token slots starting at ``batch_size * swa_page_size``.

    Returns:
        ``(effective_swa_k_cache, swa_k_pack sliced by core_metadata.swa_slice)``.
    """

    pool_swa_k_cache = token_to_kv_pool.get_swa_key_buffer(layer_id)
    num_pool_pages = forward_batch.batch_size
    num_newly_gen_tokens, _ = swa_k.shape

    swa_kv_pool = token_to_kv_pool.swa_kv_pool
    swa_page_size = swa_kv_pool.page_size
    assert swa_page_size == 128
    # ceil_align rounds the token count up to a multiple of the page size;
    # divide to get pages. Passing the aligned token count as a page count
    # would reserve about swa_page_size times more pages than the new
    # tokens can occupy.
    num_new_pages = ceil_align(num_newly_gen_tokens, swa_page_size) // swa_page_size
    effective_swa_k_cache = swa_kv_pool.create_buffer(
        num_pages=num_pool_pages + num_new_pages,
    )

    loc_swa = forward_batch.req_pool_indices
    assert loc_swa.shape[0] == forward_batch.batch_size == num_pool_pages
    effective_swa_k_cache[:num_pool_pages, :] = pool_swa_k_cache[loc_swa, :].view(
        effective_swa_k_cache.dtype
    )

    swa_k_pack = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k)
    # New tokens occupy consecutive token slots right after the pooled pages.
    offset = num_pool_pages * swa_page_size
    loc_newly_gen = torch.arange(
        offset,
        offset + num_newly_gen_tokens,
        device=loc_swa.device,
    )
    index_buf_accessor_v4.SetKAndS.execute(
        pool=swa_kv_pool,
        buf=effective_swa_k_cache,
        loc=loc_newly_gen,
        nope_fp8_rope_bf16_pack=swa_k_pack,
    )

    if h := debug_dump_hook:
        h(
            "forward__swa_info",
            dict(
                loc_swa=loc_swa,
                loc_newly_gen=loc_newly_gen,
            ),
        )

    return effective_swa_k_cache, swa_k_pack.slice_pack(core_metadata.swa_slice)
b/python/sglang/srt/layers/attention/deepseek_v4_backend_radix.py @@ -0,0 +1,1341 @@ +from __future__ import annotations + +import dataclasses +import functools +import logging +import warnings +from dataclasses import dataclass, field +from typing import ( + TYPE_CHECKING, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) + +import torch +import torch.nn.functional as F + +from sglang.srt.environ import envs +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.compressed.compressor import ( + CompressorBackend, + FusedCompressMetadata, + create_paged_compressor_data, +) +from sglang.srt.layers.attention.compressed.indexer import C4IndexerBackend +from sglang.srt.layers.attention.compressed.metadata import ( + PagedIndexerMetadata, + maybe_copy_inplace, +) +from sglang.srt.layers.attention.debug_flash_mla_adapter import ( + flash_mla_with_kvcache_entrypoint, +) +from sglang.srt.layers.attention.nsa.quant_k_cache_v4 import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_round_robin_split +from sglang.srt.layers.attention.triton_ops.compressed_metadata import ( + init_compressed_metadata as _init_compressed_metadata_triton, +) +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.spec_info import SpecInput +from sglang.srt.utils import ceil_align + +if TYPE_CHECKING: + from flash_mla.flash_mla_interface import FlashMLASchedMeta + + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + +SWA_WINDOW = 128 +C4_TOPK = 512 +PAGE_INDEX_ALIGNED_SIZE = 64 + + +T = TypeVar("T", 
def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T:
    """Right-pad the last dimension with -1 up to a multiple of ``multiples_of``.

    ``None`` is passed through unchanged.
    """
    if x is None:
        return None
    width = x.shape[-1]
    extra = ceil_align(width, multiples_of) - width
    return F.pad(x, pad=(0, extra), mode="constant", value=-1)


def _copy_metadata(
    src,
    dst,
    check_eq_fields: List[str],
    copy_fields: List[str],
    assign_fields: Optional[List[str]] = None,
):
    """Field-wise in-place copy between two dataclass instances.

    Check-eq fields must match, copy fields use ``copy_`` (falling back to
    ``setattr`` with a warning when the destination has no ``copy_``; a field
    that is None on both sides is skipped), and assign fields are rebound.
    Afterwards the three lists must cover every dataclass field exactly once.
    """
    if not assign_fields:
        assign_fields = []

    for field_name in check_eq_fields:
        src_val, dst_val = getattr(src, field_name), getattr(dst, field_name)
        assert src_val == dst_val, f"{field_name=} {src_val=} {dst_val=}"

    for field_name in copy_fields:
        src_val, dst_val = getattr(src, field_name), getattr(dst, field_name)
        if src_val is None and dst_val is None:
            continue
        assert dst_val is not None, f"{field_name=} {src_val=} {dst_val=}"
        if not hasattr(dst_val, "copy_"):
            warnings.warn(
                f"{field_name=} {type(dst_val)=} does not have copy_, use setattr"
            )
            setattr(dst, field_name, src_val)
            continue
        dst_val.copy_(src_val)

    for field_name in assign_fields:
        setattr(dst, field_name, getattr(src, field_name))

    # Every field must be named exactly once across the three lists.
    provided_fields = check_eq_fields + copy_fields + assign_fields
    assert len(provided_fields) == len(
        set(provided_fields)
    ), f"{provided_fields=} has dup"
    all_fields = set(f.name for f in dataclasses.fields(src))
    provided_fields = set(provided_fields)
    assert (
        provided_fields == all_fields
    ), f"{provided_fields - all_fields=}, {all_fields - provided_fields=}"


def _create_flashmla_metadata():
    """Return a fresh FlashMLA scheduling-metadata object."""
    import flash_mla

    sched_meta = flash_mla.get_mla_metadata()[0]
    return sched_meta


def _create_dummy_paged_compress_data(compress_ratio: int):
    """Placeholder that always returns None."""
    return None


@dataclass
class DSV4AttnMetadataRadix:
    """Core attention metadata for the radix backend.

    The optional ``c4_*`` / ``c128_*`` fields are filled by
    ``init_compressed_metadata``; the ``c4_sparse_*`` buffers and FlashMLA
    metadata by ``init_flashmla_related``.
    """

    page_size: int
    page_table: torch.Tensor
    raw_out_loc: torch.Tensor
    cuda_int32_kwargs: dict

    seq_lens_casual: torch.Tensor
    positions_casual: torch.Tensor

    swa_page_indices: torch.Tensor
    swa_topk_lengths: torch.Tensor

    c4_sparse_topk: int
    c4_out_loc: Optional[torch.Tensor] = None
    c4_positions: Optional[torch.Tensor] = None
    c4_topk_lengths_raw: Optional[torch.Tensor] = None
    c4_topk_lengths_clamp1: Optional[torch.Tensor] = None
    c4_sparse_topk_lengths: torch.Tensor = field(init=False)
    c4_sparse_page_indices: torch.Tensor = field(init=False)

    c128_out_loc: Optional[torch.Tensor] = None
    c128_positions: Optional[torch.Tensor] = None
    c128_page_indices: Optional[torch.Tensor] = None
    c128_topk_lengths_clamp1: Optional[torch.Tensor] = None

    c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)
    c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False)

    @property
    def positions(self) -> torch.Tensor:
        """Alias for ``positions_casual``."""
        return self.positions_casual

    def get_flashmla_metadata(self, compress_ratio: Literal[0, 4, 128]):
        """Return the FlashMLA metadata for ratio 0, 4 or 128; else ValueError."""
        if compress_ratio == 0:
            return self.c1_flashmla_metadata
        if compress_ratio == 4:
            return self.c4_flashmla_metadata
        if compress_ratio == 128:
            return self.c128_flashmla_metadata
        raise ValueError(f"invalid {compress_ratio=}")

    def copy_(self, other: DSV4AttnMetadataRadix) -> None:
        """Copy ``other`` into this instance in place via ``_copy_metadata``."""
        must_match = [
            "c4_sparse_topk",
            "page_size",
            "cuda_int32_kwargs",
        ]
        copied_in_place = [
            "raw_out_loc",
            "seq_lens_casual",
            "positions_casual",
            "c4_positions",
            "c128_positions",
            "c4_out_loc",
            "c128_out_loc",
            "page_table",
            "swa_page_indices",
            "swa_topk_lengths",
            "c128_page_indices",
            "c128_topk_lengths_clamp1",
            "c4_topk_lengths_raw",
            "c4_topk_lengths_clamp1",
            "c4_sparse_topk_lengths",
            "c4_sparse_page_indices",
        ]
        rebound = [
            "c1_flashmla_metadata",
            "c4_flashmla_metadata",
            "c128_flashmla_metadata",
        ]
        _copy_metadata(
            src=other,
            dst=self,
            check_eq_fields=must_match,
            copy_fields=copied_in_place,
            assign_fields=rebound,
        )

    def init_compressed_metadata(self):
        """Fill the ``c4_*`` / ``c128_*`` fields and pad the page-index tensors."""
        assert self.page_table.dim() == 2
        assert (
            self.raw_out_loc.shape == self.seq_lens_casual.shape
        ), f"{self.raw_out_loc.shape=}, {self.seq_lens_casual.shape=}"

        compressed = _init_compressed_metadata_triton(
            self.seq_lens_casual,
            self.positions_casual,
            self.raw_out_loc,
            self.page_table,
            self.page_size,
            compute_page_indices=True,
        )
        (
            self.c4_out_loc,
            self.c4_positions,
            self.c4_topk_lengths_raw,
            self.c4_topk_lengths_clamp1,
            self.c128_out_loc,
            self.c128_positions,
            self.c128_topk_lengths_clamp1,
            self.c128_page_indices,
        ) = compressed

        self.c128_page_indices = _pad_last_dim(self.c128_page_indices)
        self.swa_page_indices = _pad_last_dim(self.swa_page_indices)

    # Fields that apply_cp_reindex slices down to this rank's rows.
    _CP_REINDEX_FIELDS = [
        "seq_lens_casual",
        "positions_casual",
        "swa_page_indices",
        "swa_topk_lengths",
        "page_table",
        "c4_positions",
        "c4_topk_lengths_raw",
        "c4_topk_lengths_clamp1",
        "c128_positions",
        "c128_page_indices",
        "c128_topk_lengths_clamp1",
    ]
    # Fields that apply_cp_reindex leaves at their full (global) length.
    _CP_GLOBAL_FIELDS = [
        "raw_out_loc",
        "c4_out_loc",
        "c128_out_loc",
    ]

    def apply_cp_reindex(self) -> None:
        """Keep every ``cp_size``-th row, starting at this rank, of the reindex fields.

        The global row count must be divisible by ``cp_size``. Fields listed
        in ``_CP_GLOBAL_FIELDS`` are verified to keep their global length.
        """
        rank = get_attention_tp_rank()
        cp_size = get_attention_tp_size()
        rank_rows = slice(rank, None, cp_size)
        pre_global_len = self.seq_lens_casual.shape[0]
        assert pre_global_len % cp_size == 0, (
            f"apply_cp_reindex: global token count {pre_global_len} is not divisible by cp_size={cp_size}. "
            "CP round-robin requires padding to ensure divisibility."
        )
        expected_local_len = pre_global_len // cp_size

        for field_name in self._CP_REINDEX_FIELDS:
            val = getattr(self, field_name, None)
            assert isinstance(
                val, torch.Tensor
            ), f"CP reindex: {field_name} is {type(val)}, expected Tensor"
            setattr(self, field_name, val[rank_rows].contiguous())

        for field_name in self._CP_REINDEX_FIELDS:
            val = getattr(self, field_name)
            assert val.shape[0] == expected_local_len, (
                f"apply_cp_reindex post-condition: {field_name}.shape[0]={val.shape[0]} "
                f"!= expected_local_len={expected_local_len} (cp_size={cp_size})"
            )

        for field_name in self._CP_GLOBAL_FIELDS:
            val = getattr(self, field_name, None)
            if val is not None:
                assert val.shape[0] == pre_global_len, (
                    f"apply_cp_reindex post-condition: global field {field_name}.shape[0]={val.shape[0]} "
                    f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)"
                )

    def init_flashmla_related(self):
        """Create the ``c4_sparse_*`` buffers and the three FlashMLA metadata objects."""
        # c4_sparse_topk is set from model_config.index_topk per-model
        # (small model: 512, large model: 1024).
        assert self.c4_sparse_topk in (512, 1024), (
            f"unexpected c4_sparse_topk={self.c4_sparse_topk}; "
            "supported: 512 (small) or 1024 (large)"
        )
        assert self.c4_topk_lengths_clamp1 is not None
        lengths = self.c4_topk_lengths_clamp1
        self.c4_sparse_topk_lengths = torch.clamp(lengths, max=self.c4_sparse_topk)
        # -1-filled index buffer, one row per length entry, padded to the
        # page-index alignment.
        self.c4_sparse_page_indices = torch.full(
            (lengths.size(0), self.c4_sparse_topk),
            -1,
            dtype=torch.int32,
            device=lengths.device,
        )
        self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices)
        self.c1_flashmla_metadata = _create_flashmla_metadata()
        self.c4_flashmla_metadata = _create_flashmla_metadata()
        self.c128_flashmla_metadata = _create_flashmla_metadata()


@dataclass
class DSV4MetadataRadix:
    """Bundle of core attention, indexer and compressor metadata."""

    core_attn_metadata: DSV4AttnMetadataRadix
    indexer_metadata: Optional[PagedIndexerMetadata]

    c4_compress_metadata: Optional[FusedCompressMetadata] = None
    c128_compress_metadata: Optional[FusedCompressMetadata] = None

    @property
    def core_metadata(self) -> DSV4AttnMetadataRadix:
        """Alias for ``core_attn_metadata``."""
        return self.core_attn_metadata

    def copy_(self, other: DSV4MetadataRadix):
        """Copy all parts of ``other`` in place; None parts are skipped."""
        self.core_attn_metadata.copy_(other.core_attn_metadata)
        for part in (
            "indexer_metadata",
            "c4_compress_metadata",
            "c128_compress_metadata",
        ):
            maybe_copy_inplace(getattr(self, part), src=getattr(other, part))


@dataclass
class DSV4MetadataRawVerify:
    """Raw batch tensors kept for the verify path."""

    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor
    out_cache_loc: torch.Tensor

    extend_seq_lens: Optional[torch.Tensor] = None
    real_metadata: Optional[DSV4MetadataRadix] = None

    def copy_(self, other: DSV4MetadataRawVerify):
        """Copy the three tensors in place and rebind ``extend_seq_lens``."""
        for name in ("req_pool_indices", "seq_lens", "out_cache_loc"):
            getattr(self, name).copy_(getattr(other, name))

        # Rebound by reference, not copied; real_metadata is left untouched.
        self.extend_seq_lens = other.extend_seq_lens
@dataclass
class _DecodeCudaGraphSharedData:
    # Field-less placeholder dataclass: it declares no attributes and no
    # behavior. An instance is created in
    # DeepseekV4BackendRadix.init_cuda_graph_state and stored on the backend
    # as `decode_cuda_graph_shared_data`.
    pass
def init_forward_metadata_prefill(
    self,
    max_seq_len: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: List[int],
    out_cache_loc: torch.Tensor,
    num_tokens: int,
    extend_seq_lens: torch.Tensor,
    extend_seq_lens_cpu: List[int],
    need_compress: bool = True,
    use_prefill_cuda_graph: bool = False,
) -> DSV4MetadataRadix:
    """Build the full metadata bundle for a prefill/extend batch.

    The per-request lengths are first expanded to per-token values (padded to
    ``out_cache_loc.shape[0]`` by ``expand_prefill_casually``), then the core
    attention metadata is created with ``is_prefill=True``.

    When ``need_compress`` is false, no indexer metadata is produced and the
    c4/c128 compress metadata come from ``_create_dummy_paged_compress_data``;
    otherwise they come from ``create_paged_compressor_data``.
    """
    per_token_seq_lens, per_token_req_idx = self.expand_prefill_casually(
        num_tokens=num_tokens,
        seq_lens=seq_lens_cpu,
        extend_seq_lens=extend_seq_lens_cpu,
        req_pool_indices=req_pool_indices,
        padded_num_tokens=out_cache_loc.shape[0],
    )
    core = self.make_core_attn_metadata(
        req_to_token=self.req_to_token,
        req_pool_indices_repeated=per_token_req_idx,
        seq_lens_casual=per_token_seq_lens,
        max_seq_len=max_seq_len,
        out_loc=out_cache_loc,
        need_compress=need_compress,
        is_prefill=True,
    )

    # A single branch decides both the indexer metadata and the factory used
    # for the two compress-metadata objects.
    if need_compress:
        indexer = self.init_forward_metadata_indexer(core)
        build_compress = functools.partial(
            create_paged_compressor_data,
            is_prefill=True,
            token_to_kv_pool=self.token_to_kv_pool,
            req_to_token=self.req_to_token,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            extend_lens=extend_seq_lens,
            extend_lens_cpu=extend_seq_lens_cpu,
            use_prefill_cuda_graph=use_prefill_cuda_graph,
        )
    else:
        indexer = None
        build_compress = _create_dummy_paged_compress_data

    c4_compress = build_compress(compress_ratio=4)
    c128_compress = build_compress(compress_ratio=128)
    return DSV4MetadataRadix(
        core,
        indexer,
        c4_compress_metadata=c4_compress,
        c128_compress_metadata=c128_compress,
    )
def init_forward_metadata_target_verify_old(
    self,
    max_seq_len: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: Optional[List[int]] = None,
    out_cache_loc: Optional[torch.Tensor] = None,
    use_prefill_cuda_graph: bool = False,
) -> DSV4MetadataRadix:
    """Build target-verify metadata by treating the draft tokens as an extend.

    Every request is extended by ``self.speculative_num_draft_tokens`` tokens,
    so both the tensor and the host copy of the sequence lengths are shifted
    by that amount before delegating to ``init_forward_metadata_prefill``
    with ``need_compress=True``.

    Args:
        max_seq_len: Forwarded unchanged to the prefill builder.
        req_pool_indices: One request-pool index per request.
        seq_lens: Per-request lengths *before* the draft tokens are added.
        seq_lens_cpu: Host copy of ``seq_lens``. When omitted it is derived
            with ``seq_lens.tolist()`` (a host transfer); previously the
            ``None`` default raised ``TypeError`` when iterated.
        out_cache_loc: Cache slots for the draft tokens. When omitted, a
            zero tensor of ``num_draft_tokens * batch_size`` entries is used.
        use_prefill_cuda_graph: Forwarded unchanged to the prefill builder.
    """
    num_draft = self.speculative_num_draft_tokens
    batch_size = len(seq_lens)

    # Must be resolved before `seq_lens` is shifted below, so that the same
    # offset is applied exactly once to both representations.
    if seq_lens_cpu is None:
        seq_lens_cpu = seq_lens.tolist()

    seq_lens = seq_lens + num_draft
    seq_lens_cpu = [x + num_draft for x in seq_lens_cpu]
    extend_seq_lens_cpu = [num_draft] * batch_size
    extend_seq_lens = self._move_to_device(extend_seq_lens_cpu)
    num_tokens = num_draft * batch_size
    if out_cache_loc is None:
        out_cache_loc = seq_lens.new_zeros(num_tokens)
    return self.init_forward_metadata_prefill(
        max_seq_len=max_seq_len,
        req_pool_indices=req_pool_indices,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        out_cache_loc=out_cache_loc,
        num_tokens=num_tokens,
        extend_seq_lens=extend_seq_lens,
        extend_seq_lens_cpu=extend_seq_lens_cpu,
        need_compress=True,
        use_prefill_cuda_graph=use_prefill_cuda_graph,
    )
seq_lens + self.speculative_num_draft_tokens + extend_seq_lens = raw_metadata.extend_seq_lens + + seq_lens_casual, req_pool_indices_repeated = ( + self.expend_extend_with_same_length( + bs, num_draft_tokens, seq_lens, req_pool_indices + ) + ) + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices_repeated, + seq_lens_casual=seq_lens_casual, + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + out_loc=out_cache_loc, + need_compress=True, + ) + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + create = functools.partial( + create_paged_compressor_data, + is_prefill=True, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_lens=extend_seq_lens, + seq_lens_cpu=None, + extend_lens_cpu=None, + use_prefill_cuda_graph=True, + num_q_tokens=num_draft_tokens * bs, + ) + return DSV4MetadataRadix( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def make_forward_metadata_from_raw_decode( + self, raw_metadata: DSV4MetadataRawDecode + ) -> DSV4MetadataRadix: + req_pool_indices = raw_metadata.req_pool_indices + seq_lens = raw_metadata.seq_lens + out_cache_loc = raw_metadata.out_cache_loc + + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices, + seq_lens_casual=seq_lens, + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + out_loc=out_cache_loc, + need_compress=True, + ) + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + + create = functools.partial( + create_paged_compressor_data, + is_prefill=False, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) + + return DSV4MetadataRadix( + core_attn_metadata, + 
def init_forward_metadata_draft_extend(
    self,
    max_seq_len: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: List[int],
    num_tokens_per_bs: int,
    out_cache_loc: Optional[torch.Tensor] = None,
    use_prefill_cuda_graph: bool = False,
) -> DSV4MetadataRadix:
    """Build draft-extend metadata where every request extends by the same amount.

    Each of the ``len(seq_lens)`` requests contributes ``num_tokens_per_bs``
    tokens. The work is delegated to ``init_forward_metadata_prefill`` with
    ``need_compress=False``. A missing ``out_cache_loc`` is replaced by a zero
    tensor with one entry per token, created from ``seq_lens``.
    """
    request_count = len(seq_lens)
    per_request_extend = [num_tokens_per_bs for _ in range(request_count)]
    per_request_extend_dev = self._move_to_device(per_request_extend)
    total_tokens = request_count * num_tokens_per_bs

    cache_loc = (
        out_cache_loc
        if out_cache_loc is not None
        else seq_lens.new_zeros(total_tokens)
    )

    prefill_kwargs = dict(
        seq_lens=seq_lens,
        max_seq_len=max_seq_len,
        req_pool_indices=req_pool_indices,
        seq_lens_cpu=seq_lens_cpu,
        out_cache_loc=cache_loc,
        num_tokens=total_tokens,
        extend_seq_lens=per_request_extend_dev,
        extend_seq_lens_cpu=per_request_extend,
        need_compress=False,
        use_prefill_cuda_graph=use_prefill_cuda_graph,
    )
    return self.init_forward_metadata_prefill(**prefill_kwargs)
def init_forward_metadata_capture_cuda_graph(
    self,
    bs: int,
    num_tokens: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: ForwardMode,
    spec_info: Optional[SpecInput],
) -> None:
    """Create and register the metadata used while capturing a CUDA graph.

    The metadata is built with ``self.MAX_SEQ_LEN_FOR_CAPTURE`` as the maximum
    sequence length, stored in the per-batch-size table that matches
    ``forward_mode`` and installed as ``self.forward_metadata``.

    ``self._current_capture_raw`` records the metadata of the capture in
    progress when it is a raw decode/verify object, and ``None`` otherwise.
    ``on_after_cuda_graph_warmup_pass`` uses it to restore
    ``self.forward_metadata`` after a warmup pass.

    ``encoder_lens`` and ``spec_info`` are accepted but not read here.

    Raises:
        NotImplementedError: if ``forward_mode`` is not decode/idle,
            target-verify or draft-extend.
    """
    assert req_pool_indices.size(0) == bs
    assert seq_lens.size(0) == bs

    if forward_mode.is_decode_or_idle():
        metadata = self.init_forward_metadata_decode(
            max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            # Placeholder cache locations: one zero per request.
            out_cache_loc=torch.zeros_like(seq_lens),
        )

        self.decode_cuda_graph_metadata_of_bs[bs] = metadata
        self.forward_metadata = metadata

        self._current_capture_raw = (
            metadata if isinstance(metadata, DSV4MetadataRawDecode) else None
        )
    elif forward_mode.is_target_verify():
        out_cache_loc = torch.zeros(num_tokens, **self.cuda_int32_kwargs)
        metadata = self.init_forward_metadata_target_verify(
            max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            out_cache_loc=out_cache_loc,
            use_prefill_cuda_graph=True,
        )
        self.target_verify_cuda_graph_metadata_of_bs[bs] = metadata
        self.forward_metadata = metadata

        self._current_capture_raw = (
            metadata if isinstance(metadata, DSV4MetadataRawVerify) else None
        )
    elif forward_mode.is_draft_extend(include_v2=True):
        num_tokens_per_bs = num_tokens // bs
        metadata = self.init_forward_metadata_draft_extend(
            max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens.tolist(),
            num_tokens_per_bs=num_tokens_per_bs,
            use_prefill_cuda_graph=True,
        )
        self.draft_extend_cuda_graph_metadata_of_bs[bs] = metadata
        self.forward_metadata = metadata

        # Draft-extend never yields raw metadata. Clear any value left by an
        # earlier decode/verify capture so the warmup hook cannot overwrite
        # the metadata installed above with a stale object.
        self._current_capture_raw = None
    else:
        raise NotImplementedError(f"{forward_mode=} not supported yet")
f"[IDLE replay] bs={bs}, " + f"local_seq_lens_len={len(seq_lens)}, " + f"has_graph={bs in self.decode_cuda_graph_metadata_of_bs}" + ) + device = seq_lens.device + seq_lens = torch.ones(bs, dtype=seq_lens.dtype, device=device) + seq_lens_cpu = torch.ones(bs, dtype=torch.int64) + seq_lens_sum = bs + req_pool_indices = torch.zeros( + bs, dtype=req_pool_indices.dtype, device=device + ) + out_cache_loc = torch.zeros(bs, dtype=torch.int64, device=device) + + assert seq_lens_cpu is not None + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] + + if forward_mode.is_decode_or_idle(): + assert out_cache_loc is not None + actual_max_seq_len = seq_lens_cpu.max().item() + + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + + assert len(out_cache_loc.shape) == 1, f"{out_cache_loc.shape=}" + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, bs - len(out_cache_loc)), + mode="constant", + value=0, + ) + + temp_metadata = self.init_forward_metadata_decode( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + ) + + chosen_metadata = self.decode_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + elif forward_mode.is_target_verify(): + assert out_cache_loc is not None + actual_max_seq_len = seq_lens_cpu.max().item() + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + num_tokens = self.speculative_num_draft_tokens * bs + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, num_tokens - len(out_cache_loc)), + mode="constant", + value=0, + ) + temp_metadata = self.init_forward_metadata_target_verify( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + 
def replay_cuda_graph_metadata_from(
    self,
    bs: int,
    temp_metadata: Union[
        DSV4MetadataRadix,
        DSV4MetadataRawVerify,
        DSV4MetadataRawDecode,
    ],
    forward_mode: ForwardMode,
) -> None:
    """Refresh the captured metadata for ``bs`` from an already-built object.

    The per-batch-size table is chosen from ``forward_mode`` (decode/idle,
    target-verify, then draft-extend). The captured entry is updated via its
    ``copy_`` method and then installed as ``self.forward_metadata``.

    Raises:
        NotImplementedError: if ``forward_mode`` matches none of the three
            supported modes.
    """
    if forward_mode.is_decode_or_idle():
        captured_by_bs = self.decode_cuda_graph_metadata_of_bs
    elif forward_mode.is_target_verify():
        captured_by_bs = self.target_verify_cuda_graph_metadata_of_bs
    elif forward_mode.is_draft_extend(include_v2=True):
        captured_by_bs = self.draft_extend_cuda_graph_metadata_of_bs
    else:
        raise NotImplementedError

    # Same three steps for every mode: look up, refresh, publish.
    captured = captured_by_bs[bs]
    captured.copy_(temp_metadata)
    self.forward_metadata = captured
def _maybe_upgrade_forward_metadata(self) -> None:
    """Replace raw forward metadata with the fully materialized Radix form.

    With SGLANG_PREP_IN_CUDA_GRAPH=1, init_forward_metadata_* returns a Raw
    metadata that only carries a few tensors. The full Radix metadata
    (c4/c128 compress + core_attn + indexer metadata) has to exist before any
    caller touches those fields. For 1.6T the first two layers have
    compress_ratio=128, so forward_core_compressor / forward_c4_indexer can
    fire before attn_backend.forward() and must trigger the upgrade
    themselves.

    Does nothing when ``self.forward_metadata`` is neither raw type.
    """
    pending = self.forward_metadata
    if isinstance(pending, DSV4MetadataRawVerify):
        materialize = self.make_forward_metadata_from_raw_verify
    elif isinstance(pending, DSV4MetadataRawDecode):
        materialize = self.make_forward_metadata_from_raw_decode
    else:
        return

    self.forward_metadata = materialize(raw_metadata=pending)
swa_window_size * k_cache_total_dim].view( + swa_k_cache.shape[0], swa_window_size, 1, k_cache_total_dim + ) + + if extra_k_cache is not None: + page_sizes = { + 4: token_to_kv_pool.page_size // 4, + 128: token_to_kv_pool.page_size // 128, + } + extra_k_cache = extra_k_cache[ + :, : page_sizes[compress_ratio] * k_cache_total_dim + ].view( + extra_k_cache.shape[0], + page_sizes[compress_ratio], + 1, + k_cache_total_dim, + ) + swa_page_indices = core_attn_metadata.swa_page_indices + swa_topk_lengths = core_attn_metadata.swa_topk_lengths + + if self.mtp_enabled: + if swa_page_indices.shape[0] != q.shape[0]: + swa_page_indices = _pad_tensor_to_size( + swa_page_indices, q.shape[0], value=0 + ) + + if swa_topk_lengths.shape[0] != q.shape[0]: + swa_topk_lengths = _pad_tensor_to_size( + swa_topk_lengths, q.shape[0], value=1 + ) + + if q.ndim == 3: + q = q.unsqueeze(1) + if swa_page_indices.ndim == 2: + swa_page_indices = swa_page_indices.unsqueeze(1) + if extra_indices is not None and extra_indices.ndim == 2: + extra_indices = extra_indices.unsqueeze(1) + + assert attn_sink is not None + + flashmla_metadata = core_attn_metadata.get_flashmla_metadata(compress_ratio) + + assert ( + swa_page_indices.shape[-1] % 64 == 0 + ), f"{swa_page_indices.shape=}'s last dimension is not aligned to 64" + if extra_indices is not None: + assert ( + extra_indices.shape[-1] % 64 == 0 + ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" + + input_dict = dict( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=self.softmax_scale, + is_fp8_kvcache=True, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + ) + + backend = envs.SGLANG_HACK_FLASHMLA_BACKEND.get() + o = flash_mla_with_kvcache_entrypoint(**input_dict, 
def expend_extend_with_same_length(
    self,
    bs: int,
    qo_len: int,
    seq_lens: torch.Tensor,
    req_pool_indices: torch.Tensor,
):
    """Expand per-request values to per-token values for equal-length extends.

    For each request, the ``qo_len`` output entries hold the lengths
    ``seq_len - qo_len + 1 .. seq_len`` in increasing order, and the request's
    pool index is repeated ``qo_len`` times.

    Returns:
        ``(seq_lens_casual, req_pool_indices_repeated)``, each flattened to
        ``bs * qo_len`` entries in request-major order.
    """
    int32_opts = self.cuda_int32_kwargs

    # Offsets -(qo_len - 1) .. 0, broadcast against one row per request.
    step_offsets = torch.arange(1 - qo_len, 1, **int32_opts)
    per_token_lens = (seq_lens.unsqueeze(1) + step_offsets).reshape(-1)

    owner_of_token = torch.arange(bs, **int32_opts).repeat_interleave(qo_len)
    per_token_pool_idx = req_pool_indices[owner_of_token]
    return per_token_lens, per_token_pool_idx
def get_swa_page_indices(
    self,
    seq_lens_casual: torch.Tensor,
    req_pool_indices_repeated: torch.Tensor,
) -> torch.Tensor:
    """Look up the sliding-window cache locations for every query token.

    For token ``i``, column ``j`` refers to position
    ``seq_lens_casual[i] - 1 - j`` of that token's request, i.e. the newest
    position comes first. Positions before the start of the sequence are
    reported as ``-1`` prior to translation. The gathered locations are passed
    through ``token_to_kv_pool.translate_loc_from_full_to_swa``.

    The intermediate table is asserted to have shape
    ``(seq_lens_casual.size(0), SWA_WINDOW)``.
    """
    newest_pos = seq_lens_casual - 1
    token_count = seq_lens_casual.size(0)

    lookback = torch.arange(SWA_WINDOW, **self.cuda_int32_kwargs)
    window_pos = newest_pos[:, None] - lookback[None, :]
    before_start = window_pos < 0

    # Clamp so the gather below never reads a negative column; the affected
    # entries are overwritten with -1 right after.
    gather_cols = window_pos.clamp(min=0)
    full_locs = self.req_to_token[req_pool_indices_repeated[:, None], gather_cols]
    assert full_locs.shape == (token_count, SWA_WINDOW)

    full_locs = full_locs.masked_fill(before_start, -1)
    return self.token_to_kv_pool.translate_loc_from_full_to_swa(full_locs)
self.speculative_num_steps == 1: + return + + self.attn_backends[0].init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.seq_lens_cpu, + out_cache_loc=forward_batch.out_cache_loc, + ) + temp_metadata = self.attn_backends[0].forward_metadata + + for i in range(1, self.speculative_num_steps - 1): + self.attn_backends[i].replay_cuda_graph_metadata_from( + bs=bs, + temp_metadata=temp_metadata, + forward_mode=ForwardMode.DECODE, + ) + + +def _pad_tensor_to_size(tensor: torch.Tensor, size: int, *, value: int = 0): + if value == 0: + return torch.cat( + [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])], + dim=0, + ) + else: + return torch.cat( + [ + tensor, + tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value), + ], + dim=0, + ) diff --git a/python/sglang/srt/layers/attention/indexer_topk_capturer.py b/python/sglang/srt/layers/attention/indexer_topk_capturer.py new file mode 100644 index 000000000000..326ad29481a4 --- /dev/null +++ b/python/sglang/srt/layers/attention/indexer_topk_capturer.py @@ -0,0 +1,118 @@ +import logging +from typing import TYPE_CHECKING, Optional + +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.topk_capturer_base import ( + _GB, + _MB, + BaseTopkCapturer, + BaseTopkCapturerNoop, +) + +if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig + +logger = logging.getLogger(__name__) + +INDEX_TOPK = 512 + + +def _count_indexer_layers(model_config: "ModelConfig") -> int: + compress_ratios = getattr(model_config.hf_text_config, "compress_ratios", None) + if compress_ratios is None: + return 0 + return sum(1 for r in compress_ratios if r == 4) + + +class IndexerTopkCapturer(BaseTopkCapturer): + def __init__( + self, 
    def _sync_to_host(self, forward_batch, can_run_graph, cuda_graph_batch):
        # Copy the top-k indices captured for this forward pass from the device
        # staging buffer into the host cache.
        #
        # The first `num_tokens` rows of `device_cache.buffer` (all entries of
        # the middle axis, first `topk_size` entries of the last axis) are moved
        # to CPU and written to the rows of `host_cache.buffer` selected by
        # `forward_batch.out_cache_loc`.
        #
        # `can_run_graph` and `cuda_graph_batch` are accepted but not used here.
        # NOTE(review): the middle axis is presumably the indexer-layer axis and
        # `topk_size` is presumably set by BaseTopkCapturer.__init__ -- confirm
        # against topk_capturer_base.
        num_tokens = forward_batch.out_cache_loc.shape[0]
        # Both `.cpu()` calls below are device-to-host copies.
        out_cache_loc_cpu = forward_batch.out_cache_loc.cpu()
        self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[
            :num_tokens, :, : self.topk_size
        ].cpu()
def create_indexer_capturer(
    enable: bool,
    model_config: "ModelConfig",
    num_tokens: int,
    max_running_requests: int,
    device: str,
):
    """Build the indexer top-k capturer, or a no-op stand-in.

    A real ``IndexerTopkCapturer`` is returned only when ``enable`` is true
    and the constructed capturer reports itself as enabled; in every other
    case an ``IndexerTopkCapturerNoop`` is returned.
    """
    # Capturing switched off by the caller: nothing to construct.
    if not enable:
        return IndexerTopkCapturerNoop()

    candidate = IndexerTopkCapturer(
        model_config=model_config,
        num_tokens=num_tokens,
        max_running_requests=max_running_requests,
        device=device,
    )
    # The capturer disables itself when the model has no indexer layers.
    if not candidate.is_enabled():
        return IndexerTopkCapturerNoop()
    return candidate
def _set_k_and_s_triton(
    buf: torch.Tensor,
    loc: torch.Tensor,
    nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack,
    page_size: int,
):
    """Scatter quantized K (fp8 nope + bf16 rope) and its scales into ``buf``.

    One Triton program is launched per token to write. Within a page of
    ``buf`` the layout written by the kernel is:

    * per token, ``nope_dim`` fp8 bytes followed by ``rope_dim`` bf16 values
      (``nope_dim + 2 * rope_dim`` bytes per token), for ``page_size`` tokens;
    * then, starting at byte ``page_size * (nope_dim + 2 * rope_dim)``, one
      slot of ``scale_dim + 1`` bytes per token, of which the first
      ``scale_dim`` bytes receive the scales.

    Args:
        buf: Contiguous uint8 buffer of shape ``(num_pages, bytes_per_page)``.
        loc: Contiguous 1-D int32/int64 tensor of destination token slots.
        nope_fp8_rope_bf16_pack: Quantized K to write; all three tensors must
            be contiguous and have ``loc.shape[0]`` rows.
        page_size: Number of token slots per page.

    Returns:
        None. ``buf`` is modified in place.
    """
    _num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape

    k_nope, k_rope, scale_k_nope = (
        nope_fp8_rope_bf16_pack.k_nope_fp8,
        nope_fp8_rope_bf16_pack.k_rope_bf16,
        nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0,
    )

    num_tokens_to_write_nope, nope_dim = k_nope.shape
    num_tokens_to_write_rope, rope_dim = k_rope.shape
    num_tokens_to_write_scale, scale_dim = scale_k_nope.shape

    assert (
        num_tokens_to_write
        == num_tokens_to_write_nope
        == num_tokens_to_write_rope
        == num_tokens_to_write_scale
    )

    assert buf.dtype == torch.uint8
    assert loc.dtype in [torch.int64, torch.int32], f"{loc.dtype=}"

    assert k_nope.dtype == fp8_dtype
    assert k_rope.dtype == torch.bfloat16
    assert scale_k_nope.dtype == torch.uint8, f"{scale_k_nope.dtype=}"

    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert k_nope.is_contiguous()
    assert k_rope.is_contiguous()
    assert scale_k_nope.is_contiguous()

    # The kernel derives byte offsets as `page_index * BUF_NUMEL_PER_PAGE + ...`
    # from `loc`. With an int32 `loc` that arithmetic is done in int32 and
    # wraps once the buffer exceeds 2 GiB, so widen to int64 up front (same
    # approach as SetKAndS.triton in index_buf_accessor.py). This is a no-op
    # when `loc` is already int64.
    loc = loc.to(torch.int64)

    # Three typed views over the same bytes; the kernel addresses each region
    # through the view whose element type matches what it stores there.
    buf_fp8 = buf.view(fp8_dtype)
    buf_bf16 = buf.view(torch.bfloat16)
    buf_uint8 = buf.view(torch.uint8)

    nope_rope_bytes = nope_dim + rope_dim * 2
    s_offset_nbytes_in_page = page_size * (nope_dim + rope_dim * 2)

    _set_k_and_s_triton_kernel[(num_tokens_to_write,)](
        buf_fp8,
        buf_bf16,
        buf_uint8,
        loc,
        k_nope,
        k_rope,
        scale_k_nope,
        k_nope.stride(0),
        k_rope.stride(0),
        scale_k_nope.stride(0),
        PAGE_SIZE=page_size,
        BUF_NUMEL_PER_PAGE=buf_numel_per_page,
        NUM_NOPE_ELEMS_PER_TOKEN=nope_dim,
        NUM_ROPE_ELEMS_PER_TOKEN=rope_dim,
        NUM_SCALE_ELEMS_PER_TOKEN=scale_dim,
        NUM_NOPE_ROPE_BYTES_PER_TOKEN=nope_rope_bytes,
        PADDED_SCALE_ELEMS_PER_TOKEN=scale_dim + 1,
        S_OFFSET_NBYTES_IN_PAGE=s_offset_nbytes_in_page,
        BLOCK_NOPE=512,
        BLOCK_ROPE=64,
        BLOCK_SCALE=8,
    )
def _set_k_and_s_torch(
    buf: torch.Tensor,
    loc: torch.Tensor,
    nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack,
    page_size: int,
):
    """Pure-PyTorch reference for ``_set_k_and_s_triton``.

    Writes, for every token ``i``, its fp8 nope values, bf16 rope values and
    uint8 scales into the paged byte buffer ``buf`` at the slot ``loc[i]``,
    using the same per-page layout as the Triton kernel: token data of
    ``nope_dim + 2 * rope_dim`` bytes each, followed by a scale region with
    ``scale_dim + 1`` bytes reserved per token.

    Args:
        buf: Contiguous uint8 buffer of shape ``(num_pages, bytes_per_page)``.
        loc: Contiguous 1-D int32/int64 tensor of destination token slots.
        nope_fp8_rope_bf16_pack: Quantized K to write; all three tensors must
            be contiguous and have ``loc.shape[0]`` rows.
        page_size: Number of token slots per page.

    Returns:
        None. ``buf`` is modified in place.
    """
    _num_pages, buf_numel_per_page = buf.shape
    (num_tokens_to_write,) = loc.shape

    k_nope, k_rope, scale_k_nope = (
        nope_fp8_rope_bf16_pack.k_nope_fp8,
        nope_fp8_rope_bf16_pack.k_rope_bf16,
        nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0,
    )

    num_tokens_to_write_nope, nope_dim = k_nope.shape
    num_tokens_to_write_rope, rope_dim = k_rope.shape
    num_tokens_to_write_scale, scale_dim = scale_k_nope.shape

    assert (
        num_tokens_to_write
        == num_tokens_to_write_nope
        == num_tokens_to_write_rope
        == num_tokens_to_write_scale
    ), f"{num_tokens_to_write=} {num_tokens_to_write_nope=} {num_tokens_to_write_rope=} {num_tokens_to_write_scale=}"

    assert buf.dtype == torch.uint8
    assert loc.dtype in [
        torch.int64,
        torch.int32,
    ], f"{loc.dtype=}"

    assert k_nope.dtype == fp8_dtype
    assert k_rope.dtype == torch.bfloat16
    assert scale_k_nope.dtype == torch.uint8

    assert buf.is_contiguous()
    assert loc.is_contiguous()
    assert k_nope.is_contiguous()
    assert k_rope.is_contiguous()
    assert scale_k_nope.is_contiguous()

    # `buf` is contiguous, so these flattened typed views alias its storage
    # and the slice assignments below write through to `buf`.
    buf_fp8 = buf.view(fp8_dtype).flatten()
    buf_bf16 = buf.view(torch.bfloat16).flatten()
    buf_scale = buf.view(torch.uint8).flatten()

    # int32 * python-int stays int32 in torch and silently wraps for buffers
    # larger than 2 GiB; do all offset arithmetic in int64.
    loc = loc.to(torch.int64)

    nope_rope_bytes = nope_dim + rope_dim * 2
    loc_page_index = loc // page_size
    loc_token_offset_in_page = loc % page_size

    # Byte offset, inside a page, where the scale region begins.
    s_offset_nbytes_in_page = page_size * nope_rope_bytes

    page_base = loc_page_index * buf_numel_per_page
    token_base = loc_token_offset_in_page * nope_rope_bytes

    # Offsets are in elements of the respective view: bytes for the fp8 and
    # uint8 views, 2-byte units for the bf16 view (hence the `// 2`).
    # `.tolist()` moves them to the host once instead of indexing a 0-d
    # tensor on every loop iteration.
    nope_offset = (page_base + token_base).tolist()
    rope_offset = (page_base // 2 + (token_base + nope_dim) // 2).tolist()
    s_offset = (
        page_base
        + s_offset_nbytes_in_page
        + loc_token_offset_in_page * (scale_dim + 1)
    ).tolist()

    for i in range(num_tokens_to_write):
        nope_start = nope_offset[i]
        rope_start = rope_offset[i]
        scale_start = s_offset[i]
        buf_fp8[nope_start : nope_start + nope_dim] = k_nope[i]
        buf_bf16[rope_start : rope_start + rope_dim] = k_rope[i]
        buf_scale[scale_start : scale_start + scale_dim] = scale_k_nope[i]
= is_fp8_fnuz() if _is_cuda: @@ -46,6 +54,9 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn DUAL_STREAM_TOKEN_THRESHOLD = 1024 if _is_cuda else 0 @@ -112,12 +123,13 @@ def topk_transform( def rotate_activation(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.bfloat16 - # from sgl_kernel import hadamard_transform - if _is_hip: + if _is_hip or _is_sm103: from fast_hadamard_transform import hadamard_transform else: - from sgl_kernel import hadamard_transform + try: + from sgl_kernel import hadamard_transform + except ImportError: + from fast_hadamard_transform import hadamard_transform hidden_size = x.size(-1) assert ( @@ -371,7 +383,9 @@ def _get_topk_paged( if _is_cuda: if schedule_metadata is None: schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata( - seqlens_32, blocksize, self.sm_count + seqlens_32.unsqueeze(-1) if seqlens_32.dim() == 1 else seqlens_32, + blocksize, + self.sm_count, ) assert len(q_fp8.shape) == 3 @@ -420,7 +434,7 @@ def _get_topk_paged( q_fp8, kv_cache_fp8, weights, - seqlens_32, + seqlens_32.unsqueeze(-1) if seqlens_32.dim() == 1 else seqlens_32, block_tables, schedule_metadata, max_seq_len, @@ -738,7 +752,7 @@ def _get_topk_ragged_with_cp( actual_seq_q_list.append(actual_seq_q) batch_idx_list.append(batch_idx) - k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_fp8 = torch.cat(k_fp8_list, dim=0).view(fp8_dtype) k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) ks = torch.cat(ks_list, dim=0) @@ -779,7 +793,7 @@ def _get_topk_ragged_with_cp( block_tables[0], ) - k_fp8 = k_fp8.view(torch.float8_e4m3fn) + k_fp8 = k_fp8.view(fp8_dtype) k_scale = k_scale.view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) ks = torch.full((actual_seq_q,), offset, dtype=torch.int32, device="cuda") @@ 
-872,7 +886,7 @@ def forward_indexer( block_tables[i], ) - k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous() + k_fp8 = k_fp8.view(fp8_dtype).unsqueeze(0).contiguous() k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous() index_score = fp8_index( diff --git a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py index 5454071b897d..fa9c9ba5dda7 100644 --- a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py +++ b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py @@ -2,6 +2,10 @@ import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + def quantize_k_cache(cache_k): return _quantize_k_cache_fast_wrapped(cache_k) @@ -75,7 +79,7 @@ def _quantize_k_cache_ref( result = torch.empty( (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), - dtype=torch.float8_e4m3fn, + dtype=fp8_dtype, device=input_k_cache.device, ) result_k_nope_part = result[..., :dv] @@ -100,7 +104,7 @@ def _quantize_k_cache_ref( ..., tile_idx * tile_size : (tile_idx + 1) * tile_size ].float() / cur_scale_factors_inv.float() - ).to(torch.float8_e4m3fn) + ).to(fp8_dtype) result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( cur_quantized_nope ) @@ -152,7 +156,7 @@ def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128): output = torch.empty( (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope), - dtype=torch.float8_e4m3fn, + dtype=fp8_dtype, device=k_nope.device, ) output_nope_q = output[..., :dim_nope] @@ -180,8 +184,8 @@ def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128): GROUP_SIZE=group_size, DIM_NOPE=dim_nope, DIM_ROPE=dim_rope, - FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, - FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + FP8_MIN=torch.finfo(fp8_dtype).min, + 
# Fused quantization kernel for one K row per token.
#
# Launch grid is 2-D: axis 0 is the token, axis 1 is the tile index. Programs
# with tile_id in [0, NUM_TILES) quantize one TILE_SIZE-wide slice of the nope
# part to fp8 with a power-of-two scale; the extra program with
# tile_id == NUM_TILES copies the rope part through unchanged.
#
# Pointers:
#   k_bf16_ptr             input rows (nope followed by rope), row stride k_bf16_stride_0
#   k_nope_fp8_ptr         output quantized nope values, row stride k_nope_fp8_stride_0
#   k_rope_bf16_ptr        output rope values, row stride k_rope_bf16_stride_0
#   scale_k_nope_uint8_ptr output per-tile scale exponents, row stride scale_stride_0
# NOTE(review): the input is presumably bf16 as the name suggests; the kernel
# itself only converts the loaded tile to float32 -- confirm against the caller.
@triton.jit
def _quant_k_cache_fused_kernel(
    k_bf16_ptr,
    k_nope_fp8_ptr,
    k_rope_bf16_ptr,
    scale_k_nope_uint8_ptr,
    k_bf16_stride_0,
    k_nope_fp8_stride_0,
    k_rope_bf16_stride_0,
    scale_stride_0,
    DIM_NOPE: tl.constexpr,
    DIM_ROPE: tl.constexpr,
    TILE_SIZE: tl.constexpr,
    NUM_TILES: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    EPS: tl.constexpr,
):
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    if tile_id == NUM_TILES:
        # Rope branch: plain copy of the DIM_ROPE values that follow the nope
        # part in the input row. The mask guards the case DIM_ROPE < TILE_SIZE.
        rope_range = tl.arange(0, TILE_SIZE)
        rope_mask = rope_range < DIM_ROPE

        in_rope_offsets = token_id * k_bf16_stride_0 + DIM_NOPE + rope_range
        rope_data = tl.load(k_bf16_ptr + in_rope_offsets, mask=rope_mask, other=0.0)

        out_rope_offsets = token_id * k_rope_bf16_stride_0 + rope_range
        tl.store(k_rope_bf16_ptr + out_rope_offsets, rope_data, mask=rope_mask)
    else:
        # Nope branch: quantize tile `tile_id` of this token. No mask is used,
        # so the tile is assumed to lie fully inside the row.
        tile_range = tl.arange(0, TILE_SIZE)

        in_tile_offsets = token_id * k_bf16_stride_0 + tile_id * TILE_SIZE + tile_range
        x_bf16 = tl.load(k_bf16_ptr + in_tile_offsets)
        x_fp32 = x_bf16.to(tl.float32)

        # Scale so that the tile's largest magnitude maps to FP8_MAX; EPS keeps
        # log2 finite for an all-zero tile.
        abs_x = tl.abs(x_fp32)
        max_abs = tl.max(abs_x)
        max_abs_clamped = tl.maximum(max_abs, EPS)
        scale = max_abs_clamped / FP8_MAX

        # Round the scale up to the next power of two so it can be stored as a
        # bare exponent.
        log2_scale = tl.log2(scale)
        ceil_log2 = tl.math.ceil(log2_scale)
        scale_pow2_fp32 = tl.exp2(ceil_log2)
        scale_inv = 1.0 / scale_pow2_fp32
        x_scaled = x_fp32 * scale_inv
        # Output element type is taken from the destination pointer.
        x_fp8 = tl.clamp(x_scaled, FP8_MIN, FP8_MAX).to(k_nope_fp8_ptr.dtype.element_ty)

        out_fp8_offsets = (
            token_id * k_nope_fp8_stride_0 + tile_id * TILE_SIZE + tile_range
        )
        tl.store(k_nope_fp8_ptr + out_fp8_offsets, x_fp8)

        # Store the exponent with a bias of 127 as one uint8 per tile.
        exponent = ceil_log2.to(tl.int32)
        scale_uint8 = (exponent + 127).to(tl.uint8)

        out_scale_offset = token_id * scale_stride_0 + tile_id
        tl.store(scale_k_nope_uint8_ptr + out_scale_offset, scale_uint8)
def quant_to_nope_fp8_rope_bf16_pack(k_bf16: torch.Tensor) -> NopeFp8RopeBf16Pack:
    """PyTorch reference: split K into fp8 nope + bf16 rope with UE8M0 scales.

    The 512-wide input row is split into a 448-wide nope part and a 64-wide
    rope part. The nope part is quantized per 64-element tile with a
    power-of-two scale; the rope part is passed through unchanged.

    Args:
        k_bf16: bfloat16 tensor of shape ``(num_tokens, 512)``.

    Returns:
        ``NopeFp8RopeBf16Pack`` with ``k_nope_fp8`` of shape
        ``(num_tokens, 448)``, ``k_rope_bf16`` of shape ``(num_tokens, 64)``
        and ``scale_k_nope_ue8m0`` (uint8) of shape ``(num_tokens, 7)``.
    """
    assert k_bf16.dtype == torch.bfloat16
    _num_tokens, hidden_dim = k_bf16.shape
    assert hidden_dim == 512
    dim_nope = 448
    dim_rope = 64

    k_nope_bf16, k_rope_bf16 = k_bf16.split([dim_nope, dim_rope], dim=-1)

    tile_size = 64
    num_tiles = dim_nope // tile_size

    # Use the representable maximum of the *selected* fp8 format rather than
    # a hard-coded 448: `fp8_dtype` may be float8_e4m3fnuz, whose maximum is
    # smaller, and scaling against 448 would push values out of range. This
    # matches the Triton path, which passes torch.finfo(fp8_dtype).max.
    fp8_max = torch.finfo(fp8_dtype).max

    x = k_nope_bf16.contiguous().view(-1, num_tiles, tile_size)
    scale = x.abs().amax(dim=-1).float() / fp8_max
    # Round each tile scale up to a power of two so it is exactly
    # representable as an exponent-only (e8m0) value.
    scale_pow2_fp32 = _cast_scale_inv_to_ue8m0(scale, out_dtype=torch.float32)
    scale_k_nope_ue8m0 = scale_pow2_fp32.to(torch.float8_e8m0fnu)
    k_nope_fp8 = (x.float() / scale_pow2_fp32.unsqueeze(-1)).to(fp8_dtype)
    k_nope_fp8 = k_nope_fp8.view(-1, tile_size * num_tiles)
    # Reinterpret the e8m0 bytes as uint8 for storage.
    scale_k_nope_ue8m0 = scale_k_nope_ue8m0.view(torch.uint8)

    return NopeFp8RopeBf16Pack(
        k_nope_fp8=k_nope_fp8,
        k_rope_bf16=k_rope_bf16.contiguous(),
        scale_k_nope_ue8m0=scale_k_nope_ue8m0,
    )


def _cast_scale_inv_to_ue8m0(
    scales_inv: torch.Tensor, out_dtype=torch.float32
) -> torch.Tensor:
    """Round each scale up to the nearest power of two.

    Values are clamped to at least ``1e-4`` before taking ``log2`` so that
    zero scales do not produce ``-inf``. The result is cast to ``out_dtype``.
    """
    return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype)
b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py @@ -1,11 +1,11 @@ -from typing import Optional, Tuple +import functools +from typing import Any, Optional, Tuple import tilelang import tilelang.language as T import torch -from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz -from sglang.srt.utils import is_gfx95_supported, is_hip +from sglang.srt.utils import is_hip tilelang.set_log_level("WARNING") @@ -13,19 +13,18 @@ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, } -# TL_DISABLE_FAST_MATH has deprecated in v0.1.7.post1 tilelang -if hasattr(tilelang.PassConfigKey, "TL_DISABLE_FAST_MATH"): - pass_configs[tilelang.PassConfigKey.TL_DISABLE_FAST_MATH] = True -elif hasattr(tilelang.PassConfigKey, "TL_ENABLE_FAST_MATH"): - pass_configs[tilelang.PassConfigKey.TL_ENABLE_FAST_MATH] = False - -_is_hip = is_hip() -_is_gfx95_supported = is_gfx95_supported() -_is_fp8_fnuz = is_fp8_fnuz() BF16 = "bfloat16" -FP8 = "float8_e4m3fnuz" if _is_fp8_fnuz else "float8_e4m3" +if is_hip(): + FP8 = "float8_e5m2fnuz" + FP8_ = torch.float8_e5m2 +else: + FP8 = "float8_e4m3" + FP8_ = torch.float8_e4m3fn FP32 = "float32" +INT32 = "int32" + +_is_hip = is_hip() def fast_log2_ceil(x): @@ -49,8 +48,8 @@ def act_quant_kernel( N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False ): M = T.symbolic("M") - fp8_min = -224.0 if _is_fp8_fnuz else -448.0 - fp8_max = 224.0 if _is_fp8_fnuz else 448.0 + fp8_min = -448.0 + fp8_max = 448.0 fp8_max_inv = 1 / fp8_max num_stages = 0 if round_scale else 2 blk_m = 32 @@ -115,10 +114,7 @@ def act_quant( x.size(-1) % block_size == 0 ), f"Last dimension size must be divisible by block_size (block_size={block_size})" N = x.size(-1) - if _is_fp8_fnuz: - y = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) - else: - y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y = torch.empty_like(x, dtype=FP8_) s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) 
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None) kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size)) @@ -787,23 +783,117 @@ def tilelang_sparse_fwd( topk = indices.shape[-1] assert topk == 2048 if _is_hip: - if _is_gfx95_supported: - kernel = sparse_attention_fwd_kernel_v1( - num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1 - ) - else: # reduce LDS usage on gfx942 target - kernel = sparse_attention_fwd_kernel_v1( - num_heads, - d_v, - tail_dim, - topk, - sm_scale=sm_scale, - block_I=32, - num_stages=1, - threads=128, - ) + kernel = sparse_attention_fwd_kernel_v1( + num_heads, d_v, tail_dim, topk, sm_scale=sm_scale, num_stages=1 + ) else: kernel = sparse_attention_fwd_kernel_v2( num_heads, d_v, tail_dim, topk, sm_scale=sm_scale ) return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore + + +@functools.cache +def fp8_paged_mqa_logits_kernel( + head_dim: int = 128, + num_heads: int = 64, + block_size: int = 64, + clear_accum: bool = True, +) -> Any: + N = T.symbolic("batch_size") + L = T.symbolic("max_table_length") + S = T.symbolic("max_seq_len") + C = T.symbolic("num_blocks") + B = block_size + D = head_dim + H = num_heads + d_0, d_1 = T.dynamic("d_0, d_1") + + assert D % 4 == 0 + assert H % 4 == 0 + assert D == 128 + + @tilelang.jit + def fp8_paged_mqa_logits( + q: T.Tensor[(N, H, D), FP8], + kvcache: T.StridedTensor[(C, B, D), (d_0, D, 1), FP8], + kvcache_scale: T.StridedTensor[(C, B), (d_1, 1), FP32], + weight: T.Tensor[(N, H), FP32], + seq_lens: T.Tensor[(N,), INT32], + page_table: T.Tensor[(N, L), INT32], + o: T.Tensor[(N, S), FP32], + ) -> None: + _ = N, L, S, C, D, H, B, d_0, d_1 + with T.Kernel(N) as bx: + seq_len = seq_lens[bx] + q_smem = T.alloc_shared((H, D), FP8) + q_s_frag = T.alloc_fragment((H,), FP32) + T.copy(q[bx, 0, 0], q_smem) + T.copy(weight[bx, 0], q_s_frag) + + for i in T.Pipelined(T.ceildiv(seq_len, B), num_stages=2): + page = page_table[bx, i] + k_smem = T.alloc_shared((B, 
def tilelang_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kvcache_fp8: torch.Tensor,
    weight: torch.Tensor,
    seq_lens: torch.Tensor,
    page_table: torch.Tensor,
    deep_gemm_metadata: Any,
    max_seq_len: int,
    clean_logits: bool = True,
) -> torch.Tensor:
    """Compute paged MQA logits with the TileLang ``fp8_paged_mqa_logits`` kernel.

    Args:
        q_fp8: Query tensor of shape ``(batch_size, 1, num_heads, 128)``.
        kvcache_fp8: Paged KV cache of shape ``(num_blocks, 64, 1, 128 + 4)``;
            per block, the first ``64 * 128`` elements are reinterpreted as
            fp8 keys and the remaining ``64 * 4`` as float32 scales.
        weight: Per-head weights of shape ``(batch_size, num_heads)``.
        seq_lens: 1-D tensor of shape ``(batch_size,)``.
        page_table: Tensor whose first dimension is ``batch_size``.
        deep_gemm_metadata: Accepted and ignored.
        max_seq_len: Width of the returned logits tensor.
        clean_logits: Must be ``False`` (see note below).

    Returns:
        float32 tensor of shape ``(batch_size, max_seq_len)``. It is allocated
        with ``new_empty``, so positions the kernel does not write are
        uninitialized.

    NOTE(review): the default ``clean_logits=True`` is rejected by the
    assertion below, so callers must pass ``clean_logits=False`` explicitly.
    The signature presumably mirrors the deep_gemm entry point -- confirm
    before changing the default.
    """
    # Metadata argument is unused by this implementation.
    _ = deep_gemm_metadata
    batch_size, _, num_heads, head_dim = q_fp8.shape
    block_size = kvcache_fp8.shape[1]
    assert head_dim == 128, "TODO"
    assert block_size == 64, "TODO"
    assert q_fp8.shape == (batch_size, 1, num_heads, head_dim)
    assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4)
    assert weight.shape == (batch_size, num_heads)
    assert seq_lens.shape == (batch_size,)
    assert page_table.shape[0] == batch_size
    assert clean_logits == False

    logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32)
    # `clean_logits` is forwarded as the kernel's `clear_accum` option.
    kernel = fp8_paged_mqa_logits_kernel(
        head_dim=head_dim,
        num_heads=num_heads,
        block_size=block_size,
        clear_accum=clean_logits,
    )
    # Drop the singleton query-length axis.
    q_fp8 = q_fp8.view(batch_size, num_heads, head_dim)
    # Flatten each block, then split it into the key region and the trailing
    # scale region and reinterpret each with its own dtype.
    kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4))
    kvcache = kvcache_fp8[..., : block_size * head_dim].view(dtype=FP8_)
    kvcache = kvcache.view(-1, block_size, head_dim)
    kvcache_scale = kvcache_fp8[..., block_size * head_dim :].view(dtype=torch.float32)
    kernel(q_fp8, kvcache, kvcache_scale, weight, seq_lens, page_table, logits)
    return logits
forward_batch.forward_mode.is_context_parallel_extend() ): + if envs.SGLANG_DEBUG_HACK_CP_ASSERT_PURE_EXTEND.get(): + _assert_cp_pure_extend(forward_batch) return True else: return False +def _assert_cp_pure_extend(forward_batch: "ForwardBatch") -> None: + from sglang.srt.model_executor.forward_batch_info import ForwardMode + + mode = forward_batch.forward_mode + assert mode == ForwardMode.EXTEND, ( + f"SGLANG_DEBUG_HACK_CP_ASSERT_PURE_EXTEND: expected ForwardMode.EXTEND, got {mode}. " + "CP round-robin may be silently enabled on MIXED batches." + ) + + extend_lens = list(forward_batch.extend_seq_lens_cpu) + seq_lens = list(forward_batch.seq_lens_cpu.tolist()) + assert len(extend_lens) == len( + seq_lens + ), f"extend_seq_lens_cpu ({len(extend_lens)}) != seq_lens_cpu ({len(seq_lens)})" + mismatched = [ + (i, e, s) for i, (e, s) in enumerate(zip(extend_lens, seq_lens)) if e != s + ] + assert not mismatched, ( + f"SGLANG_DEBUG_HACK_CP_ASSERT_PURE_EXTEND: found chunked-prefill continuation " + f"(extend_seq_lens != seq_lens) at {mismatched[:5]}{'...' if len(mismatched) > 5 else ''}. " + "A request has prior KV cache; CP round-robin may have domain mismatch." 
class CpRoundRobinRerange:
    """Reorder an all-gathered tensor from rank-major to round-robin order.

    The input holds ``cp_size`` equally sized chunks stacked along dim 0
    (chunk ``r`` = rows ``[r * per_rank_len, (r + 1) * per_rank_len)``).
    The output interleaves them: output row ``local * cp_size + r`` is input
    row ``r * per_rank_len + local``.
    """

    # dtypes accepted by the Triton implementation.
    _TRITON_DTYPES = (torch.bfloat16, torch.float16, torch.float32)

    @classmethod
    def execute(cls, gathered: torch.Tensor, cp_size: int) -> torch.Tensor:
        """Reorder ``gathered``, preferring the Triton kernel when it applies.

        The Triton path only supports CUDA tensors of the dtypes in
        ``_TRITON_DTYPES``; any other input is handled by the equivalent
        pure-PyTorch path instead of failing an assertion.
        """
        if gathered.is_cuda and gathered.dtype in cls._TRITON_DTYPES:
            return cls.triton(gathered, cp_size)
        return cls.vanilla(gathered, cp_size)

    @classmethod
    def vanilla(cls, gathered: torch.Tensor, cp_size: int) -> torch.Tensor:
        """Pure-PyTorch implementation (view + transpose + reshape)."""
        out_shape = gathered.shape
        # Spell out the per-rank length instead of using -1 so that an empty
        # input (0 rows) can still be viewed; -1 is ambiguous for 0 elements.
        per_rank_len = out_shape[0] // cp_size
        return (
            gathered.view(cp_size, per_rank_len, *out_shape[1:])
            .transpose(0, 1)
            .reshape(out_shape)
        )

    @classmethod
    def triton(cls, gathered: torch.Tensor, cp_size: int) -> torch.Tensor:
        """Triton implementation; one program per (output row, column block).

        Requires a contiguous CUDA tensor of a supported float dtype whose
        dim-0 length is divisible by ``cp_size``. Returns a new tensor with
        the same shape and dtype as ``gathered``.
        """
        assert (
            gathered.is_cuda
        ), f"gathered must be on CUDA, got device={gathered.device}"
        assert gathered.dtype in (
            torch.bfloat16,
            torch.float16,
            torch.float32,
        ), f"unsupported dtype {gathered.dtype}"
        assert (
            gathered.ndim >= 1
        ), f"gathered.ndim must be >=1, got shape={tuple(gathered.shape)}"
        assert (
            gathered.is_contiguous()
        ), f"gathered must be contiguous, got strides={gathered.stride()} shape={tuple(gathered.shape)}"
        assert (
            isinstance(cp_size, int) and cp_size >= 1
        ), f"cp_size must be positive int, got {cp_size!r}"
        total_rows = gathered.shape[0]
        assert (
            total_rows % cp_size == 0
        ), f"total_rows={total_rows} not divisible by cp_size={cp_size}"
        per_rank_len = total_rows // cp_size

        out = torch.empty_like(gathered)
        # Nothing to copy; also avoids launching an empty grid.
        if total_rows == 0 or gathered.numel() == 0:
            return out
        # Collapse all trailing dims so the kernel sees a 2-D row/column layout.
        view_in = gathered.reshape(total_rows, -1)
        view_out = out.view(total_rows, -1)
        hidden = view_in.shape[1]

        BLOCK_H = 1024
        grid = (total_rows, triton.cdiv(hidden, BLOCK_H))
        _cp_round_robin_rerange_kernel[grid](
            view_in,
            view_out,
            per_rank_len,
            hidden,
            cp_size=cp_size,
            BLOCK_H=BLOCK_H,
        )
        return out
a/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py b/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py new file mode 100644 index 000000000000..f6933e4b55d0 --- /dev/null +++ b/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py @@ -0,0 +1,259 @@ + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + + + +@triton.jit +def _init_compressed_attn_metadata_kernel( + seq_lens_ptr, + positions_ptr, + raw_out_loc_ptr, + page_table_ptr, + c4_out_loc_ptr, + c4_positions_ptr, + c4_seq_lens_raw_ptr, + c4_seq_lens_clamp1_ptr, + c128_out_loc_ptr, + c128_positions_ptr, + c128_seq_lens_clamp1_ptr, + c128_page_indices_ptr, + bs, + max_pages, + page_size: tl.constexpr, + c128_max_seq_len: tl.constexpr, + c128_page_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + COMPUTE_PAGE_INDICES: tl.constexpr, +): + batch_id = tl.program_id(0) + if batch_id >= bs: + return + + seq_len = tl.load(seq_lens_ptr + batch_id) + position = tl.load(positions_ptr + batch_id) + raw_out_loc = tl.load(raw_out_loc_ptr + batch_id) + + c4_should_compress = (seq_len % 4) == 0 + c4_out_loc = tl.where(c4_should_compress, raw_out_loc // 4, 0) + c4_positions = position & (~3) + c4_seq_lens_raw = seq_len // 4 + c4_seq_lens_clamp1 = tl.maximum(c4_seq_lens_raw, 1) + + tl.store(c4_out_loc_ptr + batch_id, c4_out_loc) + tl.store(c4_positions_ptr + batch_id, c4_positions) + tl.store(c4_seq_lens_raw_ptr + batch_id, c4_seq_lens_raw) + tl.store(c4_seq_lens_clamp1_ptr + batch_id, c4_seq_lens_clamp1) + + c128_should_compress = (seq_len % 128) == 0 + c128_out_loc = tl.where(c128_should_compress, raw_out_loc // 128, 0) + c128_positions = position & (~127) + c128_seq_lens_raw = seq_len // 128 + c128_seq_lens_clamp1 = tl.maximum(c128_seq_lens_raw, 1) + + tl.store(c128_out_loc_ptr + batch_id, c128_out_loc) + tl.store(c128_positions_ptr + batch_id, c128_positions) + tl.store(c128_seq_lens_clamp1_ptr + batch_id, 
c128_seq_lens_clamp1) + + if COMPUTE_PAGE_INDICES: + page_indices_base = batch_id * c128_max_seq_len + for block_start in range(0, c128_max_seq_len, BLOCK_SIZE): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < c128_max_seq_len + + page_idx = offsets // c128_page_size + offset_in_page = offsets % c128_page_size + + page_mask = mask & (page_idx < max_pages) + page_table_vals = tl.load( + page_table_ptr + batch_id * max_pages + page_idx, + mask=page_mask, + other=0, + ) + + c_page_indices_vals = page_table_vals * c128_page_size + offset_in_page + + valid_mask = offsets < c128_seq_lens_raw + c_page_indices_vals = tl.where(valid_mask, c_page_indices_vals, -1) + + tl.store( + c128_page_indices_ptr + page_indices_base + offsets, + c_page_indices_vals, + mask=mask, + ) + + +def _init_compressed_attn_metadata_triton( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], +]: + bs = seq_lens.shape[0] + device = seq_lens.device + + c4_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c4_positions = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_raw = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + c128_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c128_positions = torch.empty(bs, dtype=torch.int32, device=device) + c128_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + if compute_page_indices: + assert ( + page_table is not None + ), "page_table required when compute_page_indices=True" + assert page_size > 0, "page_size required when compute_page_indices=True" + max_pages = page_table.shape[1] + c128_page_size = page_size // 128 + 
c128_max_seq_len = c128_page_size * max_pages + c128_page_indices = torch.empty( + bs, c128_max_seq_len, dtype=torch.int32, device=device + ) + BLOCK_SIZE = triton.next_power_of_2(max(c128_page_size, 64)) + else: + max_pages = 0 + c128_page_size = 1 + c128_max_seq_len = 0 + c128_page_indices = None + BLOCK_SIZE = 64 + if page_table is None: + page_table = torch.empty(0, dtype=torch.int32, device=device) + + grid = (bs,) + _init_compressed_attn_metadata_kernel[grid]( + seq_lens, + positions, + raw_out_loc, + page_table, + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + ( + c128_page_indices + if c128_page_indices is not None + else torch.empty(0, dtype=torch.int32, device=device) + ), + bs, + max_pages, + page_size if page_size > 0 else 128, + c128_max_seq_len, + c128_page_size, + BLOCK_SIZE, + compute_page_indices, + ) + + return ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) + + + + +def init_compressed_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], +]: + return _init_compressed_attn_metadata_triton( + seq_lens, + positions, + raw_out_loc, + page_table, + page_size, + compute_page_indices, + ) + + + + +def init_c4_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + _, + _, + _, + _, + ) = init_compressed_metadata( + seq_lens, + positions, + raw_out_loc, + page_table=None, + page_size=0, + 
compute_page_indices=False, + ) + return c4_out_loc, c4_positions, c4_seq_lens_raw, c4_seq_lens_clamp1 + + +def init_c128_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: torch.Tensor, + page_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ( + _, + _, + _, + _, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) = init_compressed_metadata( + seq_lens, + positions, + raw_out_loc, + page_table=page_table, + page_size=page_size, + compute_page_indices=True, + ) + return c128_out_loc, c128_positions, c128_seq_lens_clamp1, c128_page_indices diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index de8f7983f360..998989caac70 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -12,6 +12,10 @@ import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, @@ -197,7 +201,7 @@ def unpad_draft_extend_output_kernel( def _quantize_fp8_qkv(q, k, v, layer): - q = q.to(torch.float8_e4m3fn) + q = q.to(fp8_dtype) k_scale = getattr(layer, "k_scale_float", None) if k_scale is None: @@ -209,7 +213,7 @@ def _quantize_fp8_qkv(q, k, v, layer): ) k = k_2d.reshape(k.shape) else: - k = k.to(torch.float8_e4m3fn) + k = k.to(fp8_dtype) v_scale = getattr(layer, "v_scale_float", None) if v_scale is None: @@ -221,7 +225,7 @@ def _quantize_fp8_qkv(q, k, v, layer): ) v = v_2d.reshape(v.shape) else: - v = v.to(torch.float8_e4m3fn) + v = v.to(fp8_dtype) return q, k, v, k_scale, v_scale @@ -702,7 +706,7 @@ def 
quantize_and_rope_for_fp8( - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn """ - attn_dtype = torch.float8_e4m3fn + attn_dtype = fp8_dtype q_len, num_heads = q_rope.shape[0], q_rope.shape[1] # Allocate output tensors with FP8 dtype @@ -840,7 +844,7 @@ def forward_decode( ) -> torch.Tensor: """Run forward for decode using TRTLLM MLA kernel.""" merge_query = q_rope is not None - if self.data_type == torch.float8_e4m3fn: + if self.data_type == fp8_dtype: # For FP8 path, we quantize the query and rope parts and merge them into a single tensor # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend assert all( @@ -965,7 +969,7 @@ def forward_extend( # TODO refactor to avoid code duplication merge_query = q_rope is not None if ( - self.data_type == torch.float8_e4m3fn + self.data_type == fp8_dtype ) and forward_batch.forward_mode.is_target_verify(): # For FP8 path, we quantize the query and rope parts and merge them into a single tensor # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend @@ -1130,7 +1134,7 @@ def forward_extend( v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim) q_scale = k_scale = v_scale = 1.0 - if self.data_type == torch.float8_e4m3fn: + if self.data_type == fp8_dtype: q, k, v, k_scale, v_scale = _quantize_fp8_qkv(q, k, v, layer) common_trtllm_args = { diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py index 5e25e56a239c..8298e21f10be 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/compile_utils.py @@ -22,7 +22,7 @@ import deep_gemm -_BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1)) +_BUILTIN_M_LIST: List[int] = [] _ENABLE_JIT_DEEPGEMM_PRECOMPILE = 
envs.SGLANG_JIT_DEEPGEMM_PRECOMPILE.get() _DO_COMPILE_ALL = True _IS_FIRST_RANK_ON_NODE = envs.SGLANG_IS_FIRST_RANK_ON_NODE.get() @@ -44,14 +44,45 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs): global _DO_COMPILE_ALL global _IS_FIRST_RANK_ON_NODE - # Generate m_max - m_max = 1024 * 16 - if server_args.chunked_prefill_size < 1: - m_max = 1024 * 64 - elif server_args.chunked_prefill_size > 8192: - m_max = server_args.chunked_prefill_size * 2 - m_max = min(1024 * 128, m_max) - _BUILTIN_M_LIST = list(range(1, m_max + 1)) + _BUILTIN_M_LIST = [] + + if envs.SGLANG_JIT_DEEPGEMM_FAST_WARMUP.get(): + # In fast warmup mode, only compile a small set of typical Ms + + # First cover all the small bs to ensure decode performance + _BUILTIN_M_LIST += list(range(1, 1025)) + + # Then cover larger batch sizes with gradually increasing steps + # For example, when chunked prefill size is 16384 + # The sampled Ms would be: + # 1024, 1026, ... 2046 (step 2) + # 2048, 2052, ... 4092 (step 4) + # 4096, 4104, ... 8184 (step 8) + # 8192, 8208, ... 16384 (step 16) + # In total: 1024 + 1024/2 + 2048/4 + 4096/8 + 8192/16 = 3072 kernels + next_m, sample_step = 1024, 2 + max_prefill_bs = ( + min(server_args.chunked_prefill_size, 32 * 1024) + if server_args.chunked_prefill_size >= 1 + else 16 * 1024 + ) + while next_m < max_prefill_bs: + _BUILTIN_M_LIST += list( + range(next_m, min(2 * next_m, max_prefill_bs), sample_step) + ) + next_m = next_m * 2 + sample_step = sample_step * 2 + _BUILTIN_M_LIST.append(max_prefill_bs) + _BUILTIN_M_LIST = sorted(set(_BUILTIN_M_LIST)) + else: + # When fast warmup isn't enabled, generate m_max and compile all the covered Ms. 
+ m_max = 1024 * 16 + if server_args.chunked_prefill_size < 1: + m_max = 1024 * 64 + elif server_args.chunked_prefill_size > 8192: + m_max = server_args.chunked_prefill_size * 2 + m_max = min(1024 * 128, m_max) + _BUILTIN_M_LIST += list(range(1, m_max + 1)) _IS_FIRST_RANK_ON_NODE = server_args.base_gpu_id == gpu_id @@ -163,12 +194,17 @@ def _compile_deep_gemm_one_type_all( kernel_type, max_m=max_m, n=n, k=k, num_groups=num_groups ) - old_compile_mode = deep_gemm.get_compile_mode() - deep_gemm.set_compile_mode(1) + has_compile_mode_api = hasattr(deep_gemm, "get_compile_mode") and hasattr( + deep_gemm, "set_compile_mode" + ) + if has_compile_mode_api: + old_compile_mode = deep_gemm.get_compile_mode() + deep_gemm.set_compile_mode(1) # TODO can use multi thread for m in tqdm(m_list, desc=f"DeepGEMM warmup"): executor.execute(m=m) - deep_gemm.set_compile_mode(old_compile_mode) + if has_compile_mode_api: + deep_gemm.set_compile_mode(old_compile_mode) # clean up input buffers torch.cuda.current_stream().synchronize() @@ -263,7 +299,7 @@ def execute(self, m): (self.lhs_q[:m], self.lhs_s[:m]), (self.rhs_q, self.rhs_s), self.out[:m], - m_indices=self.m_indices[:m], + self.m_indices[:m], ) diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py index 88d0a959b156..5d5544be45b0 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py @@ -4,6 +4,7 @@ import torch +from sglang.srt.environ import envs from sglang.srt.layers.deep_gemm_wrapper import compile_utils from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401 DEEPGEMM_BLACKWELL, @@ -39,6 +40,13 @@ def grouped_gemm_nt_f8f8bf16_masked( _sanity_check_input(lhs) _sanity_check_input(rhs) + if envs.SGLANG_HACK_SKIP_FP4_FP8_GEMM.get(): + out.zero_() + return + + lhs = _ensure_cuda(lhs) + rhs = _ensure_cuda(rhs) + with compile_utils.deep_gemm_execution_hook( 
expected_m, n, k, num_groups, kernel_type ): @@ -46,12 +54,20 @@ def grouped_gemm_nt_f8f8bf16_masked( overlap_args.num_sms if overlap_args is not None else None ): + fp4_kwargs = ( + dict(recipe_a=(1, 128), recipe_b=(1, 32)) + if envs.SGLANG_DSV4_MODE.get() == "2604" + and envs.SGLANG_DSV4_FP4_EXPERTS.get() + else {} + ) + return deep_gemm.fp8_m_grouped_gemm_nt_masked( lhs, rhs, out, masked_m, expected_m, + **fp4_kwargs, **( dict( enable_overlap=True, @@ -64,6 +80,15 @@ def grouped_gemm_nt_f8f8bf16_masked( ) +def _ensure_cuda( + pair: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[torch.Tensor, torch.Tensor]: + return ( + pair[0].cuda() if not pair[0].is_cuda else pair[0], + pair[1].cuda() if not pair[1].is_cuda else pair[1], + ) + + def grouped_gemm_nt_f8f8bf16_contig( lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], @@ -74,11 +99,25 @@ def grouped_gemm_nt_f8f8bf16_contig( num_groups, n, _ = rhs[0].shape kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG + if m == 0: + return + _sanity_check_input(lhs) _sanity_check_input(rhs) + if envs.SGLANG_HACK_SKIP_FP4_FP8_GEMM.get(): + out.zero_() + return + fp4_kwargs = ( + dict(recipe_a=(1, 128), recipe_b=(1, 32)) + if envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + else {} + ) + with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type): - deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + lhs, rhs, out, m_indices, **fp4_kwargs + ) def gemm_nt_f8f8bf16( diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py b/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py new file mode 100644 index 000000000000..d392f73078a6 --- /dev/null +++ b/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass +from typing import List, Union + +import deep_gemm +import torch + +from 
sglang.srt.environ import envs + + +@dataclass +class _PagedMqaLogitsMetadataChunk: + start: int + end: int + schedule_meta: torch.Tensor + + +@dataclass +class _PagedMqaLogitsMetadata: + chunks: List[_PagedMqaLogitsMetadataChunk] + + def copy_(self, other: "_PagedMqaLogitsMetadata"): + raise Exception("Not expect to be copied") + + +def get_paged_mqa_logits_metadata_chunked( + context_lens: torch.Tensor, + block_kv: int, + num_sms: int, +) -> Union[_PagedMqaLogitsMetadata, torch.Tensor]: + chunk_size = envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() + batch_size = context_lens.shape[0] + + if batch_size <= chunk_size: + return deep_gemm.get_paged_mqa_logits_metadata( + context_lens.unsqueeze(-1) if context_lens.dim() == 1 else context_lens, + block_kv, + num_sms, + ) + + chunks: List[_PagedMqaLogitsMetadataChunk] = [] + for start in range(0, batch_size, chunk_size): + end = min(start + chunk_size, batch_size) + schedule_meta = deep_gemm.get_paged_mqa_logits_metadata( + ( + (context_lens[start:end]).unsqueeze(-1) + if context_lens.dim() == 1 + else context_lens[start:end] + ), + block_kv, + num_sms, + ) + chunks.append( + _PagedMqaLogitsMetadataChunk( + start=start, end=end, schedule_meta=schedule_meta + ) + ) + + return _PagedMqaLogitsMetadata(chunks=chunks) + + +def fp8_paged_mqa_logits_chunked( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_table: torch.Tensor, + schedule_meta: Union[_PagedMqaLogitsMetadata, torch.Tensor], + max_context_len: int, + clean_logits: bool, +) -> torch.Tensor: + if not isinstance(schedule_meta, _PagedMqaLogitsMetadata): + return deep_gemm.fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + ) + + all_logits = [] + for chunk_meta in schedule_meta.chunks: + chunk_logits = deep_gemm.fp8_paged_mqa_logits( + q[chunk_meta.start : chunk_meta.end], + kv_cache, + weights[chunk_meta.start : 
chunk_meta.end], + context_lens[chunk_meta.start : chunk_meta.end], + block_table[chunk_meta.start : chunk_meta.end], + chunk_meta.schedule_meta, + max_context_len, + clean_logits, + ) + all_logits.append(chunk_logits) + + return torch.cat(all_logits, dim=0) diff --git a/python/sglang/srt/layers/deepseek_v4_rope.py b/python/sglang/srt/layers/deepseek_v4_rope.py new file mode 100644 index 000000000000..57561b62f76a --- /dev/null +++ b/python/sglang/srt/layers/deepseek_v4_rope.py @@ -0,0 +1,196 @@ +import math +from functools import lru_cache +from typing import Optional + +import tilelang +import torch +import triton +import triton.language as tl + +from sglang.srt.utils.common import maybe_torch_compile + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + + +@lru_cache(2) +def precompute_freqs_cis( + dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow +) -> torch.Tensor: + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if original_seq_len > 0: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 
2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +@maybe_torch_compile +def apply_rotary_emb( + x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False +) -> torch.Tensor: + y = x + x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) + if inverse: + freqs_cis = freqs_cis.conj() + if x.ndim == 3: + freqs_cis = freqs_cis.unsqueeze(1) + x = torch.view_as_real(x * freqs_cis).flatten(-2) + y.copy_(x) + return y + + +@triton.jit +def apply_rotary_emb_triton_kernel( + x_ptr, + freqs_ptr, + positions_ptr, + rope_dim, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_freq_pos, + stride_freq_dim, + USE_POS: tl.constexpr, + IS_INVERSE: tl.constexpr, + IS_3D: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_dim = tl.program_id(2) + + if USE_POS: + position = tl.load(positions_ptr + pid_batch) + else: + position = pid_batch + + if IS_3D: + base_offset = pid_batch * stride_x_batch + pid_head * stride_x_head + else: + base_offset = pid_batch * stride_x_batch + + offs_pair = pid_dim * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs_pair < (rope_dim // 2) + + offs_x_real = base_offset + offs_pair * 2 * stride_x_dim + offs_x_imag = base_offset + (offs_pair * 2 + 1) * stride_x_dim + + x_real = tl.load(x_ptr + offs_x_real, mask=mask, other=0.0).to(tl.float32) + x_imag = tl.load(x_ptr + offs_x_imag, mask=mask, other=0.0).to(tl.float32) + + offs_freq_real = position * stride_freq_pos + offs_pair * 2 * stride_freq_dim + offs_freq_imag = position * stride_freq_pos + (offs_pair * 2 + 1) * stride_freq_dim + + freq_real = tl.load(freqs_ptr + offs_freq_real, mask=mask, other=0.0) + freq_imag = tl.load(freqs_ptr + offs_freq_imag, mask=mask, other=0.0) + + if IS_INVERSE: + out_real = x_real * freq_real + x_imag * freq_imag + out_imag = x_imag * 
freq_real - x_real * freq_imag + else: + out_real = x_real * freq_real - x_imag * freq_imag + out_imag = x_real * freq_imag + x_imag * freq_real + + tl.store(x_ptr + offs_x_real, out_real, mask=mask) + tl.store(x_ptr + offs_x_imag, out_imag, mask=mask) + + +def apply_rotary_emb_triton( + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: Optional[torch.Tensor] = None, + inverse: bool = False, +) -> torch.Tensor: + is_3d = x.ndim == 3 + + if is_3d: + batch_size, n_heads, rope_dim = x.shape + else: + batch_size, rope_dim = x.shape + n_heads = 1 + + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + + BLOCK_SIZE = 128 + + num_blocks_dim = triton.cdiv(rope_dim // 2, BLOCK_SIZE) + grid = (batch_size, n_heads if is_3d else 1, num_blocks_dim) + + if positions is not None: + assert positions.shape == ( + batch_size, + ), f"positions shape {positions.shape} != ({batch_size},)" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + positions, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=True, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + assert ( + freqs_real.shape[0] == batch_size + ), f"freqs_cis batch size {freqs_real.shape[0]} != x batch size {batch_size}" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + None, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=False, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return x diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7fde05894b59..6dd72c046200 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -24,6 +24,7 @@ is_batch_invariant_mode_enabled, rms_norm_batch_invariant, ) +from sglang.srt.environ import envs from sglang.srt.layers.utils import MultiPlatformOp from 
sglang.srt.server_args import get_global_server_args from sglang.srt.utils import ( @@ -51,6 +52,7 @@ if _is_flashinfer_available: try: from flashinfer.norm import layernorm + from flashinfer.norm import rmsnorm as fi_rmsnorm _flashinfer_layernorm_available = True except (ImportError, AttributeError): @@ -108,6 +110,9 @@ def forward_cuda( ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if x.numel() == 0: return x + if envs.SGLANG_OPT_USE_FLASHINFER_NORM.get(): + return fi_rmsnorm(x, self.weight, self.variance_epsilon) + if self.variance_size_override is not None: return self.forward_native(x, residual, post_residual_addition) if is_batch_invariant_mode_enabled(): @@ -154,11 +159,15 @@ def forward_aiter( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_hip: + if x.shape[0] == 0: + if residual is not None: + return x, residual + return x + if residual is not None: residual_out = torch.empty_like(x) output = torch.empty_like(x) - if post_residual_addition is not None: - residual = residual + post_residual_addition fused_add_rms_norm( output, x, diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 39919eedd72c..1e48d7f0b799 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -252,7 +252,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.dtype == loaded_weight.dtype ), "init para dtype and loaded weight dtype should be the same" - assert param.size() == loaded_weight.size() + assert ( + param.size() == loaded_weight.size() + ), f"{param.shape=} {param.dtype=} {loaded_weight.shape=} {loaded_weight.dtype=}" param.data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -408,7 +410,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if 
len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.reshape(1) - assert param_data.shape == loaded_weight.shape + assert ( + param_data.shape == loaded_weight.shape + ), f"param_data.shape={param_data.shape} != loaded_weight.shape={loaded_weight.shape}" param_data.copy_(loaded_weight) def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): diff --git a/python/sglang/srt/layers/linear_bf16_fp32/__init__.py b/python/sglang/srt/layers/linear_bf16_fp32/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_B200.json b/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_B200.json new file mode 100644 index 000000000000..ebca61c39cc4 --- /dev/null +++ b/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_B200.json @@ -0,0 +1,302 @@ +{ + "entries": { + "1,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.8, + "deep_gemm_us": 8.632 + }, + "1,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 9.533, + "deep_gemm_us": 8.952 + }, + "1,256,4096": { + "chosen": "cublas", + "cublas_us": 5.042, + "deep_gemm_us": 9.132 + }, + "1,512,4096": { + "chosen": "cublas", + "cublas_us": 6.443, + "deep_gemm_us": 9.099 + }, + "1024,1024,4096": { + "chosen": "cublas", + "cublas_us": 10.001, + "deep_gemm_us": 11.539 + }, + "1024,2048,4096": { + "chosen": "cublas", + "cublas_us": 15.009, + "deep_gemm_us": 16.094 + }, + "1024,256,4096": { + "chosen": "cublas", + "cublas_us": 7.329, + "deep_gemm_us": 10.018 + }, + "1024,512,4096": { + "chosen": "cublas", + "cublas_us": 7.643, + "deep_gemm_us": 10.483 + }, + "128,1024,4096": { + "chosen": "cublas", + "cublas_us": 7.204, + "deep_gemm_us": 9.314 + }, + "128,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.279, + "deep_gemm_us": 10.081 + }, + "128,256,4096": { + "chosen": "cublas", + "cublas_us": 6.84, + "deep_gemm_us": 9.281 + }, + "128,512,4096": { + "chosen": "cublas", + "cublas_us": 
7.095, + "deep_gemm_us": 9.305 + }, + "16,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.962, + "deep_gemm_us": 9.176 + }, + "16,2048,4096": { + "chosen": "cublas", + "cublas_us": 9.168, + "deep_gemm_us": 9.365 + }, + "16,256,4096": { + "chosen": "cublas", + "cublas_us": 5.391, + "deep_gemm_us": 9.18 + }, + "16,512,4096": { + "chosen": "cublas", + "cublas_us": 6.715, + "deep_gemm_us": 9.202 + }, + "2,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.708, + "deep_gemm_us": 8.83 + }, + "2,2048,4096": { + "chosen": "cublas", + "cublas_us": 8.971, + "deep_gemm_us": 9.026 + }, + "2,256,4096": { + "chosen": "cublas", + "cublas_us": 5.294, + "deep_gemm_us": 9.122 + }, + "2,512,4096": { + "chosen": "cublas", + "cublas_us": 6.608, + "deep_gemm_us": 9.095 + }, + "2048,1024,4096": { + "chosen": "cublas", + "cublas_us": 15.047, + "deep_gemm_us": 16.022 + }, + "2048,2048,4096": { + "chosen": "cublas", + "cublas_us": 26.603, + "deep_gemm_us": 27.409 + }, + "2048,256,4096": { + "chosen": "cublas", + "cublas_us": 8.028, + "deep_gemm_us": 10.657 + }, + "2048,512,4096": { + "chosen": "cublas", + "cublas_us": 10.036, + "deep_gemm_us": 11.74 + }, + "256,1024,4096": { + "chosen": "cublas", + "cublas_us": 7.284, + "deep_gemm_us": 9.714 + }, + "256,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.581, + "deep_gemm_us": 9.977 + }, + "256,256,4096": { + "chosen": "cublas", + "cublas_us": 7.098, + "deep_gemm_us": 9.408 + }, + "256,512,4096": { + "chosen": "cublas", + "cublas_us": 7.186, + "deep_gemm_us": 9.518 + }, + "32,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.878, + "deep_gemm_us": 9.191 + }, + "32,2048,4096": { + "chosen": "cublas", + "cublas_us": 8.322, + "deep_gemm_us": 9.384 + }, + "32,256,4096": { + "chosen": "cublas", + "cublas_us": 5.396, + "deep_gemm_us": 9.19 + }, + "32,512,4096": { + "chosen": "cublas", + "cublas_us": 6.923, + "deep_gemm_us": 9.27 + }, + "4,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.729, + "deep_gemm_us": 8.835 + }, + 
"4,2048,4096": { + "chosen": "cublas", + "cublas_us": 8.97, + "deep_gemm_us": 9.026 + }, + "4,256,4096": { + "chosen": "cublas", + "cublas_us": 5.315, + "deep_gemm_us": 9.119 + }, + "4,512,4096": { + "chosen": "cublas", + "cublas_us": 6.604, + "deep_gemm_us": 9.117 + }, + "4096,1024,4096": { + "chosen": "cublas", + "cublas_us": 27.144, + "deep_gemm_us": 27.626 + }, + "4096,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 54.767, + "deep_gemm_us": 53.59 + }, + "4096,256,4096": { + "chosen": "cublas", + "cublas_us": 10.074, + "deep_gemm_us": 11.698 + }, + "4096,512,4096": { + "chosen": "cublas", + "cublas_us": 14.95, + "deep_gemm_us": 15.43 + }, + "512,1024,4096": { + "chosen": "cublas", + "cublas_us": 7.625, + "deep_gemm_us": 9.939 + }, + "512,2048,4096": { + "chosen": "cublas", + "cublas_us": 9.947, + "deep_gemm_us": 10.97 + }, + "512,256,4096": { + "chosen": "cublas", + "cublas_us": 7.168, + "deep_gemm_us": 9.518 + }, + "512,512,4096": { + "chosen": "cublas", + "cublas_us": 7.305, + "deep_gemm_us": 9.743 + }, + "64,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.943, + "deep_gemm_us": 9.197 + }, + "64,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.178, + "deep_gemm_us": 9.507 + }, + "64,256,4096": { + "chosen": "cublas", + "cublas_us": 6.282, + "deep_gemm_us": 9.272 + }, + "64,512,4096": { + "chosen": "cublas", + "cublas_us": 6.82, + "deep_gemm_us": 9.272 + }, + "8,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.759, + "deep_gemm_us": 8.872 + }, + "8,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 9.196, + "deep_gemm_us": 9.073 + }, + "8,256,4096": { + "chosen": "cublas", + "cublas_us": 5.278, + "deep_gemm_us": 9.134 + }, + "8,512,4096": { + "chosen": "cublas", + "cublas_us": 6.631, + "deep_gemm_us": 9.145 + } + }, + "metadata": { + "device_name": "NVIDIA B200", + "m_buckets": [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096 + ], + "nk_pairs": [ + [ + 256, + 4096 + ], + [ + 512, + 4096 + ], + [ + 1024, 
+ 4096 + ], + [ + 2048, + 4096 + ] + ], + "rep_ms": 50, + "tuned_at": "2026-04-21T10:56:08" + } +} diff --git a/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_H200.json b/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_H200.json new file mode 100644 index 000000000000..0a21ee11896f --- /dev/null +++ b/python/sglang/srt/layers/linear_bf16_fp32/configs/device_name=NVIDIA_H200.json @@ -0,0 +1,302 @@ +{ + "entries": { + "1,1024,4096": { + "chosen": "cublas", + "cublas_us": 5.939, + "deep_gemm_us": 8.423 + }, + "1,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.433, + "deep_gemm_us": 9.208 + }, + "1,256,4096": { + "chosen": "cublas", + "cublas_us": 5.822, + "deep_gemm_us": 6.48 + }, + "1,512,4096": { + "chosen": "cublas", + "cublas_us": 5.879, + "deep_gemm_us": 6.616 + }, + "1024,1024,4096": { + "chosen": "cublas", + "cublas_us": 15.438, + "deep_gemm_us": 18.147 + }, + "1024,2048,4096": { + "chosen": "cublas", + "cublas_us": 26.645, + "deep_gemm_us": 27.86 + }, + "1024,256,4096": { + "chosen": "cublas", + "cublas_us": 8.727, + "deep_gemm_us": 11.936 + }, + "1024,512,4096": { + "chosen": "cublas", + "cublas_us": 10.003, + "deep_gemm_us": 10.036 + }, + "128,1024,4096": { + "chosen": "deep_gemm", + "cublas_us": 7.161, + "deep_gemm_us": 6.895 + }, + "128,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 8.615, + "deep_gemm_us": 8.306 + }, + "128,256,4096": { + "chosen": "cublas", + "cublas_us": 6.152, + "deep_gemm_us": 6.649 + }, + "128,512,4096": { + "chosen": "cublas", + "cublas_us": 6.375, + "deep_gemm_us": 6.774 + }, + "16,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.106, + "deep_gemm_us": 6.528 + }, + "16,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 6.835, + "deep_gemm_us": 6.624 + }, + "16,256,4096": { + "chosen": "cublas", + "cublas_us": 5.972, + "deep_gemm_us": 6.41 + }, + "16,512,4096": { + "chosen": "cublas", + "cublas_us": 5.973, + "deep_gemm_us": 6.538 + }, + "2,1024,4096": { + "chosen": 
"cublas", + "cublas_us": 5.991, + "deep_gemm_us": 7.11 + }, + "2,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.431, + "deep_gemm_us": 9.01 + }, + "2,256,4096": { + "chosen": "cublas", + "cublas_us": 5.835, + "deep_gemm_us": 6.476 + }, + "2,512,4096": { + "chosen": "cublas", + "cublas_us": 5.82, + "deep_gemm_us": 6.637 + }, + "2048,1024,4096": { + "chosen": "cublas", + "cublas_us": 26.918, + "deep_gemm_us": 27.936 + }, + "2048,2048,4096": { + "chosen": "cublas", + "cublas_us": 53.439, + "deep_gemm_us": 56.603 + }, + "2048,256,4096": { + "chosen": "cublas", + "cublas_us": 10.131, + "deep_gemm_us": 12.903 + }, + "2048,512,4096": { + "chosen": "cublas", + "cublas_us": 16.607, + "deep_gemm_us": 18.883 + }, + "256,1024,4096": { + "chosen": "deep_gemm", + "cublas_us": 8.917, + "deep_gemm_us": 7.648 + }, + "256,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 10.955, + "deep_gemm_us": 9.91 + }, + "256,256,4096": { + "chosen": "cublas", + "cublas_us": 6.375, + "deep_gemm_us": 6.76 + }, + "256,512,4096": { + "chosen": "deep_gemm", + "cublas_us": 7.076, + "deep_gemm_us": 6.774 + }, + "32,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.275, + "deep_gemm_us": 6.52 + }, + "32,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 7.017, + "deep_gemm_us": 6.638 + }, + "32,256,4096": { + "chosen": "cublas", + "cublas_us": 5.96, + "deep_gemm_us": 6.409 + }, + "32,512,4096": { + "chosen": "cublas", + "cublas_us": 6.148, + "deep_gemm_us": 6.531 + }, + "4,1024,4096": { + "chosen": "cublas", + "cublas_us": 5.942, + "deep_gemm_us": 7.891 + }, + "4,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.115, + "deep_gemm_us": 8.442 + }, + "4,256,4096": { + "chosen": "cublas", + "cublas_us": 5.752, + "deep_gemm_us": 6.466 + }, + "4,512,4096": { + "chosen": "cublas", + "cublas_us": 5.895, + "deep_gemm_us": 6.559 + }, + "4096,1024,4096": { + "chosen": "cublas", + "cublas_us": 53.854, + "deep_gemm_us": 56.633 + }, + "4096,2048,4096": { + "chosen": "cublas", + "cublas_us": 
107.806, + "deep_gemm_us": 120.604 + }, + "4096,256,4096": { + "chosen": "cublas", + "cublas_us": 18.814, + "deep_gemm_us": 19.207 + }, + "4096,512,4096": { + "chosen": "cublas", + "cublas_us": 28.492, + "deep_gemm_us": 29.546 + }, + "512,1024,4096": { + "chosen": "deep_gemm", + "cublas_us": 11.019, + "deep_gemm_us": 10.094 + }, + "512,2048,4096": { + "chosen": "deep_gemm", + "cublas_us": 18.014, + "deep_gemm_us": 17.726 + }, + "512,256,4096": { + "chosen": "cublas", + "cublas_us": 7.15, + "deep_gemm_us": 7.758 + }, + "512,512,4096": { + "chosen": "cublas", + "cublas_us": 8.479, + "deep_gemm_us": 8.946 + }, + "64,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.427, + "deep_gemm_us": 7.138 + }, + "64,2048,4096": { + "chosen": "cublas", + "cublas_us": 7.722, + "deep_gemm_us": 7.977 + }, + "64,256,4096": { + "chosen": "cublas", + "cublas_us": 6.086, + "deep_gemm_us": 6.535 + }, + "64,512,4096": { + "chosen": "cublas", + "cublas_us": 6.228, + "deep_gemm_us": 6.72 + }, + "8,1024,4096": { + "chosen": "cublas", + "cublas_us": 6.043, + "deep_gemm_us": 6.513 + }, + "8,2048,4096": { + "chosen": "cublas", + "cublas_us": 6.636, + "deep_gemm_us": 6.727 + }, + "8,256,4096": { + "chosen": "cublas", + "cublas_us": 5.848, + "deep_gemm_us": 6.432 + }, + "8,512,4096": { + "chosen": "cublas", + "cublas_us": 5.953, + "deep_gemm_us": 6.538 + } + }, + "metadata": { + "device_name": "NVIDIA H200", + "m_buckets": [ + 1, + 2, + 4, + 8, + 16, + 32, + 64, + 128, + 256, + 512, + 1024, + 2048, + 4096 + ], + "nk_pairs": [ + [ + 256, + 4096 + ], + [ + 512, + 4096 + ], + [ + 1024, + 4096 + ], + [ + 2048, + 4096 + ] + ], + "rep_ms": 50, + "tuned_at": "2026-04-21T11:27:10" + } +} diff --git a/python/sglang/srt/layers/linear_bf16_fp32/selector.py b/python/sglang/srt/layers/linear_bf16_fp32/selector.py new file mode 100644 index 000000000000..610825daaa98 --- /dev/null +++ b/python/sglang/srt/layers/linear_bf16_fp32/selector.py @@ -0,0 +1,56 @@ + +from __future__ import annotations + +import json 
def pick_backend(*, m: int, n: int, k: int) -> Backend:
    """Return the tuned backend name for a linear_bf16_fp32 GEMM of shape (m, n, k).

    ``m`` is rounded up with ``next_power_of_2`` to select its bucket and the
    per-device table is searched for the key ``"<m_bucket>,<n>,<k>"``.

    ``_FALLBACK`` is returned when
      * the key is not present in the table, or
      * the stored ``"chosen"`` value is missing or is not a known backend.
        The tuner writes ``null`` there when every backend raised for a
        shape, so the raw value must not be handed back to callers.

    Args:
        m: Real (un-bucketed) number of rows of the activation.
        n: Output features of the linear layer.
        k: Input features of the linear layer.

    Returns:
        ``"cublas"`` or ``"deep_gemm"``.
    """
    m_bucket = next_power_of_2(m)
    device_name = _cached_device_name()
    entries = _load_config(device_name)

    key = f"{m_bucket},{n},{k}"
    entry = entries.get(key)
    if entry is None:
        logger.debug(
            "linear_bf16_fp32 config miss key=%s (real M=%d) device=%s; falling back to %s",
            key,
            m,
            device_name,
            _FALLBACK,
        )
        return _FALLBACK

    chosen = entry.get("chosen")
    if chosen not in ("cublas", "deep_gemm"):
        # A tuned entry exists but carries no usable winner (e.g. null).
        logger.debug(
            "linear_bf16_fp32 config key=%s device=%s has unusable chosen=%r; falling back to %s",
            key,
            device_name,
            chosen,
            _FALLBACK,
        )
        return _FALLBACK
    return chosen


@lru_cache(maxsize=1)
def _cached_device_name() -> str:
    """Name of device 0 with spaces replaced by underscores (cached after first call)."""
    return get_device_name(0).replace(" ", "_")


@lru_cache(maxsize=None)
def _load_config(device_name: str) -> dict:
    """Load the ``entries`` mapping of the tuned config for ``device_name``.

    The file is looked up as ``configs/device_name=<device_name>.json`` next to
    this module. A missing file is reported once with a warning (the result is
    cached per device name) and yields an empty mapping, so every lookup falls
    back to ``_FALLBACK``.
    """
    path = _CONFIG_DIR / f"device_name={device_name}.json"
    if not path.exists():
        logger.warning(
            "linear_bf16_fp32 tuned config not found at %s; selector will always fall back to %s",
            path,
            _FALLBACK,
        )
        return {}
    with path.open() as f:
        payload = json.load(f)
    return payload.get("entries", {})
def main(
    nk_pairs: Annotated[
        str,
        typer.Option(help="Space-separated N,K pairs, e.g. '4096,7168 256,7168'"),
    ],
    output: Annotated[
        Path,
        typer.Option(
            help="Output JSON path, typically configs/device_name=.json"
        ),
    ],
    m_buckets: Annotated[
        str,
        typer.Option(help="Comma-separated M values to sweep"),
    ] = ",".join(str(m) for m in DEFAULT_M_BUCKETS),
    rep_ms: Annotated[
        int,
        typer.Option(
            help="triton do_bench_cudagraph `rep` in ms (timing budget per backend/shape)"
        ),
    ] = 50,
) -> None:
    """Benchmark every backend for each (M, N, K) shape and write the winners as JSON.

    Args:
        nk_pairs: Whitespace-separated ``N,K`` pairs.
        output: Destination JSON file; parent directories are created.
        m_buckets: Comma-separated M values; empty items are ignored.
        rep_ms: Timing budget forwarded to the benchmark helper.

    Raises:
        RuntimeError: If no CUDA device is available.
        typer.BadParameter: If ``m_buckets`` / ``nk_pairs`` cannot be parsed or
            describe an empty sweep.
    """
    # Explicit check instead of `assert`, which disappears under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA device required")

    device_name = torch.cuda.get_device_name(0)

    try:
        m_list = [int(s) for s in m_buckets.split(",") if s.strip()]
    except ValueError as exc:
        raise typer.BadParameter(
            f"--m-buckets must be comma-separated integers, got {m_buckets!r}"
        ) from exc

    nk_list: List[Tuple[int, int]] = []
    for pair in nk_pairs.split():
        try:
            n_str, k_str = pair.split(",")
            nk_list.append((int(n_str), int(k_str)))
        except ValueError as exc:
            raise typer.BadParameter(
                f"--nk-pairs entries must look like 'N,K', got {pair!r}"
            ) from exc

    # An empty sweep would write a config with no entries, which makes the
    # selector fall back for every shape without any hint as to why.
    if not m_list or not nk_list:
        raise typer.BadParameter(
            "nothing to tune: --m-buckets and --nk-pairs must each contain at least one value"
        )

    logger.info(
        "tuning on device=%s m_buckets=%s nk_pairs=%s rep_ms=%d",
        device_name,
        m_list,
        nk_list,
        rep_ms,
    )

    entries: Dict[str, Dict] = {}
    for n, k in nk_list:
        for m in m_list:
            # Key layout must match what the selector builds: "M,N,K".
            key = f"{m},{n},{k}"
            logger.info("tune %s", key)
            entry = _tune_one(m=m, n=n, k=k, rep_ms=rep_ms)
            logger.info("  -> %s", entry)
            entries[key] = entry

    payload = {
        "metadata": {
            "device_name": device_name,
            "tuned_at": _dt.datetime.now().isoformat(timespec="seconds"),
            "rep_ms": rep_ms,
            "m_buckets": m_list,
            "nk_pairs": [[n, k] for n, k in nk_list],
        },
        "entries": entries,
    }

    output.parent.mkdir(parents=True, exist_ok=True)
    with output.open("w") as f:
        json.dump(payload, f, indent=2, sort_keys=True)
    logger.info("wrote %s with %d entries", output, len(entries))
def _tune_one(*, m: int, n: int, k: int, rep_ms: int) -> Dict:
    """Time every registered backend on one (m, n, k) problem and pick the fastest.

    Returns a dict with ``"chosen"`` (backend name, or ``None`` when all of
    them failed) and one ``"<backend>_us"`` field per backend (``None`` for a
    backend that raised).
    """
    failed = float("inf")
    lhs = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
    rhs = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

    timings: Dict[str, float] = {}
    for backend, runner in _BACKENDS.items():
        try:
            elapsed = _bench(fn=runner, x=lhs, y=rhs, rep_ms=rep_ms)
        except Exception:
            logger.warning(
                "backend=%s shape=(M=%d,N=%d,K=%d) raised; marking as +inf",
                backend,
                m,
                n,
                k,
                exc_info=True,
            )
            elapsed = failed
        timings[backend] = elapsed

    # Ties resolve to the backend registered first.
    fastest = min(timings, key=timings.get)
    chosen: Optional[str] = None if timings[fastest] == failed else fastest

    entry: Dict = {"chosen": chosen}
    entry.update(
        {
            f"{backend}_us": (None if elapsed == failed else round(elapsed, 3))
            for backend, elapsed in timings.items()
        }
    )
    return entry


def _bench(*, fn: Callable, x: torch.Tensor, y: torch.Tensor, rep_ms: int) -> float:
    """Return the time of ``fn(x, y)`` in microseconds as measured under a CUDA graph."""

    def run_once():
        return fn(x, y)

    # Three eager warm-up calls before the measurement.
    warmup_calls = 3
    for _ in range(warmup_calls):
        fn(x, y)
    torch.cuda.synchronize()

    elapsed_ms = triton.testing.do_bench_cudagraph(run_once, rep=rep_ms)
    return elapsed_ms * 1000.0
# Kernel factory: builds a TileLang prim_func specialised on `hc`,
# `sinkhorn_iters` and `eps`. One thread block (64 threads) handles one row
# `i` of `mixes`, whose (2 + hc) * hc values are consumed as three segments:
#   [0, hc)          -> pre[i, :]  = sigmoid(mix * hc_scale[0] + hc_base) + eps
#   [hc, 2 * hc)     -> post[i, :] = 2 * sigmoid(mix * hc_scale[1] + hc_base)
#   [2 * hc, end)    -> comb[i, :, :], an hc x hc matrix built from
#                       mix * hc_scale[2] + hc_base and then normalised
#                       alternately along rows and columns.
# Only `#` comments are used inside the traced bodies so that nothing is added
# to what the TileLang front end parses.
@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
    n = T.symbolic("n")
    mix_hc = (2 + hc) * hc
    threads = 64

    ENABLE_PDL = is_arch_support_pdl()

    @T.prim_func
    def hc_split_sinkhorn_kernel_(
        mixes: T.Tensor[(n, mix_hc), FP32],
        hc_scale: T.Tensor[(3,), T.float32],
        hc_base: T.Tensor[(mix_hc,), T.float32],
        pre: T.Tensor[(n, hc), FP32],
        post: T.Tensor[(n, hc), FP32],
        comb: T.Tensor[(n, hc, hc), FP32],
    ):
        with T.Kernel(n, threads=threads) as i:
            # NOTE(review): pdl_sync / pdl_trigger are only emitted when
            # is_arch_support_pdl() is true; presumably programmatic dependent
            # launch hooks - confirm against the TileLang docs.
            if ENABLE_PDL:
                T.pdl_sync()

            mixes_shared = T.alloc_shared(mix_hc, FP32)
            comb_frag = T.alloc_fragment((hc, hc), FP32)
            T.copy(mixes[i, :], mixes_shared)

            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(
                    mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc]
                )
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = (
                    mixes_shared[j * hc + k + hc * 2] * hc_scale[2]
                    + hc_base[j * hc + k + hc * 2]
                )

            row_sum = T.alloc_fragment(hc, FP32)
            col_sum = T.alloc_fragment(hc, FP32)

            # First round: max-subtracted softmax along each row; eps is added
            # AFTER the division here, unlike the later rounds where it sits in
            # the denominator.
            row_max = T.alloc_fragment(hc, FP32)
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps

            # ...followed by the first column normalisation.
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            # Remaining sinkhorn_iters - 1 rounds: row then column normalisation.
            for _ in T.serial(sinkhorn_iters - 1):
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            T.copy(comb_frag, comb[i, :, :])
            if ENABLE_PDL:
                T.pdl_trigger()

    return hc_split_sinkhorn_kernel_


def hc_split_sinkhorn(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
):
    """Split ``mixes`` into pre / post / comb factors with ``hc_split_sinkhorn_kernel``.

    Args:
        mixes: 3-D tensor ``(b, s, (2 + hc_mult) * hc_mult)``; the kernel
            declares its inputs and outputs as float32.
        hc_scale: Three scales, one per segment (pre, post, comb).
        hc_base: Per-element bias with ``(2 + hc_mult) * hc_mult`` entries.
        hc_mult: Size ``hc`` of each factor.
        sinkhorn_iters: Total number of row/column normalisation rounds.
        eps: Stabiliser added to ``pre`` and used inside the normalisation.

    Returns:
        ``(pre, post, comb)`` with shapes ``(b, s, hc_mult)``,
        ``(b, s, hc_mult)`` and ``(b, s, hc_mult, hc_mult)``, allocated with
        ``mixes.new_empty`` (same dtype and device as ``mixes``).
    """
    b, s, _ = mixes.size()
    pre = mixes.new_empty(b, s, hc_mult)
    post = mixes.new_empty(b, s, hc_mult)
    comb = mixes.new_empty(b, s, hc_mult, hc_mult)
    kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
    # The kernel works on flattened (b * s, ...) views and writes the outputs
    # in place through them.
    kernel(
        mixes.view(-1, (2 + hc_mult) * hc_mult),
        hc_scale,
        hc_base,
        pre.view(-1, hc_mult),
        post.view(-1, hc_mult),
        comb.view(-1, hc_mult, hc_mult),
    )
    return pre, post, comb
rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = ( + T.sigmoid( + mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] + ) + * hc_post_mult_value + ) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = ( + mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + + hc_base[j * hc_mult + k + hc_mult * 2] + ) + + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = 
# Non-split-k GEMM + squared-sum kernel. For every block of `token_block`
# tokens it produces
#   out[t, :hc_mult3] = x[t, :] @ fn[:hc_mult3, :]^T   (fp32 accumulation)
#   sqrsum[t]         = sum_h x[t, h] ** 2
# by walking the hidden dimension in `hidden_block` chunks.
# The accumulator tile is always 32 columns wide (hence `hc_mult3 <= 32`);
# only the first `hc_mult3` columns are stored.
# Only `#` comments are used in the body: it is traced by the TileLang JIT.
@tilelang.jit
def mhc_pre_gemm_sqrsum_tilelang(
    x,
    fn,
    out,
    sqrsum,
    hc_mult3: int,
    hc_hidden_size: int,
    token_block: int = 32,
    hidden_block: int = 256,
) -> tilelang.JITKernel:
    assert hc_mult3 <= 32
    num_tokens = T.dynamic("num_tokens")
    assert hc_hidden_size % hidden_block == 0

    x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16)
    fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32)
    out: T.Tensor((num_tokens, hc_mult3), T.float32)
    sqrsum: T.Tensor((num_tokens), T.float32)

    ENABLE_PDL = is_arch_support_pdl()
    with T.Kernel(T.ceildiv(num_tokens, token_block)) as px:
        out_frag = T.alloc_fragment((token_block, 32), T.float32)
        # Four partial sums per token; folded into one value after the loop.
        sqrsum_part = T.alloc_fragment((token_block, 4), T.float32)
        T.clear(out_frag)
        T.clear(sqrsum_part)
        if ENABLE_PDL:
            T.pdl_sync()
        for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2):
            x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16)
            fn_smem = T.alloc_shared((32, hidden_block), T.float32)

            T.annotate_layout(
                {x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)}
            )

            T.copy(x[px * token_block, pz * hidden_block], x_smem_16)
            # NOTE(review): fn has hc_mult3 rows but a 32-row tile is copied;
            # presumably the copy pads the out-of-range rows - confirm.
            T.copy(fn[0, pz * hidden_block], fn_smem)

            # bf16 shared -> bf16 fragment -> fp32 fragment.
            x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16)
            T.copy(x_smem_16, x_frag_16)
            x_frag = T.alloc_fragment((token_block, hidden_block), T.float32)
            T.copy(x_frag_16, x_frag)

            for jj in T.serial(hidden_block // 4):
                for i, j in T.Parallel(token_block, 4):
                    sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j]

            # clear_accum=False: out_frag keeps accumulating across chunks.
            T.gemm(
                x_frag,
                fn_smem,
                out_frag,
                transpose_A=False,
                transpose_B=True,
                wg_wait=0,
                clear_accum=False,
            )
        sqrsum_l = T.alloc_fragment(token_block, T.float32)
        T.reduce_sum(sqrsum_part, sqrsum_l)
        # NOTE(review): unlike the split-k stage-0 kernel, these stores carry
        # no explicit `< num_tokens` guard for the last (partial) token block;
        # confirm TileLang's bounds handling covers it.
        for i in T.Parallel(token_block):
            sqrsum[px * token_block + i] = sqrsum_l[i]
        for i, j in T.Parallel(token_block, 32):
            if j < hc_mult3:
                out[px * token_block + i, j] = out_frag[i, j]
        if ENABLE_PDL:
            T.pdl_trigger()
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
    n_splits_pre: int = 32,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the mHC "pre" step: GEMM + squared-sum, then the fused mixing kernel.

    ``residual`` is ``(..., hc_mult, hidden_size)`` in bfloat16 and ``fn`` is
    ``(hc_mult * (2 + hc_mult), hc_mult * hidden_size)`` in float32.

    Up to 2048 tokens the split-k kernel pair is used (``n_splits_pre`` splits,
    only for ``hc_mult * hidden_size`` of 16384 or 28672); larger batches use
    the single-pass kernel. Both paths require ``n_splits == 1``.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` shaped ``(..., hc_mult, 1)``,
        ``(..., hc_mult, hc_mult)`` and ``(..., hidden_size)``.

    Raises:
        AssertionError: On a dtype / shape mismatch or ``n_splits != 1``.
        NotImplementedError: On the split-k path for an unsupported size.
    """
    assert residual.dtype == torch.bfloat16
    for fp32_arg in (fn, hc_scale, hc_base):
        assert fp32_arg.dtype == torch.float32

    hc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    hc_mult2 = hc_mult * hc_mult
    hc_mult3 = hc_mult * 2 + hc_mult2
    hc_hidden_size = hc_mult * hidden_size

    assert fn.shape[0] == hc_mult3
    assert fn.shape[1] == hc_hidden_size
    assert hc_scale.shape == (3,)
    assert hc_base.shape == (hc_mult3,)

    lead_dims = residual.shape[:-2]
    tokens = residual.view(-1, hc_mult, hidden_size)
    num_tokens = tokens.shape[0]

    device = residual.device
    fp32 = {"dtype": torch.float32, "device": device}

    # Outputs of the fused kernel.
    post_mix = torch.empty(num_tokens, hc_mult, **fp32)
    comb_mix = torch.empty(num_tokens, hc_mult2, **fp32)
    layer_input = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Intermediate GEMM results, with a leading split axis.
    gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, **fp32)
    gemm_out_sqrsum = torch.empty(n_splits, num_tokens, **fp32)

    use_splitk = num_tokens <= 2048
    if use_splitk:
        assert n_splits == 1
        hidden_block = {16384: 256, 28672: 128}.get(hc_hidden_size)
        if hidden_block is None:
            raise NotImplementedError(
                f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, "
                f"got {hc_hidden_size}"
            )
        stage_0, stage_1 = mhc_pre_gemm_sqrsum_splitk_kernel(
            hc_mult3,
            hc_hidden_size,
            split_k=n_splits_pre,
            token_block=32,
            hidden_block=hidden_block,
        )
        # Stage 0 writes one partial result per split; stage 1 reduces them.
        partial_out = gemm_out_mul.new_empty(n_splits_pre, num_tokens, 32)
        partial_sqrsum = gemm_out_sqrsum.new_empty(n_splits_pre, num_tokens)
        stage_0(
            tokens.view(num_tokens, hc_hidden_size),
            fn,
            partial_out,
            partial_sqrsum,
        )
        stage_1(
            partial_out,
            partial_sqrsum,
            gemm_out_mul.squeeze(0),
            gemm_out_sqrsum.squeeze(0),
        )
        # Drop the references to the scratch buffers before the fused kernel.
        del partial_out, partial_sqrsum
    else:
        assert (
            n_splits == 1
        ), "The simple TileLang version gemm_sqrsum doesn't support split-k"
        mhc_pre_gemm_sqrsum_tilelang(
            tokens.view(num_tokens, hc_mult * hidden_size),
            fn,
            gemm_out_mul.squeeze(0),
            gemm_out_sqrsum.squeeze(0),
            hc_mult3,
            hc_mult * hidden_size,
        )

    mhc_pre_big_fuse_tilelang(
        gemm_out_mul,
        gemm_out_sqrsum,
        hc_scale,
        hc_base,
        tokens,
        post_mix,
        comb_mix,
        layer_input,
        hidden_size,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        n_splits,
        hc_mult,
    )

    # Restore the caller's leading dimensions.
    return (
        post_mix.view(*lead_dims, hc_mult, 1),
        comb_mix.view(*lead_dims, hc_mult, hc_mult),
        layer_input.view(*lead_dims, hidden_size),
    )
def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Run the mHC "post" kernel and return the new residual.

    Under the NSA prefill round-robin context-parallel split, every input is
    first passed through ``strict_contiguous``. The result has the layout of
    ``residual`` (``torch.empty_like``) and is filled by ``mhc_post_tilelang``.
    """
    if is_nsa_prefill_cp_round_robin_split():
        x, residual, post_layer_mix, comb_res_mix = [
            strict_contiguous(tensor)
            for tensor in (x, residual, post_layer_mix, comb_res_mix)
        ]

    result = torch.empty_like(residual)
    # The kernel takes the post mix without its trailing singleton dim.
    post_mix_2d = post_layer_mix.squeeze(-1)
    hc = residual.shape[-2]
    hidden = residual.shape[-1]
    mhc_post_tilelang(comb_res_mix, residual, post_mix_2d, x, result, hc, hidden)
    return result
class HashTopK(nn.Module):
    """Router whose routed experts are looked up from a token-id table.

    ``tid2eid`` is an int32 table of shape
    ``(vocab_size, topk - num_fused_shared_experts)``; row ``t`` holds the
    routed expert ids used for token id ``t``. Weights are taken from the
    router scores at those experts.
    """

    def __init__(
        self,
        topk,
        num_experts,
        num_fused_shared_experts,
        vocab_size,
        scoring_func="sqrtsoftplus",
        routed_scaling_factor=1.5,
        apply_routed_scaling_factor_on_output=False,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.topk = topk
        self.routed_scaling_factor = routed_scaling_factor
        self.num_fused_shared_experts = num_fused_shared_experts
        self.score_func = scoring_func
        # Filled by weight loading; never trained.
        self.tid2eid = nn.Parameter(
            torch.empty(vocab_size, topk - num_fused_shared_experts, dtype=torch.int32),
            requires_grad=False,
        )

        assert not apply_routed_scaling_factor_on_output, "not implemented"

    def empty_topk_output(self, device: torch.device):
        """Return a zero-row ``StandardTopKOutput`` on ``device``.

        All three tensors have ``topk - num_fused_shared_experts`` columns.
        """
        topk = self.topk - self.num_fused_shared_experts
        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
        topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
        return StandardTopKOutput(topk_weights, topk_ids, router_logits)

    def _forward_torch(
        self, router_logits: torch.Tensor, input_ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pure-PyTorch reference for the hash top-k.

        Scores are ``softmax`` / ``sigmoid`` of the logits, or
        ``sqrt(softplus(logits))`` for any other ``score_func``. Routed weights
        are renormalised to sum to 1 unless the scoring function is softmax.
        With exactly one fused shared expert, the last column holds an id in
        ``[num_experts, num_experts + 1)`` and the weight
        ``sum(routed weights) / routed_scaling_factor``.

        Returns:
            ``(topk_weights, topk_ids)``, both ``(num_tokens, topk)``; ids are
            int32 and weights use the dtype of the scores.
        """
        if self.score_func == "softmax":
            scores = router_logits.softmax(dim=-1)
        elif self.score_func == "sigmoid":
            scores = router_logits.sigmoid()
        else:
            scores = torch.nn.functional.softplus(router_logits).sqrt()

        num_token = scores.shape[0]

        topk_ids = torch.zeros(
            (num_token, self.topk), dtype=torch.int32, device=scores.device
        )
        topk_weights = torch.zeros(
            (num_token, self.topk), dtype=scores.dtype, device=scores.device
        )

        if self.num_fused_shared_experts == 1:
            topk_ids[:, :-1] = self.tid2eid[input_ids]
            # gather() needs an int64 index on PyTorch builds that do not
            # accept int32 indices; the ids themselves stay int32.
            topk_weights[:, :-1] = scores.gather(1, topk_ids[:, :-1].long())

            if self.score_func != "softmax":
                topk_weights[:, :-1] /= topk_weights[:, :-1].sum(dim=-1, keepdim=True)

            topk_ids[:, -1] = torch.randint(
                low=self.num_experts,
                high=self.num_experts + self.num_fused_shared_experts,
                size=(num_token,),
                dtype=topk_ids.dtype,
                device=topk_ids.device,
            )

            topk_weights[:, -1] = (
                topk_weights[:, :-1].sum(dim=-1) / self.routed_scaling_factor
            )
        else:
            # NOTE(review): this branch fills all `topk` columns from tid2eid,
            # so it only fits num_fused_shared_experts == 0 - confirm that
            # values > 1 never reach the torch path.
            topk_ids[:, :] = self.tid2eid[input_ids]
            topk_weights[:, :] = scores.gather(1, topk_ids.long())
            if self.score_func != "softmax":
                topk_weights[:, :] /= topk_weights[:, :].sum(dim=-1, keepdim=True)

        return topk_weights, topk_ids

    def forward(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        input_ids: torch.Tensor,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ):
        """Compute the top-k output for a batch of tokens.

        Uses the fused ``hash_topk`` kernel when
        ``SGLANG_OPT_USE_FUSED_HASH_TOPK`` is set, otherwise
        ``_forward_torch``. The ids then go through the optional random
        override, the logical-to-physical expert mapping and the padded-region
        masking before being wrapped in a ``StandardTopKOutput``.
        """
        assert (
            input_ids.shape[0] == hidden_states.shape[0] == router_logits.shape[0]
        ), f"{input_ids.shape=} {hidden_states.shape=} {router_logits.shape=}"

        # Debug switch: overwrites the lookup table in place with zeros.
        if envs.SGLANG_HACK_FORCE_TID2EID_ZERO.get():
            self.tid2eid.zero_()

        if envs.SGLANG_OPT_USE_FUSED_HASH_TOPK.get():
            from sglang.jit_kernel.deepseek_v4 import hash_topk

            topk_weights, topk_ids = hash_topk(
                router_logits=router_logits,
                input_ids=input_ids,
                tid2eid=self.tid2eid,
                num_fused_shared_experts=self.num_fused_shared_experts,
                routed_scaling_factor=self.routed_scaling_factor,
                scoring_func=self.score_func,
            )
        else:
            topk_weights, topk_ids = self._forward_torch(router_logits, input_ids)

        if is_hip():
            topk_weights = topk_weights.to(torch.float32)

        topk_ids = _maybe_override_topk_ids_random(topk_ids, self.num_experts)
        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
        _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
        topk_output = StandardTopKOutput(
            topk_weights=topk_weights, topk_ids=topk_ids, router_logits=router_logits
        )
        return topk_output
topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids + + +def biased_topk_jit_kernel_impl( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + from sglang.jit_kernel.moe_fused_gate import moe_fused_gate + + topk_weights, topk_ids = moe_fused_gate( + gating_output, + correction_bias, + topk=topk, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..eacde3f6b8fb --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ 
-0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..b60f7dc039df --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 1f1c3e709d49..586041f1b9cd 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -1,19 +1,25 @@ from 
def swiglu_limit_func(
    output: torch.Tensor,
    input: torch.Tensor,  # first half is gate, second half is up
    swiglu_limit: float = 0.0,
) -> None:
    """Write ``silu(gate) * up`` into ``output``, optionally clamping first.

    The last dimension of ``input`` is split in half: the first half is the
    gate, the second half is the up projection. Any number of leading
    dimensions is accepted.

    Args:
        output: Destination tensor, filled in place via ``copy_``.
        input: Tensor whose last dimension is ``2 * d``.
        swiglu_limit: When positive, the gate is clamped from above to this
            value and the up projection is clamped to ``[-limit, limit]``
            before the activation. Non-positive values disable clamping.

    Raises:
        ValueError: if the last dimension of ``input`` is odd.
    """
    width = input.shape[-1]
    if width % 2 != 0:
        # An odd width would give halves of different sizes, which can
        # broadcast silently instead of failing.
        raise ValueError(
            f"swiglu_limit_func expects an even last dimension, got {width}"
        )

    d = width // 2
    gate = input[..., :d]
    up = input[..., d:]

    if swiglu_limit > 0:
        gate = torch.clamp(gate, max=swiglu_limit)
        up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)

    output.copy_(F.silu(gate) * up)
assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + is_mxfp4_marlin = ( + num_bits == 4 + and w1_zeros is None + and w2_zeros is None + and w1_scale.dtype == torch.float8_e8m0fnu + and w2_scale.dtype == torch.float8_e8m0fnu + ) + if is_mxfp4_marlin: + assert w1_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w1_scale to be torch.float8_e8m0fnu, " + f"got {w1_scale.dtype}" + ) + assert w2_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w2_scale to be torch.float8_e8m0fnu, " + f"got {w2_scale.dtype}" + ) + else: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -119,8 +159,8 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None, w1_scale) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None, w2_scale) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -140,7 +180,7 @@ def fused_marlin_moe( use_atomic_add = ( hidden_states.dtype == torch.half or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - ) + ) and (not is_mxfp4_marlin) intermediate_cache1 = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default( hidden_states, @@ -171,7 +211,14 @@ def fused_marlin_moe( is_zp_float=False, ) - silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) + if clamp_limit is not None: + swiglu_limit_func( + 
intermediate_cache2, + intermediate_cache1.view(-1, 2 * N), + clamp_limit, + ) + else: + silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) if expert_map is not None: intermediate_cache3.zero_() @@ -206,13 +253,4 @@ def fused_marlin_moe( ).view(-1, topk, K) output = hidden_states if inplace else torch.empty_like(hidden_states) - - if routed_scaling_factor is None: - routed_scaling_factor = 1.0 - - moe_sum_reduce( - intermediate_cache3, - output, - routed_scaling_factor, - ) - return output + return torch.sum(intermediate_cache3, dim=1, out=output) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index a1885fade143..b168a98d9ac8 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -13,6 +13,8 @@ import torch.nn.functional as F import triton.language as tl +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker +from sglang.srt.environ import envs from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.utils import ( cpu_has_amx_support, @@ -87,6 +89,7 @@ def inplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> None: fused_experts_impl( hidden_states, @@ -117,6 +120,7 @@ def inplace_fused_experts( gemm1_alpha, gemm1_limit, filter_expert, + swiglu_limit=swiglu_limit, ) @@ -149,6 +153,7 @@ def outplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -179,6 +184,7 @@ def outplace_fused_experts( gemm1_alpha=gemm1_alpha, gemm1_limit=gemm1_limit, filter_expert=filter_expert, + swiglu_limit=swiglu_limit, ) @@ -237,6 +243,7 @@ def fused_experts( 
moe_runner_config.gemm1_alpha, moe_runner_config.gemm1_clamp_limit, filter_expert, + moe_runner_config.swiglu_limit, ) return hidden_states else: @@ -268,6 +275,7 @@ def fused_experts( gemm1_alpha=moe_runner_config.gemm1_alpha, gemm1_limit=moe_runner_config.gemm1_clamp_limit, filter_expert=filter_expert, + swiglu_limit=moe_runner_config.swiglu_limit, ) @@ -319,6 +327,7 @@ def fused_experts_impl( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -478,23 +487,67 @@ def fused_experts_impl( gemm1_alpha, gemm1_limit, ) - elif _is_cuda or _is_hip: - if not filter_expert: - silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) - else: - act_and_mul_triton( - intermediate_cache1.view(-1, N), - intermediate_cache2, - config, - topk_ids, - expert_ids, - down_moe_use_tma, - activation, - ) else: - vllm_ops.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, N) + is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" + assert is_2604b == (swiglu_limit is not None), ( + f"swiglu_limit must be non-None iff submode=2604B " + f"(got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, swiglu_limit={swiglu_limit!r})" ) + + swiglu_limit_for_triton: Optional[float] = None + swiglu_limit_for_silu_and_mul_clamp: Optional[float] = None + if is_2604b: + assert swiglu_limit == 10 + assert intermediate_cache1.shape == (total_tokens, N) + assert ( + _is_cuda or _is_hip + ), "DSV4 2604 submode 2604B only supports CUDA/HIP downstream" + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + if filter_expert: + swiglu_limit_for_triton = swiglu_limit + else: + assert ( + _is_cuda + ), "fused silu_and_mul_clamp kernel is CUDA-only; HIP must disable SWIGLU_CLAMP_FUSION" + swiglu_limit_for_silu_and_mul_clamp = swiglu_limit + else: + half = N // 2 + intermediate_cache1[:, 
:half].clamp_(max=swiglu_limit) + intermediate_cache1[:, half:].clamp_( + min=-swiglu_limit, max=swiglu_limit + ) + deepseek_v4_moe_code_path_checker.observed += 1 + + if _is_cuda or _is_hip: + if not filter_expert: + if swiglu_limit_for_silu_and_mul_clamp is not None: + from sglang.jit_kernel.deepseek_v4 import silu_and_mul_clamp + + silu_and_mul_clamp( + intermediate_cache1.view(-1, N), + intermediate_cache2, + swiglu_limit_for_silu_and_mul_clamp, + ) + else: + silu_and_mul( + intermediate_cache1.view(-1, N), intermediate_cache2 + ) + else: + act_and_mul_triton( + intermediate_cache1.view(-1, N), + intermediate_cache2, + config, + topk_ids, + expert_ids, + down_moe_use_tma, + activation, + swiglu_limit=swiglu_limit_for_triton, + ) + else: + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) + ) elif activation == "gelu" and is_gated: assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" assert gemm1_limit is None, "gemm1_limit is not supported for gelu" diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py index 230b64057ab4..29e23868e8d2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -8,6 +8,7 @@ import triton import triton.language as tl +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -871,6 +872,8 @@ def act_and_mul_kernel( expert_step: tl.constexpr, BLOCK_SIZE: tl.constexpr, ACTIVATION_TYPE: tl.constexpr, + SWIGLU_LIMIT: tl.constexpr = 0.0, + HAS_SWIGLU_LIMIT: tl.constexpr = False, ): """ Unified activation and multiply kernel that handles both sorted and unsorted routing, @@ -899,6 +902,10 @@ def act_and_mul_kernel( gate_output = 
tl.load(gate_output_ptr + offset, mask=mask) up_output = tl.load(up_output_ptr + offset, mask=mask) + if HAS_SWIGLU_LIMIT: + gate_output = tl.minimum(gate_output, SWIGLU_LIMIT) + up_output = tl.maximum(tl.minimum(up_output, SWIGLU_LIMIT), -SWIGLU_LIMIT) + gate_output_activated = _apply_activation(gate_output, ACTIVATION_TYPE) gate_output_activated = gate_output_activated.to(InDtype) @@ -915,6 +922,7 @@ def act_and_mul_triton( expert_ids: Optional[torch.Tensor] = None, down_moe_use_tma: bool = False, activation: str = "silu", + swiglu_limit: Optional[float] = None, ) -> None: """ Args: @@ -925,11 +933,16 @@ def act_and_mul_triton( expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True) down_moe_use_tma: Whether to use sorted routing layout activation: Activation type ("silu" or "gelu") + swiglu_limit: if not None, clamp gate to [-inf, L] and up to [-L, L] before activation + (compiles a separate kernel variant via tl.constexpr). """ grid = (down_input.shape[0],) hidden_size = gateup_output.shape[1] expert_ids_row = topk_ids.view(-1) if not down_moe_use_tma else expert_ids expert_step = 1 if not down_moe_use_tma else config["BLOCK_SIZE_M"] + has_swiglu_limit = swiglu_limit is not None + if has_swiglu_limit: + deepseek_v4_moe_code_path_checker.observed += 1 act_and_mul_kernel[grid]( gateup_output, down_input, @@ -938,6 +951,8 @@ def act_and_mul_triton( expert_step, BLOCK_SIZE=512, ACTIVATION_TYPE=activation, + SWIGLU_LIMIT=float(swiglu_limit) if has_swiglu_limit else 0.0, + HAS_SWIGLU_LIMIT=has_swiglu_limit, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 019843ae0365..8154ca9ff65f 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -181,6 +181,7 @@ def __init__( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_clamp_limit: Optional[float] = 
None, + swiglu_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, routing_method_type: Optional[RoutingMethodType] = None, @@ -255,6 +256,7 @@ def __init__( routed_scaling_factor=routed_scaling_factor, gemm1_alpha=gemm1_alpha, gemm1_clamp_limit=gemm1_clamp_limit, + swiglu_limit=swiglu_limit, is_gated=is_gated, routing_method_type=routing_method_type, ) @@ -1130,9 +1132,10 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): self.moe_runner_config.activation == "silu" ), "Only silu is supported for flashinfer trtllm moe" assert self.quant_method is not None - assert ( - topk_output.topk_config.renormalize - ), "Renormalize is required for flashinfer trtllm moe" + if hasattr(topk_output, "topk_config"): + assert ( + topk_output.topk_config.renormalize + ), "Renormalize is required for flashinfer trtllm moe" assert ( self.num_fused_shared_experts == 0 ), "Fused shared experts are not supported for flashinfer trtllm moe" @@ -1140,14 +1143,18 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): self.moe_runner_config.is_gated ), "Only gated MoEs are supported for flashinfer trtllm moe" - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - correction_bias = topk_config.correction_bias - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + assert TopKOutputChecker.format_is_bypassed( + topk_output + ) or TopKOutputChecker.format_is_standard(topk_output) if isinstance(self.quant_method, UnquantizedFusedMoEMethod): + assert TopKOutputChecker.format_is_bypassed( + topk_output + ), "BF16 MoE requires bypassed topk output" + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = topk_config.correction_bias + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor # lazy import try: from flashinfer.fused_moe import 
trtllm_bf16_moe diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 12dd2ba6a237..088fbfbef33d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -48,6 +48,7 @@ class MoeRunnerConfig: routed_scaling_factor: Optional[float] = None gemm1_alpha: Optional[float] = None gemm1_clamp_limit: Optional[float] = None + swiglu_limit: Optional[float] = None @dataclass diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py index 7fa8193fb328..8899f6234182 100644 --- a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py +++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py @@ -1,10 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple +import einops import torch +from sglang.jit_kernel.deepseek_v4 import silu_and_mul_masked_post_quant +from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, +) +from sglang.srt.environ import envs, is_large_dummy_model from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.moe_runner.base import ( MoeQuantInfo, @@ -35,9 +41,11 @@ _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +# Imported only for the SGLANG_OPT_FIX_MEGA_MOE_MEMORY=False fallback path. 
if not (_is_npu or _is_hip): - from sgl_kernel import silu_and_mul - + from sgl_kernel import silu_and_mul as _legacy_silu_and_mul +else: + _legacy_silu_and_mul = None _MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT") _DEEPGEMM_ON_H20 = get_bool_env_var("SGLANG_DEEPGEMM_ON_H20") @@ -106,6 +114,13 @@ def __init__(self, config: MoeRunnerConfig): super().__init__(config) assert self.config.activation == "silu" assert self.config.is_gated + self.swiglu_limit = self.config.swiglu_limit + self.use_swizzle = False + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + assert envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get() + assert envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get() + assert envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get() + self.use_swizzle = True def run( self, @@ -129,9 +144,10 @@ def _run_contiguous_gemm( quant_info: DeepGemmMoeQuantInfo, running_state: dict, ) -> torch.Tensor: + from sglang.jit_kernel.deepseek_v4 import silu_and_mul_contig_post_quant from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_fp8, + create_per_token_group_quant_fp8_output_scale, ) hidden_states = runner_input.hidden_states @@ -169,25 +185,66 @@ def _run_contiguous_gemm( dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) - down_input = torch.empty( - ( - all_tokens, - N // 2, - ), - device=gateup_output.device, - dtype=torch.bfloat16, - ) - silu_and_mul(gateup_output.view(-1, N), down_input) - del gateup_output + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" + swiglu_limit_arg: Optional[float] = None + if is_2604b: + swiglu_limit_arg = self.swiglu_limit - down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( - down_input, - scale_block_size, - column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - 
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - ) - del down_input + down_input_fp8 = torch.empty( + (all_tokens, N // 2), + device=gateup_output.device, + dtype=torch.float8_e4m3fn, + ) + down_input_scale = create_per_token_group_quant_fp8_output_scale( + x_shape=(all_tokens, N // 2), + device=gateup_output.device, + group_size=scale_block_size, + column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + silu_and_mul_contig_post_quant( + input=gateup_output, + output=down_input_fp8, + output_scale=down_input_scale, + quant_group_size=scale_block_size, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + transposed=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + swiglu_limit=swiglu_limit_arg, + swizzle=self.use_swizzle, + ) + del gateup_output + else: + # Hacky byte-equal fallback that reproduces the optimize-branch + # code path exactly: bf16 silu_and_mul then a separate per-token + # group fp8 quant. Kept behind the mega-moe-memory flag. 
+ from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + gateup_output = _apply_swiglu_limit( + gateup_output, swiglu_limit=self.swiglu_limit + ) + deepseek_v4_moe_code_path_checker.observed += 1 + + down_input = torch.empty( + (all_tokens, N // 2), + device=gateup_output.device, + dtype=torch.bfloat16, + ) + _legacy_silu_and_mul(gateup_output.view(-1, N), down_input) + del gateup_output + + down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( + down_input, + scale_block_size, + column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + del down_input down_output = torch.empty( (all_tokens, K), @@ -213,12 +270,6 @@ def _run_masked_gemm( running_state: dict, ) -> torch.Tensor: from sglang.srt.layers import deep_gemm_wrapper - from sglang.srt.layers.moe.ep_moe.kernels import ( - silu_and_mul_masked_post_quant_fwd, - ) - from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_8bit, - ) hidden_states = runner_input.hidden_states hidden_states_scale = runner_input.hidden_states_scale @@ -262,47 +313,42 @@ def _run_masked_gemm( dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) + is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" + assert is_2604b == ( + self.swiglu_limit is not None + ), f"swiglu_limit must be non-None iff submode=2604B (got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, swiglu_limit={self.swiglu_limit!r})" + swiglu_limit_arg: Optional[float] = None + if is_2604b: + assert ( + not _MASKED_GEMM_FAST_ACT + ), "DSV4 2604 submode 2604B does not support SGLANG_MASKED_GEMM_FAST_ACT" + assert ( + envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get() + ), "DSV4 2604 submode 2604B requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + 
def _varlen_deep_gemm_silu_mul_quant(
    gateup_output: torch.Tensor,
    masked_m: Optional[torch.Tensor],
    group_size: int,
    topk: int,
    swiglu_limit: Optional[float] = None,
    swizzle: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run SiLU-and-mul on a 3-D gate/up tensor and quantize the result to FP8.

    One of three kernels is used:
      * ``SGLANG_MASKED_GEMM_FAST_ACT``: the fused per-token-group quant kernel
        (no ``swizzle`` and no ``swiglu_limit`` allowed);
      * ``SGLANG_OPT_USE_JIT_EP_ACTIVATION``: the JIT masked post-quant kernel,
        which also receives ``swiglu_limit`` and ``swizzle``;
      * otherwise the legacy masked post-quant kernel (again no ``swizzle`` and
        no ``swiglu_limit``).

    Returns:
        ``(down_input, down_input_scale)``: the FP8 activations and their scales.
    """
    from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_8bit,
    )

    if _MASKED_GEMM_FAST_ACT:
        assert not swizzle, (
            "SGLANG_OPT_FIX_MEGA_MOE_MEMORY is incompatible with "
            "SGLANG_MASKED_GEMM_FAST_ACT (swizzled layout only supported by JIT act)"
        )
        assert (
            swiglu_limit is None
        ), "swiglu_limit (DSV4 2604 submode 2604B) is not supported together with SGLANG_MASKED_GEMM_FAST_ACT"
        return sglang_per_token_group_quant_8bit(
            x=gateup_output,
            dst_dtype=torch.float8_e4m3fn,
            group_size=group_size,
            masked_m=masked_m,
            column_major_scales=True,
            scale_tma_aligned=True,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
            fuse_silu_and_mul=True,
            enable_v2=True,
        )

    assert masked_m is not None
    device = gateup_output.device
    dim0, dim1, gateup_width = gateup_output.shape
    # The last dimension holds gate and up side by side; the activation halves it.
    act_width = gateup_width // 2
    num_scale_groups = act_width // group_size
    down_input = torch.empty(
        (dim0, dim1, act_width),
        device=device,
        dtype=torch.float8_e4m3fn,
    )

    if not envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get():
        assert (
            swiglu_limit is None
        ), "swiglu_limit (DSV4 2604 submode 2604B) requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True"
        assert (
            not swizzle
        ), "SGLANG_OPT_FIX_MEGA_MOE_MEMORY requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True"
        down_input_scale = torch.empty(
            (dim0, dim1, num_scale_groups),
            device=device,
            dtype=torch.float32,
        )
        silu_and_mul_masked_post_quant_fwd(
            gateup_output,
            down_input,
            down_input_scale,
            group_size,
            masked_m,
            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
        )
        return down_input, down_input_scale

    assert dim1 % 4 == 0 and num_scale_groups % 4 == 0
    packed_ue8m0 = deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
    if packed_ue8m0:
        # Four scales per int32, written with the last two axes swapped.
        scale_shape = (dim0, num_scale_groups // 4, dim1)
        scale_dtype = torch.int32
    else:
        scale_shape = (dim0, dim1, num_scale_groups)
        scale_dtype = torch.float32
    down_input_scale = torch.empty(scale_shape, device=device, dtype=scale_dtype)
    silu_and_mul_masked_post_quant(
        gateup_output,
        down_input,
        down_input_scale,
        group_size,
        masked_m,
        scale_ue8m0=packed_ue8m0,
        topk=topk,
        transposed=packed_ue8m0,
        swiglu_limit=swiglu_limit,
        swizzle=swizzle,
    )
    if packed_ue8m0:
        # Swap the axes back as a view for the caller.
        down_input_scale = down_input_scale.transpose(-1, -2)
    return down_input, down_input_scale


def _apply_swiglu_limit(
    gateup_output: torch.Tensor, swiglu_limit: float
) -> torch.Tensor:
    """Return a clamped copy of a 2-D gate/up tensor.

    The first half of the last dimension (gate) is clamped from above to
    ``swiglu_limit``; the second half (up) is clamped to
    ``[-swiglu_limit, swiglu_limit]``. The input is not modified.
    """
    assert swiglu_limit == 10

    num_tokens, hidden_size_x2 = gateup_output.shape
    if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model():
        assert hidden_size_x2 == 2048 * 2
    assert gateup_output.dtype == torch.bfloat16

    half_shape = (num_tokens, hidden_size_x2 // 2)
    gate, up = gateup_output.chunk(2, dim=-1)
    assert gate.shape == half_shape
    assert up.shape == half_shape

    clamped_up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
    clamped_gate = gate.clamp(max=swiglu_limit)

    out = torch.cat([clamped_gate, clamped_up], dim=-1)
    assert out.shape == (num_tokens, hidden_size_x2)
    return out
+ marlin_inplace = False + if ( + quant_info.weight_bits == 4 + and quant_info.w13_qzeros is None + and quant_info.w2_qzeros is None + and quant_info.w13_scales.dtype == torch.float8_e8m0fnu + and quant_info.w2_scales.dtype == torch.float8_e8m0fnu + and hidden_states.dtype == torch.float16 + ): + # MXFP4(E8M0) Marlin kernels are only numerically valid on the bf16 + # activation path. The fp16 + E8M0 path is intentionally not generated + # in sgl-kernel, so upcast activations here and cast the result back. + marlin_hidden_states = hidden_states.to(torch.bfloat16) + marlin_inplace = False + output = fused_marlin_moe( - hidden_states=hidden_states, + hidden_states=marlin_hidden_states, w1=quant_info.w13_qweight, w2=quant_info.w2_qweight, w1_scale=quant_info.w13_scales, @@ -116,8 +143,9 @@ def fused_experts_none_to_marlin( workspace=MARLIN_MOE_WORKSPACE, num_bits=quant_info.weight_bits, is_k_full=quant_info.is_k_full, - inplace=runner_config.inplace, + inplace=marlin_inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + clamp_limit=runner_config.swiglu_limit, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py index 00bd68755587..a258bfe8d06d 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,13 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, get_attention_dp_rank, + get_attention_tp_size, get_dp_local_info, is_dp_attention_enabled, ) +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args @@ -181,6 +184,17 @@ def __init__( device=device, ) + if 
get_moe_a2a_backend().is_deepep(): + attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 + self.gather_buffer = torch.empty( + ( + self.device_cache.buffer.shape[0] * attn_tp_size, + self.device_cache.buffer.shape[2], + ), + dtype=torch.int32, + device=device, + ) + def _sync_fwd_experts_buffer_DtoH( self, forward_batch: ForwardBatch, @@ -206,6 +220,12 @@ def _sync_fwd_experts_buffer_DtoH( ].cpu() def capture(self, layer_id: int, topk_ids: torch.Tensor): + if get_moe_a2a_backend().is_deepep(): + local_topk_ids = topk_ids + topk_ids = self.gather_buffer[ + : local_topk_ids.size(0) * get_attention_tp_size() + ] + attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) def get_routed_experts( @@ -281,7 +301,6 @@ def set_global_experts_capturer(capturer: RoutedExpertsCapturer): def extract_routed_experts_from_meta_info(data): # To solve the performance issue, we return the experts_ids in base64 # We left this function for user to change it back to normal int32 - # See detokenizer_manager::_extract_routed_experts routed_experts_base64 = data["meta_info"].get("routed_experts", None) routed_experts = np.frombuffer( pybase64.b64decode(routed_experts_base64.encode("utf-8")), dtype=np.int32 diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index 8539639d5e9a..d5fe05012630 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -490,7 +490,6 @@ def combine_a( topk_ids: torch.Tensor, topk_weights: torch.Tensor, ): - if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu: output = hidden_states else: diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 0a127009885a..a47534ce7f4a 100644 --- 
a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -12,6 +12,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import ( get_dp_global_num_tokens, get_local_dp_buffer, @@ -86,6 +87,13 @@ def __init__(self, moe_runner_config: MoeRunnerConfig): self.enable_flashinfer_cutlass_moe = ( get_moe_runner_backend().is_flashinfer_cutlass() ) + self.enable_flashinfer_mxfp4_moe = ( + get_moe_runner_backend().is_flashinfer_mxfp4() + ) + self.skip_local_expert_mapping = ( + self.enable_flashinfer_mxfp4_moe + and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get() + ) self.num_experts = moe_runner_config.num_experts self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts self.num_local_routed_experts = ( @@ -142,6 +150,7 @@ def dispatch( if ( self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe + and not self.skip_local_expert_mapping and TopKOutputChecker.format_is_standard(topk_output) ): if self.local_expert_mapping is None: @@ -167,7 +176,11 @@ def dispatch( ) ) - if self.local_expert_mapping is not None and not _use_aiter: + if ( + self.local_expert_mapping is not None + and not _use_aiter + and not self.skip_local_expert_mapping + ): if TopKOutputChecker.format_is_standard(topk_output): topk_output = topk_output._replace( topk_ids=self.local_expert_mapping[topk_output.topk_ids] diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 419786c2f06e..b12177e45b1c 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -35,10 +35,12 @@ except ImportError: pass +from sglang.jit_kernel.deepseek_v4 import mask_topk_ids from sglang.srt.distributed import get_tp_group from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from 
sglang.srt.environ import envs from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( @@ -278,9 +280,12 @@ def forward_cuda( output_format = self.topk_config.output_format elif get_moe_runner_backend().is_triton_kernels(): output_format = TopKOutputFormat.TRITON_KERNEL - elif ( - get_moe_runner_backend().is_flashinfer_trtllm() - or get_moe_runner_backend().is_flashinfer_mxfp4() + elif get_moe_runner_backend().is_flashinfer_trtllm() or ( + get_moe_runner_backend().is_flashinfer_mxfp4() + and not ( + envs.SGLANG_DSV4_MODE.get() == "2604" + and envs.SGLANG_DSV4_FP4_EXPERTS.get() + ) ): output_format = TopKOutputFormat.BYPASSED else: @@ -707,11 +712,24 @@ def is_power_of_two(n): def _mask_topk_ids_padded_region( topk_ids: torch.Tensor, num_token_non_padded: Optional[torch.Tensor] = None, -): +) -> None: if num_token_non_padded is None: return - indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device) - topk_ids[indices >= num_token_non_padded, :] = -1 + if envs.SGLANG_OPT_USE_FAST_MASK_EP.get(): + mask_topk_ids(topk_ids, num_token_non_padded) + else: + indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device) + topk_ids[indices >= num_token_non_padded, :] = -1 + + +def _maybe_override_topk_ids_random( + topk_ids: torch.Tensor, num_experts: int +) -> torch.Tensor: + if not envs.SGLANG_HACK_OVERRIDE_TOPK_IDS_RANDOM.get() or topk_ids.numel() == 0: + return topk_ids + n_tokens, k = topk_ids.shape + scores = torch.rand(n_tokens, num_experts, device=topk_ids.device) + return scores.topk(k, dim=-1).indices.to(topk_ids.dtype) @torch.compile(dynamic=True, backend=get_compiler_backend()) @@ -842,6 +860,7 @@ def biased_grouped_topk_gpu( # Use optimized path for Kimi K2 (384 experts with num_expert_group=1) num_experts = gating_output.shape[1] if _is_cuda and num_experts == 384 and num_expert_group == 1: + assert False, 
"dpsk should not use kimi" return kimi_k2_moe_fused_gate( gating_output.to(dtype=torch.float32), correction_bias, @@ -996,17 +1015,38 @@ def select_experts( ) elif custom_routing_function is None: assert not apply_routed_scaling_factor_on_output, "Not implemented" - # Qwen3MOE uses fused_topk - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=num_routed_topk if _use_aiter else top_k, - renormalize=renormalize, - correction_bias=correction_bias, - num_token_non_padded=num_token_non_padded, - expert_location_dispatch_info=expert_location_dispatch_info, - scoring_func=scoring_func, - ) + if scoring_func == "sqrtsoftplus": + if envs.SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK.get(): + from sglang.srt.layers.moe.deepseek_v4_topk import ( + biased_topk_jit_kernel_impl as biased_topk_impl, + ) + else: + from sglang.srt.layers.moe.deepseek_v4_topk import biased_topk_impl + + topk_weights, topk_ids = biased_topk_impl( + hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + else: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + correction_bias=correction_bias, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + scoring_func=scoring_func, + ) else: assert ( num_token_non_padded is None @@ -1041,6 +1081,9 @@ def select_experts( N, # base id for shared experts ) + topk_ids = _maybe_override_topk_ids_random( + 
topk_ids, num_experts=router_logits.shape[-1] + ) get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) get_global_experts_capturer().capture( layer_id=layer_id, diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 565f5b9fd202..9e4ca8ebba4a 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -33,6 +33,7 @@ def _dtype_rank(dtype: torch.dtype) -> Optional[int]: torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, + torch.float8_e8m0fnu, ): return 0 if dtype in (torch.float16, torch.bfloat16): @@ -69,6 +70,8 @@ def copy_with_check(target: torch.Tensor, loaded_weight: torch.Tensor): raise ValueError( f"Downcasting not allowed: {target.dtype=}, {loaded_weight.dtype=}" ) + if loaded_weight.dtype == torch.float8_e8m0fnu: + assert target.dtype in {torch.float8_e8m0fnu, torch.float32} target.copy_(loaded_weight) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 573f69a3c4e9..e11f208455b1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -14,7 +14,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) -from sglang.srt.environ import envs +from sglang.srt.environ import envs, is_large_dummy_model from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.dp_attention import is_allocation_symmetric from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig @@ -182,7 +182,24 @@ def get_quant_method( return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) + from sglang.srt.environ import envs + + fp8_method = Fp8MoEMethod(self) + + if ( + envs.SGLANG_DSV4_MODE.get() == "2604" + and envs.SGLANG_DSV4_FP4_EXPERTS.get() + and ( +
get_moe_runner_backend().is_flashinfer_mxfp4() + or get_moe_runner_backend().is_marlin() + ) + ): + from sglang.srt.layers.quantization.mxfp4_deepseek import ( + DeepSeekMxfp4MoEMethod, + ) + + return DeepSeekMxfp4MoEMethod(fp8_method, prefix=prefix) + return fp8_method elif isinstance(layer, RadixAttention): return Fp8KVCacheMethod(self) return None @@ -601,6 +618,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.is_fp4_expert = ( + envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + ) if get_moe_runner_backend().is_cutlass(): assert ( cutlass_fp8_supported() @@ -660,7 +680,26 @@ def create_weights( ) # WEIGHTS - if _is_hip and _use_hip_int4: + if self.is_fp4_expert: + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + elif _is_hip and _use_hip_int4: # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -707,7 +746,32 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - if self.block_quant: + if self.is_fp4_expert: + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model(): + assert hidden_size == 4096 + assert intermediate_size_per_partition == 2048 + fp4_block_k = 32 + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // fp4_block_k, + dtype=torch.float32, + ), + 
requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + elif self.block_quant: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, @@ -858,6 +922,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) else: # For fp8 moe run with deepgemm, the expert weights and scales need be requantized to ue8m0 + from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE from sglang.srt.model_loader.utils import ( should_deepgemm_weight_requant_ue8m0, @@ -866,8 +931,50 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: # Check if MoE will actually use DeepGEMM runner will_use_deepgemm = self.is_deepgemm_moe_runner_backend_enabled() + if self.is_fp4_expert: + # FP4 experts support three MoE backends: + # - marlin (Hopper w4a16): only needs int8 view + # - flashinfer_mxfp4: only needs int8 view + # - deepgemm/auto (Blackwell): int8 view + mega_moe or scale conversion + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + + if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + from sglang.srt.models.deepseek_v4 import ( + build_mega_moe_experts_weights, + ) + + build_mega_moe_experts_weights(layer) + return + + if ( + envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + and will_use_deepgemm + ): + from deep_gemm import transform_sf_into_required_layout + + for scale_param, weight_param in [ + (layer.w13_weight_scale_inv, layer.w13_weight), + (layer.w2_weight_scale_inv, layer.w2_weight), + ]: + num_experts, n, _ = scale_param.data.shape + k = weight_param.shape[2] * 2 + scale_param.data = transform_sf_into_required_layout( + 
scale_param.data, + mn=n, + k=k, + recipe=(1, 32), + num_groups=num_experts, + disable_ue8m0_cast=False, + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + if ( - should_deepgemm_weight_requant_ue8m0( + not self.is_fp4_expert + and should_deepgemm_weight_requant_ue8m0( weight_block_size=getattr( self.quant_config, "weight_block_size", None ), diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 7701f9757f52..e590a4e7abfa 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -514,6 +514,51 @@ def sglang_per_token_group_quant_fp8( return x_q, x_s +def sglang_per_token_group_quant_fp8_ue8m0( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert ( + x.shape[-1] % group_size == 0 + ), f"hidden ({x.shape[-1]}) must be divisible by group_size ({group_size})" + assert x.is_contiguous(), "x must be contiguous" + assert enable_sgl_per_token_group_quant_8bit, ( + "sgl_per_token_group_quant_8bit is required (v2 kernel supports " + "group_size in {16, 32, 64, 128})" + ) + + *x_batch, x_q_mn, x_q_k = x.shape + x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype) + + x_s_mn = x_q_mn + x_s_k = x_q_k // group_size + aligned_mn = ceil_align(x_s_mn, 4) + aligned_k = ceil_align(x_s_k, 4) + x_s = torch.empty( + (*x_batch, aligned_k // 4, aligned_mn), + device=x.device, + dtype=torch.int, + ).transpose(-1, -2)[..., :x_s_mn, :] + + if x.shape[0] > 0: + sgl_per_token_group_quant_8bit( + x, + x_q, + x_s, + group_size, + eps, + fp8_min, + fp8_max, + True, # scale_ue8m0 + False, # fuse_silu_and_mul + None, # masked_m + enable_v2=True, + ) + + return x_q, x_s + + # TODO maybe unify int8 and fp8 code later def sglang_per_token_group_quant_8bit( x: torch.Tensor, @@ -970,8 +1015,25 @@ def get_w8a8_block_fp8_configs( logger, f"Using 
configuration from {config_file_path} for W8A8 Block FP8 kernel.", ) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + raw = {int(key): val for key, val in json.load(f).items()} + + sanitized = {} + clamped_ms = [] + for m_key, cfg in raw.items(): + if cfg["BLOCK_SIZE_K"] < block_k: + clamped_ms.append((m_key, cfg["BLOCK_SIZE_K"])) + cfg = {**cfg, "BLOCK_SIZE_K": block_k} + sanitized[m_key] = cfg + if clamped_ms: + logger.warning( + "Clamped BLOCK_SIZE_K up to %d in tuned config %s for entries %s " + "(scale stepping requires BLOCK_SIZE_K >= block_k).", + block_k, + json_file_name, + clamped_ms, + ) + + return sanitized # If no optimized configuration is available, we will use the default # configuration diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 61219f6b04a7..ef532d7b1c39 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -384,6 +384,10 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback( # TODO: https://github.com/sgl-project/sglang/pull/6890#issuecomment-2943395737 shape_supported = weight.shape[0] % 64 == 0 and weight.shape[1] % 128 == 0 + assert not get_bool_env_var( + "SGLANG_HACK_DEEPGEMM_W8A8_FORCE_TRITON" + ), "removed flag" + if not (shape_supported and dtype_supported): # fall back to triton # If weight_scale is in UE8M0 packed format (int32), convert back to float32 @@ -393,6 +397,9 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback( weight_scale = _unpack_ue8m0_scale_for_triton( weight_scale, weight.shape, block_size ) + + assert not get_bool_env_var("SGLANG_HACK_CUSTOM_W8A8_GEMM"), "removed flag" + return triton_w8a8_block_fp8_linear( input, weight, block_size, weight_scale, input_scale, bias ) @@ -680,7 +687,6 @@ def requant_weight_ue8m0( weight_dequant=weight_dequant, weight_block_size=weight_block_size, ) - out_s = 
transform_scale_ue8m0(out_s, mn=out_w.shape[-2]) return out_w, out_s diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..d0f75bdcbd66 --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, +) +from sglang.srt.utils import is_cuda +from sglang.srt.utils.common import get_bool_env_var + +_is_cuda = is_cuda() +_INVERT_MXFP4_MARLIN_SCALES = get_bool_env_var("SGLANG_MXFP4_MARLIN_INVERT_SCALE") +_SKIP_MXFP4_MARLIN_SCALE_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_SCALE_TRANSPOSE" +) +_SKIP_MXFP4_MARLIN_SCALE_SWIZZLE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_SCALE_SWIZZLE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_TRANSPOSE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_PERMUTE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_PERMUTE" +) +_SKIP_MXFP4_MARLIN_W13_SCALE_SWIZZLE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_W13_SKIP_SCALE_SWIZZLE" +) +_SKIP_MXFP4_MARLIN_WEIGHT_REPACK_TRANSPOSE = get_bool_env_var( + "SGLANG_MXFP4_MARLIN_SKIP_WEIGHT_REPACK_TRANSPOSE" +) + +if _is_cuda: + from sgl_kernel import gptq_marlin_repack + + +def mxfp4_marlin_process_scales( + marlin_scales: torch.Tensor, + input_dtype: torch.dtype | None = None, + apply_swizzle: bool = True, +) -> torch.Tensor: + if ( + apply_swizzle + and not _SKIP_MXFP4_MARLIN_SCALE_SWIZZLE + and (input_dtype is None or input_dtype.itemsize == 2) + ): + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + if input_dtype == torch.float8_e4m3fn: + marlin_scales = marlin_scales.view(torch.uint8) + assert 
marlin_scales.max() <= 249 + # exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6 + marlin_scales = marlin_scales + 6 + marlin_scales = marlin_scales.view(torch.float8_e8m0fnu) + return marlin_scales + + +def _normalize_scale_tensor( + scales: torch.Tensor, target_dtype: torch.dtype +) -> torch.Tensor: + # The kernel consumes E8M0 exponents. Regardless of the placeholder dtype + # the loader used, we want the *numerical* value 2**e in ``target_dtype``. + # float32/bfloat16/float16 containers hold the numerical 2**e directly + # (they were filled via a dtype-promoting copy from uint8/e8m0). + # uint8/int8 containers hold the raw E8M0 byte and must be reinterpreted. + if scales.dtype == torch.float8_e8m0fnu: + return scales.to(target_dtype) + if scales.dtype == torch.uint8: + return scales.view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype == torch.int8: + return scales.view(torch.uint8).view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype in (torch.float32, torch.bfloat16, torch.float16): + return scales.to(target_dtype) + raise TypeError(f"Unsupported MXFP4 scale dtype for Marlin: {scales.dtype}") + + +def prepare_moe_mxfp4_layer_for_marlin(layer: torch.nn.Module) -> None: + group_size = 32 + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + num_experts = w13.shape[0] + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + param_dtype = getattr( + layer, + "orig_dtype", + w13_bias.dtype if w13_bias is not None else torch.bfloat16, + ) + + device = w13.device + layer.workspace = marlin_make_workspace(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + def _repack_weight(weight: torch.Tensor, is_w13: bool) -> torch.Tensor: + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, 
intermediate_size + assert weight.shape == (num_experts, size_n, size_k // 2) + + tensor_list = [] + for i in range(num_experts): + qweight = weight[i].view(torch.int32) + expected_packed_k = size_k // (32 // 4) + if ( + not _SKIP_MXFP4_MARLIN_WEIGHT_REPACK_TRANSPOSE + or qweight.size(0) != expected_packed_k + ): + qweight = qweight.T + qweight = qweight.contiguous() + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + tensor_list.append(marlin_qweight) + return torch.stack(tensor_list) + + def _permute_scales(scales: torch.Tensor, is_w13: bool) -> torch.Tensor: + scales = _normalize_scale_tensor(scales, param_dtype) + if _INVERT_MXFP4_MARLIN_SCALES: + scales = torch.reciprocal(scales) + + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, intermediate_size + + tensor_list = [] + for i in range(num_experts): + scale = scales[i] + skip_transpose = _SKIP_MXFP4_MARLIN_SCALE_TRANSPOSE or ( + is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_TRANSPOSE + ) + if not skip_transpose: + scale = scale.T + scale = scale.contiguous() + skip_permute = is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_PERMUTE + if skip_permute: + marlin_scales = scale + else: + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + apply_swizzle = not (is_w13 and _SKIP_MXFP4_MARLIN_W13_SCALE_SWIZZLE) + tensor_list.append( + mxfp4_marlin_process_scales( + marlin_scales, + input_dtype=param_dtype, + apply_swizzle=apply_swizzle, + ) + ) + return torch.stack(tensor_list) + + def _permute_bias(bias: torch.Tensor | None) -> torch.Tensor | None: + if bias is None: + return None + tensor_list = [] + for i in range(num_experts): + tensor_list.append(marlin_permute_bias(bias[i].to(param_dtype))) + return torch.stack(tensor_list) + + w13_marlin = _repack_weight(w13, True) + w2_marlin = _repack_weight(w2, False) + w13_scale_marlin = 
_permute_scales(w13_scale, True) + w2_scale_marlin = _permute_scales(w2_scale, False) + + layer.w13_weight = torch.nn.Parameter(w13_marlin, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_marlin, requires_grad=False) + layer.w13_weight_scale_inv = torch.nn.Parameter(w13_scale_marlin, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter(w2_scale_marlin, requires_grad=False) + + if w13_bias is not None: + layer.w13_bias = torch.nn.Parameter(_permute_bias(w13_bias), requires_grad=False) + if w2_bias is not None: + layer.w2_bias = torch.nn.Parameter(_permute_bias(w2_bias), requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/mxfp4_deepseek.py b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py new file mode 100644 index 000000000000..8705f395eb87 --- /dev/null +++ b/python/sglang/srt/layers/quantization/mxfp4_deepseek.py @@ -0,0 +1,539 @@ + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +import triton +import triton.language as tl +from torch.nn import Module +from torch.nn.parameter import Parameter + +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo +from sglang.srt.layers.moe.utils import MoeRunnerBackend, get_moe_runner_backend +from sglang.srt.layers.moe.utils import RoutingMethodType +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + is_flashinfer_available, + is_sm90_supported, + log_info_on_rank0, + set_weight_attrs, +) +from sglang.srt.utils.common import next_power_of_2 + +if is_flashinfer_available(): + from flashinfer import mxfp8_quantize, shuffle_matrix_a, shuffle_matrix_sf_a + from flashinfer.fp4_quantization import block_scale_interleave + from 
class PackTopkIds:
    """Pack expert ids and routing weights into a single int32 per entry.

    Every packed element stores the expert id in its upper 16 bits and the
    raw bit pattern of the routing weight, rounded to bfloat16, in its lower
    16 bits.
    """

    @classmethod
    def execute(
        cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor
    ) -> torch.Tensor:
        """Pack ``topk_ids``/``topk_weights``; always uses the Triton path."""
        packer = cls.triton
        return packer(topk_ids, topk_weights)

    @classmethod
    def vanilla(
        cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor
    ) -> torch.Tensor:
        """Pure-PyTorch packing with the same bit layout as the Triton kernel."""
        as_bf16 = topk_weights.to(torch.bfloat16)
        # Reinterpret the bf16 bits as int16, widen, then mask away the sign
        # extension so only the low 16 bits carry the weight.
        low_half = as_bf16.view(torch.int16).to(torch.int32) & 0xFFFF
        high_half = topk_ids.to(torch.int32) << 16
        return high_half | low_half

    @classmethod
    def triton(cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
        """Pack with ``_pack_topk_ids_triton_kernel``.

        Requires equally shaped, contiguous inputs with at least one
        dimension: ``topk_ids`` as int32 and ``topk_weights`` as float32.
        Returns a new int32 tensor shaped like ``topk_ids``. Empty inputs
        return immediately without launching the kernel.
        """
        shapes_agree = topk_ids.shape == topk_weights.shape
        assert shapes_agree, f"shape mismatch: {topk_ids.shape=} vs {topk_weights.shape=}"
        has_dims = topk_ids.ndim >= 1
        assert has_dims, f"expected >=1D, got {topk_ids.shape=}"

        ids_are_i32 = topk_ids.dtype == torch.int32
        assert ids_are_i32, f"topk_ids must be int32, got {topk_ids.dtype}"
        weights_are_f32 = topk_weights.dtype == torch.float32
        assert weights_are_f32, f"topk_weights must be float32, got {topk_weights.dtype}"

        # The kernel addresses both inputs with flat offsets.
        assert topk_ids.is_contiguous(), "topk_ids must be contiguous"
        assert topk_weights.is_contiguous(), "topk_weights must be contiguous"

        packed = torch.empty_like(topk_ids, dtype=torch.int32)
        total = packed.numel()
        if total != 0:
            elems_per_program = 1024
            # One program per chunk of ``elems_per_program`` flat elements.
            launch_grid = (triton.cdiv(total, elems_per_program),)
            _pack_topk_ids_triton_kernel[launch_grid](
                topk_ids,
                topk_weights,
                packed,
                total,
                BLOCK_SIZE=elems_per_program,
            )
        return packed
torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w13_weight_scale.format_ue8m0 = False + w2_weight_scale.format_ue8m0 = False + scale_attrs = dict(extra_weight_attrs) + scale_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + set_weight_attrs(w13_weight_scale, scale_attrs) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + set_weight_attrs(w2_weight_scale, scale_attrs) + + def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1 + + self._fp8.process_weights_after_loading(layer) + + if getattr(layer, "_mega_moe_weights_built", False): + return + + if self.moe_runner_backend.is_marlin(): + from sglang.srt.layers.quantization.marlin_utils import ( + check_moe_marlin_supports_layer, + ) + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + if not is_sm90_supported(): + raise RuntimeError( + "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." 
+ ) + if not check_moe_marlin_supports_layer(layer, 32): + raise RuntimeError( + "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." + ) + + # NOTE: the Marlin MoE runner consumes w13 in the checkpoint's + # native ``[w1; w3]`` order -- see ``silu_and_mul`` in + # fused_marlin_moe.py which expects ``gate = intermediate[:, :N]`` + # (first half) and ``up = intermediate[:, N:]`` (second half). + # Unlike the flashinfer trtllm_fp4 kernel (which wants [w3, w1]), + # we must *not* call ``reorder_w1w3_to_w3w1`` here. + + log_info_on_rank0( + logger, + f"Preparing DeepSeekV4 MXFP4 experts for Marlin backend (layer: {self.prefix})...", + ) + prepare_moe_mxfp4_layer_for_marlin(layer) + layer._dsv4_mxfp4_backend = "marlin" + return + + w13_w, w13_s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale_inv.data + ) + layer.w13_weight = Parameter(w13_w, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_s, requires_grad=False) + + log_info_on_rank0( + logger, + f"Shuffling FP4 expert weights for TRT-LLM MxFP4 kernel " + f"(layer: {self.prefix})...", + ) + + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + num_experts = w13.shape[0] + + if w13_scale.dtype == torch.float32: + w13_scale = w13_scale.to(torch.float8_e8m0fnu) + w2_scale = w2_scale.to(torch.float8_e8m0fnu) + + epilogue_tile_m = 128 + g1_w, g1_s, g2_w, g2_s = [], [], [], [] + if _USE_OFFICIAL_SHUFFLE: + cache: dict = {} + for i in range(num_experts): + w13_u8 = w13[i].view(torch.uint8) + w13_s_u8 = w13_scale[i].view(torch.uint8) + w2_u8 = w2[i].view(torch.uint8) + w2_s_u8 = w2_scale[i].view(torch.uint8) + + perm = _maybe_get_cached_w3_w1_permute_indices( + cache, + w13_u8, + epilogue_tile_m, + ) + g1_w.append(w13_u8[perm.to(w13_u8.device)].contiguous()) + perm_sf = _maybe_get_cached_w3_w1_permute_indices( + cache, + w13_s_u8, + epilogue_tile_m, + num_elts_per_sf=16, + ) + 
g1_s.append( + block_scale_interleave( + w13_s_u8[perm_sf.to(w13_s_u8.device)].contiguous() + ) + ) + + perm = get_w2_permute_indices_with_cache( + cache, + w2_u8, + epilogue_tile_m, + ) + g2_w.append(w2_u8[perm.to(w2_u8.device)].contiguous()) + perm_sf = get_w2_permute_indices_with_cache( + cache, + w2_s_u8, + epilogue_tile_m, + num_elts_per_sf=16, + ) + g2_s.append( + block_scale_interleave( + w2_s_u8[perm_sf.to(w2_s_u8.device)].contiguous() + ) + ) + else: + for i in range(num_experts): + g1_w.append(shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)) + g1_s.append( + shuffle_matrix_sf_a(w13_scale[i].view(torch.uint8), epilogue_tile_m) + ) + g2_w.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)) + g2_s.append( + shuffle_matrix_sf_a(w2_scale[i].view(torch.uint8), epilogue_tile_m) + ) + + layer.w13_weight = Parameter(torch.stack(g1_w), requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + torch.stack(g1_s) + .view(torch.float8_e4m3fn) + .reshape(num_experts, w13.shape[1], -1), + requires_grad=False, + ) + layer.w2_weight = Parameter(torch.stack(g2_w), requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + torch.stack(g2_s) + .view(torch.float8_e4m3fn) + .reshape(num_experts, w2.shape[1], -1), + requires_grad=False, + ) + + if envs.SGLANG_OPT_MXFP4_STATIC_SCALE_ONES.get(): + self._register_static_scale_ones(layer) + torch.cuda.empty_cache() + + def _register_static_scale_ones(self, layer: Module) -> None: + device = layer.w13_weight.device + for name in ( + "output1_scale_scalar", + "output1_scale_gate_scalar", + "output2_scale_scalar", + ): + layer.register_buffer( + name, + torch.ones(layer.num_local_experts, device=device, dtype=torch.float32), + persistent=False, + ) + + def apply( + self, + layer: Module, + dispatch_output: DispatchOutput, + ) -> CombineInput: + if self.moe_runner_backend.is_marlin(): + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + from 
sglang.srt.layers.moe.topk import TopKOutputChecker + + topk_output = dispatch_output.topk_output + if not TopKOutputChecker.format_is_standard(topk_output): + raise ValueError(f"Unsupported topk output format: {topk_output.format}") + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale_inv, + w2_scales=layer.w2_weight_scale_inv, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + is_k_full=True, + ) + runner_output = self.runner.run(dispatch_output, quant_info=quant_info) + return StandardCombineInput(hidden_states=runner_output.hidden_states) + + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale_inv + w2_scale = layer.w2_weight_scale_inv + + intermediate_size = w2.shape[2] * 2 if w2.dtype == torch.uint8 else w2.shape[2] + hidden_size = w13.shape[2] * 2 if w13.dtype == torch.uint8 else w13.shape[2] + + num_local_experts = layer.num_local_experts + if w13_scale.dim() == 2: + w13_scale = w13_scale.reshape(num_local_experts, 2 * intermediate_size, -1) + if w2_scale.dim() == 2: + w2_scale = w2_scale.reshape(num_local_experts, hidden_size, -1) + + if TopKOutputChecker.format_is_standard(topk_output): + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + elif TopKOutputChecker.format_is_bypassed(topk_output): + raise NotImplementedError( + "the old code in this branch is WRONG. e.g. 
it does not consider HashTopK, and may miss args" + ) + else: + raise ValueError(f"Unsupported topk output format: {topk_output.format}") + + if not envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get(): + local_expert_offset = layer.moe_ep_rank * layer.num_local_experts + topk_ids = torch.where( + topk_ids >= 0, + topk_ids + local_expert_offset, + topk_ids, + ) + packed_topk = PackTopkIds.execute(topk_ids, topk_weights) + + precision = self.flashinfer_mxfp4_moe_precision + if precision == "bf16": + assert hidden_states.dtype == torch.bfloat16 + x_quant = hidden_states + x_scale = None + origin_dim = x_quant.shape[-1] + if hidden_size != origin_dim: + x_quant = torch.nn.functional.pad( + x_quant, + (0, hidden_size - origin_dim), + mode="constant", + value=0.0, + ) + elif precision == "default": + x_quant, x_scale = mxfp8_quantize( + hidden_states, False, alignment=hidden_size + ) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ) + else: + raise NotImplementedError(f"Unsupported mxfp4 moe precision: {precision}") + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = x_quant.shape[0] + out_hidden_size = ( + x_quant.shape[-1] * 2 + if x_quant.dtype == torch.uint8 + else x_quant.shape[-1] + ) + symm_output = torch.empty( + num_tokens, out_hidden_size, dtype=torch.bfloat16, device=x_quant.device + ) + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" and ( + self._gemm1_clamp_limit_tensor is not None + ): + deepseek_v4_moe_code_path_checker.observed += 1 + + output = trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_topk, + routing_bias=None, + hidden_states=x_quant, + hidden_states_scale=x_scale, + gemm1_weights=w13, + gemm1_weights_scale=w13_scale, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=self._gemm1_clamp_limit_tensor, + gemm2_weights=w2, + gemm2_weights_scale=w2_scale, + gemm2_bias=None, + output1_scale_scalar=( + 
def get_tensor_size_bytes(t: torch.Tensor):
    """Return the number of bytes taken by the logical elements of ``t``.

    Computed as element count times per-element size, so it reflects the
    tensor's shape rather than the size of its underlying storage.

    Uses the tensor's own ``numel()`` / ``element_size()`` instead of
    importing numpy on every call and reducing the shape through
    ``np.prod``.
    """
    return t.numel() * t.element_size()
class BaseHostCache:
    """Host-side int32 store with one ``(num_layers, topk_size)`` slab per token.

    The backing tensor is zero-initialised and allocated on the CPU with
    ``pin_memory=True``.
    """

    def __init__(self, num_tokens: int, num_layers: int, topk_size: int):
        buffer_shape = (num_tokens, num_layers, topk_size)
        self.buffer = torch.zeros(
            buffer_shape, dtype=torch.int32, device="cpu", pin_memory=True
        )
        # Keep the individual dimensions around for callers.
        self.num_tokens, self.num_layers, self.topk_size = (
            num_tokens,
            num_layers,
            topk_size,
        )

    def get_buffer_size_bytes(self):
        """Return the size of the host buffer as given by ``get_tensor_size_bytes``."""
        host_buffer = self.buffer
        return get_tensor_size_bytes(host_buffer)
def is_strict_contiguous(x: torch.Tensor) -> bool:
    """Return True when every stride of ``x`` equals the row-major stride.

    Unlike ``Tensor.is_contiguous()``, the stride of each dimension is
    compared, including dimensions of size 1.

    A zero-length dimension contributes a factor of 1 (not 0) to the expected
    stride of the dimensions before it. That matches the strides PyTorch
    assigns to freshly allocated tensors, so an empty tensor such as
    ``torch.empty(3, 0)`` (strides ``(1, 1)``) is correctly reported as
    strictly contiguous.
    """
    expected_stride = 1
    # Walk from the innermost dimension outwards.
    for size, stride in zip(reversed(x.shape), reversed(x.stride())):
        if stride != expected_stride:
            return False
        # Multiplying by a raw size of 0 would zero the expectation for all
        # outer dimensions and reject canonically laid-out empty tensors.
        expected_stride *= max(size, 1)
    return True


def strict_contiguous(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` itself if strictly contiguous, otherwise a contiguous copy.

    The copy is made with ``clone(memory_format=torch.contiguous_format)``.
    """
    if is_strict_contiguous(x):
        return x
    return x.clone(memory_format=torch.contiguous_format)
diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index a65b0dd28b2a..17a305e45c01 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -342,6 +342,18 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): return output_strs + def _extract_topk_base64(self, data_list) -> List[List[int]]: + if data_list is None: + return None + return [ + ( + pybase64.b64encode(item.numpy().tobytes()).decode("utf-8") + if item is not None + else [] + ) + for item in data_list + ] + def _extract_routed_experts( self, recv_obj: BatchTokenIDOutput ) -> list[str | None] | None: @@ -364,7 +376,8 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): if len(recv_obj.rids) > 0 else [] ) - routed_experts = self._extract_routed_experts(recv_obj) + routed_experts = self._extract_topk_base64(recv_obj.routed_experts) + indexer_topk = self._extract_topk_base64(recv_obj.indexer_topk) return BatchStrOutput( rids=recv_obj.rids, @@ -391,6 +404,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, + indexer_topk=indexer_topk, routed_experts=routed_experts, customized_info=recv_obj.customized_info, placeholder_tokens_idx=None, diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py new file mode 100644 index 000000000000..92c3693ffd28 --- /dev/null +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -0,0 +1,449 @@ + +import logging +from typing import List, NamedTuple + +import torch + +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4SingleKVPoolHost, + HiSparseTokenToKVPoolAllocator, +) +from 
class HiSparseAct(NamedTuple):
    """Bookkeeping record for one staging copy queued in ``ack_staging_queue``.

    An entry is appended by ``admit_request_into_staging`` and consumed by
    ``collect_ready_reqs``, which polls ``finish_event.query()`` to decide
    whether the request is ready.
    """

    # Recorded before the staging work is issued; the staging stream waits on it.
    start_event: device_module.Event
    # Recorded inside the ``write_staging_stream`` context after the backup call.
    finish_event: device_module.Event
    # The request whose entries are being staged.
    req: Req
def set_decode_producer_stream(self, stream) -> None:
    """Store the stream that later backup/cleanup work must wait on.

    While this is not ``None``, ``_eager_backup_previous_token`` makes
    ``decode_backup_stream`` wait on it before copying, and
    ``request_finished`` makes the current stream wait on it before
    releasing a request's resources.

    Args:
        stream: Stream object to wait on; stored as-is without validation.
    """
    self.decode_producer_stream = stream
def alloc_device_buffer(self, req: Req) -> None:
    """Allocate the per-request device buffer and record its locations.

    Requests ``padded_buffer_size`` slots from the allocator for the
    request's prefill tokens (translated to compressed locations), then
    fills this request's row in the bookkeeping tensors.

    Raises:
        RuntimeError: if the allocator returns ``None``.
    """
    num_prefill_tokens = len(req.fill_ids)
    slot = req.req_pool_idx
    full_locs = self.req_to_token_pool.req_to_token[slot, :num_prefill_tokens]
    compressed_locs = self.mem_pool_device.translate_loc_from_full_to_compressed(
        full_locs
    )
    num_compressed = len(compressed_locs)
    capacity = self.padded_buffer_size

    allocated = self.token_to_kv_pool_allocator.alloc_device_buffer(
        compressed_locs, capacity
    )
    if allocated is None:
        logger.error(
            "HiSparse: alloc_device_buffer failed for req %s "
            "(compressed_len=%d, alloc_size=%d)",
            req.rid,
            num_compressed,
            capacity,
        )
        raise RuntimeError("HiSparse alloc_device_buffer returned None")

    buffer_locs = allocated.to(torch.int32)

    # Per-request row: buffer locations and the capacity actually granted.
    self.req_to_device_buffer[slot, :capacity] = buffer_locs
    self.req_device_buffer_size[slot] = capacity

    # Per-layer rows: the first ``device_buffer_size`` token entries are set
    # from ``_device_buffer_arange_i32``; the location entries mirror the
    # freshly allocated buffer.
    self.req_device_buffer_tokens[:, slot, : self.device_buffer_size] = (
        self._device_buffer_arange_i32
    )
    self.req_device_buffer_token_locs[:, slot, :capacity] = buffer_locs[:capacity]
def map_last_loc_to_buffer(
    self,
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
) -> None:
    """Point the newest compressed location of each request at a buffer slot.

    First triggers ``_eager_backup_previous_token`` for the whole batch. Then,
    for requests whose sequence length is a multiple of ``compress_ratio``:

    * picks the buffer location at index ``seq_len // compress_ratio - 1``,
      capped at ``device_buffer_size``;
    * writes it into column ``device_buffer_size`` of
      ``req_device_buffer_token_locs`` for every layer;
    * stores it in ``full_to_hisparse_device_index_mapping`` at the
      compressed location derived from ``out_cache_loc``.

    Does nothing after the backup step when no request qualifies.
    """
    pool_indices_host = req_pool_indices.cpu()
    self._eager_backup_previous_token(
        seq_lens, req_pool_indices, seq_lens_cpu, pool_indices_host
    )

    ratio = self.compress_ratio
    on_boundary = (seq_lens % ratio) == 0
    if not on_boundary.any():
        return

    boundary_pool_indices = req_pool_indices[on_boundary]
    boundary_cache_loc = out_cache_loc[on_boundary]
    boundary_seq_lens = seq_lens[on_boundary]

    # Index of the newest compressed entry, capped at the reserved slot.
    newest_compressed = boundary_seq_lens // ratio - 1
    slot_positions = torch.clamp(newest_compressed, max=self.device_buffer_size)
    slot_locs = self.req_to_device_buffer[boundary_pool_indices, slot_positions]

    self.req_device_buffer_token_locs[
        :, boundary_pool_indices, self.device_buffer_size
    ] = slot_locs.to(torch.int32)

    last_compressed_locs = self.token_to_kv_pool_allocator.get_last_loc_compressed(
        boundary_cache_loc
    )
    mapping = self.mem_pool_device.full_to_hisparse_device_index_mapping
    mapping[last_compressed_locs] = slot_locs
seq_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens_cpu: torch.Tensor, + req_pool_indices_cpu: torch.Tensor, + ) -> None: + backup_indices = [] + for i in range(len(seq_lens_cpu)): + req_idx = int(req_pool_indices_cpu[i]) + if self._skip_first_backup[req_idx]: + self._skip_first_backup[req_idx] = False + continue + if (int(seq_lens_cpu[i]) - 1) % self.compress_ratio == 0: + backup_indices.append(i) + + if not backup_indices: + return + + backup_indices_gpu = torch.tensor( + backup_indices, dtype=torch.int64, device=self.device + ) + backup_req_indices = req_pool_indices[backup_indices_gpu] + + prev_seq_lens = seq_lens[backup_indices_gpu] - 1 + compressed_prev_seq_lens = prev_seq_lens // self.compress_ratio + actual_compressed_pos = compressed_prev_seq_lens - 1 + + buffer_slot = actual_compressed_pos.clamp(max=self.device_buffer_size) + + device_locs = self.req_to_device_buffer[backup_req_indices, buffer_slot] + + host_locs = self.mem_pool_host.alloc(len(device_locs)) + if host_locs is None: + logger.error( + "HiSparse: host mem pool alloc failed for %d decode backup tokens", + len(device_locs), + ) + raise RuntimeError( + f"HiSparse host mem pool alloc failed for {len(device_locs)} decode backup tokens" + ) + host_locs = host_locs.to(device=self.device) + self.req_to_host_pool[backup_req_indices, actual_compressed_pos] = host_locs + + self.wait_for_pending_backup() + schedule_stream = device_module.current_stream() + with device_module.stream(self.decode_backup_stream): + self.decode_backup_stream.wait_stream(schedule_stream) + if self.decode_producer_stream is not None: + self.decode_backup_stream.wait_stream(self.decode_producer_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_locs, + device_locs, + ) + self._backup_done_event.record() + if host_locs.is_cuda: + host_locs.record_stream(self.decode_backup_stream) + if backup_req_indices.is_cuda: + backup_req_indices.record_stream(self.decode_backup_stream) + 
def get_front_topk_tokens(
    self,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
) -> torch.Tensor:
    """Return the first ``top_k`` buffer locations of each request as int32.

    Columns at or beyond a request's compressed length
    (``seq_len // compress_ratio``) are set to ``-1``. The result is a new
    tensor; ``req_to_device_buffer`` is not modified.
    """
    valid_counts = seq_lens // self.compress_ratio
    front_locs = self.req_to_device_buffer[req_pool_indices, : self.top_k]
    front_locs = front_locs.to(torch.int32)
    # Column index >= per-request valid count marks an entry as absent.
    beyond_valid = self._top_k_arange >= valid_counts[:, None]
    front_locs.masked_fill_(beyond_valid, -1)
    return front_locs
def swap_in_selected_pages(
    self,
    req_pool_indices: torch.Tensor,
    compressed_seq_lens: torch.Tensor,
    top_k_result: torch.Tensor,
    layer_id: int,
) -> torch.Tensor:
    """Run ``load_cache_to_device_buffer_mla`` for one layer.

    The first ``len(req_pool_indices)`` rows of ``top_k_device_locs_buffer``
    are reset to ``-1``, handed to the kernel as its ``top_k_device_locs``
    argument, and returned. The returned tensor is a view of that shared
    buffer, not a copy.
    """
    batch = req_pool_indices.size(0)

    # Reuse the preallocated output rows; reset them before the kernel runs.
    device_locs_out = self.top_k_device_locs_buffer[:batch]
    device_locs_out.fill_(-1)

    load_cache_to_device_buffer_mla(
        top_k_tokens=top_k_result,
        device_buffer_tokens=self.req_device_buffer_tokens[layer_id],
        host_cache_locs=self.req_to_host_pool,
        device_buffer_locs=self.req_device_buffer_token_locs[layer_id],
        host_cache=self.mem_pool_host.kv_buffer[layer_id],
        device_buffer=self.mem_pool_device.kv_buffer[layer_id],
        top_k_device_locs=device_locs_out,
        req_pool_indices=req_pool_indices,
        seq_lens=compressed_seq_lens,
        lru_slots=self.lru_slots[layer_id],
        item_size_bytes=self.item_size_bytes,
        num_top_k=self.top_k,
        hot_buffer_size=self.device_buffer_size,
        page_size=1,
        block_size=1024,
        num_real_reqs=self.num_real_reqs,
    )
    return device_locs_out
@@ -1077,6 +1083,8 @@ class BatchStrOutput( # routed_experts[i] is a tensor of shape (token, layer, top_k) for request i routed_experts: List[Optional[torch.Tensor]] + indexer_topk: List[List[int]] + # The information of placeholder tokens (e.g., image token) # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 4f2e3fb19197..b0e66948407d 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -281,6 +281,9 @@ def _handle_output_by_index(output, i): routed_experts=_extract_field_by_index( output, "routed_experts", i, check_length=False ), + indexer_topk=_extract_field_by_index( + output, "indexer_topk", i, check_length=False + ), customized_info=_extract_field_by_index( output, "customized_info", i, check_length=False ), diff --git a/python/sglang/srt/managers/prefill_delayer.py b/python/sglang/srt/managers/prefill_delayer.py index 8df34fe8ee6f..0104b7fd0a13 100644 --- a/python/sglang/srt/managers/prefill_delayer.py +++ b/python/sglang/srt/managers/prefill_delayer.py @@ -32,6 +32,9 @@ class _NegotiateOutput(NamedTuple): output_reason: str num_prefillable: int num_token_watermark_force_allow: int + # Debug info: prev_state captured at decision time (wait_success/wait_timeout + # zero out next_state, so we keep prev_state separately to inspect timing). 
+ debug_prev_state: Optional[_State] = None class PrefillDelayer: @@ -128,6 +131,7 @@ def _negotiate_should_allow_prefill_pure( next_state=None, output_allow=True, output_reason="wait_success" if exist_previous_wait else "no_wait", + debug_prev_state=prev_state, **debug_info, ) elif prefillable_status == "none": @@ -162,6 +166,7 @@ def _negotiate_should_allow_prefill_pure( next_state=None, output_allow=True, output_reason="wait_timeout", + debug_prev_state=prev_state, **debug_info, ) else: @@ -218,11 +223,25 @@ def _record_single_pass_result( output: _NegotiateOutput, metrics_collector: Optional["SchedulerMetricsCollector"], ) -> None: + # Compute waited time/passes (independent of metrics_collector path so DEBUG_LOG can print them). + # next_state captures in-progress delay; debug_prev_state captures completed wait_success/wait_timeout. + if (_dbg_s := output.next_state) is not None: + _dbg_wait_seconds = time.perf_counter() - _dbg_s.start_time + _dbg_forward_passes = _dbg_s.delayed_count + elif (_dbg_s := output.debug_prev_state) is not None: + _dbg_wait_seconds = time.perf_counter() - _dbg_s.start_time + _dbg_forward_passes = _dbg_s.delayed_count + else: + _dbg_wait_seconds = _dbg_forward_passes = 0 + if _DEBUG_LOG: if output.output_allow and (output.output_reason == "wait_timeout"): logger.info( f"PrefillDelayer timeout thus not forbid prefill " f"(num_prefillable={output.num_prefillable}, " + f"input_estimation={output.input_estimation}, " + f"forward_passes={_dbg_forward_passes}, " + f"wait_seconds={_dbg_wait_seconds:.4f}, " f"actual_execution={actual_execution})" ) elif output.output_allow and (output.output_reason == "token_watermark"): @@ -232,6 +251,33 @@ def _record_single_pass_result( f"num_token_watermark_force_allow={output.num_token_watermark_force_allow}, " f"actual_execution={actual_execution})" ) + elif output.output_allow and (output.output_reason == "wait_success"): + logger.info( + f"PrefillDelayer wait_success: prefill allowed after delay " + 
f"(num_prefillable={output.num_prefillable}, " + f"input_estimation={output.input_estimation}, " + f"forward_passes={_dbg_forward_passes}, " + f"wait_seconds={_dbg_wait_seconds:.4f}, " + f"actual_execution={actual_execution})" + ) + elif output.output_allow and (output.output_reason == "no_wait"): + logger.info( + f"PrefillDelayer no_wait: prefill allowed immediately " + f"(num_prefillable={output.num_prefillable}, " + f"input_estimation={output.input_estimation}, " + f"actual_execution={actual_execution})" + ) + elif (not output.output_allow) and (output.output_reason == "delay"): + logger.info( + f"PrefillDelayer delay: prefill blocked this pass " + f"(num_prefillable={output.num_prefillable}, " + f"input_estimation={output.input_estimation}, " + f"forward_passes={_dbg_forward_passes}, " + f"wait_seconds={_dbg_wait_seconds:.4f})" + ) + elif output.output_reason == "": + # prefillable_status=='none' branch — silenced (one per scheduler pass × 8 ranks → log explosion) + pass else: assert output.output_reason in { "", @@ -241,14 +287,9 @@ def _record_single_pass_result( } if metrics_collector is not None: - if (s := output.next_state) is not None: - wait_seconds = time.perf_counter() - s.start_time - forward_passes = s.delayed_count - else: - wait_seconds = forward_passes = 0 metrics_collector.observe_prefill_delayer_outcome( - forward_passes=forward_passes, - wait_seconds=wait_seconds, + forward_passes=_dbg_forward_passes, + wait_seconds=_dbg_wait_seconds, input_estimation=output.input_estimation, output_allow=output.output_allow, output_reason=output.output_reason, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9681a1d70dcf..674f27f74e98 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -90,6 +90,7 @@ from typing import Any, Dict from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.managers.hisparse_coordinator import 
HiSparseCoordinator from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm @@ -511,6 +512,7 @@ def __init__( require_reasoning: bool = False, return_hidden_states: bool = False, return_routed_experts: bool = False, + return_indexer_topk: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, @@ -643,6 +645,8 @@ def __init__( self.host_hit_length = 0 # The node to lock until for swa radix tree lock ref self.swa_uuid_for_lock: Optional[int] = None + # Whether the prefill-time SWA tree lock has been released early + self.swa_prefix_lock_released: bool = False # The prefix length that is inserted into the tree cache self.cache_protected_len: int = 0 @@ -716,6 +720,11 @@ def __init__( self.routed_experts: Optional[torch.Tensor] = ( None # cpu tensor: shape (seqlen, topk) ) + + self.return_indexer_topk = return_indexer_topk + self.indexer_topk: Optional[torch.Tensor] = ( + None # cpu tensor: shape (seqlen, num_indexer_layers, index_topk) + ) # Customized info self.customized_info: Optional[Dict[str, List[Any]]] = None @@ -779,6 +788,8 @@ def __init__( self.dllm_block_offset = 0 self.dllm_config = dllm_config + self.hisparse_staging = False + @property def seqlen(self) -> int: """Get the current sequence length of the request.""" @@ -1075,8 +1086,10 @@ def reset_for_retract(self): self.prefix_indices = torch.empty((0,), dtype=torch.int64) self.routed_experts = None + self.indexer_topk = None self.last_node = None self.swa_uuid_for_lock = None + self.swa_prefix_lock_released = False self.extend_input_len = 0 self.customized_info = None self.is_retracted = True @@ -1096,6 +1109,9 @@ def reset_for_retract(self): self.kv_committed_len = 0 self.kv_committed_freed = False self.kv_overallocated_freed = False + self.swa_evicted_seqlen = 0 + self.extend_batch_idx = 0 + self.decode_batch_idx = 0 def offload_kv_cache(self, 
req_to_token_pool, token_to_kv_pool_allocator): token_indices = req_to_token_pool.req_to_token[ @@ -1332,6 +1348,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return captured experts return_routed_experts: bool = False + return_indexer_topk: bool = False + # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False @@ -1345,6 +1363,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Metrics dp_cooperation_info: Optional[DPCooperationInfo] = None + hisparse_coordinator: Optional[HiSparseCoordinator] = None + @classmethod def init_new( cls, @@ -1365,7 +1385,7 @@ def init_new( if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): is_hybrid_swa = True - return cls( + batch = cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool_allocator=token_to_kv_pool_allocator, @@ -1380,11 +1400,13 @@ def init_new( spec_algorithm=spec_algorithm, return_hidden_states=any(req.return_hidden_states for req in reqs), return_routed_experts=any(req.return_routed_experts for req in reqs), + return_indexer_topk=any(req.return_indexer_topk for req in reqs), is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, dllm_staging_reqs=dllm_staging_reqs, dllm_config=dllm_config, ) + return batch def batch_size(self): return len(self.reqs) @@ -1809,6 +1831,9 @@ def new_tokens_required_next_decode( new_pages = sum(1 for r in requests if r.kv_committed_len % page_size == 0) return new_pages * page_size + if self.is_spec_v2: + return self._new_tokens_required_next_decode_spec_v2(requests, page_size) + server_args = get_global_server_args() len_per_topk = server_args.speculative_num_steps or 1 spec_topk = server_args.speculative_eagle_topk or 1 @@ -1824,9 +1849,20 @@ def new_tokens_required_next_decode( spec_tokens = ceil_align(spec_tokens, page_size) num_tokens = max(len_per_topk * spec_topk, spec_tokens) * len(requests) + return num_tokens + + def 
_new_tokens_required_next_decode_spec_v2(self, requests, page_size): + """Tight estimate matching eagle_info_v2.prepare_for_decode allocation.""" + from sglang.srt.managers.utils import get_alloc_len_per_decode - # v2 eagle has over-allocation - return num_tokens * (1 + self.is_spec_v2) + alloc_len = get_alloc_len_per_decode() + total = 0 + for r in requests: + x = max(0, r.kv_committed_len + 2 * alloc_len - r.kv_allocated_len) + cur = r.kv_allocated_len + nxt = cur + x + total += ceil_align(nxt, page_size) - ceil_align(cur, page_size) + return total def check_decode_mem(self, selected_indices: Optional[List[int]] = None): num_tokens = self.new_tokens_required_next_decode(selected_indices) @@ -1904,6 +1940,9 @@ def retract_decode( def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): req = self.reqs[idx] + if self.hisparse_coordinator is not None and not req.finished(): + self.hisparse_coordinator.retract_req(req) + if server_args.disaggregation_mode == "decode": req.offload_kv_cache( self.req_to_token_pool, self.token_to_kv_pool_allocator @@ -2008,6 +2047,14 @@ def prepare_for_decode(self): self.orig_seq_lens.add_(1) self.seq_lens_sum += bs + if self.hisparse_coordinator is not None: + self.hisparse_coordinator.map_last_loc_to_buffer( + self.seq_lens, + self.out_cache_loc, + self.req_pool_indices, + self.seq_lens_cpu, + ) + if get_global_server_args().enable_mamba_extra_buffer(): self.mamba_track_indices = torch.tensor( [ @@ -2264,6 +2311,10 @@ def copy(self): def maybe_evict_swa(self): if self.tree_cache.supports_swa(): sliding_window_size = self.tree_cache.sliding_window_size + release_leaf_lock = ( + envs.SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW.get() + and hasattr(self.tree_cache, "dec_swa_lock_only") + ) for idx, req in enumerate(self.reqs): if self.forward_mode.is_decode(): # We set evict_swa condition here with two reasons: @@ -2271,6 +2322,22 @@ def maybe_evict_swa(self): # 2. 
Evict swa every window_size tokens to reduce the overhead. if req.decode_batch_idx % sliding_window_size == 1: self._evict_swa(req, req.seqlen - 1) + + # Once the decode position has moved past the sliding window, + # the SWA portion of the prefill-time tree lock is no longer + # needed by this request. Convert it from protected to + # evictable so SWA LRU can reclaim it under pressure. + if ( + release_leaf_lock + and not req.swa_prefix_lock_released + and req.swa_uuid_for_lock is not None + and req.last_node is not None + and req.decode_batch_idx >= sliding_window_size + ): + self.tree_cache.dec_swa_lock_only( + req.last_node, req.swa_uuid_for_lock + ) + req.swa_prefix_lock_released = True elif self.forward_mode.is_extend() and self.tree_cache.is_chunk_cache(): pre_len = self.prefix_lens[idx] if self.enable_overlap: @@ -2298,8 +2365,13 @@ def _evict_swa(self, req: Req, pre_len: int): ), "cache_protected_len must be page aligned" req.swa_evicted_seqlen = max(req.swa_evicted_seqlen, req.cache_protected_len) + if envs.SGLANG_OPT_SWA_EVICT_DROP_PAGE_MARGIN.get(): + evict_threshold = pre_len - sliding_window_size + else: + evict_threshold = pre_len - sliding_window_size - self.tree_cache.page_size new_swa_evicted_seqlen = max( - req.swa_evicted_seqlen, pre_len - sliding_window_size + req.swa_evicted_seqlen, + evict_threshold, ) if self.tree_cache.page_size > 1: diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 65ffc7198a0d..62223abd608e 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -40,6 +40,7 @@ InsertParams, MatchPrefixParams, ) +from sglang.srt.mem_cache.hisparse_memory_pool import HiSparseTokenToKVPoolAllocator from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator from sglang.srt.server_args import ServerArgs @@ -418,11 +419,16 @@ def 
__init__( ] ) + # Hisparse wraps an SWATokenToKVPoolAllocator internally and exposes + # the full SWA allocator interface. self.is_hybrid_swa = isinstance( - self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator + self.token_to_kv_pool_allocator, + (SWATokenToKVPoolAllocator, HiSparseTokenToKVPoolAllocator), ) self.is_hybrid_ssm_cache = self.tree_cache.supports_mamba() + self.rem_swa_token_offset = 0 + self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold ) @@ -449,11 +455,9 @@ def _get_running_request_total_token_offset(self, req: Req) -> int: @property def rem_total_tokens(self): if self.is_hybrid_swa: - available_and_evictable = min( + available_and_evictable = ( self.token_to_kv_pool_allocator.full_available_size() - + self.tree_cache.full_evictable_size(), - self.token_to_kv_pool_allocator.swa_available_size() - + self.tree_cache.swa_evictable_size(), + + self.tree_cache.full_evictable_size() ) elif self.is_hybrid_ssm_cache: available_and_evictable = ( @@ -467,14 +471,20 @@ def rem_total_tokens(self): ) return available_and_evictable - self.rem_total_token_offset + @property + def rem_swa_tokens(self): + return ( + self.token_to_kv_pool_allocator.swa_available_size() + + self.tree_cache.swa_evictable_size() + - self.rem_swa_token_offset + ) + @property def cur_rem_tokens(self): if self.is_hybrid_swa: - available_and_evictable = min( + available_and_evictable = ( self.token_to_kv_pool_allocator.full_available_size() - + self.tree_cache.full_evictable_size(), - self.token_to_kv_pool_allocator.swa_available_size() - + self.tree_cache.swa_evictable_size(), + + self.tree_cache.full_evictable_size() ) elif self.is_hybrid_ssm_cache: available_and_evictable = ( @@ -489,11 +499,21 @@ def cur_rem_tokens(self): return available_and_evictable - self.cur_rem_token_offset + def _swa_budget_for_req(self, extend_input_len: int) -> int: + if self.rem_chunk_tokens is not None: + alloc = min(extend_input_len, self.rem_chunk_tokens) + 
else: + alloc = extend_input_len + return max(alloc, self.tree_cache.sliding_window_size) + self.page_size + def ceil_paged_tokens(self, tokens: int) -> int: return -(-tokens // self.page_size) * self.page_size def budget_state(self): - if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: + no_token = self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0 + if not no_token and self.is_hybrid_swa: + no_token = self.rem_swa_tokens <= 0 + if no_token: return AddReqResult.NO_TOKEN if self.rem_input_tokens <= 0: @@ -518,6 +538,9 @@ def _update_prefill_budget( self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len + if self.is_hybrid_swa: + self.rem_swa_token_offset += self._swa_budget_for_req(extend_input_len) + if self.dllm_config is not None: self.rem_dllm_tokens -= extend_input_len elif self.rem_chunk_tokens is not None: @@ -567,9 +590,14 @@ def add_chunked_req(self, req: Req): _rem_tokens = self._get_dllm_remain_tokens() else: _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens)) - # The chunked_req must be added to the list; otherwise, it will cause a memory leak. - # Therefore, in certain cases where _rem_tokens <= 0, it should be replaced with rem_chunk_tokens. 
+ if self.is_hybrid_swa: + _rem_tokens = min( + _rem_tokens, int(self.rem_swa_tokens) - self.page_size + ) if _rem_tokens <= 0: + if self.is_hybrid_swa: + # skip to avoid alloc_extend OOM + return req _rem_tokens = self.rem_chunk_tokens truncated = req.extend_input_len > _rem_tokens @@ -604,11 +632,12 @@ def _lock_node(self, last_node: TreeNode): self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req): - # Early exit if no enough tokens for the input tokens - if self.ceil_paged_tokens(req.extend_input_len) > min( - self.cur_rem_tokens, self.rem_total_tokens - ): + paged_input = self.ceil_paged_tokens(req.extend_input_len) + if paged_input > min(self.cur_rem_tokens, self.rem_total_tokens): return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + if self._swa_budget_for_req(req.extend_input_len) > self.rem_swa_tokens: + return AddReqResult.NO_TOKEN def add_req_state(r, insert_sort=False): new_token_ratio = ( @@ -659,6 +688,13 @@ def add_req_state(r, insert_sort=False): return AddReqResult.NO_TOKEN tokens_freed += tokens_occupied + if (self.prefill_delayer_single_pass is not None) and ( + not self.prefill_delayer_single_pass.negotiate_should_allow_prefill( + local_prefillable=True + ) + ): + return AddReqResult.OTHER + if self.dllm_config is not None: if self.rem_dllm_tokens <= 0: return AddReqResult.OTHER @@ -705,10 +741,11 @@ def add_one_req( if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True): return self.add_one_req_ignore_eos(req) - total_tokens = req.extend_input_len + min( + max_new = min( max(req.sampling_params.max_new_tokens - len(req.output_ids), 0), CLIP_MAX_NEW_TOKENS, ) + total_tokens = req.extend_input_len + max_new + self.page_size # adjusting the input_tokens based on host_hit_length and page_size real_input_tokens = req.extend_input_len - req.host_hit_length @@ -718,6 +755,11 @@ def add_one_req( if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + 
swa_needed = self._swa_budget_for_req(req.extend_input_len) + if swa_needed >= self.rem_swa_tokens: + return AddReqResult.NO_TOKEN + if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: return AddReqResult.OTHER @@ -726,6 +768,11 @@ def add_one_req( if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + swa_needed = self._swa_budget_for_req(req.extend_input_len) + if swa_needed >= self.rem_swa_tokens: + return AddReqResult.NO_TOKEN + if req.host_hit_length > 0: new_indices, req.last_node = self.tree_cache.init_load_back( req.last_host_node, req.host_hit_length @@ -788,6 +835,13 @@ def add_one_req( trunc_len // truncation_align_size ) + now_input_len = trunc_len + len(req.prefix_indices) + now_input_len = now_input_len // self.page_size * self.page_size + trunc_len = now_input_len - len(req.prefix_indices) + + if trunc_len <= 0: + return AddReqResult.OTHER + # Chunked prefill req.set_extend_input_len(trunc_len) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 18fd50130581..09642cd9ac83 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.lora.lora_overlap_loader import LoRAOverlapLoader +from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator from sglang.srt.managers.io_struct import ( AbortReq, ActiveRanksOutput, @@ -170,6 +171,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin from sglang.srt.parser.reasoning_parser import ReasoningParser +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from 
sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( @@ -308,6 +310,8 @@ def __init__( self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hicache_storage = server_args.hicache_storage_backend is not None self.max_recv_per_poll = envs.SGLANG_SCHEDULER_MAX_RECV_PER_POLL.get() + self.enable_hisparse = server_args.enable_hisparse + self.hisparse_coordinator: Optional[HiSparseCoordinator] = None # Distributed rank info self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( @@ -690,6 +694,9 @@ def init_cache_with_memory_pool(self): else: self.tree_cache = RadixCache(params) + if self.enable_hisparse: + self.hisparse_coordinator = self.tp_worker.model_runner.hisparse_coordinator + if ( server_args.disaggregation_mode == "decode" and server_args.disaggregation_decode_enable_offload_kvcache @@ -729,6 +736,15 @@ def init_chunked_prefill(self): if self.chunked_prefill_size <= 0: # -1 means disable self.chunked_prefill_size = None self.chunked_req = None + # Tracks whether the current self.chunked_req was actually scheduled + # into last iteration's batch (i.e., in can_run_list -> got a fresh + # req_pool_idx from prepare_for_extend). Used to gate the + # stash_chunked_request call at the top of get_next_batch_to_run: + # if add_chunked_req early-returned under hybrid-SWA pressure, + # the req_pool_idx was already freed and fill_ids was reset by + # init_next_round_input, so running stash would double-free and + # corrupt prefix_indices. 
+ self._chunked_req_scheduled_last_iter = False self.is_mixed_chunk = ( self.chunked_prefill_size is not None and self.server_args.enable_mixed_chunk @@ -851,7 +867,7 @@ def init_disaggregation(self): self.disagg_metadata_buffers = MetadataBuffers( buffer_size, hidden_size=( - model_config.hidden_size + model_config.spec_hidden_size if self.spec_algorithm.is_eagle() else 16 # minimal padding size for RDMA ), @@ -905,7 +921,7 @@ def init_disaggregation(self): self.disagg_metadata_buffers = MetadataBuffers( buffer_size, hidden_size=( - model_config.hidden_size + model_config.spec_hidden_size if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone() else 16 # minimal padding size for RDMA @@ -1447,6 +1463,7 @@ def handle_generate_request( require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, + return_indexer_topk=recv_req.return_indexer_topk, eos_token_ids=self.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, @@ -1790,6 +1807,40 @@ def stash_chunked_request(self, req: Req): else: self.req_to_token_pool.free(req.req_pool_idx) + def _build_hisparse_decode_batch(self, reqs): + device = self.device + + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + tree_cache=self.tree_cache, + model_config=self.model_config, + enable_overlap=self.enable_overlap, + spec_algorithm=self.spec_algorithm, + ) + + batch.req_pool_indices = torch.tensor( + [r.req_pool_idx for r in reqs], dtype=torch.int64, device=device + ) + seq_lens = [len(r.origin_input_ids) + len(r.output_ids) - 1 for r in reqs] + batch.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=device) + batch.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) + batch.orig_seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) + 
batch.seq_lens_sum = sum(seq_lens) + batch.output_ids = torch.tensor( + [r.output_ids[-1] for r in reqs], dtype=torch.int64, device=device + ) + + if batch.return_logprob: + batch.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + batch.token_ids_logprobs = [list(r.origin_input_ids) for r in reqs] + + batch.sampling_info = SamplingBatchInfo.from_schedule_batch( + batch, self.model_config.vocab_size + ) + return batch + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: self._abort_on_queued_timeout() if self.dllm_config is not None: @@ -1812,9 +1863,29 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Move the chunked request out of the batch so that we can merge # only finished requests to running_batch. chunked_req_to_exclude.add(self.chunked_req) - self.stash_chunked_request(self.chunked_req) - if self.last_batch and self.last_batch.forward_mode.is_extend(): + if ( + not envs.SGLANG_FIX_SWA_CHUNKED_REQ_DOUBLE_FREE.get() + or self._chunked_req_scheduled_last_iter + ): + self.stash_chunked_request(self.chunked_req) + + if self.enable_hisparse: + ready_reqs = self.hisparse_coordinator.collect_ready_reqs() + if len(ready_reqs) > 0: + new_batch = self._build_hisparse_decode_batch(ready_reqs) + if self.running_batch.is_empty(): + self.running_batch = new_batch + else: + self.running_batch.merge_batch(new_batch) + self.running_batch.hisparse_coordinator = self.hisparse_coordinator + self.running_batch.batch_is_full = False + + if ( + not self.enable_hisparse + and self.last_batch + and self.last_batch.forward_mode.is_extend() + ): if self.last_batch.chunked_req is not None: # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. # We need to discard it. 
@@ -1876,14 +1947,14 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: def get_num_allocatable_reqs(self, running_bs): res = get_global_server_args().pp_max_micro_batch_size - running_bs - if self.pp_size > 1: - res = min(res, self.req_to_token_pool.available_size()) + res = min(res, self.req_to_token_pool.available_size()) return res def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: prefill_delayer_single_pass = None if self.prefill_delayer: - _, token_usage, _, _ = self._get_token_info() + # _, token_usage, _, _ = self._get_token_info() + token_usage = 0.5 # HACK since it is unused prefill_delayer_single_pass = PrefillDelayerSinglePassExecutor( self.prefill_delayer, token_usage=token_usage ) @@ -1906,7 +1977,7 @@ def _get_new_batch_prefill_raw( for req in ready_grammar_requests: self._add_request_to_queue(req) - if self.try_preemption: + if self.try_preemption or self.is_hybrid_swa: # Reset batch_is_full to try preemption with a prefill adder. self.running_batch.batch_is_full = False @@ -1979,6 +2050,11 @@ def _get_new_batch_prefill_raw( if self.chunked_req is not None: self.chunked_req.init_next_round_input() self.chunked_req = adder.add_chunked_req(self.chunked_req) + self._chunked_req_scheduled_last_iter = ( + self.chunked_req in adder.can_run_list + ) + else: + self._chunked_req_scheduled_last_iter = False if self.enable_lora: running_loras = {req.lora_id for req in self.running_batch.reqs} @@ -2074,6 +2150,9 @@ def _get_new_batch_prefill_raw( # Update chunked prefill assert self.chunked_req is None self.chunked_req = adder.new_chunked_req + # new_chunked_req is added to can_run_list by add_one_req, + # so it will be scheduled this iter -> stash is needed next iter. 
+ self._chunked_req_scheduled_last_iter = True if self.chunked_req is not None: self.chunked_req.is_chunked += 1 diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index c4728b714b57..1b8e8b308401 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -8,6 +8,9 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs +from sglang.srt.layers.attention.indexer_topk_capturer import ( + get_global_indexer_capturer, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer from sglang.srt.managers.io_struct import ( @@ -71,6 +74,13 @@ def maybe_collect_routed_experts(self: Scheduler, req: Req): req_to_token_pool=self.req_to_token_pool, ) + def maybe_collect_indexer_topk(self: Scheduler, req: Req): + req.indexer_topk = get_global_indexer_capturer().get_topk( + req_pool_idx=req.req_pool_idx, + seqlen=req.seqlen, + req_to_token_pool=self.req_to_token_pool, + ) + def maybe_collect_customized_info( self: Scheduler, i: int, req: Req, logits_output: LogitsProcessorOutput ): @@ -137,11 +147,14 @@ def process_batch_result_prefill( if req.finished(): self.maybe_collect_routed_experts(req) + self.maybe_collect_indexer_topk(req) release_kv_cache(req, self.tree_cache) req.time_stats.completion_time = time.perf_counter() elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) + if self.enable_hisparse: + self.hisparse_coordinator.admit_request_into_staging(req) self.maybe_collect_customized_info(i, req, logits_output) @@ -413,12 +426,15 @@ def process_batch_result_decode( if req.finished(): self.maybe_collect_routed_experts(req) + self.maybe_collect_indexer_topk(req) if 
self.server_args.disaggregation_decode_enable_offload_kvcache: # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes if not self.decode_offload_manager.offload_kv_cache(req): release_kv_cache(req, self.tree_cache) else: + if self.enable_hisparse: + self.hisparse_coordinator.request_finished(req) release_kv_cache(req, self.tree_cache) req.time_stats.completion_time = time.perf_counter() @@ -854,6 +870,7 @@ def stream_output_generation( retraction_counts = [] output_hidden_states = None load = self.get_load() + indexer_topk = None routed_experts = None customized_info = {} @@ -1054,7 +1071,10 @@ def stream_output_generation( if routed_experts is None: routed_experts = [] routed_experts.append(req.routed_experts) - + if req.return_indexer_topk: + if indexer_topk is None: + indexer_topk = [] + indexer_topk.append(req.indexer_topk) if req.customized_info is not None: for k, v in req.customized_info.items(): if k not in customized_info: @@ -1109,6 +1129,7 @@ def stream_output_generation( output_token_entropy_val=None, output_hidden_states=output_hidden_states, routed_experts=routed_experts, + indexer_topk=indexer_topk, customized_info=customized_info, placeholder_tokens_idx=None, placeholder_tokens_val=None, diff --git a/python/sglang/srt/managers/scheduler_profiler_mixin.py b/python/sglang/srt/managers/scheduler_profiler_mixin.py index 7d08f12b35dd..c01d6a469091 100644 --- a/python/sglang/srt/managers/scheduler_profiler_mixin.py +++ b/python/sglang/srt/managers/scheduler_profiler_mixin.py @@ -34,6 +34,31 @@ logger = logging.getLogger(__name__) +_kineto_warmed = False + + +def _warmup_kineto_once() -> None: + """Workaround torch.profiler+kineto first-call dropping all GPU activities + on PyTorch 2.9.1 + CUDA 13.0 + GB300. Run a tiny dummy 1-kernel profile + once per process before the first real profile. 
+ """ + global _kineto_warmed + if _kineto_warmed or not torch.cuda.is_available(): + return + logger.info("[KINETO_WARMUP] running dummy 1-kernel profile to warm CUPTI") + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + ): + t = torch.zeros(64, device="cuda") + t.add_(1.0) + torch.cuda.synchronize() + _kineto_warmed = True + logger.info("[KINETO_WARMUP] done") + + class SchedulerProfilerMixin: def init_profiler(self: Scheduler): if envs.SGLANG_PROFILE_V2.get(): @@ -188,6 +213,13 @@ def start_profile( self.rpd_profiler.rangePush("", "rpd profile range", "") self.profile_in_progress = True elif torchprof_activities: + if ( + envs.SGLANG_HACK_WARMUP_KINETO.get() + and not _is_npu + and torch.profiler.ProfilerActivity.CUDA in torchprof_activities + ): + _warmup_kineto_once() + self.torch_profiler = torch.profiler.profile( activities=torchprof_activities, with_stack=with_stack if with_stack is not None else True, diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 484a949f5b23..9b5a995c829f 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -68,6 +68,12 @@ def _get_swa_token_info(self: Scheduler): swa_num_used = self.swa_tokens_per_layer - ( swa_available_size + swa_evictable_size ) + if self.enable_hisparse: + full_num_used = max(0, full_num_used) + swa_num_used = max(0, swa_num_used) + else: + if swa_num_used < 0: + raise ValueError(f"swa_num_used < 0") full_token_usage = full_num_used / self.full_tokens_per_layer swa_token_usage = swa_num_used / self.swa_tokens_per_layer return ( @@ -176,7 +182,91 @@ def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int: return ret + def _get_batch_swa_uncached_sizes( + self: Scheduler, batch: ScheduleBatch + ) -> tuple[int, int]: + full_uncached = 
0 + swa_uncached = 0 + for req in batch.reqs: + assert req.kv_committed_freed == req.kv_overallocated_freed + if req.kv_committed_freed: + continue + + allocated_len = req.kv_allocated_len + if self.page_size > 1: + allocated_len = ceil_align(allocated_len, self.page_size) + assert req.cache_protected_len % self.page_size == 0 + + full_uncached += allocated_len - req.cache_protected_len + swa_uncached += allocated_len - max( + req.cache_protected_len, req.swa_evicted_seqlen + ) + + return full_uncached, swa_uncached + + def self_check_swa_during_busy(self: Scheduler): + current_batch: ScheduleBatch = self.last_batch + + if current_batch is None: + return + + spec_topk = self.server_args.speculative_eagle_topk or 1 + if spec_topk > 1: + warnings.warn( + "Runtime memory check (busy) is not supported when speculation topk > 1." + ) + return + + ( + _, + _, + _, + _, + full_available, + full_evictable, + swa_available, + swa_evictable, + ) = self._get_swa_token_info() + + full_protected = self.tree_cache.full_protected_size() + swa_protected = self.tree_cache.swa_protected_size() + + full_uncached, swa_uncached = self._get_batch_swa_uncached_sizes(current_batch) + + if ( + self.running_batch is not None + and self.running_batch is not current_batch + and not self.running_batch.is_empty() + ): + f_unc, s_unc = self._get_batch_swa_uncached_sizes(self.running_batch) + full_uncached += f_unc + swa_uncached += s_unc + + full_total = full_available + full_evictable + full_protected + full_uncached + swa_total = swa_available + swa_evictable + swa_protected + swa_uncached + + if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get() > 1: + log_msg = ( + f"[SWA Mem Check (BUSY)] " + f"full: ({full_available=} + {full_evictable=} + {full_protected=} + {full_uncached=}) = {full_total=} " + f"swa: ({swa_available=} + {swa_evictable=} + {swa_protected=} + {swa_uncached=}) = {swa_total=}" + ) + logger.info(log_msg) + + assert full_total == self.full_tokens_per_layer, ( + f"Full Pool 
Mem Leak Detected! {full_total=} vs {self.full_tokens_per_layer=}, " + f"{full_available=}, {full_evictable=}, {full_protected=}, {full_uncached=}" + ) + assert swa_total == self.swa_tokens_per_layer, ( + f"SWA Pool Mem Leak Detected! {swa_total=} vs {self.swa_tokens_per_layer=}, " + f"{swa_available=}, {swa_evictable=}, {swa_protected=}, {swa_uncached=}" + ) + def self_check_during_busy(self: Scheduler): + if self.is_hybrid_swa: + self.self_check_swa_during_busy() + return + current_batch: ScheduleBatch = self.last_batch if current_batch is None: @@ -218,7 +308,11 @@ def _check_req_pool(self: Scheduler): else: req_total_size = self.req_to_token_pool.size - if len(self.req_to_token_pool.free_slots) != req_total_size: + if self.disaggregation_mode == DisaggregationMode.DECODE: + expected_free = req_total_size + else: + expected_free = req_total_size - 1 + if len(self.req_to_token_pool.free_slots) != expected_free: msg = ( "req_to_token_pool memory leak detected!" f"available_size={len(self.req_to_token_pool.free_slots)}, " @@ -232,6 +326,9 @@ def _check_req_pool(self: Scheduler): ) def check_memory(self: Scheduler): + if self.enable_hisparse: + return + if self.is_hybrid_swa: memory_leak, token_msg = self._check_hybrid_memory() elif self.is_hybrid_ssm and self.tree_cache.supports_mamba(): @@ -329,6 +426,13 @@ def self_check_during_idle(self: Scheduler): if queue_size: return + if ( + self.enable_hisparse + and self.hisparse_coordinator is not None + and self.hisparse_coordinator.has_ongoing_staging() + ): + return + self.check_memory() self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ae6211887b44..778057df4e60 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -937,6 +937,7 @@ def _create_tokenized_object( require_reasoning=obj.require_reasoning, 
def register_hisparse_coordinator(self, coordinator):
    """Store ``coordinator`` on this worker's model runner.

    The object is assigned to ``self.model_runner.hisparse_coordinator``;
    nothing is returned.
    """
    runner = self.model_runner
    runner.hisparse_coordinator = coordinator
@dataclasses.dataclass
class KVAndScore:
    """Wrapper around one tensor whose last dim packs a KV part and a score part.

    ``kv`` is the first ``item_size`` entries of the last dimension of
    ``kv_score`` and ``score`` is everything after them, where
    ``item_size = kv_score.shape[-1] // 2``. Shapes handed to ``new_empty``
    and ``view`` are expressed with ``item_size`` as the last dimension; the
    wrapper widens it to the packed width.
    """

    kv_score: torch.Tensor

    def __post_init__(self):
        # Width of the KV part, derived once from the packed tensor.
        self._item_size = self.kv_score.shape[-1] // 2

    @property
    def kv(self) -> torch.Tensor:
        """View of the leading ``item_size`` entries of the last dimension."""
        return self.kv_score[..., : self._item_size]

    @property
    def score(self) -> torch.Tensor:
        """View of the entries after the first ``item_size`` of the last dimension."""
        return self.kv_score[..., self._item_size :]

    @property
    def shape(self):
        """Shape of the packed tensor (last dimension includes both parts)."""
        return self.kv_score.shape

    @staticmethod
    def from_kv_score(*, kv: torch.Tensor, score: torch.Tensor) -> KVAndScore:
        """Pack equally shaped ``kv`` and ``score`` tensors along the last dim."""
        assert kv.shape == score.shape
        return KVAndScore(torch.cat([kv, score], dim=-1))

    def new_empty(self, new_shape) -> KVAndScore:
        """Allocate an uninitialized wrapper; ``new_shape[-1]`` must equal ``item_size``."""
        assert new_shape[-1] == self._item_size
        new_shape = list(new_shape)
        new_shape[-1] = 2 * self._item_size
        return KVAndScore(self.kv_score.new_empty(new_shape, requires_grad=False))

    def __getitem__(self, index) -> KVAndScore:
        return KVAndScore(self.kv_score[index])

    def __setitem__(self, index, value: KVAndScore):
        self.kv_score[index] = value.kv_score

    def clear(self):
        """Reset in place: zero the KV part and fill the score part with ``-inf``."""
        self.kv.zero_()
        self.score.fill_(float("-inf"))

    def view(self, *args):
        """Return a wrapper over ``kv_score.view(...)``.

        The shape may be given as separate ints or as one tuple/list, the two
        spellings ``torch.Tensor.view`` itself accepts. An integer last
        dimension other than ``-1`` is replaced by ``2 * item_size`` so callers
        can pass the unpacked width.
        """
        # Accept view((a, b)) as well as view(a, b); without this the
        # single-sequence form would skip the last-dimension widening below.
        if len(args) == 1 and isinstance(args[0], (tuple, list)):
            args = args[0]
        dims = list(args)
        if isinstance(dims[-1], int) and dims[-1] != -1:
            dims[-1] = 2 * self._item_size
        return KVAndScore(self.kv_score.view(*dims))

    def clone(self) -> KVAndScore:
        return KVAndScore(self.kv_score.clone())

    @staticmethod
    def cat(tensors: list[KVAndScore], dim: int) -> KVAndScore:
        """Concatenate wrappers along ``dim``, which must not be the last dim.

        Raises:
            AssertionError: if ``tensors`` is empty, if ``dim`` addresses the
                last dimension (as ``-1`` or as its positive index), or if the
                wrappers disagree on ``item_size``.
        """
        assert len(tensors) > 0, "At least one tensor is required for concatenation."
        # Joining along the packed dimension would interleave kv and score
        # parts, so reject it under either spelling of that axis.
        last_dim = tensors[0].kv_score.dim() - 1
        assert dim not in (-1, last_dim), "Concatenation along last dim is not supported."
        item_size = tensors[0]._item_size
        for v in tensors:
            assert (
                v._item_size == item_size
            ), "All tensors must have the same item size."

        return KVAndScore(torch.cat([v.kv_score for v in tensors], dim=dim))
def translate_from_swa_loc_to_state_loc(
    self, swa_loc: torch.Tensor
) -> torch.Tensor:
    """Map SWA locations to state-buffer locations.

    Each location becomes ``(swa_loc // swa_page_size) * ring_size +
    (swa_loc % ring_size)``; negative input locations map to ``-1``.
    """
    ring = self.ring_size
    page_index = swa_loc // self.swa_page_size
    slot_in_ring = swa_loc % ring
    mapped = page_index * ring + slot_in_ring
    # Negative locations are passed through as the sentinel -1.
    return torch.where(swa_loc < 0, -1, mapped)
def get_compress_state_ring_size(
    compress_ratio: int, is_speculative: bool = False
) -> int:
    """Return the compress-state ring size for a compression ratio.

    Args:
        compress_ratio: 4 or 128; anything else fails an assertion.
        is_speculative: selects the doubled ring size (16 / 256 instead of
            8 / 128).

    Returns:
        1 when ``compress_ratio`` is 128 and ``ONLINE_C128`` is set,
        otherwise 8/16 for ratio 4 and 128/256 for ratio 128.
    """
    assert compress_ratio in [4, 128], f"Unsupported {compress_ratio = }"
    # Online c128 stores one (max, sum, kv) state per index rather than a
    # 128-slot ring of raw tokens, hence a ring of size 1. It cannot be
    # combined with speculative decode for now.
    if compress_ratio == 128 and ONLINE_C128:
        assert not is_speculative, "online c128 does not support MTP"
        return 1
    base_size = 8 if compress_ratio == 4 else 128
    return 2 * base_size if is_speculative else base_size
def get_key_buffer(self, layer_id: int):
    """Return the cache buffer stored for ``layer_id``.

    ``self.kv_buffer`` is indexed with ``layer_id`` unchanged, the same way
    ``set_key_buffer`` and ``set_key_buffer_fused`` index it, so reads and
    writes address the same per-layer tensor. When ``store_dtype`` differs
    from ``dtype`` the buffer is returned reinterpreted as ``dtype``.
    """
    # Previously only the dtype-view branch subtracted self.start_layer,
    # which disagreed with the other branch and with both setters.
    layer_buffer = self.kv_buffer[layer_id]
    if self.store_dtype != self.dtype:
        return layer_buffer.view(self.dtype)
    return layer_buffer
class HiSparseC4DevicePool(DeepSeekV4SingleKVPool):
    """``DeepSeekV4SingleKVPool`` whose write locations go through an index mapping.

    ``set_key_buffer`` and ``set_key_buffer_fused`` receive compressed
    indices, look them up in the tensor given to ``register_mapping`` and
    forward the resulting locations to the parent implementation. The CPU
    copy helpers are not supported.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: int | None = None,
        end_layer: int | None = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            qk_nope_head_dim,
            qk_rope_head_dim,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )
        self.compress_ratio = 4
        # Raw addresses of the per-layer buffers, kept as a device tensor.
        layer_addresses = [layer_buf.data_ptr() for layer_buf in self.kv_buffer]
        self.data_ptrs = torch.tensor(
            layer_addresses,
            dtype=torch.uint64,
            device=self.device,
        )

    def register_mapping(self, full_to_hisparse_device_index_mapping: torch.Tensor):
        """Remember the lookup tensor used by the translate/set methods."""
        self.full_to_hisparse_device_index_mapping = (
            full_to_hisparse_device_index_mapping
        )

    def translate_loc_from_full_to_compressed(self, full_indices: torch.Tensor):
        """Keep indices where ``(index + 1) % compress_ratio == 0``, floor-divided by the ratio."""
        ratio = self.compress_ratio
        closes_group = (full_indices + 1) % ratio == 0
        return full_indices[closes_group] // ratio

    def translate_loc_from_compressed_to_hisparse_device(
        self, compressed_indices: torch.Tensor
    ):
        """Look up compressed indices in the mapping and cast the result to int32."""
        device_loc = self._translate_loc_from_compressed_to_hisparse_device(
            compressed_indices
        )
        return device_loc.to(torch.int32)

    def _translate_loc_from_compressed_to_hisparse_device(
        self, compressed_indices: torch.Tensor
    ):
        """Look up compressed indices in the mapping, keeping the mapping's dtype."""
        return self.full_to_hisparse_device_index_mapping[compressed_indices]

    def translate_loc_from_full_to_hisparse_device(self, full_indices: torch.Tensor):
        """Full index -> compressed index -> mapped location (no int32 cast)."""
        compressed = self.translate_loc_from_full_to_compressed(full_indices)
        return self._translate_loc_from_compressed_to_hisparse_device(compressed)

    def set_key_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_nope_fp8_rope_bf16_pack,
    ):
        """Translate ``loc`` through the mapping, then delegate to the parent."""
        device_loc = self.translate_loc_from_compressed_to_hisparse_device(loc)
        super().set_key_buffer(layer_id, device_loc, cache_nope_fp8_rope_bf16_pack)

    def set_key_buffer_fused(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
    ) -> None:
        """Translate ``loc`` through the mapping, then delegate to the parent's fused store."""
        device_loc = self.translate_loc_from_compressed_to_hisparse_device(loc)
        return super().set_key_buffer_fused(layer_id, device_loc, cache_k)

    def get_cpu_copy(self, indices):
        """Unsupported for this pool; always raises ``NotImplementedError``."""
        raise NotImplementedError("HiSparseC4DevicePool does not support get_cpu_copy")

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Unsupported for this pool; always raises ``NotImplementedError``."""
        raise NotImplementedError("HiSparseC4DevicePool does not support load_cpu_copy")
class DeepSeekV4LayerItem(NamedTuple):
    """Per-layer entry of ``DeepSeekV4TokenToKVPool.layer_mapping``.

    One item is appended for every element of ``compression_ratios`` by
    ``_init_compressed_layer_mapping``.
    """

    # Compression ratio of the layer: 0, 4 or 128.
    compress_ratio: Literal[0, 4, 128]
    # Running index of this layer among the layers that share its ratio.
    compress_layer_id: int
    # The c4 or c128 KV pool for ratio 4 / 128; left at None for ratio 0.
    compress_kv_pool: Optional[DeepSeekV4SingleKVPool] = None
False, + ): + super().__init__( + swa_size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + c4_logical_size = c128_size * 32 + + logger.info( + "Initialize DeepSeekV4TokenToKVPool with " + f"{max_num_reqs=} {swa_size=} {c4_size=} " + f"{c4_logical_size=} {c128_size=} " + f"{c4_state_pool_size=} {c128_state_pool_size=}" + ) + + self.max_num_reqs = max_num_reqs + self.c4_size = c4_size + self.c4_logical_size = c4_logical_size + self.c128_size = c128_size + self.c4_state_pool_size = c4_state_pool_size + self.c128_state_pool_size = c128_state_pool_size + self.state_dtype = state_dtype + self.compression_ratios = compression_ratios + + assert page_size % swa_page_size == 0 + + self.swa_size = swa_size + self.swa_window_size = swa_page_size + self.swa_page_size = swa_page_size + self.scale_pad = 1 + + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.indexer_head_dim = indexer_head_dim + + c4_layer_num = sum(1 for r in compression_ratios if r == 4) + c128_layer_num = sum(1 for r in compression_ratios if r == 128) + c4_page_size = page_size // 4 + c128_page_size = page_size // 128 + self.swa_kv_pool = DeepSeekV4SingleKVPool( + swa_size, + swa_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + layer_num, + device, + enable_memory_saver, + is_swa_pool=True, + ) + + c4_kv_pool_type = DeepSeekV4SingleKVPool + if enable_hisparse: + c4_kv_pool_type = HiSparseC4DevicePool + self.c4_kv_pool = c4_kv_pool_type( + c4_size, + c4_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c4_layer_num, + device, + enable_memory_saver, + ) + + self.c128_kv_pool = DeepSeekV4SingleKVPool( + c128_size, + c128_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c128_layer_num, + device, + enable_memory_saver, + ) + + self.c4_indexer_kv_pool = DeepSeekV4IndexerPool( + self.c4_logical_size, + c4_page_size, + dtype, + indexer_head_dim, + c4_layer_num, + device, + 
enable_memory_saver, + ) + + self._init_compressed_layer_mapping() + + self._init_paged_compress_states(enable_memory_saver) + + self._should_cache_swa = envs.SGLANG_OPT_CACHE_SWA_TRANSLATION.get() + + self._dbg_dump_pool_sizes() + + def _dbg_dump_pool_sizes(self): + import os + + if os.environ.get("SGLANG_HISPARSE_DBG_POOL_SIZES") != "1": + return + try: + rank = torch.distributed.get_rank() + except Exception: + rank = 0 + if rank != 0: + return + + def sum_bufs(name, bufs): + if bufs is None: + return 0 + total = 0 + count = 0 + for b in bufs: + if b is None: + continue + t = getattr(b, "kv_score", None) + if t is None: + t = b + try: + total += t.element_size() * t.numel() + except Exception: + t = getattr(t, "kv_score_buffer", None) + if t is not None and hasattr(t, "kv_score"): + total += t.kv_score.element_size() * t.kv_score.numel() + count += 1 + logger.warning( + "HSDBG[pool] %-28s #bufs=%3d total=%10.2f MiB", + name, + count, + total / 2**20, + ) + return total + + total_all = 0 + total_all += sum_bufs("swa_kv_pool", self.swa_kv_pool.kv_buffer) + total_all += sum_bufs("c4_kv_pool", self.c4_kv_pool.kv_buffer) + total_all += sum_bufs("c128_kv_pool", self.c128_kv_pool.kv_buffer) + total_all += sum_bufs( + "c4_indexer_kv_pool", self.c4_indexer_kv_pool.index_k_with_scale_buffer + ) + if hasattr(self, "compress_state_pools"): + c4_state_bufs = [] + c128_state_bufs = [] + for ratio, pool in zip(self.compression_ratios, self.compress_state_pools): + if pool is None: + continue + if ratio == 4: + c4_state_bufs.append(pool.kv_score_buffer.kv_score) + elif ratio == 128: + c128_state_bufs.append(pool.kv_score_buffer.kv_score) + total_all += sum_bufs("c4_state_pool", c4_state_bufs) + total_all += sum_bufs("c128_state_pool", c128_state_bufs) + idx_bufs = [] + for pool in self.indexer_compress_state_pools: + if pool is None: + continue + idx_bufs.append(pool.kv_score_buffer.kv_score) + total_all += sum_bufs("c4_indexer_state_pool", idx_bufs) + logger.warning( + 
"HSDBG[pool] %-28s total=%10.2f MiB = %.2f GiB", + "GRAND_TOTAL", + total_all / 2**20, + total_all / 2**30, + ) + + def register_mapping(self, full_to_swa_index_mapping: torch.Tensor): + self.full_to_swa_index_mapping = full_to_swa_index_mapping + + def get_ring_size(self, compress_ratio: int) -> int: + server_args = get_global_server_args() + is_speculative = server_args.speculative_algorithm is not None + return get_compress_state_ring_size(compress_ratio, is_speculative) + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + assert self.full_to_swa_index_mapping is not None + + return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) + + def get_contiguous_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for bufs in [ + self.c4_kv_pool.kv_buffer, + self.c4_indexer_kv_pool.index_k_with_scale_buffer, + self.c128_kv_pool.kv_buffer, + ]: + for buf in bufs: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + return data_ptrs, data_lens, item_lens + + def get_state_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for buf in self.swa_kv_pool.kv_buffer: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + for pools in [ + self.compress_state_pools, + self.indexer_compress_state_pools, + ]: + for pool in pools: + if pool is None: + continue + t = pool.kv_score_buffer.kv_score + assert t.ndim == 2, f"expected 2D buffer, got {t.ndim}D" + data_ptrs.append(t.data_ptr()) + data_lens.append(t.nbytes) + item_lens.append(t[0].nbytes * pool.ring_size) + + return data_ptrs, data_lens, item_lens + + def _init_paged_compress_states(self, 
def _init_compressed_layer_mapping(self):
    """Build ``self.layer_mapping`` from ``self.compression_ratios``.

    Every ratio yields one ``DeepSeekV4LayerItem`` holding the ratio, a
    counter that increases separately for each ratio value, and the matching
    pool (``c4_kv_pool`` for 4, ``c128_kv_pool`` for 128, ``None`` for 0).

    Raises:
        ValueError: if a ratio is not 0, 4 or 128.
    """
    pool_for_ratio = {0: None, 4: self.c4_kv_pool, 128: self.c128_kv_pool}
    # Independent running index per ratio value.
    next_index = dict.fromkeys(pool_for_ratio, 0)
    self.layer_mapping: List[DeepSeekV4LayerItem] = []

    for ratio in self.compression_ratios:
        if ratio not in pool_for_ratio:
            raise ValueError(f"Unsupported compression ratio: {ratio}")
        self.layer_mapping.append(
            DeepSeekV4LayerItem(
                compress_ratio=ratio,
                compress_layer_id=next_index[ratio],
                compress_kv_pool=pool_for_ratio[ratio],
            )
        )
        next_index[ratio] += 1
def set_swa_key_buffer_radix_fused(
    self,
    layer_id: int,
    raw_loc: torch.Tensor,
    cache_k: torch.Tensor,
) -> None:
    """Translate ``raw_loc`` to SWA locations and run the SWA pool's fused store.

    With ``self._should_cache_swa`` set, the translation is computed only
    when ``layer_id == 0`` and stored in ``self.cached_loc``; every call then
    uses the stored value. Otherwise the translation is done on each call.
    """
    # NOTE(review): the cached path presumably relies on layer 0 being
    # written first with the same raw_loc as later layers - confirm with callers.
    if not self._should_cache_swa:
        translated = self.translate_loc_from_full_to_swa(raw_loc)
        return self.swa_kv_pool.set_key_buffer_fused(layer_id, translated, cache_k)

    if layer_id == 0:
        self.cached_loc = self.translate_loc_from_full_to_swa(raw_loc)
    return self.swa_kv_pool.set_key_buffer_fused(layer_id, self.cached_loc, cache_k)
compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + return compress_kv_pool.set_key_buffer_fused(compress_layer_id, loc, cache_k) + + def set_index_k_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.set_index_fused(compress_layer_id, loc, cache_k) + diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py new file mode 100644 index 000000000000..56ba2a3d25a9 --- /dev/null +++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py @@ -0,0 +1,381 @@ +import logging +import weakref +from typing import Optional + +import psutil +import torch + +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + PagedTokenToKVPoolAllocator, +) +from sglang.srt.mem_cache.deepseekv4_memory_pool import ( + DeepSeekV4TokenToKVPool, + HiSparseC4DevicePool, +) +from sglang.srt.utils.common import get_num_new_pages + +logger = logging.getLogger(__name__) + + +class DeepSeekV4SingleKVPoolHost: + + def __init__( + self, + device_pool: HiSparseC4DevicePool, + host_size: int, + page_size: int, + pin_memory: bool = True, + device: str = "cpu", + ): + + assert host_size > 0, "Host size must be specified and greater than 0" + assert page_size == 1, "Host page size must be 1 for DeepSeekV4SingleKVPoolHost" + + self.device_pool = device_pool + self.size = host_size + self.page_size = page_size + self.num_pages = (self.size + self.page_size - 1) // self.page_size + self.pin_memory = pin_memory + self.device = device + + self.dtype = device_pool.store_dtype + self.layer_num = device_pool.layer_num + self.kv_cache_total_dim = device_pool.kv_cache_total_dim + + self.kv_buffer = self.init_kv_buffer() + self.data_refs = [self.kv_buffer[i] for i in 
range(self.layer_num)] + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.data_refs], + dtype=torch.uint64, + device=self.device_pool.device, + ) + self.clear() + + def clear(self): + self.free_slots = torch.arange( + 1, self.num_pages + 1, dtype=torch.int64, device="cpu" + ) + + def init_kv_buffer(self): + dims = (self.layer_num, self.size + self.page_size, self.kv_cache_total_dim) + requested_bytes = ( + self.layer_num + * (self.size + self.page_size) + * self.kv_cache_total_dim + * self.dtype.itemsize + ) + host_mem = psutil.virtual_memory() + # keep 10 GiB of the currently available host memory free for other usage + ten_gb = 10 * (1024**3) + available_bytes = host_mem.available - ten_gb + if requested_bytes > available_bytes: + raise ValueError( + f"Not enough host memory available. Requesting " + f"{requested_bytes / 1e9:.2f} GB but only have " + f"{available_bytes / 1e9:.2f} GB free. Please reduce the " + f"size of the hierarchical cache." + ) + else: + logger.info( + f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." 
+ ) + + host_pool = torch.empty(dims, dtype=self.dtype, device=self.device) + assert self.pin_memory, "DeepSeekV4SingleKVPoolHost requires pin_memory=True" + if self.pin_memory: + torch.cuda.cudart().cudaHostRegister( + host_pool.data_ptr(), host_pool.numel() * host_pool.element_size(), 0 + ) + return host_pool + + def backup_from_device_all_layer(self, device_pool, host_indices, device_indices): + from sglang.jit_kernel.deepseek_v4 import hisparse_offload_to_host + + if host_indices.device != device_indices.device: + host_indices = host_indices.to(device=device_indices.device) + host_indices_i64 = ( + host_indices.to(torch.int64) + if host_indices.dtype != torch.int64 + else host_indices + ) + device_indices_i64 = ( + device_indices.to(torch.int64) + if device_indices.dtype != torch.int64 + else device_indices + ) + hisparse_offload_to_host( + gpu_ptrs=device_pool.data_ptrs, + cpu_ptrs=self.data_ptrs, + gpu_indices=device_indices_i64, + cpu_indices=host_indices_i64, + ) + + def available_size(self): + return len(self.free_slots) + + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + if need_size > self.available_size(): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + return select_index + + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices.cpu()]) + return len(indices) + + +class HiSparseTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): + + def __init__( + self, + logical_attn_allocator: BaseTokenToKVPoolAllocator, + ): + assert isinstance(logical_attn_allocator._kvcache, DeepSeekV4TokenToKVPool) + assert isinstance( + logical_attn_allocator._kvcache.c4_kv_pool, HiSparseC4DevicePool + ) + self.compress_ratio = 4 + + self.hisparse_kvcache = logical_attn_allocator._kvcache.c4_kv_pool + self._size_full = logical_attn_allocator.size_full + self._size_hisparse = self.hisparse_kvcache.size + + self.dtype = self.hisparse_kvcache.dtype + 
self.device = self.hisparse_kvcache.device + self.page_size = self.hisparse_kvcache.page_size + + self.logical_attn_allocator = logical_attn_allocator + self._kvcache = logical_attn_allocator._kvcache + self.hisparse_attn_allocator = PagedTokenToKVPoolAllocator( + self._size_hisparse, + self.page_size, + self.dtype, + self.device, + self.hisparse_kvcache, + logical_attn_allocator.need_sort, + ) + + self.full_to_hisparse_device_index_mapping = torch.cat( + [ + torch.zeros( + self._kvcache.c4_logical_size + self.page_size, + dtype=torch.int64, + device=self.device, + ), + torch.tensor([-1], dtype=torch.int64, device=self.device), + ] + ) + + self.need_sort = logical_attn_allocator.need_sort + self.free_pages = None + self.release_pages = None + self.is_not_in_free_group = True + self.free_group = [] + self.clear() + + self.hisparse_kvcache.register_mapping( + weakref.proxy(self.full_to_hisparse_device_index_mapping) + ) + + @property + def size_full(self) -> int: + return self._size_full + + def full_available_size(self): + return min( + self.logical_attn_allocator.full_available_size(), + self.hisparse_attn_allocator.available_size() * self.compress_ratio, + ) + + def swa_available_size(self): + return self.logical_attn_allocator.swa_available_size() + + def free_swa(self, free_indices: torch.Tensor): + self.logical_attn_allocator.free_swa(free_indices) + + def available_size(self) -> int: + return min( + self.logical_attn_allocator.available_size(), + self.hisparse_attn_allocator.available_size() * self.compress_ratio, + ) + + def alloc(self, need_size: int): + raise NotImplementedError( + "Page size = 1 is not supported in HiSparse allocator" + ) + + def alloc_device_buffer(self, allocated_indices, need_size: int): + assert need_size % self.page_size == 0 + hisparse_indices = self.full_to_hisparse_device_index_mapping[allocated_indices] + self.full_to_hisparse_device_index_mapping[allocated_indices] = 0 + + device_buffer_size = need_size - self.page_size + P = 
len(hisparse_indices) + if P > device_buffer_size + 1: + newest_src = hisparse_indices[P - 1].clone() + old_at_dbs = hisparse_indices[device_buffer_size].clone() + hisparse_indices[device_buffer_size] = newest_src + hisparse_indices[P - 1] = old_at_dbs + + if len(hisparse_indices) >= need_size: + buffer_indices = hisparse_indices[:need_size] + surplus = hisparse_indices[need_size:] + if surplus.numel() > 0: + buffer_pages = torch.unique(buffer_indices // self.page_size) + surplus_pages = torch.unique(surplus // self.page_size) + pure_surplus = surplus_pages[~torch.isin(surplus_pages, buffer_pages)] + if pure_surplus.numel() > 0: + self.hisparse_attn_allocator.is_not_in_free_group = True + self.hisparse_attn_allocator.free(pure_surplus * self.page_size) + else: + page_residual_length = len(hisparse_indices) % self.page_size + if page_residual_length != 0: + hisparse_indices = torch.cat( + [ + hisparse_indices, + torch.arange( + hisparse_indices[-1] + 1, + hisparse_indices[-1] + + self.page_size + - page_residual_length + + 1, + device=self.device, + ), + ] + ) + extra_indices = self.hisparse_attn_allocator.alloc( + need_size - len(hisparse_indices) + ) + assert ( + extra_indices is not None + ), "Hisparse allocation failed in alloc_device_buffer" + buffer_indices = torch.cat([hisparse_indices, extra_indices]) + return buffer_indices + + def free_hisparse_indices(self, buffer_indices: torch.Tensor): + self.hisparse_attn_allocator.is_not_in_free_group = True + self.hisparse_attn_allocator.free(buffer_indices[buffer_indices > 0]) + + def get_last_loc_compressed(self, last_locs: torch.Tensor): + return (last_locs - 3) // self.compress_ratio + + def get_last_loc_hisparse_device(self, last_locs: torch.Tensor): + return self.hisparse_kvcache._translate_loc_from_compressed_to_hisparse_device( + self.get_last_loc_compressed(last_locs) + ) + + def alloc_extend( + self, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: 
torch.Tensor, + last_loc: torch.Tensor, + extend_num_tokens: int, + ): + assert self.page_size > 1 + + num_new_pages_logical = get_num_new_pages( + seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu + ) + num_new_pages_hisparse = get_num_new_pages( + seq_lens=seq_lens_cpu // self.compress_ratio, + page_size=self.page_size, + prefix_lens=prefix_lens_cpu // self.compress_ratio, + ) + if ( + num_new_pages_logical + > self.logical_attn_allocator.available_size() // self.page_size + ): + return None + if ( + num_new_pages_hisparse + > self.hisparse_attn_allocator.available_size() // self.page_size + ): + return None + + logical_indices = self.logical_attn_allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + assert logical_indices is not None, "Logical allocation failed in alloc_extend" + + compressed_logical_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(logical_indices) + ) + hisparse_last_loc = self.get_last_loc_hisparse_device(last_loc) + hisparse_indices = self.hisparse_attn_allocator.alloc_extend( + prefix_lens // self.compress_ratio, + prefix_lens_cpu // self.compress_ratio, + seq_lens // self.compress_ratio, + seq_lens_cpu // self.compress_ratio, + hisparse_last_loc, + len(compressed_logical_indices), + ) + assert ( + hisparse_indices is not None + ), "Hisparse allocation failed in alloc_extend" + + self.full_to_hisparse_device_index_mapping[compressed_logical_indices] = ( + hisparse_indices.to(torch.int64) + ) + return logical_indices + + def alloc_decode( + self, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + ): + return self.logical_attn_allocator.alloc_decode( + seq_lens, seq_lens_cpu, last_loc + ) + + def free_compressed(self, compressed_indices: torch.Tensor): + hisparse_indices = ( + self.hisparse_kvcache.translate_loc_from_compressed_to_hisparse_device( + compressed_indices + ) + ) + 
hisparse_indices = hisparse_indices[hisparse_indices > 0] + self.free_hisparse_indices(hisparse_indices) + self.full_to_hisparse_device_index_mapping[compressed_indices] = 0 + + def free_hisparse(self, free_indices: torch.Tensor): + compressed_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(free_indices) + ) + self.free_compressed(compressed_indices) + + def clear(self): + self.logical_attn_allocator.clear() + self.hisparse_attn_allocator.clear() + + self.full_to_hisparse_device_index_mapping[:-1].fill_(0) + self.is_not_in_free_group = True + self.free_group = [] + + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + + if self.is_not_in_free_group: + self.logical_attn_allocator.free(free_index) + else: + self.free_group.append(free_index) + assert ( + self.logical_attn_allocator.available_size() + <= self.logical_attn_allocator.size + ) + assert ( + self.hisparse_attn_allocator.available_size() + <= self.hisparse_attn_allocator.size + ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 0afbb15fd7e8..40367a920921 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -137,7 +137,7 @@ def __init__( (size, max_context_len), dtype=torch.int32, device=device ) - self.free_slots = list(range(size)) + self.free_slots = list(range(1, size)) def write(self, indices, values): self.req_to_token[indices] = values @@ -161,7 +161,7 @@ def free(self, free_index: Union[int, List[int]]): self.free_slots.extend(free_index) def clear(self): - self.free_slots = list(range(self.size)) + self.free_slots = list(range(1, self.size)) class MambaPool: diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 0faf201cbd48..c44162575166 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -1,6 +1,6 @@ import 
logging import weakref -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch @@ -10,8 +10,10 @@ PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool +from sglang.srt.utils.common import get_num_new_pages logger = logging.getLogger(__name__) GB = 1024 * 1024 * 1024 @@ -230,29 +232,32 @@ def __init__( page_size: int, dtype: torch.dtype, device: str, - kvcache: SWAKVPool, + kvcache: Union[SWAKVPool, DeepSeekV4TokenToKVPool], need_sort: bool, ): - assert isinstance(kvcache, SWAKVPool) + assert isinstance(kvcache, (SWAKVPool, DeepSeekV4TokenToKVPool)) self._size_full = size self._size_swa = size_swa self.dtype = dtype self.device = device self.page_size = page_size + full_kv_pool = getattr(kvcache, "full_kv_pool", None) + swa_kv_pool = getattr(kvcache, "swa_kv_pool", None) + if page_size == 1: self.full_attn_allocator = TokenToKVPoolAllocator( size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = TokenToKVPoolAllocator( size_swa, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) else: @@ -261,7 +266,7 @@ def __init__( page_size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = PagedTokenToKVPoolAllocator( @@ -269,7 +274,7 @@ def __init__( page_size, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) # Note: append one more item of value -1 in the end so -1 maps to -1. 
@@ -360,10 +365,13 @@ def alloc_extend( extend_num_tokens: int, ): assert self.page_size > 1 - num_tokens = extend_num_tokens + len(seq_lens) * self.page_size - if num_tokens > self.full_attn_allocator.available_size(): + + num_new_pages = get_num_new_pages( + seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu + ) + if num_new_pages > self.full_attn_allocator.available_size() // self.page_size: return None - if num_tokens > self.swa_attn_allocator.available_size(): + if num_new_pages > self.swa_attn_allocator.available_size() // self.page_size: return None swa_last_loc = self.translate_loc_from_full_to_swa(last_loc) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 4b07b841f2aa..9feb2c8b0998 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -28,6 +28,7 @@ import torch from numpy import float64 +from sglang.srt.environ import envs from sglang.srt.mem_cache.base_prefix_cache import ( BasePrefixCache, EvictParams, @@ -513,7 +514,12 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # Remove req slot release the cache lock self.req_to_token_pool.free(req.req_pool_idx) - self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + self.dec_lock_ref( + req.last_node, + req.swa_uuid_for_lock, + skip_swa=req.swa_prefix_lock_released, + ) + req.swa_prefix_lock_released = False def cache_unfinished_req(self, req: Req, chunked=False) -> None: """Cache request when it is unfinished.""" @@ -587,7 +593,12 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: req.cache_protected_len = len(new_indices) - self.dec_lock_ref(req.last_node, req.swa_uuid_for_lock) + self.dec_lock_ref( + req.last_node, + req.swa_uuid_for_lock, + skip_swa=req.swa_prefix_lock_released, + ) + req.swa_prefix_lock_released = False swa_uuid_for_lock = self.inc_lock_ref(new_last_node) # `req.prefix_indices` will be used 
in `PrefillAdder::add_chunked_req` later @@ -635,12 +646,15 @@ def evict(self, params: EvictParams) -> EvictResult: # 1. free node kv indices, evict full and swa tokens self.token_to_kv_pool_allocator.free(x.value) full_num_evicted += len(x.value) - swa_num_evicted += len(x.value) + # Tombstoned leaves had their SWA freed earlier in `dec_swa_lock_only` + if not x.swa_tombstone: + swa_num_evicted += len(x.value) # 2. get the next leaf, update the lru lists x_next = self.full_lru_list.get_prev_leaf_no_lock(x) self.full_lru_list.remove_node(x) - self.swa_lru_list.remove_node(x) + if not x.swa_tombstone: + self.swa_lru_list.remove_node(x) # 3. delete the leaf node self._delete_leaf(x) @@ -677,6 +691,18 @@ def evict(self, params: EvictParams) -> EvictResult: # 3. tombstone the node self._tombstone_internal_node(x) + elif x.full_lock_ref > 0: + # Leaf still holds a full-side lock (can happen when the + # SWA leaf-lock early-release optimization revived a + # tombstoned leaf). Treat it like an internal tombstone. + self.token_to_kv_pool_allocator.free_swa(x.value) + swa_num_evicted += len(x.value) + + x_next = self.swa_lru_list.get_prev_no_lock(x) + self.swa_lru_list.remove_node(x) + + self.swa_evictable_size_ -= len(x.value) + x.swa_tombstone = True else: assert ( x.full_lock_ref == 0 @@ -745,17 +771,25 @@ def inc_lock_ref(self, node: TreeNode) -> Optional[int]: node = node.parent return swa_uuid_for_lock - def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None): + def dec_lock_ref( + self, + node: TreeNode, + swa_uuid_for_lock: Optional[int] = None, + skip_swa: bool = False, + ): """ Decrement the lock reference count for the node. It unlocks the full_lock_ref for nodes between the [last node, root), exclusive. It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive. If swa_uuid_for_lock is None, it unlocks to the root, exclusive. 
+ + If skip_swa is True, only the full_lock_ref is decremented; the SWA lock is + assumed to have been released already (e.g. via `dec_swa_lock_only`). """ if self.disable: return - dec_lock_swa = True + dec_lock_swa = not skip_swa while node != self.root_node: assert ( node.full_lock_ref > 0 @@ -782,6 +816,61 @@ def dec_lock_ref(self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None): node = node.parent + def dec_swa_lock_only( + self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None + ): + """ + Decrement only the swa_lock_ref (and swa_protected_size_) along the chain + [node, swa_uuid_for_lock], inclusive. The full_lock_ref is left untouched + so the caller's full-cache protection is preserved. + + Used to early-release the SWA portion of a request's tree lock once the + request's decode position has advanced past the sliding window, so the + protected window can be reclaimed. + + For internal nodes, the standard protected -> evictable transition is + applied (node stays in swa_lru_list and may be evicted by SWA LRU later). + For leaf nodes, since `swa_lru_list` cannot contain a leaf with + `full_lock_ref > 0` (SWA-eviction would also delete the still-referenced + leaf), we instead free the SWA pool slots immediately and mark the leaf + as `swa_tombstone=True`. The full kv stays alive until the full-side + lock drops; future prefix-matches stop before this tombstoned leaf. + + Caller must ensure this is invoked at most once per (node, swa_uuid_for_lock) + pair (track via e.g. `Req.swa_prefix_lock_released`). When the request + finally releases its full lock via `dec_lock_ref`, pass `skip_swa=True` + to avoid touching SWA state again. 
+ """ + if self.disable: + return + + while node != self.root_node: + assert ( + not node.swa_tombstone + ), f"dec_swa_lock_only on swa_tombstone node, {node.id=}" + assert ( + node.swa_lock_ref > 0 + ), f"dec_swa_lock_only on node with {node.swa_lock_ref=}, {node.id=}" + + if node.swa_lock_ref == 1: + self.swa_protected_size_ -= len(node.value) + if len(node.children) == 0: + # Leaf: free SWA pool slots and tombstone, and remove from + # swa_lru_list so SWA-eviction won't pick this tombstoned + # leaf (which still holds full_lock_ref > 0). The full kv + # stays alive until the request releases its full lock. + self.token_to_kv_pool_allocator.free_swa(node.value) + self.swa_lru_list.remove_node(node) + node.swa_tombstone = True + else: + # Internal: standard protected -> evictable. + self.swa_evictable_size_ += len(node.value) + node.swa_lock_ref -= 1 + + if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock: + break + node = node.parent + def sanity_check(self): self.full_lru_list.sanity_check(self) self.swa_lru_list.sanity_check(self) @@ -846,9 +935,13 @@ def _match_prefix_helper( match_len_since_tombstone = float("inf") best_value_len = 0 best_last_node = node + enable_compact = envs.SGLANG_OPT_SWA_RADIX_CACHE_COMPACT.get() while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] + if enable_compact: + self._compact_single_child_chain(child) + if child.swa_tombstone: # update best_value_len and best_last_node if needed if match_len_since_tombstone >= self.sliding_window_size: @@ -897,6 +990,75 @@ def _match_prefix_helper( return value[:best_value_len], best_last_node + def _compact_single_child_chain(self, node: TreeNode) -> None: + while len(node.children) == 1: + child = next(iter(node.children.values())) + if len(child.children) == 0: + break + sum_gc_full_lock_ref = sum( + gc.full_lock_ref for gc in child.children.values() + ) + if child.full_lock_ref > sum_gc_full_lock_ref: + break + if ( + child.swa_tombstone != 
node.swa_tombstone + or child.full_lock_ref != node.full_lock_ref + or child.swa_lock_ref != node.swa_lock_ref + ): + break + + node.key = RadixKey( + node.key.token_ids + child.key.token_ids, node.key.extra_key + ) + node.value = torch.cat([node.value, child.value]) + node.children = child.children + for grandchild in node.children.values(): + grandchild.parent = node + + if child.swa_uuid is not None: + node.swa_uuid = child.swa_uuid + + self.full_lru_list.remove_node(child) + if not child.swa_tombstone: + self.swa_lru_list.remove_node(child) + + def _maybe_split_leaf_for_swa_lock(self, leaf: TreeNode) -> TreeNode: + """``inc_lock_ref`` protects ``len(leaf.value)`` SWA tokens for the + leaf even though SWA only actually needs the last + ``sliding_window_size`` tokens. With chunked prefill, leaves can be + thousands of tokens long, which inflates ``swa_protected_size_`` by + ~``chunked_prefill_size / sliding_window_size`` and causes premature + SWA pool exhaustion / retract thrashing. + """ + if ( + leaf is self.root_node + or leaf.swa_lock_ref > 0 + or leaf.swa_tombstone + or len(leaf.value) == 0 + ): + return leaf + + # Smallest page-aligned size that still covers the sliding window. 
+ tail_size = ( + (self.sliding_window_size + self.page_size - 1) + // self.page_size + * self.page_size + ) + if len(leaf.value) <= tail_size: + return leaf + + split_at = len(leaf.value) - tail_size + + if split_at <= 0 or split_at >= len(leaf.value): + return leaf + if self.page_size > 1 and ( + split_at % self.page_size != 0 or len(leaf.value) % self.page_size != 0 + ): + return leaf + + self._split_node(leaf.key, leaf, split_at) + return leaf + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: # new_node -> child new_node = TreeNode() @@ -1018,6 +1180,14 @@ def _insert_helper( child_key = self.get_child_key_fn(key) if len(key): + logger.debug( + f"Has Additional Node: len(key)={len(key)}, total_prefix_length={total_prefix_length}, swa_evicted_seqlen={swa_evicted_seqlen}, len(value)={len(value)}" + ) + + if swa_evicted_seqlen == total_prefix_length + len(key): + self.token_to_kv_pool_allocator.free(value) + return total_prefix_length + if ( swa_evicted_seqlen > total_prefix_length and swa_evicted_seqlen < total_prefix_length + len(key) @@ -1032,7 +1202,13 @@ def _insert_helper( key = key[swa_tombstone_len:] value = value[swa_tombstone_len:] - self._add_new_node(node, key, value, swa_tombstone=False) + new_leaf = self._add_new_node(node, key, value, swa_tombstone=False) + + if envs.SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT.get(): + # Cap the leaf at one (page-aligned) sliding window so a future + # inc_lock_ref only protects `sliding_window_size` tokens of SWA pool. 
+ self._maybe_split_leaf_for_swa_lock(new_leaf) + return total_prefix_length def _add_new_node( @@ -1080,15 +1256,15 @@ def _iteratively_delete_tombstone_leaf( return node, full_num_evicted def _delete_leaf(self, node: TreeNode) -> None: - assert ( - not node.swa_tombstone - ), f"Invariant violated: leaf node is a tombstone, {node.id=}" assert len(node.children) == 0, f"leaf node has children, {node.id=}" key = self.get_child_key_fn(node.key) v = node.parent.children.pop(key, None) assert v == node, f"parent does not have child key, {key}" self.full_evictable_size_ -= len(node.key) - self.swa_evictable_size_ -= len(node.key) + # Tombstoned leaves were never (re-)added to swa_lru_list and were + # already removed from swa_evictable_size_ when they were tombstoned. + if not node.swa_tombstone: + self.swa_evictable_size_ -= len(node.key) def _tombstone_internal_node(self, node: TreeNode) -> None: assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b0b2ede6dbde..be642033fe46 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -678,6 +678,12 @@ def capture_one_batch_size( ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) + if ( + getattr(self.model_runner, "hisparse_coordinator", None) is not None + and self.capture_forward_mode.is_decode() + ): + forward_batch.hisparse_coordinator = self.model_runner.hisparse_coordinator + if lora_ids is not None: self.model_runner.lora_manager.prepare_lora_batch(forward_batch) @@ -726,11 +732,13 @@ def run_once(): self.device_module.synchronize() self.model_runner.tp_group.barrier() run_once() + attn_backend.on_after_cuda_graph_warmup_pass() if get_global_graph_memory_pool() is None: set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally 
to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) + out = self._capture_graph( graph, get_global_graph_memory_pool(), stream, run_once ) @@ -832,6 +840,8 @@ def replay_prepare( self.capture_forward_mode, forward_batch.spec_info, seq_lens_cpu=seq_lens_cpu, + out_cache_loc=forward_batch.out_cache_loc, + actual_forward_mode=forward_batch.forward_mode, ) # Store fields @@ -860,6 +870,7 @@ def replay( else: graph_key = self.bs self.graphs[graph_key].replay() + output = self.output_buffers[graph_key] if isinstance(output, LogitsProcessorOutput): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index efd9f07d3d1b..cba8cded5194 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -98,8 +99,8 @@ class ForwardMode(IntEnum): # Used in diffusion LLM inference DLLM_EXTEND = auto() - def is_prefill(self): - return self.is_extend() + def is_prefill(self, include_draft_extend_v2: bool = False): + return self.is_extend(include_draft_extend_v2=include_draft_extend_v2) def is_extend(self, include_draft_extend_v2: bool = False): return ( @@ -375,6 +376,8 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): # For hidden states before normal return_hidden_states_before_norm: bool = False + hisparse_coordinator: Optional[HiSparseCoordinator] = None + @classmethod def init_new( cls, diff --git 
logger = logging.getLogger(__name__)


@dataclass
class DSv4PoolSizes:
    """Pool capacities computed by ``DSv4MemoryCalculator``."""

    # Page-aligned token capacity the other sizes are derived from.
    full_max_total_num_tokens: int
    # int(full * swa_ratio), rounded down to a multiple of page_size.
    swa_max_total_num_tokens: int
    # full // (4 * c4_shrink_factor).
    c4_max_total_num_tokens: int
    # full // 128.
    c128_max_total_num_tokens: int
    # (swa // swa_page_size) * c4 ring size.
    c4_state_pool_size: int
    # (swa // swa_page_size) * c128 ring size.
    c128_state_pool_size: int


class DSv4MemoryCalculator:
    """Turn a byte budget (or a configured token count) into DSv4 pool sizes.

    The calculator precomputes ``bytes_per_full_token``: the summed cost, over
    all layers and all pools, of admitting one token into the full pool.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        page_size: int,
        swa_ratio: float,
        is_speculative: bool = False,
        c4_shrink_factor: int = 1,
    ):
        """Capture the model dimensions and sizing knobs.

        Args:
            model_config: Supplies head dims, ``compress_ratios`` (one entry
                per layer) and ``window_size``.
            page_size: Token capacities are rounded down to a multiple of this.
            swa_ratio: SWA pool capacity as a fraction of the full capacity.
            is_speculative: Forwarded to ``get_compress_state_ring_size``.
            c4_shrink_factor: Divides the c4 pool capacity; must be >= 1.
        """
        self.qk_nope_head_dim = model_config.qk_nope_head_dim
        self.qk_rope_head_dim = model_config.qk_rope_head_dim
        self.indexer_head_dim = model_config.index_head_dim
        self.compression_ratios = model_config.compress_ratios
        # The attention window length is used as the SWA page size here.
        self.swa_page_size = model_config.window_size
        self.page_size = page_size
        self.swa_ratio = swa_ratio
        self.is_speculative = is_speculative
        assert c4_shrink_factor >= 1
        self.c4_shrink_factor = c4_shrink_factor

        self.c4_ring_size = get_compress_state_ring_size(4, self.is_speculative)
        self.c128_ring_size = get_compress_state_ring_size(128, self.is_speculative)

        self.num_layers_total = len(self.compression_ratios)
        self.num_layers_ca4 = sum(1 for r in self.compression_ratios if r == 4)
        self.num_layers_ca128 = sum(1 for r in self.compression_ratios if r == 128)

        self.bytes_per_full_token = self.get_bytes_per_full_token()

    def get_bytes_per_full_token(self) -> float:
        """Return the bytes consumed across all pools per full-pool token."""
        # NOTE(review): presumably 1 byte per nope element, 2 bytes per rope
        # element and 8 bytes of per-token metadata -- confirm against the pool.
        kv_bytes = self.qk_nope_head_dim + self.qk_rope_head_dim * 2 + 8

        quant_block_size = 128
        # One byte per indexer element plus 4 bytes per 128-element block.
        indexer_bytes = (
            self.indexer_head_dim + self.indexer_head_dim // quant_block_size * 4
        )

        attn_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
        state_dtype_size = 4
        c4_state_bytes = 2 * 2 * attn_head_dim * state_dtype_size
        # Online c128 stores (max, sum, kv) per slot (3*head_dim) instead of
        # raw (kv, score) (2*head_dim). Combined with ring_size=1 this still
        # nets a large reduction (~3/256x) but the per-slot bytes go up.
        c128_online = envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get()
        c128_state_bytes = (3 if c128_online else 2 * 1) * attn_head_dim * state_dtype_size
        c4_indexer_state_bytes = 2 * 2 * self.indexer_head_dim * state_dtype_size

        # State slots per SWA token: ring_size slots for every swa_page_size tokens.
        c4_state_ratio = self.c4_ring_size / self.swa_page_size
        c128_state_ratio = self.c128_ring_size / self.swa_page_size

        c4_frac = 1 / (4 * self.c4_shrink_factor)
        # NOTE(review): the indexer term uses a fixed 1/4 rather than c4_frac,
        # i.e. it is not reduced by c4_shrink_factor -- confirm this is intended.
        bytes_per_full_token = (
            self.swa_ratio * kv_bytes * self.num_layers_total
            + c4_frac * kv_bytes * self.num_layers_ca4
            + 1 / 128 * kv_bytes * self.num_layers_ca128
            + 1 / 4 * indexer_bytes * self.num_layers_ca4
            + self.swa_ratio * c4_state_ratio * c4_state_bytes * self.num_layers_ca4
            + self.swa_ratio
            * c128_state_ratio
            * c128_state_bytes
            * self.num_layers_ca128
            + self.swa_ratio
            * c4_state_ratio
            * c4_indexer_state_bytes
            * self.num_layers_ca4
        )

        return bytes_per_full_token

    def _build_pool_sizes(self, full_token: int, available_bytes: float) -> DSv4PoolSizes:
        """Page-align ``full_token`` and derive every dependent pool size.

        ``available_bytes`` is used only in the log line.
        """
        full_token = full_token // self.page_size * self.page_size

        swa_tokens = int(full_token * self.swa_ratio) // self.page_size * self.page_size

        pool_sizes = DSv4PoolSizes(
            full_max_total_num_tokens=full_token,
            swa_max_total_num_tokens=swa_tokens,
            c4_max_total_num_tokens=full_token // (4 * self.c4_shrink_factor),
            c128_max_total_num_tokens=full_token // 128,
            c4_state_pool_size=swa_tokens // self.swa_page_size * self.c4_ring_size,
            c128_state_pool_size=swa_tokens // self.swa_page_size * self.c128_ring_size,
        )

        logger.info(
            f"DSv4 memory calculation: "
            f"bytes_per_full_token={self.bytes_per_full_token:.2f}, "
            f"available_bytes={available_bytes / (1 << 30):.2f} GB, "
            f"full_token={full_token}"
        )

        return pool_sizes

    def calculate_pool_sizes(self, available_bytes: int) -> DSv4PoolSizes:
        """Return the pool sizes that fit into ``available_bytes``."""
        full_token = int(available_bytes / self.bytes_per_full_token)
        return self._build_pool_sizes(full_token, available_bytes)

    def get_pool_sizes_by_profiling(self, mr: ModelRunner) -> DSv4PoolSizes:
        """Size the pools from the GPU memory currently free on ``mr``'s device.

        When speculative decoding is enabled, the budget is scaled by
        ``target_layers / (target_layers + 1)``, which sets aside a share for a
        single draft layer.
        """
        world_group = get_world_group()
        available_bytes = profile_available_bytes(
            device=mr.device,
            gpu_id=mr.gpu_id,
            total_gpu_memory=mr.total_gpu_memory,
            mem_fraction_static=mr.mem_fraction_static,
            distributed=world_group.world_size > 1,
            cpu_group=world_group.cpu_group,
        )

        if self.is_speculative:
            draft_layers = 1
            target_layers = self.num_layers_total
            target_ratio = target_layers / (target_layers + draft_layers)
            available_bytes = int(available_bytes * target_ratio)

        return self.calculate_pool_sizes(available_bytes)

    def get_pool_sizes_by_configuration(self, max_total_tokens: int) -> DSv4PoolSizes:
        """Size the pools for an explicitly configured full-token capacity.

        The token count is used directly. Converting it to bytes and dividing
        back is a float round trip that can truncate to ``max_total_tokens - 1``
        and then lose a whole page to the alignment step.
        """
        estimated_bytes = max_total_tokens * self.bytes_per_full_token
        return self._build_pool_sizes(int(max_total_tokens), estimated_bytes)


def profile_available_bytes(
    device: str,
    gpu_id: int,
    total_gpu_memory: float,
    mem_fraction_static: float,
    distributed: bool = False,
    cpu_group=None,
) -> int:
    """Return the byte budget left for KV/state pools.

    Computed as ``available - total * (1 - mem_fraction_static)``, where both
    memory figures are in GiB (the result is scaled by ``1 << 30``). The value
    is not clamped, so it is negative when the reserved share exceeds the
    memory currently available.
    """
    from sglang.srt.utils.common import get_available_gpu_memory

    available_gpu_memory = get_available_gpu_memory(
        device, gpu_id, distributed=distributed, cpu_group=cpu_group
    )
    rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)

    available_bytes = int(rest_memory * (1 << 30))

    logger.info(
        f"Memory profiling: available_gpu_memory={available_gpu_memory:.2f} GB, "
        f"total_gpu_memory={total_gpu_memory:.2f} GB, "
        f"mem_fraction_static={mem_fraction_static:.2f}, "
        f"rest_memory={rest_memory:.2f} GB"
    )

    return available_bytes
sglang.srt.layers.dp_attention import ( @@ -310,12 +315,15 @@ def __init__( self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.is_hybrid_swa = model_config.is_hybrid_swa - self.is_hybrid_swa_compress = model_config.is_hybrid_swa_compress + self.is_hybrid_swa_compress = getattr( + model_config, "is_hybrid_swa_compress", False + ) self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size self.forward_pass_id = 0 self.init_new_workspace = False self.draft_model_idx = draft_model_idx + self.enable_hisparse = server_args.enable_hisparse self.remote_instance_transfer_engine = None self.remote_instance_transfer_engine_session_id = "" @@ -366,7 +374,7 @@ def __init__( self.init_threads_binding() # Get memory before model loading - min_per_gpu_memory = self.init_torch_distributed() + self.total_gpu_memory = self.init_torch_distributed() # Init forward stream for overlap schedule self.forward_stream = torch.get_device_module(self.device).Stream() @@ -387,7 +395,7 @@ def __init__( deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args) # Initialize the model runner - self.initialize(min_per_gpu_memory) + self.initialize() self.check_quantized_moe_compatibility() # Temporary cached values @@ -404,6 +412,9 @@ def __init__( self._model_update_group = {} self._weights_send_group = {} + if not hasattr(self, "hisparse_coordinator"): + self.hisparse_coordinator = None + def init_mindspore_runner(self): # Init the mindspore runner # for now, there is only some communication initialization work @@ -418,7 +429,7 @@ def init_mindspore_runner(self): port=self.dist_port, ) - def initialize(self, min_per_gpu_memory: float): + def initialize(self): server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( @@ -493,6 +504,8 @@ def initialize(self, min_per_gpu_memory: float): self.end_layer = getattr(self.model, 
"end_layer", model_num_layers) self.num_effective_layers = self.end_layer - self.start_layer + self.adjust_hybrid_swa_layers_for_pp() + # For LoopCoder models, each loop has its own layer_id, so we need to multiply by loop_num loop_num = getattr(self.model_config.hf_config, "loop_num", 1) if loop_num > 1: @@ -507,23 +520,6 @@ def initialize(self, min_per_gpu_memory: float): ) ), "PP is not compatible with MTP models." - # Consider PP, so use start_layer and end_layer. - full_attention_layer_ids = [ - layer_idx - for layer_idx in range(self.start_layer, self.end_layer + 1) - if hasattr(self.model_config, "full_attention_layer_ids") - and layer_idx in self.model_config.full_attention_layer_ids - ] - swa_attention_layer_ids = [ - layer_idx - for layer_idx in range(self.start_layer, self.end_layer + 1) - if hasattr(self.model_config, "swa_attention_layer_ids") - and layer_idx in self.model_config.swa_attention_layer_ids - ] - # Update back to model_config. - self.model_config.swa_attention_layer_ids = swa_attention_layer_ids - self.model_config.full_attention_layer_ids = full_attention_layer_ids - # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -559,7 +555,7 @@ def initialize(self, min_per_gpu_memory: float): self.configure_kv_cache_dtype() # Init memory pool and attention backends - self.init_memory_pool(min_per_gpu_memory) + self.init_memory_pool() # Init max running requests self.max_running_requests = min( @@ -575,10 +571,33 @@ def initialize(self, min_per_gpu_memory: float): # Init routed experts capturer self.init_routed_experts_capturer() + self.init_indexer_capturer() + if self.device == "cuda": self.init_cublas() self.init_attention_backend() self.kernel_warmup() + if self.enable_hisparse: + from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator + + _hisparse_top_k = getattr( + self.model_config.hf_text_config, "index_topk", 512 + ) + 
def adjust_hybrid_swa_layers_for_pp(self):
    """Restrict the config's full/SWA layer-id lists to this rank's layer range.

    Does nothing unless the model is hybrid SWA, and does nothing for the
    SWA-with-compressed-attention variant. Otherwise both lists on
    ``self.model_config`` are overwritten with the ids that fall inside
    ``[start_layer, end_layer]``; a list the config does not define becomes
    empty.
    """
    if not self.is_hybrid_swa or self.model_config.is_swa_with_compressed_attention:
        return

    cfg = self.model_config
    # NOTE(review): the upper bound is inclusive (end_layer + 1). If end_layer
    # is exclusive elsewhere, this also picks up the first layer of the next
    # pipeline stage -- confirm the intended bound.
    local_layers = range(self.start_layer, self.end_layer + 1)

    def _ids_on_this_rank(attr_name):
        if not hasattr(cfg, attr_name):
            return []
        wanted = getattr(cfg, attr_name)
        return [idx for idx in local_layers if idx in wanted]

    # Compute both lists before writing either one back to the config.
    full_ids = _ids_on_this_rank("full_attention_layer_ids")
    swa_ids = _ids_on_this_rank("swa_attention_layer_ids")

    cfg.swa_attention_layer_ids = swa_ids
    cfg.full_attention_layer_ids = full_ids
def max_token_pool_size(self):
    """Return the max token pool size considering hybrid swa settings.

    For a hybrid SWA model this is the full-attention capacity. When that
    capacity is 0 (every layer is SWA, so there is no full pool), the value
    falls back to ``max_total_num_tokens`` so that an all-SWA model does not
    report an empty pool.
    """
    if self.is_hybrid_swa and self.full_max_total_num_tokens > 0:
        return self.full_max_total_num_tokens
    return self.max_total_num_tokens
not None: + forward_batch.hisparse_coordinator = self.hisparse_coordinator if forward_batch.forward_mode.is_decode(): ret = self.forward_decode( forward_batch, @@ -2411,6 +2482,8 @@ def sample( else forward_batch.seq_lens - 1 ), ) + + return next_token_ids def compute_logprobs_only( diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 6cfa91e87c9a..fdf1a330b92f 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -5,13 +5,23 @@ import torch -from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa +from sglang.srt.configs.model_config import ( + get_nsa_index_head_dim, + is_deepseek_compressed, + is_deepseek_nsa, +) from sglang.srt.distributed.parallel_state import get_world_group +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.mem_cache.allocator import ( PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseekv4_memory_pool import ( + DeepSeekV4IndexerPool, + DeepSeekV4TokenToKVPool, +) +from sglang.srt.mem_cache.hisparse_memory_pool import HiSparseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, HybridLinearKVPool, @@ -46,9 +56,32 @@ class ModelRunnerKVCacheMixin: def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: kv_size = torch._utils._element_size(self.kv_cache_dtype) - if self.use_mla_backend: + if is_deepseek_compressed(self.model_config.hf_config): + assert kv_size == 1, kv_size # uint8 + cell_size = ( - (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) + ( + self.model_config.qk_nope_head_dim + + self.model_config.qk_rope_head_dim * 2 + ) + * num_layers + * kv_size + ) + index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config) + 
def set_num_tokens_hybrid_swa(self: ModelRunner):
    """Split the profiled token budget between full and SWA layers.

    Reads ``self.max_total_num_tokens`` and writes
    ``full_max_total_num_tokens``, ``swa_max_total_num_tokens`` and
    ``max_total_num_tokens``, each rounded down to a multiple of the page size.
    """
    page_size = self.server_args.page_size

    full_layers_num = len(self.model_config.full_attention_layer_ids)
    swa_layers_num = len(self.model_config.swa_attention_layer_ids)

    def align_to_page(x: int) -> int:
        # Round down to a whole number of pages.
        return (x // page_size) * page_size

    if full_layers_num == 0:
        # No full-attention layers: the whole budget goes to the SWA pool.
        self.swa_max_total_num_tokens = align_to_page(self.max_total_num_tokens)
        self.full_max_total_num_tokens = 0
        self.max_total_num_tokens = self.swa_max_total_num_tokens
        logger.info(
            f"Use sliding window memory pool (all SWA). "
            f"swa_layer_tokens={self.swa_max_total_num_tokens}"
        )
        return

    # Budget expressed as tokens summed over every layer.
    total_tokens = self.max_total_num_tokens * self.model_config.num_hidden_layers
    ratio = self.server_args.swa_full_tokens_ratio
    # Solves full * full_layers + (full * ratio) * swa_layers == total_tokens.
    denominator = full_layers_num + ratio * swa_layers_num
    assert denominator > 0, (
        f"Invalid denominator={denominator}: "
        f"ratio={ratio}, swa_layers={swa_layers_num}, full_layers={full_layers_num}"
    )

    self.full_max_total_num_tokens = int(total_tokens / denominator)
    self.swa_max_total_num_tokens = int(self.full_max_total_num_tokens * ratio)

    self.full_max_total_num_tokens = align_to_page(self.full_max_total_num_tokens)
    self.swa_max_total_num_tokens = align_to_page(self.swa_max_total_num_tokens)

    self.max_total_num_tokens = self.full_max_total_num_tokens

    logger.info(
        f"Use sliding window memory pool. "
        f"full_layer_tokens={self.full_max_total_num_tokens}, "
        f"swa_layer_tokens={self.swa_max_total_num_tokens}"
    )

def set_num_tokens_hybrid_swa_compress(self: ModelRunner):
    """Size the DSv4 compressed-attention pools and store the results on ``self``.

    A speculative draft worker copies the full/SWA sizes published by the
    target worker through ``server_args._draft_pool_config`` and zeroes its
    c4/c128 sizes. Any other worker sizes its pools with
    ``DSv4MemoryCalculator`` and, when speculative decoding is on, publishes
    its full/SWA sizes for the draft worker.
    """
    from sglang.srt.model_executor.memory_profiler import DSv4MemoryCalculator

    self.state_dtype = torch.float32
    logger.info(f"DSv4 compressed attention: kv_cache_dtype={self.kv_cache_dtype}")
    logger.info(f"DSv4 compressed attention: state_dtype={self.state_dtype}")

    page_size = self.server_args.page_size
    assert (
        page_size % 128 == 0
    ), "page_size must be multiple of 128 for compressed attention"

    # Online c128 keeps a single in-progress (max, sum, kv) state per index
    # and assumes a strict forward-only schedule. Speculative decode (MTP)
    # would need rollback / replay of that state across draft and verify,
    # which the online path doesn't support yet.
    if envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get():
        assert (
            self.spec_algorithm.is_none()
        ), "SGLANG_OPT_USE_ONLINE_COMPRESS does not support speculative decode (MTP) yet"
        logger.info("DSv4 compressed attention: online c128 enabled (ring_size=1)")

    if not self.spec_algorithm.is_none() and self.is_draft_worker:
        # The non-draft branch below must already have run on the target
        # worker and stored its sizes on the shared server_args.
        config = getattr(self.server_args, "_draft_pool_config", None)
        assert (
            config is not None
        ), "Draft worker requires target's pool config but _draft_pool_config is not set."

        self.full_max_total_num_tokens = config["full_max_total_num_tokens"]
        self.swa_max_total_num_tokens = config["swa_max_total_num_tokens"]
        self.c4_max_total_num_tokens = 0
        self.c128_max_total_num_tokens = 0
        self.c4_state_pool_size = 0
        self.c128_state_pool_size = 0

        logger.info(
            f"DSv4 pool sizes (DRAFT): using TARGET's pool sizes - "
            f"full={self.full_max_total_num_tokens}, "
            f"swa={self.swa_max_total_num_tokens}"
        )
        return
    # The shrink factor is read from the environment only when HiSparse is on.
    c4_shrink = (
        envs.SGLANG_OPT_HISPARSE_C4_SHRINK.get() if self.enable_hisparse else 1
    )
    if c4_shrink > 1:
        logger.info(
            f"HiSparse c4 pool shrink factor = {c4_shrink} "
            f"(set via SGLANG_OPT_HISPARSE_C4_SHRINK)"
        )
    calculator = DSv4MemoryCalculator(
        model_config=self.model_config,
        page_size=page_size,
        swa_ratio=self.server_args.swa_full_tokens_ratio,
        # NOTE(review): this tests server_args.speculative_algorithm, while the
        # rest of the method tests self.spec_algorithm.is_none() -- confirm the
        # two always agree.
        is_speculative=self.server_args.speculative_algorithm is not None,
        c4_shrink_factor=c4_shrink,
    )

    pool_sizes = calculator.get_pool_sizes_by_profiling(self)
    # If the user set max_total_tokens and profiling produced a larger full
    # pool than self.max_total_num_tokens, re-size from that token count.
    if (
        self.server_args.max_total_tokens is not None
        and pool_sizes.full_max_total_num_tokens > self.max_total_num_tokens
    ):
        pool_sizes = calculator.get_pool_sizes_by_configuration(
            max_total_tokens=self.max_total_num_tokens
        )

    self.full_max_total_num_tokens = pool_sizes.full_max_total_num_tokens
    self.swa_max_total_num_tokens = pool_sizes.swa_max_total_num_tokens
    self.c4_max_total_num_tokens = pool_sizes.c4_max_total_num_tokens
    self.c128_max_total_num_tokens = pool_sizes.c128_max_total_num_tokens
    self.c4_state_pool_size = pool_sizes.c4_state_pool_size
    self.c128_state_pool_size = pool_sizes.c128_state_pool_size

    self.max_total_num_tokens = self.full_max_total_num_tokens

    if not self.spec_algorithm.is_none() and not self.is_draft_worker:
        # Publish the sizes so the draft worker's branch above can reuse them.
        self.server_args._draft_pool_config = {
            "full_max_total_num_tokens": self.full_max_total_num_tokens,
            "swa_max_total_num_tokens": self.swa_max_total_num_tokens,
        }

    logger.info(
        f"DSv4 pool sizes: "
        f"full={self.full_max_total_num_tokens}, "
        f"swa={self.swa_max_total_num_tokens}, "
        f"c4={self.c4_max_total_num_tokens}, "
        f"c128={self.c128_max_total_num_tokens}, "
        f"c4_state={self.c4_state_pool_size}, "
        f"c128_state={self.c128_state_pool_size}"
    )
) - self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) + self.max_total_num_tokens = min( + self.max_total_num_tokens, max_total_tokens_configured + ) self.max_total_num_tokens = ( self.max_total_num_tokens @@ -322,7 +440,11 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): # create token size for hybrid cache if self.is_hybrid_swa: - self.set_num_tokens_hybrid_swa() + assert self.sliding_window_size is not None and self.sliding_window_size > 0 + if self.model_config.is_swa_with_compressed_attention: + self.set_num_tokens_hybrid_swa_compress() + else: + self.set_num_tokens_hybrid_swa() if not self.spec_algorithm.is_none() and not self.is_draft_worker: # Draft worker should use SWA adjusted max_total_num_tokens for cache size, otherwise it may cause oob in kv cache store @@ -399,7 +521,46 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): # Initialize token_to_kv_pool is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) - if self.server_args.attention_backend == "ascend": + is_v4_model = is_deepseek_compressed(self.model_config.hf_config) + if is_v4_model: + + swa_page_size = self.page_size + assert swa_page_size == 256, "In paged swa mode, page_size must be 256." 
+ + if self.is_draft_worker: + from sglang.srt.models.deepseek_v4_nextn import ( + COMPRESS_RATIO_NEXTN_LAYER, + ) + + compression_ratios = [ + COMPRESS_RATIO_NEXTN_LAYER + ] * self.num_effective_layers + else: + compression_ratios = self.model_config.compress_ratios + self.token_to_kv_pool = DeepSeekV4TokenToKVPool( + max_num_reqs=self.server_args.max_running_requests, + swa_size=self.swa_max_total_num_tokens, + c4_size=self.c4_max_total_num_tokens, + c128_size=self.c128_max_total_num_tokens, + c4_state_pool_size=self.c4_state_pool_size, + c128_state_pool_size=self.c128_state_pool_size, + page_size=self.page_size, + swa_page_size=swa_page_size, + dtype=self.kv_cache_dtype, + state_dtype=self.state_dtype, + qk_nope_head_dim=self.model_config.qk_nope_head_dim, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + indexer_head_dim=self.model_config.index_head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + compression_ratios=compression_ratios, + start_layer=self.start_layer, + end_layer=self.end_layer, + enable_hisparse=self.enable_hisparse, + ) + + elif self.server_args.attention_backend == "ascend": if self.use_mla_backend: from sglang.srt.hardware_backend.npu.memory_pool_npu import ( NPUMLATokenToKVPool, @@ -638,6 +799,10 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): kvcache=self.token_to_kv_pool, need_sort=need_sort, ) + if self.enable_hisparse: + self.token_to_kv_pool_allocator = HiSparseTokenToKVPoolAllocator( + self.token_to_kv_pool_allocator + ) else: assert self.is_draft_worker diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index cb6aa0d78114..92895ad35529 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -495,8 +495,10 @@ def _get_weights_iterator( hf_weights_files, ) elif use_safetensors: - weight_loader_disable_mmap = ( - 
get_global_server_args().weight_loader_disable_mmap + server_args = get_global_server_args() + weight_loader_disable_mmap = server_args.weight_loader_disable_mmap + weight_loader_drop_cache_after_load = ( + server_args.weight_loader_drop_cache_after_load ) if self.load_config.load_format == LoadFormat.FASTSAFETENSORS: @@ -510,10 +512,13 @@ def _get_weights_iterator( "num_threads", self.DEFAULT_NUM_THREADS ), disable_mmap=weight_loader_disable_mmap, + drop_cache_after_load=weight_loader_drop_cache_after_load, ) else: weights_iterator = safetensors_weights_iterator( - hf_weights_files, disable_mmap=weight_loader_disable_mmap + hf_weights_files, + disable_mmap=weight_loader_disable_mmap, + drop_cache_after_load=weight_loader_drop_cache_after_load, ) else: diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py index f6cabe6dba18..8821c3155b1d 100644 --- a/python/sglang/srt/model_loader/utils.py +++ b/python/sglang/srt/model_loader/utils.py @@ -122,7 +122,10 @@ def post_load_weights(model: nn.Module, model_config: ModelConfig): # 2. Post-processing of weights, including assigning specific member variables. # For `dummy_init`, only the second stage is required. if hasattr(model, "post_load_weights"): - if model_config.hf_config.architectures[0] == "DeepseekV3ForCausalLMNextN": + if model_config.hf_config.architectures[0] in ( + "DeepseekV3ForCausalLMNextN", + "DeepseekV4ForCausalLMNextN", + ): model.post_load_weights(is_nextn=True) else: model.post_load_weights() diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 1bfe0facd914..881e0e77cf59 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -712,11 +712,30 @@ def safetensors_encrypted_weights_iterator( raise NotImplementedError() +def _drop_file_cache_after_load(path: str) -> None: + """Release of checkpoint pages after weights have been copied out. 
Used to avoid CPU OOM in RL. """ + posix_fadvise = getattr(os, "posix_fadvise", None) + dontneed = getattr(os, "POSIX_FADV_DONTNEED", None) + if posix_fadvise is None or dontneed is None: + return + + fd = None + try: + fd = os.open(path, os.O_RDONLY) + posix_fadvise(fd, 0, 0, dontneed) + except OSError as e: + logger.debug("Failed to drop file cache for %s: %s", path, e) + finally: + if fd is not None: + os.close(fd) + + def safetensors_weights_iterator( hf_weights_files: List[str], is_all_weights_sharded: bool = False, decryption_key: Optional[str] = None, disable_mmap: bool = False, + drop_cache_after_load: bool = False, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. @@ -748,6 +767,8 @@ def safetensors_weights_iterator( with safetensors.safe_open(st_file, framework="pt", device="cpu") as f: for name in f.keys(): yield name, f.get_tensor(name) + if drop_cache_after_load: + _drop_file_cache_after_load(st_file) def fastsafetensors_weights_iterator( @@ -811,6 +832,7 @@ def multi_thread_safetensors_weights_iterator( decryption_key: Optional[str] = None, max_workers: int = 4, disable_mmap: bool = False, + drop_cache_after_load: bool = False, ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Multi-Thread iterate over the weights in the model safetensor files. 
@@ -838,7 +860,7 @@ def _load_file(st_file: str): with safetensors.safe_open(st_file, framework="pt", device="cpu") as f: result = {k: f.get_tensor(k) for k in f.keys()} - return result + return st_file, result with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files] @@ -855,9 +877,12 @@ def _load_file(st_file: str): futures_iter = concurrent.futures.as_completed(futures) for future in futures_iter: - state_dict = future.result() + st_file, state_dict = future.result() for name, param in state_dict.items(): yield name, param + del state_dict + if drop_cache_after_load: + _drop_file_cache_after_load(st_file) def _load_pt_file(bin_file: str) -> dict: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index cde44eb93d61..eb1b7dad37a9 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -20,7 +20,7 @@ import logging import os from contextlib import nullcontext -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -74,6 +74,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, + get_dp_global_num_tokens, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -89,6 +90,7 @@ get_moe_runner_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, ) +from sglang.srt.layers.moe.deepseek_v4_topk import HashTopK from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.kt_ep_wrapper import KTEPWrapperMethod @@ -109,6 +111,7 @@ per_tensor_quant_mla_fp8, per_token_group_quant_mla_deep_gemm_masked_fp8, ) +from sglang.srt.layers.quantization.mxfp4_deepseek import 
DeepSeekMxfp4MoEMethod from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope_wrapper from sglang.srt.layers.utils import PPMissingLayer @@ -155,6 +158,10 @@ use_intel_amx_backend, ) +if TYPE_CHECKING: + from deep_gemm import SymmBuffer + + if _use_aiter_gfx95: from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( @@ -220,9 +227,11 @@ def __init__( prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, + swiglu_limit: Optional[float] = None, ) -> None: super().__init__() self.tp_size = tp_size + self.swiglu_limit = swiglu_limit self.gate_up_proj = MergedColumnParallelLinear( hidden_size, @@ -276,6 +285,12 @@ def forward( x = (x, None, y) gate_up, _ = self.gate_up_proj(x) + if self.swiglu_limit is not None: + _g, _u = gate_up.chunk(2, dim=-1) + _lim = float(self.swiglu_limit) + gate_up = torch.cat( + [_g.clamp(max=_lim), _u.clamp(min=-_lim, max=_lim)], dim=-1 + ) x = self.act_fn(gate_up) x, _ = self.down_proj( x, @@ -291,13 +306,16 @@ def __init__( quant_config, prefix: str = "", is_nextn: bool = False, + is_hash_moe: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.is_nextn = is_nextn self.weight = nn.Parameter( torch.empty((config.n_routed_experts, config.hidden_size)) ) - if config.topk_method == "noaux_tc": + + if config.topk_method == "noaux_tc" and not is_hash_moe: correction_bias_dtype = ( torch.bfloat16 if quant_config is not None @@ -331,8 +349,8 @@ def forward( if get_global_server_args().enable_deterministic_inference: return F.linear(hidden_states, self.weight, None) - if forward_batch is not None and nsa_use_prefill_cp(forward_batch): - logits = F.linear(hidden_states, self.weight, None) + if False: + pass else: # NOTE: For some unknown reason, router_gemm seems degrade accept length. 
if ( @@ -352,11 +370,50 @@ def forward( hidden_states, self.weight, gemm_output_zero_allocator ) else: - logits = F.linear(hidden_states, self.weight, None) + from sglang.jit_kernel.deepseek_v4 import linear_bf16_fp32 + + logits = linear_bf16_fp32(hidden_states, self.weight) return logits +_MEGA_MOE_SYMM_BUFFER: dict = {} + + +def _get_mega_moe_symm_buffer( + group, + num_experts: int, + num_max_tokens_per_rank: int, + num_topk: int, + hidden: int, + intermediate_hidden: int, +) -> SymmBuffer: + import deep_gemm + + key = ( + id(group), + num_max_tokens_per_rank, + num_experts, + num_topk, + hidden, + intermediate_hidden, + ) + buf = _MEGA_MOE_SYMM_BUFFER.get(key) + if buf is None: + buf = deep_gemm.get_symm_buffer_for_mega_moe( + group, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch=True, + activation="swiglu", + ) + _MEGA_MOE_SYMM_BUFFER[key] = buf + return buf + + class DeepseekV2MoE(nn.Module): def __init__( @@ -367,6 +424,7 @@ def __init__( prefix: str = "", alt_stream: Optional[torch.cuda.Stream] = None, is_nextn: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -383,6 +441,12 @@ def __init__( self.alt_stream = alt_stream self.is_nextn = is_nextn + if envs.SGLANG_DSV4_MODE.get() == "2604": + n_hash_layers = config.num_hash_layers + else: + n_hash_layers = getattr(config, "n_hash_layers", 0) + self.is_hash = layer_id < n_hash_layers and not (is_deepseek_v4 and is_nextn) + if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -400,6 +464,8 @@ def __init__( quant_config=quant_config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn, + is_hash_moe=self.is_hash, + is_deepseek_v4=is_deepseek_v4, ) # scaling factor for fused shared experts on AMD-platform. 
@@ -424,55 +490,77 @@ def __init__( routing_method_type=getattr( config, "routing_method_type", RoutingMethodType.DeepSeekV3 ), + swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("experts", prefix), ) - self.topk = TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, - layer_id=self.layer_id, - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - quant_config=quant_config, - routed_scaling_factor=self.routed_scaling_factor, - apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, - fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, - # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized - # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. 
- output_format=( - TopKOutputFormat.STANDARD - if (quant_config is None) - and (not get_moe_runner_backend().is_flashinfer_trtllm()) - else None - ), - ) + self.use_grouped_topk = config.n_group > config.topk_group + + if self.is_hash and not (is_nextn and is_deepseek_v4): + self.topk = HashTopK( + topk=config.num_experts_per_tok + self.num_fused_shared_experts, + num_experts=config.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + vocab_size=config.vocab_size, + scoring_func=config.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + ) + else: + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + scoring_func=config.scoring_func, + correction_bias=self.gate.e_score_correction_bias, + quant_config=quant_config, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, + output_format=( + TopKOutputFormat.STANDARD + if (quant_config is None) + and (not get_moe_runner_backend().is_flashinfer_trtllm()) + else None + ), + ) self.shared_experts_is_int8 = False self.shared_experts_is_fp8 = False self.shared_experts_weight_block_size = None + self._shared_expert_tp1 = False if config.n_shared_experts is not None and self.num_fused_shared_experts == 0: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - # disable tp for shared experts when enable deepep moe, or with fp4 allgather + # disable tp for shared experts when enable deepep moe, or with fp4 allgather, + # or when DSV4 FP4 experts 
are used (shared experts remain FP8 whose scale + # shape may not be divisible by tp_size, e.g. scale [24,56] vs tp=16). + _shared_expert_use_tp1 = ( + get_moe_a2a_backend().is_deepep() + or get_moe_a2a_backend().is_mooncake() + or get_moe_a2a_backend().is_ascend_fuseep() + or get_moe_a2a_backend().is_flashinfer() + or should_use_flashinfer_cutlass_moe_fp4_allgather() + or envs.SGLANG_SHARED_EXPERT_TP1.get() + ) self.shared_experts = DeepseekV2MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("shared_experts", prefix), **( dict(tp_rank=0, tp_size=1) - if get_moe_a2a_backend().is_deepep() - or get_moe_a2a_backend().is_mooncake() - or get_moe_a2a_backend().is_ascend_fuseep() - or get_moe_a2a_backend().is_flashinfer() - or should_use_flashinfer_cutlass_moe_fp4_allgather() + if _shared_expert_use_tp1 else {} ), ) + self._shared_expert_tp1 = _shared_expert_use_tp1 is_packed_weight = hasattr( self.shared_experts.gate_up_proj.quant_method, "quant_config" ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in { @@ -535,6 +623,9 @@ def __init__( ) self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo() + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert hasattr(self, "shared_experts") + def get_moe_weights(self): return [ x.data @@ -552,12 +643,22 @@ def forward( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: + if self._should_use_mega_moe(hidden_states): + return self.forward_mega_moe( + hidden_states, + forward_batch, + input_ids_global=input_ids_global, + ) + if not self._enable_a2a_moe: from sglang.srt.model_executor.cuda_graph_runner import 
get_is_capture_mode if ( - self.alt_stream is not None + envs.SGLANG_OPT_ALLOW_SHARED_EXPERT_DUAL_STREAM.get() + and self.alt_stream is not None and self.num_fused_shared_experts == 0 and hidden_states.shape[0] > 0 and get_is_capture_mode() @@ -567,6 +668,8 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, + input_ids, + input_ids_global=input_ids_global, ) else: return self.forward_normal( @@ -574,9 +677,13 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, + input_ids, + input_ids_global=input_ids_global, ) else: - return self.forward_deepep(hidden_states, forward_batch) + return self.forward_deepep( + hidden_states, forward_batch, input_ids_global=input_ids_global + ) def forward_normal_dual_stream( self, @@ -584,6 +691,8 @@ def forward_normal_dual_stream( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() @@ -595,13 +704,28 @@ def forward_normal_dual_stream( with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) - topk_output = self.topk(hidden_states, router_logits) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} + topk_output = self.topk(hidden_states, router_logits, **topk_kwargs) final_hidden_states = self.experts(hidden_states, topk_output) if not _is_cuda or isinstance(self.experts.quant_method, KTEPWrapperMethod): final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - final_hidden_states += shared_output + + if ( + isinstance(self.experts.quant_method, DeepSeekMxfp4MoEMethod) + and envs.SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD.get() + ): + if not self._shared_expert_tp1: + final_hidden_states = 
shared_output.add_( + final_hidden_states, alpha=self.routed_scaling_factor + ) + else: + final_hidden_states.mul_(self.routed_scaling_factor) + else: + if not self._shared_expert_tp1: + final_hidden_states += shared_output + if ( self.tp_size > 1 and not should_allreduce_fusion @@ -609,6 +733,10 @@ def forward_normal_dual_stream( and not should_use_flashinfer_cutlass_moe_fp4_allgather() ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + # When shared expert uses TP1 (replicated weights), its output is already + # complete and must be added AFTER all-reduce to avoid being summed tp_size times. + if self._shared_expert_tp1: + final_hidden_states += shared_output return final_hidden_states def forward_normal( @@ -617,6 +745,8 @@ def forward_normal( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -632,7 +762,8 @@ def forward_normal( ) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) - topk_output = self.topk(hidden_states, router_logits) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} + topk_output = self.topk(hidden_states, router_logits, **topk_kwargs) else: shared_output = None topk_output = self.topk.empty_topk_output(hidden_states.device) @@ -678,8 +809,21 @@ def _post_combine_hook( ): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor - if shared_output is not None: - final_hidden_states += shared_output + + if ( + isinstance(self.experts.quant_method, DeepSeekMxfp4MoEMethod) + and envs.SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD.get() + ): + if shared_output is not None and not self._shared_expert_tp1: + final_hidden_states 
= shared_output.add_( + final_hidden_states, alpha=self.routed_scaling_factor + ) + else: + final_hidden_states.mul_(self.routed_scaling_factor) + else: + if shared_output is not None and not self._shared_expert_tp1: + final_hidden_states += shared_output + if ( self.tp_size > 1 and not should_allreduce_fusion @@ -687,6 +831,10 @@ def _post_combine_hook( and not should_use_flashinfer_cutlass_moe_fp4_allgather() ): final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + # When shared expert uses TP1 (replicated weights), its output is already + # complete and must be added AFTER all-reduce to avoid being summed tp_size times. + if shared_output is not None and self._shared_expert_tp1: + final_hidden_states += shared_output return final_hidden_states def forward_cpu( @@ -751,6 +899,7 @@ def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: shared_output = None sbo_enabled_flag = self._fuse_shared_experts_inside_sbo and not self.is_nextn @@ -773,6 +922,7 @@ def forward_deepep( shared_event = self.alt_stream.record_event() else: shared_output = self._forward_shared_experts(hidden_states) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} topk_output = self.topk( hidden_states, router_logits, @@ -780,6 +930,7 @@ def forward_deepep( expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, ), + **topk_kwargs, ) else: topk_output = self.topk.empty_topk_output(hidden_states.device) @@ -956,6 +1107,192 @@ def _post_combine_hook( return final_hidden_states + def _should_use_mega_moe(self, hidden_states: torch.Tensor) -> bool: + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if not envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + return False + if not getattr(self.experts, "_mega_moe_weights_built", False): + return False + + if not 
envs.SGLANG_OPT_FIX_NEXTN_MEGA_MOE.get(): + if self.is_nextn: + return False + + if not envs.SGLANG_OPT_FIX_HASH_MEGA_MOE.get(): + if self.is_hash: + return False + + if get_is_capture_mode(): + return True + + global_num_tokens = get_dp_global_num_tokens() + if global_num_tokens: + max_tokens_per_rank = max(global_num_tokens) + else: + max_tokens_per_rank = hidden_states.shape[0] + cap = envs.SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK.get() + return max_tokens_per_rank <= cap + + def forward_mega_moe( + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + input_ids_global: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, + ) + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + num_tokens = hidden_states.shape[0] + + sbo_overlap_flag = ( + self.alt_stream is not None + and self.num_fused_shared_experts == 0 + and num_tokens > 0 + and get_is_capture_mode() + ) + deepseek_v4_moe_code_path_checker.observed += 1 + + if sbo_overlap_flag: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + shared_output = self._forward_shared_experts(hidden_states) + mega_stream_ctx = torch.cuda.stream(self.alt_stream) + else: + shared_output = self._forward_shared_experts(hidden_states) + mega_stream_ctx = nullcontext() + + with mega_stream_ctx: + y = self._run_mega_routed( + hidden_states, forward_batch, input_ids_global, num_tokens + ) + + if sbo_overlap_flag: + current_stream.wait_stream(self.alt_stream) + + if shared_output is not None: + y.add_(shared_output) + return y + + def _run_mega_routed( + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch], + input_ids_global: Optional[torch.Tensor], + num_tokens: int, + ) -> torch.Tensor: + import deep_gemm + + from sglang.srt.distributed.parallel_state import get_moe_ep_group + from 
sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8_ue8m0, + ) + + hidden_size = self.config.hidden_size + + if num_tokens > 0: + router_logits = self.gate(hidden_states, forward_batch=forward_batch) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} + topk_output = self.topk( + hidden_states, + router_logits, + num_token_non_padded=( + forward_batch.num_token_non_padded + if forward_batch is not None + else None + ), + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), + **topk_kwargs, + ) + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + else: + topk_ids = None + topk_weights = None + + ep_group = get_moe_ep_group().device_group + num_experts = self.experts.num_experts + top_k = self.config.num_experts_per_tok + self.num_fused_shared_experts + intermediate_size = self.config.moe_intermediate_size + num_max_tokens_per_rank = ( + envs.SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK.get() + ) + assert num_tokens <= num_max_tokens_per_rank, ( + f"mega MoE: num_tokens={num_tokens} exceeds cap " + f"SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK=" + f"{num_max_tokens_per_rank}; raise the env var or shrink " + f"cuda_graph_max_bs / chunked_prefill_size accordingly" + ) + + buf = _get_mega_moe_symm_buffer( + ep_group, + num_experts=num_experts, + num_max_tokens_per_rank=num_max_tokens_per_rank, + num_topk=top_k, + hidden=hidden_size, + intermediate_hidden=intermediate_size, + ) + + padded_max = buf.topk_idx.shape[0] + if envs.SGLANG_OPT_MEGA_MOE_FUSED_PRE_DISPATCH.get(): + from sglang.jit_kernel.deepseek_v4 import mega_moe_pre_dispatch + + if num_tokens > 0: + topk_ids_in = topk_ids + topk_weights_in = topk_weights + else: + topk_ids_in = hidden_states.new_empty((0, top_k), dtype=torch.int32) + topk_weights_in = hidden_states.new_empty( + (0, top_k), dtype=torch.float32 + ) + mega_moe_pre_dispatch( + hidden_states, + topk_ids_in, + 
topk_weights_in, + buf.x, + buf.x_sf, + buf.topk_idx, + buf.topk_weights, + quant_group_size=32, + ) + else: + if num_tokens > 0: + x_fp8, x_sf = sglang_per_token_group_quant_fp8_ue8m0( + hidden_states, group_size=32 + ) + buf.x[:num_tokens].copy_(x_fp8) + buf.x_sf[:num_tokens].copy_(x_sf) + buf.topk_idx[:num_tokens].copy_(topk_ids) + buf.topk_weights[:num_tokens].copy_(topk_weights) + if num_tokens < padded_max: + buf.topk_idx[num_tokens:].fill_(-1) + buf.topk_weights[num_tokens:].zero_() + + y = torch.empty( + (num_tokens, hidden_size), + dtype=torch.bfloat16, + device=hidden_states.device, + ) + swiglu_limit = getattr(self.config, "swiglu_limit", None) + deep_gemm.fp8_fp4_mega_moe( + y, + self.experts.mega_l1_weights, + self.experts.mega_l2_weights, + buf, + recipe=(1, 1, 32), + activation="swiglu", + activation_clamp=swiglu_limit, + fast_math=True, + ) + + if not self.experts.should_fuse_routed_scaling_factor_in_topk: + y.mul_(self.routed_scaling_factor) + return y + def _forward_shared_experts( self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None ): @@ -2286,6 +2623,7 @@ def __init__( prefix=add_prefix("mlp", prefix), tp_rank=mlp_tp_rank, tp_size=mlp_tp_size, + swiglu_limit=getattr(config, "swiglu_limit", None), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py new file mode 100644 index 000000000000..a71b854352c0 --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4.py @@ -0,0 +1,2098 @@ +from __future__ import annotations + +import concurrent.futures +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +import sglang.srt.models.deepseek_v2 as deepseek_v2 +from sglang.jit_kernel.deepseek_v4 import fused_rope, 
linear_bf16_fp32, rmsnorm_self +from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config +from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, +) +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.environ import envs, is_large_dummy_model +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation +from sglang.srt.layers.attention.nsa.utils import ( + assert_tensor_identical_across_cp_ranks, + can_cp_split, + cp_all_gather_rerange_output, + cp_split_and_rebuild_data, + cp_split_and_rebuild_position, + is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, + prepare_input_dp_with_cp_dsa, +) +from sglang.srt.layers.communicator import LayerScatterModes, get_attn_tp_context +from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + attn_tp_all_gather, + dp_gather_partial, + dp_scatter, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, + get_global_dp_buffer, + get_local_dp_buffer, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 +from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.mem_cache.compress_state import CompressStatePool +from 
sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.mem_cache.memory_pool import RadixAttention +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_loader.utils import maybe_executor_submit, should_async_load +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.dbrx import ReplicatedLinear +from sglang.srt.models.deepseek_v2 import ParallelLMHead, _is_cuda, _is_hip, _is_npu +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + BumpAllocator, + LazyValue, + add_prefix, + log_info_on_rank0, + make_layers, + maybe_torch_compile, +) + +logger = logging.getLogger(__name__) + +from sglang.srt.environ import envs + +MOE_BIT_WISE_EQUAL_MODE = False +ATTN_BIT_WISE_EQUAL_MODE = False +COMPRESSOR_BIT_WISE_EQUAL_MODE = False +_FP8_WO_A_GEMM = envs.SGLANG_OPT_FP8_WO_A_GEMM.get() + + +if TYPE_CHECKING: + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4BackendRadix, + ) + from sglang.srt.layers.quantization import QuantizationConfig + from sglang.srt.layers.rotary_embedding import RotaryEmbedding + from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + ) + + +class DeepseekRefRMSNorm(nn.Module): + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + out = rms_normalize_triton(x, self.eps, self.weight) + return out + + +@maybe_torch_compile +def rms_normalize(x: torch.Tensor, eps: float) -> torch.Tensor: + x *= torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) + return x + + +@triton.jit +def _rms_normalize_kernel( + x_ptr, + weight_ptr, + eps, + stride_row, + dim, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, +): + pid = tl.program_id(0) + + offs = tl.arange(0, 
BLOCK_SIZE) + mask = offs < dim + + base = pid * stride_row + x = tl.load(x_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + + mean_sq = tl.sum(x * x, axis=0) / dim + rms_inv = tl.rsqrt(mean_sq + eps) + out = x * rms_inv + + if HAS_WEIGHT: + weight = tl.load(weight_ptr + offs, mask=mask, other=0.0) + out = out * weight + + tl.store(x_ptr + base + offs, out, mask=mask) + + +def rms_normalize_triton( + x: torch.Tensor, eps: float, weight: torch.Tensor = None +) -> torch.Tensor: + dim = x.shape[-1] + x_flat = x.view(-1, dim) + num_rows = x_flat.shape[0] + + BLOCK_SIZE = triton.next_power_of_2(dim) + grid = (num_rows,) + + _rms_normalize_kernel[grid]( + x_flat, + weight, + eps, + x_flat.stride(0), + dim, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=(weight is not None), + ) + return x + + +class Compressor(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + is_in_indexer: bool, + rotary_emb: RotaryEmbedding, + freqs_cis: torch.Tensor, + compress_ratio: Literal[0, 4, 128], + head_dim: int, + rotate: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.is_in_indexer = is_in_indexer + self.dim = config.hidden_size + self.head_dim = head_dim + self.rope_head_dim = getattr(config, "qk_rope_head_dim", 64) + self.nope_head_dim = head_dim - self.rope_head_dim + assert compress_ratio != 0, "compress_ratio should not be 0" + self.ratio = compress_ratio + self.overlap = self.ratio == 4 + self.rotate = rotate + self.coff = coff = 1 + self.overlap + + self.ape = nn.Parameter( + torch.empty(self.ratio, coff * self.head_dim, dtype=torch.float32) + ) + wkv_gate_dtype = torch.bfloat16 + self.wkv_gate = ReplicatedLinear( + self.dim, + 2 * coff * self.head_dim, + bias=False, + quant_config=None, + prefix=add_prefix("wkv_gate", prefix), + params_dtype=wkv_gate_dtype, + ) + self.norm = DeepseekRefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = rotary_emb + self.freqs_cis = freqs_cis + + 
self.ape_converted = False + + def apply_ape_hotfix(self): + assert not self.ape_converted + self.ape_converted = True + + is_model_2604 = envs.SGLANG_DSV4_MODE.get() == "2604" + if self.overlap and (envs.SGLANG_OPT_FIX_APE_2604.get() or not is_model_2604): + orders = [0, 1] if is_model_2604 else [1, 0] + ape = torch.chunk(self.ape.data, 2, dim=-1) + ape = torch.cat([ape[orders[0]], ape[orders[1]]], dim=0) + self.ape.data.copy_(ape.view(self.ratio, -1)) + + def _get_state_pool(self, forward_batch: ForwardBatch) -> CompressStatePool: + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + if self.is_in_indexer: + ret = token_to_kv_pool.get_indexer_compress_states(self.layer_id) + else: + ret = token_to_kv_pool.get_attention_compress_states(self.layer_id) + + assert isinstance(ret, CompressStatePool) + + return ret + + def overlap_transform(self, tensor: torch.Tensor, fill_value: Any) -> torch.Tensor: + assert tensor.dim() == 3 + assert tensor.shape[1:] == (self.ratio, 2 * self.head_dim) + + s, r, d = tensor.size(0), self.ratio, self.head_dim + new_tensor = tensor.new_full((s, 2 * r, d), fill_value) + new_tensor[:, r:] = tensor[:, :, d:] + new_tensor[1:, :r] = tensor[:-1, :, :d] + return new_tensor + + def overlap_transform_decode(self, tensor: torch.Tensor) -> torch.Tensor: + assert tensor.dim() == 3 + assert tensor.shape[1:] == (2 * self.ratio, 2 * self.head_dim) + r, d = self.ratio, self.head_dim + ret = torch.cat((tensor[:, :r, :d], tensor[:, r:, d:]), dim=1) + return ret + + @staticmethod + def compute_state_len(seq_len: int, ratio: int): + return seq_len % ratio + (ratio == 4) * ratio + + @staticmethod + def compute_state_len_indices(seq_len: int, ratio: int): + state_len = seq_len % ratio + (ratio == 4) * ratio + return torch.arange(seq_len - state_len, seq_len).clamp(min=-1) + + def compress_fused( + self, + kv_score: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + backend = 
forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(backend, DeepseekV4BackendRadix) + kv_score_buffer = self._get_state_pool(forward_batch) + kv_score_buffer = kv_score_buffer.kv_score_buffer.kv_score + return backend.forward_compress( + kv_score_buffer=kv_score_buffer, + kv_score_input=kv_score, + ape=self.ape.view(-1, self.head_dim), + head_dim=self.head_dim, + norm=self.norm, + freqs_cis_cache=self.freqs_cis, + rotate=self.rotate, + compress_ratio=self.ratio, + forward_batch=forward_batch, + is_paged=True, + ) + + def forward(self, x: torch.Tensor, forward_batch: ForwardBatch) -> torch.Tensor: + if forward_batch.forward_mode.is_idle(): + assert x.shape[0] == 0 + return x.new_empty(0, self.head_dim) + + self.forward_mode = forward_batch.forward_mode + + kv_score = linear_bf16_fp32(x, self.wkv_gate.weight) + if nsa_use_prefill_cp(forward_batch): + kv_score = cp_all_gather_rerange_output( + kv_score, + get_attention_tp_size(), + forward_batch, + torch.cuda.current_stream(), + ) + return self.compress_fused(kv_score, forward_batch) + + +class C4Indexer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + rotary_emb: RotaryEmbedding, + freqs_cis: torch.Tensor, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + ): + super().__init__() + self.layer_id = layer_id + self.dim = config.hidden_size + self.n_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.index_topk = config.index_topk + self.q_lora_rank = config.q_lora_rank + self.softmax_scale = self.head_dim**-0.5 + self.n_local_heads = self.n_heads + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + params_dtype=torch.bfloat16, + prefix=add_prefix("wq_b", prefix), + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + 
bias=False, + quant_config=None, + params_dtype=torch.bfloat16, + prefix=add_prefix("weights_proj", prefix), + ) + self.compressor = Compressor( + config, + self.layer_id, + True, + rotary_emb, + freqs_cis, + compress_ratio=4, + head_dim=self.head_dim, + rotate=True, + prefix=add_prefix("compressor", prefix), + ) + self.rotary_emb = rotary_emb + self.freqs_cis = freqs_cis + self.weight_scale: float = self.softmax_scale * self.n_heads**-0.5 + self.alt_streams = alt_streams + + def compute_q(self, q_lora: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + q, _ = self.wq_b(q_lora) + q = q.view(-1, self.n_local_heads, self.head_dim) + fused_rope( + q[..., -self.rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + q = rotate_activation(q) + return q + + def compute_weights(self, x: torch.Tensor, skip_scale=False) -> torch.Tensor: + out, _ = self.weights_proj(x) + if not skip_scale: + out = out * self.weight_scale + return out + + def forward( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + forward_batch: ForwardBatch, + enable_multi_stream: bool = False, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> None: + if TYPE_CHECKING: + assert isinstance(forward_batch.attn_backend, DeepseekV4BackendRadix) + return forward_batch.attn_backend.forward_c4_indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + c4_indexer=self, + alt_streams=self.alt_streams, + enable_multi_stream=enable_multi_stream, + q_lora_ready=q_lora_ready, + ) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class MQALayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + compress_ratio_override: Optional[int] = None, + ) -> None: + super().__init__() + self.tp_rank = attn_tp_rank 
= get_attention_tp_rank() + self.tp_size = attn_tp_size = get_attention_tp_size() + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_tp_size() + self.tp_rank = attn_tp_rank = 0 + self.tp_size = attn_tp_size = 1 + self.layer_id = layer_id + self.dim = config.hidden_size + self.qk_rope_head_dim = config.qk_rope_head_dim + if envs.SGLANG_DSV4_MODE.get() == "2604": + self.qk_nope_head_dim = config.head_dim - config.qk_rope_head_dim + else: + self.qk_nope_head_dim = config.qk_nope_head_dim + self.head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim + self.n_heads = config.num_attention_heads + self.n_local_heads = self.n_heads // attn_tp_size + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // attn_tp_size + self.rope_head_dim = config.qk_rope_head_dim + self.softmax_scale = self.head_dim**-0.5 + self.hidden_size = config.hidden_size + self.q_lora_rank = config.q_lora_rank + self.o_lora_rank = config.o_lora_rank + self.eps = config.rms_norm_eps + compress_ratio = ( + compress_ratio_override + if compress_ratio_override is not None + else config.compress_ratios[layer_id] + ) + assert compress_ratio in [0, 4, 128] + self.compress_ratio: Literal[0, 4, 128] = compress_ratio + + if envs.SGLANG_DSV4_MODE.get() == "2604": + assert self.head_dim == config.head_dim + else: + assert self.head_dim == config.v_head_dim + assert config.num_key_value_heads == 1 + + rope_scaling = config.rope_scaling + if rope_scaling: + rope_scaling["rope_type"] = "deepseek_yarn" + + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get(): + assert ( + config.compress_rope_theta == 160000 + ), f"{config.compress_rope_theta=}" + rope_base = ( + config.compress_rope_theta if self.compress_ratio else config.rope_theta + ) + + self.rotary_emb = get_rope_wrapper( + head_size=self.rope_head_dim, + rotary_dim=self.rope_head_dim, + max_position=config.max_position_embeddings, + base=rope_base, + 
rope_scaling=rope_scaling, + is_neox_style=False, + device=get_global_server_args().device, + ) + + from sglang.srt.layers.deepseek_v4_rope import precompute_freqs_cis + + if envs.SGLANG_DSV4_MODE.get() == "2604": + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get(): + assert rope_scaling["factor"] == 16 + elif envs.SGLANG_DSV4_MODE.get() == "2601": + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get(): + assert rope_scaling["factor"] == 4 + else: + raise NotImplementedError + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert self.compress_ratio in {0, 4, 128} + if self.compress_ratio: + original_seq_len = rope_scaling["original_max_position_embeddings"] + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get(): + assert original_seq_len == 65536 + else: + original_seq_len = 0 + else: + original_seq_len = rope_scaling["original_max_position_embeddings"] + + rope_scaling = config.rope_scaling + freqs_cis = precompute_freqs_cis( + dim=self.qk_rope_head_dim, + seqlen=config.max_position_embeddings, + original_seq_len=original_seq_len, + base=rope_base, + factor=rope_scaling["factor"], + beta_fast=rope_scaling["beta_fast"], + beta_slow=rope_scaling["beta_slow"], + ) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) + self.freqs_cis: torch.Tensor + + if envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() and alt_streams is not None: + self.alt_streams = alt_streams[:3] + self.alt_streams_indexer = alt_streams[-2:] + else: + self.alt_streams = None + self.alt_streams_indexer = None + + from sglang.srt.utils import is_blackwell_supported + + self._multi_stream_bs_limit = 128 if is_blackwell_supported() else 64 + + self.compressor = None + self.indexer = None + if self.compress_ratio: + self.compressor = Compressor( + config, + layer_id=self.layer_id, + is_in_indexer=False, + rotary_emb=self.rotary_emb, + freqs_cis=freqs_cis, + compress_ratio=self.compress_ratio, + head_dim=self.head_dim, + rotate=False, + prefix=add_prefix("compressor", prefix), + ) + if 
self.compress_ratio == 4: + self.indexer = C4Indexer( + config, + rotary_emb=self.rotary_emb, + freqs_cis=freqs_cis, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("indexer", prefix), + alt_streams=self.alt_streams_indexer, + ) + + self.attn_sink = nn.Parameter(torch.empty(self.n_heads, dtype=torch.float32)) + self.fuse_wqa_wkv = envs.SGLANG_OPT_FUSE_WQA_WKV.get() + if self.fuse_wqa_wkv: + self.wqkv_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wqkv_a", prefix), + ) + else: + self.wq_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_a", prefix), + ) + self.wkv = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wkv", prefix), + ) + self.q_norm = RMSNorm(self.q_lora_rank, eps=self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_b", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + self.kv_norm = RMSNorm(self.head_dim, eps=self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config if _FP8_WO_A_GEMM else None, + prefix=add_prefix("wo_a", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + **({} if _FP8_WO_A_GEMM else {"params_dtype": torch.bfloat16}), + ) + if _FP8_WO_A_GEMM: + assert hasattr( + self.wo_a, "weight_scale_inv" + ), "FP8 quant_config must create weight_scale_inv" + self.wo_a.weight_scale_inv.format_ue8m0 = True + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=attn_tp_size > 1, + prefix=add_prefix("wo_b", prefix), + tp_rank=attn_tp_rank, + 
tp_size=attn_tp_size, + ) + + self.attn_mqa = RadixAttention( + self.n_local_heads, + self.head_dim, + self.softmax_scale, + num_kv_heads=1, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn_mqa", prefix), + ) + + self.overlap_store_cache = envs.SGLANG_OPT_USE_OVERLAP_STORE_CACHE.get() + self.use_jit_norm = envs.SGLANG_OPT_USE_JIT_NORM.get() + + def _compute_q_a( + self, + x: torch.Tensor, + qkv_a: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if qkv_a is not None: + q = qkv_a[..., : self.q_lora_rank] + else: + q, _ = self.wq_a(x) + q = self.q_norm(q) + q_lora = q + return q_lora + + def _compute_q_b( + self, + q: torch.Tensor, + positions: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + if self.use_jit_norm: + q = rmsnorm_self(q, self.eps) + else: + q = rms_normalize_triton(q, self.eps) + if positions is not None: + fused_rope( + q[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(q[..., -self.qk_rope_head_dim :], self.freqs_cis) + return q + + def _compute_kv( + self, + x: torch.Tensor, + positions: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + qkv_a: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if qkv_a is not None: + kv = qkv_a[..., self.q_lora_rank :] + else: + kv, _ = self.wkv(x) + kv = self.kv_norm(kv) + if positions is not None: + fused_rope( + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(kv[..., -self.qk_rope_head_dim :], self.freqs_cis) + return kv + + def _forward_prepare_multi_stream( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4BackendRadix, + freqs_cis: Optional[torch.Tensor] = None, + q_out: Optional[torch.Tensor] = None, + ) -> 
Tuple[torch.Tensor, torch.Tensor]: + assert self.alt_streams is not None + assert len(self.alt_streams) >= 3 + + current_stream = torch.cuda.current_stream() + stream_kv = self.alt_streams[0] + stream_compressor = self.alt_streams[1] + stream_indexer = self.alt_streams[2] + + stream_kv.wait_stream(current_stream) + stream_compressor.wait_stream(current_stream) + stream_indexer.wait_stream(current_stream) + + qkv_a: Optional[torch.Tensor] = None + qkv_a_ready: Optional[torch.cuda.Event] = None + if self.fuse_wqa_wkv: + qkv_a, _ = self.wqkv_a(x) + qkv_a_ready = current_stream.record_event() + + q_lora = self._compute_q_a(x, qkv_a=qkv_a) + q_lora_ready = current_stream.record_event() + + if self.indexer is not None: + with torch.cuda.stream(stream_indexer): + self.indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + enable_multi_stream=True, + q_lora_ready=q_lora_ready, + ) + + with torch.cuda.stream(stream_kv): + if qkv_a_ready is not None: + stream_kv.wait_event(qkv_a_ready) + kv = self._compute_kv(x, positions, freqs_cis, qkv_a=qkv_a) + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + del qkv_a + + if self.compressor is not None: + with torch.cuda.stream(stream_compressor): + attn_backend.forward_core_compressor( + x, forward_batch, self.layer_id, self.compressor + ) + + q = self._compute_q_b(q_lora, positions, freqs_cis) + if q_out is not None: + q_out.copy_(q) + + current_stream.wait_stream(stream_kv) + current_stream.wait_stream(stream_compressor) + current_stream.wait_stream(stream_indexer) + + return q, kv + + def _forward_prepare( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4BackendRadix, + freqs_cis: Optional[torch.Tensor] = None, + q_out: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.fuse_wqa_wkv: + qkv_a, _ = self.wqkv_a(x) + q = qkv_a[..., : 
self.q_lora_rank] + kv = qkv_a[..., self.q_lora_rank :] + del qkv_a + else: + kv, _ = self.wkv(x) + q, _ = self.wq_a(x) + q = self.q_norm(q) + q_lora = q + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + if self.use_jit_norm: + q = rmsnorm_self(q, self.eps) + else: + q = rms_normalize_triton(q, self.eps) + + kv = self.kv_norm(kv) + + fused_rope( + q[..., -self.qk_rope_head_dim :], + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + self.freqs_cis, + positions=positions, + ) + + if self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch): + kv = cp_all_gather_rerange_output( + kv.contiguous(), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + if envs.SGLANG_DEBUG_HACK_CP_CHECK_RANK_CONSISTENCY.get(): + assert_tensor_identical_across_cp_ranks( + kv, + tag=f"kv_after_allgather layer_id={self.layer_id}", + forward_batch=forward_batch, + ) + + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + if self.indexer is not None: + self.indexer(x=x, q_lora=q_lora, forward_batch=forward_batch) + if self.compressor is not None: + attn_backend.forward_core_compressor( + x, + forward_batch, + self.layer_id, + self.compressor, + ) + + if q_out is not None: + q_out.copy_(q) + return q, kv + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + debug_return_kv: bool = False, + ) -> torch.Tensor: + if not get_attn_tp_context().input_scattered and x.shape[0] == 0: + assert ( + not self.wo_b.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return x + + attn_backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(attn_backend, DeepseekV4BackendRadix) + + freqs_cis = None + + enable_multi_stream = ( + envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() + and self.alt_streams is not None + and get_is_capture_mode() + and x.shape[0] <= self._multi_stream_bs_limit + and not 
(self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch)) + ) + + tp_slice, q_padded, q_out = slice(None), None, None + if self.tp_size > 1: + q_padded = x.new_empty(x.shape[0], self.n_heads, self.head_dim) + rank = self.tp_rank + tp_slice = slice(rank * self.n_local_heads, (rank + 1) * self.n_local_heads) + q_out = q_padded[:, tp_slice, :] + + if enable_multi_stream: + q, kv = self._forward_prepare_multi_stream( + x, positions, forward_batch, attn_backend, freqs_cis, q_out + ) + else: + q, kv = self._forward_prepare( + x, positions, forward_batch, attn_backend, freqs_cis, q_out + ) + + o = attn_backend.forward( + q=q_padded if q_padded is not None else q, + k=kv, + v=kv, + layer=self.attn_mqa, + forward_batch=forward_batch, + compress_ratio=self.compress_ratio, + attn_sink=self.attn_sink, + save_kv_cache=not self.overlap_store_cache, + ) + o = o[:, tp_slice, :] + fused_rope( + o[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + inverse=True, + ) + + o = o.view(o.shape[0], self.n_local_groups, -1) + + if _FP8_WO_A_GEMM: + import deep_gemm + + T, G, D = o.shape + R = self.o_lora_rank + o_fp8, o_s = sglang_per_token_group_quant_fp8( + o.reshape(T * G, D).contiguous(), + group_size=128, + ) + output = torch.empty(T, G, R, device=o.device, dtype=torch.bfloat16) + deep_gemm.fp8_einsum( + "bhr,hdr->bhd", + (o_fp8.view(T, G, D), o_s.view(T, G, -1)), + (self.wo_a.weight.view(G, R, D), self.wo_a.weight_scale_inv.data), + output, + recipe=(1, 1, 128), + ) + o = output + else: + wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) + o = torch.einsum("tgd,grd->tgr", o, wo_a) + + o, _ = self.wo_b(o.flatten(1)) + + return o + + +class DeepseekV4DecoderLayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + moe_quant_config_override: Optional[QuantizationConfig] = None, + is_nextn: bool = False, + prefix: str = "", + alt_streams: 
Optional[List[torch.cuda.Stream]] = None, + compress_ratio_override: Optional[int] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.is_nextn = is_nextn + self.self_attn = MQALayer( + config=config, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + alt_streams=alt_streams, + compress_ratio_override=compress_ratio_override, + ) + self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn) + is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False) + is_next_layer_sparse = self._is_layer_sparse(layer_id + 1, is_nextn=False) + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=1 if is_nextn else config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + is_next_layer_sparse=is_next_layer_sparse, + ) + self.mlp = deepseek_v2.DeepseekV2MoE( + config=config, + quant_config=moe_quant_config_override or quant_config, + prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, + alt_stream=alt_streams[0] if alt_streams is not None else None, + is_nextn=is_nextn, + is_deepseek_v4=True, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.hc_mult = hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + mix_hc = (2 + hc_mult) * hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_attn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_attn_scale = 
nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.hc_ffn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.rms_norm_eps = config.rms_norm_eps + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + + def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: + if envs.SGLANG_DSV4_MODE.get() == "2604": + first_k_dense_replace = 0 + moe_layer_freq = 1 + else: + first_k_dense_replace = self.config.first_k_dense_replace + moe_layer_freq = self.config.moe_layer_freq + return is_nextn or ( + self.config.n_routed_experts is not None + and layer_id >= first_k_dense_replace + and layer_id % moe_layer_freq == 0 + ) + + def hc_pre( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + @maybe_torch_compile + def hc_pre_torch_impl(x, hc_fn): + x_flat = x.flatten(1).float() + rsqrt = torch.rsqrt( + x_flat.square().mean(-1, keepdim=True) + self.rms_norm_eps + ) + mixes = (F.linear(x_flat, hc_fn) * rsqrt).unsqueeze(1) + return x_flat, mixes + + shape, dtype = x.size(), x.dtype + + if x.shape[0] == 0: + y = torch.empty((0, shape[-1]), dtype=dtype, device=x.device) + post = torch.empty((0, self.hc_mult), dtype=dtype, device=x.device) + comb = torch.empty( + (0, self.hc_mult, self.hc_mult), dtype=dtype, device=x.device + ) + return y, post, comb + + if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get(): + from sglang.srt.layers.mhc import mhc_pre + + post, comb, y = mhc_pre( + residual=x, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=self.rms_norm_eps, + hc_pre_eps=self.hc_eps, + hc_sinkhorn_eps=self.hc_eps, + hc_post_mult_value=2.0, + sinkhorn_repeat=self.hc_sinkhorn_iters, + ) + return y, post.squeeze(-1), comb + + if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): + import deep_gemm + + x_flat = x.flatten(1).bfloat16() + + m, k = x_flat.shape + mix_hc = hc_fn.size(0) + d_out = torch.empty((m, mix_hc), dtype=torch.float, device=x.device) + s_out = torch.empty((m,), dtype=torch.float, 
device=x.device) + deep_gemm.tf32_hc_prenorm_gemm( + x_flat, hc_fn.float().contiguous(), d_out, s_out, num_splits=None + ) + rsqrt = torch.rsqrt(s_out / k + self.rms_norm_eps) + mixes = (d_out * rsqrt.unsqueeze(1)).unsqueeze(1) + else: + x_flat, mixes = hc_pre_torch_impl(x, hc_fn) + + from sglang.srt.layers.mhc import hc_split_sinkhorn + + pre, post, comb = hc_split_sinkhorn( + mixes, + hc_scale, + hc_base, + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) + y = (pre.squeeze(1).unsqueeze(-1) * x_flat.view(shape)).sum(dim=1) + return y.to(dtype), post.squeeze(1), comb.squeeze(1) + + def hc_post( + self, + x: torch.Tensor, + residual: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, + ): + + if x.shape[0] == 0: + return torch.empty( + (0, self.hc_mult, x.shape[-1]), dtype=x.dtype, device=x.device + ) + + if envs.SGLANG_OPT_USE_TILELANG_MHC_POST.get(): + from sglang.srt.layers.mhc import mhc_post + + return mhc_post(x, residual, post, comb) + + assert residual.shape == (x.shape[0], self.hc_mult, x.shape[-1]) + assert post.shape == (x.shape[0], self.hc_mult) + assert comb.shape == (x.shape[0], self.hc_mult, self.hc_mult) + + @maybe_torch_compile + def hc_post_torch_impl(x, residual, post, comb): + return ( + post.unsqueeze(-1) * x.unsqueeze(1) + + (comb.unsqueeze(-1) * residual.unsqueeze(2)).sum(dim=1) + ).type_as(x) + + return hc_post_torch_impl(x, residual, post, comb) + + def forward( + self, + positions: torch.tensor, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + forward_batch: ForwardBatch, + input_ids_global: torch.Tensor, + ) -> torch.Tensor: + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert deepseek_v4_moe_code_path_checker.observed == 0 + + residual = hidden_states + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + x=hidden_states, + positions=positions, + 
forward_batch=forward_batch, + ) + + hidden_states = self.hc_post(hidden_states, residual, post, comb) + residual = hidden_states + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) + _use_tp_moe_gather = ( + not _use_cp + and get_attention_dp_size() > 1 + and get_moe_a2a_backend().is_none() + ) + _use_tp_attn_a2a_scatter = ( + not _use_cp + and envs.SGLANG_DSV4_FIX_TP_ATTN_A2A_SCATTER.get() + and get_attention_tp_size() > 1 + and not get_moe_a2a_backend().is_none() + ) + if _use_cp: + assert get_moe_a2a_backend().is_deepep(), ( + "CP requires DeepEP (moe_a2a_backend == deepep). " + "Only DeepEP is tested with CP's per-rank token split." + ) + cp_rank = get_attention_tp_rank() + cp_size = get_attention_tp_size() + input_ids = input_ids[cp_rank::cp_size].contiguous() + input_ids_global = input_ids + elif _use_tp_moe_gather: + hidden_states, local_hidden_states = get_global_dp_buffer(), hidden_states + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + _a2a_scatter_chunks: Optional[List[torch.Tensor]] = None + if _use_tp_attn_a2a_scatter: + s, r = get_attention_tp_size(), get_attention_tp_rank() + _a2a_scatter_chunks = list(hidden_states.tensor_split(s)) + hidden_states = _a2a_scatter_chunks[r].contiguous() + input_ids = input_ids.tensor_split(s)[r].contiguous() + input_ids_global = input_ids_global.tensor_split(s)[r].contiguous() + hidden_states = self.mlp( + hidden_states, + forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + if _use_tp_moe_gather: + hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states + dp_scatter(hidden_states, global_hidden_states, forward_batch) + if _use_tp_attn_a2a_scatter: + assert _a2a_scatter_chunks is not None + gathered = [torch.empty_like(t) for t in 
_a2a_scatter_chunks] + attn_tp_all_gather(gathered, hidden_states.contiguous()) + hidden_states = torch.cat(gathered) + + hidden_states = self.hc_post(hidden_states, residual, post, comb) + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert deepseek_v4_moe_code_path_checker.observed == 1 + deepseek_v4_moe_code_path_checker.observed = 0 + + return hidden_states + + +class DeepseekV4Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + self.pp_group = get_pp_group() + self.first_k_dense_replace = config.first_k_dense_replace + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + ) + self.rms_norm_eps = config.rms_norm_eps + self.alt_streams = ( + [torch.cuda.Stream() for _ in range(5)] if (_is_cuda or _is_hip) else None + ) + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: DeepseekV4DecoderLayer( + config=config, + layer_id=idx, + quant_config=quant_config, + prefix=prefix, + alt_streams=self.alt_streams, + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix=add_prefix("layers", prefix), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gemm_output_zero_allocator_size = 0 + self.layers_to_capture = [] + if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + self.enable_a2a_moe = True + else: + self.enable_a2a_moe = False + + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + self.norm_eps = config.rms_norm_eps + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = 
nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_tp_size() + + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor], + pp_proxy_tensors: Optional[PPProxyTensors], + ) -> torch.Tensor: + total_num_layers = self.end_layer - self.start_layer + device = input_embeds.device if input_embeds is not None else input_ids.device + zero_allocator = BumpAllocator( + buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), + dtype=torch.float32, + device=device, + ) + has_gemm_output_zero_allocator = hasattr( + self, "gemm_output_zero_allocator_size" + ) + gemm_output_zero_allocator = ( + BumpAllocator( + buffer_size=self.gemm_output_zero_allocator_size, + dtype=torch.float32, + device=device, + ) + if has_gemm_output_zero_allocator + and self.gemm_output_zero_allocator_size > 0 + else None + ) + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = 
input_ids_global.squeeze(-1) + else: + input_ids_global = input_ids + + if nsa_use_prefill_cp(forward_batch): + _check_rank_consistency = ( + envs.SGLANG_DEBUG_HACK_CP_CHECK_RANK_CONSISTENCY.get() + ) + if _check_rank_consistency: + _pre_split_hidden_states = hidden_states.clone() + _pre_split_positions = positions.clone() + hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + positions = cp_split_and_rebuild_position(forward_batch, positions) + if _check_rank_consistency: + _gathered_hidden = cp_all_gather_rerange_output( + hidden_states, + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + assert torch.equal(_gathered_hidden, _pre_split_hidden_states), ( + "SGLANG_DEBUG_HACK_CP_CHECK_RANK_CONSISTENCY: " + "cp_split_and_rebuild_data ∘ cp_all_gather_rerange_output is not identity on hidden_states. " + "Round-robin split/gather helpers are inconsistent." + ) + _gathered_positions = cp_all_gather_rerange_output( + positions.unsqueeze(-1), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ).squeeze(-1) + assert torch.equal(_gathered_positions, _pre_split_positions), ( + "SGLANG_DEBUG_HACK_CP_CHECK_RANK_CONSISTENCY: " + "cp_split_and_rebuild_position ∘ cp_all_gather_rerange_output is not identity on positions." 
+ ) + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + + if nsa_use_prefill_cp(forward_batch): + hidden_states = cp_all_gather_rerange_output( + hidden_states, + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + + pre_hc_head = ( + hidden_states.flatten(1) + if envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + else None + ) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.norm(hidden_states) + + if pre_hc_head is not None: + return hidden_states, pre_hc_head + return hidden_states + + +class DeepseekV4ForCausalLM(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.determine_num_fused_shared_experts() + self.model = DeepseekV4Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.pp_group = get_pp_group() + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False + get_attn_tp_context().init_context(config.q_lora_rank, is_nsa=True) + + self._routed_experts_weights_of_layer = LazyValue( + lambda: { + layer_id: layer.mlp.get_moe_weights() + for layer_id, layer in enumerate(self.model.layers) + if isinstance(layer.mlp, deepseek_v2.DeepseekV2MoE) + } + ) + + 
self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_rank = get_attention_tp_rank() + self.cp_size = get_attention_tp_size() + + @property + def routed_experts_weights_of_layer(self): + return self._routed_experts_weights_of_layer.value + + def determine_num_fused_shared_experts(self): + self.num_fused_shared_experts = 0 + if get_global_server_args().disable_shared_experts_fusion: + return + + disable_reason = None + if self.config.n_routed_experts != 256 or self.config.n_shared_experts != 1: + disable_reason = "Config not support fused shared expert(s)." + elif (not _is_cuda or torch.cuda.get_device_capability("cuda") < (8, 0)) and ( + not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4) + ): + disable_reason = ( + "Only Deepseek V3/R1 on NV-platform with capability >= 80 " + "or AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization." + ) + elif get_moe_expert_parallel_world_size() > 1 and ( + not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4) + ): + disable_reason = "Only Deepseek V3/R1 on AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization under expert parallelism." + elif disable_reason is None and get_moe_a2a_backend().is_deepep(): + disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under deepep expert parallelism." + elif self.quant_config and self.quant_config.get_name() == "w4afp8": + disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." + elif ( + envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + ): + disable_reason = "2604 routed experts use FP4 while shared experts remain FP8; fusion would incorrectly apply FP4 to shared experts." 
+ + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + disable_reason = "2604B checkpoint requires different clamping for shared and routed experts" + + if disable_reason is not None: + get_global_server_args().disable_shared_experts_fusion = True + self.num_fused_shared_experts = 0 + log_info_on_rank0( + logger, + f"{disable_reason} Shared experts fusion optimization is disabled.", + ) + return + + self.num_fused_shared_experts = self.config.n_shared_experts + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + if self.nsa_enable_prefill_cp: + if can_cp_split(len(input_ids), self.cp_size, True, forward_batch): + forward_batch.nsa_cp_metadata = prepare_input_dp_with_cp_dsa( + len(input_ids), + self.cp_rank, + self.cp_size, + forward_batch.seq_lens_cpu.tolist(), + ) + + with get_attn_tp_context().maybe_input_scattered(forward_batch): + hidden_states = self.model.forward( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors + ) + aux_hidden_states = None + pre_hc_head = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + hidden_states, pre_hc_head = hidden_states + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + hidden_states_before_norm=pre_hc_head, + ) + + def _setup_fp8_wo_a_scales(self, is_nextn: bool) -> None: + from deep_gemm import transform_sf_into_required_layout + + layers = self.model.layers + for layer in layers: + attn = layer.self_attn + G = attn.n_local_groups + R = attn.o_lora_rank + D = attn.wo_a.weight.shape[1] + + raw_scale = attn.wo_a.weight_scale_inv.data.view(G, R // 128, D // 128) + attn.wo_a.weight_scale_inv.data = 
transform_sf_into_required_layout( + raw_scale, + mn=R, + k=D, + recipe=(1, 128, 128), + num_groups=G, + is_sfa=False, + ) + + def post_load_weights(self, is_nextn=False, weight_names=None): + if _FP8_WO_A_GEMM: + self._setup_fp8_wo_a_scales(is_nextn) + + if is_nextn: + return + for layer in self.model.layers: + self_attn = layer.self_attn + if self_attn.compress_ratio != 0 and not self_attn.compressor.ape_converted: + self_attn.compressor.apply_ape_hotfix() + if ( + self_attn.compress_ratio == 4 + and not self_attn.indexer.compressor.ape_converted + ): + self_attn.indexer.compressor.apply_ape_hotfix() + + @staticmethod + def remap_weight_name_to_dpsk_hf_format( + name: str, is_nextn: bool = False, num_hidden_layers: Optional[int] = None + ) -> str: + if name == "embed.weight": + return "model.embed_tokens.weight" + if name == "head.weight": + return "lm_head.weight" + if name == "norm.weight": + return "model.norm.weight" + if name.startswith("hc_head_"): + return "model." + name + + if is_nextn and name.startswith("mtp."): + parts = name.split(".", 2) + if len(parts) >= 3: + rest = parts[2] + nextn_spec_prefixes = [ + "e_proj", + "h_proj", + "emb", + "enorm", + "hnorm", + "norm", + "head", + "hc_head", + ] + is_nextn_spec = any(rest.startswith(p) for p in nextn_spec_prefixes) + if is_nextn_spec: + if rest.startswith("emb.tok_emb"): + rest = rest.replace("emb.tok_emb", "embed_tokens") + elif rest == "norm.weight": + rest = "shared_head.norm.weight" + elif rest.startswith("head."): + rest = "shared_head.head.weight" + elif rest == "e_proj.scale": + rest = "e_proj.weight_scale_inv" + elif rest == "h_proj.scale": + rest = "h_proj.weight_scale_inv" + name = f"model.layers.{num_hidden_layers}." + rest + + if name.startswith("layers."): + name = "model." 
+ name + name = name.replace(".attn.", ".self_attn.") + name = name.replace(".ffn.", ".mlp.") + name = name.replace(".attn_norm.", ".input_layernorm.") + name = name.replace(".ffn_norm.", ".post_attention_layernorm.") + + if not ATTN_BIT_WISE_EQUAL_MODE: + if "self_attn" in name and ( + "compressor" not in name or not COMPRESSOR_BIT_WISE_EQUAL_MODE + ): + name = name.replace(".scale", ".weight_scale_inv") + + if not MOE_BIT_WISE_EQUAL_MODE: + name = name.replace(".gate.tid2eid", ".topk.tid2eid") + name = name.replace(".gate.bias", ".gate.e_score_correction_bias") + name = name.replace(".w1.", ".gate_proj.") + name = name.replace(".w2.", ".down_proj.") + name = name.replace(".w3.", ".up_proj.") + if "mlp" in name: + name = name.replace(".scale", ".weight_scale_inv") + + return name + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + assert envs.SGLANG_DSV4_MODE.get() in ["2601", "2604"] + if envs.SGLANG_DSV4_MODE.get() == "2604": + assert envs.SGLANG_DSV4_2604_SUBMODE.get() in ["2604A", "2604B"] + else: + assert envs.SGLANG_DSV4_2604_SUBMODE.get() == "" + + if ( + envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + _debug_assert_model_path_configs() + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and is_large_dummy_model(): + assert ( + envs.SGLANG_HACK_OVERRIDE_TOPK_IDS_RANDOM.get() + ), "dummy model must use SGLANG_HACK_OVERRIDE_TOPK_IDS_RANDOM" + + if MOE_BIT_WISE_EQUAL_MODE: + assert ( + self.num_fused_shared_experts == 0 + ), "use --disable-shared-experts-fusion for MoE bit-wise equal mode" + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + 
raise ValueError("num_nextn_predict_layers is not in the config") + + if ( + envs.SGLANG_DSV4_MODE.get() == "2604" + and not envs.SGLANG_OPT_FP8_WO_A_GEMM.get() + ): + if envs.SGLANG_FIX_DSV4_BASE_MODEL_LOAD.get(): + weights = list(weights) + exists_wo_a_scale = any(n.endswith(".wo_a.scale") for n, t in weights) + if exists_wo_a_scale: + logger.info("Execute dequant fp8 wo_a") + weights = _dequant_fp8_wo_a(weights) + else: + logger.info("Skip dequant fp8 wo_a") + else: + # ----------------------------- legacy code ------------------------------ + if envs.SGLANG_DSV4_FP4_EXPERTS.get(): + weights = _dequant_fp8_wo_a(weights) + else: + weights = ((n, t) for n, t in weights if not n.endswith(".wo_a.scale")) + # ------------------------------------------------------------------------ + + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + + cache_compressor_weight = {} + COMPRESSOR_PART = ".compressor.w" + + fuse_wqa_wkv = envs.SGLANG_OPT_FUSE_WQA_WKV.get() + cache_wqkv_a_weight: dict[str, dict[str, torch.Tensor]] = {} + + def auto_weight_loader(module): + return getattr(module, "weight_loader", default_weight_loader) + + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names_out_of_layer = [ + "shared_head.norm", + "shared_head.head", + "embed_tokens", + ".e_proj", + "h_proj", + "enorm", + "hnorm", + "hc_head_base", + "hc_head_fn", + "hc_head_scale", + ] + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + 
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + weight_names = [] + for name, loaded_weight in weights: + try: + use_async_loading = should_async_load(loaded_weight) + + name = self.remap_weight_name_to_dpsk_hf_format( + name, + is_nextn=is_nextn, + num_hidden_layers=self.config.num_hidden_layers, + ) + + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if ( + self.num_fused_shared_experts > 0 + and "mlp.shared_experts" in name + ): + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + + weight_names.append(name) + + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) + >= self.config.num_hidden_layers + ): + continue + + if name.startswith("mtp"): + continue + else: + if "shared_head.head" in name or "embed_tokens" in name: + continue + + if not name.startswith(nextn_layer_prefix): + continue + + in_decoder = True + for weight_name in nextn_spec_weight_names_out_of_layer: + if weight_name in name: + in_decoder = False + name = name.replace(nextn_layer_prefix, "model") + break + + if in_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + if ("mlp.experts." 
in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict and name.startswith("mtp"): + break + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight, shard_id), + ) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + if MOE_BIT_WISE_EQUAL_MODE: + continue + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=( + param, + loaded_weight, + name, + ), + func_kwargs={ + "shard_id": shard_id, + "expert_id": expert_id, + }, + ) + loaded_params.add(name) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if ( + ".embed_tokens." in name + and not self.pp_group.is_first_rank + ): + continue + if ".norm." 
in name and not self.pp_group.is_last_rank: + continue + elif COMPRESSOR_PART in name: + is_kv = name.endswith(".wkv.weight") + is_wgate = name.endswith(".wgate.weight") + assert is_kv != is_wgate + key = name.rsplit(".", 2)[0] + assert key.endswith(".compressor") + if key not in cache_compressor_weight: + cache_compressor_weight[key] = ( + is_kv, + loaded_weight, + ) + else: + assert key in cache_compressor_weight + cached_is_kv, cached_weight = ( + cache_compressor_weight[key] + ) + assert cached_is_kv != is_kv + kv = loaded_weight if is_kv else cached_weight + wgate = loaded_weight if is_wgate else cached_weight + fused_weight = torch.cat([kv, wgate], dim=0) + param_name = key + ".wkv_gate.weight" + param = params_dict[param_name] + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + loaded_params.add(param_name) + cache_compressor_weight.pop(key) + elif fuse_wqa_wkv and ( + name.endswith(".wq_a.weight") + or name.endswith(".wq_a.weight_scale_inv") + or name.endswith(".wkv.weight") + or name.endswith(".wkv.weight_scale_inv") + ): + is_q = ".wq_a." in name + param_name = name.replace( + ".wq_a." if is_q else ".wkv.", ".wqkv_a." 
+ ) + bucket = cache_wqkv_a_weight.setdefault(param_name, {}) + shard_key = "q" if is_q else "kv" + assert ( + shard_key not in bucket + ), f"duplicate shard {shard_key} for {param_name}" + bucket[shard_key] = loaded_weight + if len(bucket) == 2: + fused_weight = torch.cat( + [bucket["q"], bucket["kv"]], dim=0 + ) + param = params_dict[param_name] + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + loaded_params.add(param_name) + cache_wqkv_a_weight.pop(param_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + if not name.startswith("mtp"): + logger.warning( + f"{name} not found in params_dict." + ) + continue + param = params_dict[name] + + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight), + ) + loaded_params.add(name) + except Exception as e: + e.add_note(f"{name=} {loaded_weight.shape=}") + raise + + for future in concurrent.futures.as_completed(futures): + future.result() + + assert len(cache_compressor_weight) == 0 + assert len(cache_wqkv_a_weight) == 0, cache_wqkv_a_weight.keys() + unloaded_params = params_dict.keys() - loaded_params + + skipped_checking_patterns = ["attn_mqa.k_scale", "attn_mqa.v_scale"] + if is_nextn: + skipped_checking_patterns.extend(["lm_head", "embed_tokens"]) + unloaded_params = { + p + for p in unloaded_params + if all( + skipped_checking_pattern not in p + for skipped_checking_pattern in skipped_checking_patterns + ) + } + if os.environ.get("SGLANG_SKIP_CHECKPOINT_LOAD_CHECK", "0") == "0": + if unloaded_params: + raise RuntimeError( + f"Some 
weights are not initialized from checkpoints: {unloaded_params}" + ) + + self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, + num_groups=None, + ) + + +EntryClass = [DeepseekV4ForCausalLM] + + +def _dequant_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + from einops import rearrange + + assert ( + weight.dtype == torch.float8_e4m3fn + ), f"expected fp8_e4m3fn, got {weight.dtype}" + assert scale.dtype in ( + torch.float8_e8m0fnu, + torch.float32, + ), f"expected fp8_e8m0fnu or float32, got {scale.dtype}" + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model(): + assert weight.shape == (8192, 4096), f"unexpected weight shape {weight.shape}" + assert scale.shape == (64, 32), f"unexpected scale shape {scale.shape}" + + weight_f32 = rearrange( + weight.float(), "(sn bn) (sk bk) -> sn bn sk bk", bn=128, bk=128 + ) + result = rearrange( + weight_f32 * scale.float()[:, None, :, None], "sn bn sk bk -> (sn bn) (sk bk)" + ) + if envs.SGLANG_DEBUG_SANITY_CHECK_CONFIG.get() and not is_large_dummy_model(): + assert result.shape == (8192, 4096) + + return result.to(torch.bfloat16) + + +def build_mega_moe_experts_weights(experts) -> None: + from deep_gemm import ( + transform_sf_into_required_layout, + transform_weights_for_mega_moe, + ) + from deep_gemm.mega import _interleave_l1_weights, _transpose_sf_for_utccp + + if getattr(experts, "_mega_moe_weights_built", False): + return + + w13 = 
experts.w13_weight.data + w13_sf_fp32 = experts.w13_weight_scale_inv.data + w2 = experts.w2_weight.data + w2_sf_fp32 = experts.w2_weight_scale_inv.data + + num_groups, n1, half_k1 = w13.shape + k1 = half_k1 * 2 + _, n2, half_k2 = w2.shape + k2 = half_k2 * 2 + + w13_sf = transform_sf_into_required_layout( + w13_sf_fp32, + mn=n1, + k=k1, + recipe=(1, 32), + num_groups=num_groups, + disable_ue8m0_cast=False, + ) + w2_sf = transform_sf_into_required_layout( + w2_sf_fp32, + mn=n2, + k=k2, + recipe=(1, 32), + num_groups=num_groups, + disable_ue8m0_cast=False, + ) + + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + # Build the interleaved L1 weight + scale once; share the weight buffer + # between `w13_weight.data` (normal deep-ep path) and `mega_l1_weights[0]` + # (mega moe path). Mega moe additionally needs a UTCCP-transposed scale; + # the deep-ep path consumes the non-transposed interleaved scale and a + # swizzle-aware activation kernel. L2 weight is untouched by the mega + # transform, so the existing `w2_weight.data` is shared directly. 
+ w13_interleaved, w13_sf_interleaved = _interleave_l1_weights((w13, w13_sf)) + w13_sf_utccp = _transpose_sf_for_utccp(w13_sf_interleaved) + w2_sf_utccp = _transpose_sf_for_utccp(w2_sf) + + experts.w13_weight.data = w13_interleaved + experts.w13_weight_scale_inv.data = w13_sf_interleaved + experts.w2_weight_scale_inv.data = w2_sf + experts.w13_weight_scale_inv.format_ue8m0 = True + experts.w2_weight_scale_inv.format_ue8m0 = True + + experts.mega_l1_weights = (experts.w13_weight.data, w13_sf_utccp) + experts.mega_l2_weights = (experts.w2_weight.data, w2_sf_utccp) + else: + l1_pair, l2_pair = transform_weights_for_mega_moe((w13, w13_sf), (w2, w2_sf)) + + experts.mega_l1_weights = l1_pair + experts.mega_l2_weights = l2_pair + + experts._mega_moe_weights_built = True + + +def _dequant_fp8_wo_a( + weights: Iterable[Tuple[str, torch.Tensor]], +) -> Iterable[Tuple[str, torch.Tensor]]: + weights_dict = dict(weights) + + for name in list(weights_dict.keys()): + if name not in weights_dict: + continue + if not name.endswith(".wo_a.weight"): + continue + scale_name = name.replace(".wo_a.weight", ".wo_a.scale") + assert scale_name in weights_dict + weight = weights_dict.pop(name) + scale = weights_dict.pop(scale_name) + yield name, _dequant_fp8(weight, scale) + + yield from weights_dict.items() + + +def _debug_assert_model_path_configs() -> None: + assert_ckpt_version = os.environ.get("SGLANG_HACK_ASSERT_CKPT_VERSION", "v1") + + model_path = Path(get_global_server_args().model_path) + ref_dir = ( + Path(__file__).resolve().parents[4] + / "deepseek_v4" + / "assembled_hf_config_0409" + / assert_ckpt_version + ) + for ref_file in ref_dir.iterdir(): + if ref_file.name in ["apply.py", "create.py", "README.md"]: + continue + user_file = model_path / ref_file.name + if not user_file.exists(): + raise AssertionError( + f"2604 mode: expected {ref_file.name} in model_path {model_path}, but not found" + ) + if user_file.read_bytes() != ref_file.read_bytes(): + raise AssertionError( + 
f"2604 mode: {ref_file.name} in model_path differs from reference.\n" + f" model_path: {user_file}\n" + f" reference: {ref_file}\n" + f" Please use the files generated by deepseek_v4/assembled_hf_config_0409/create.py" + ) + logger.info("2604 mode: all config files match reference (bytewise equal)") diff --git a/python/sglang/srt/models/deepseek_v4_nextn.py b/python/sglang/srt/models/deepseek_v4_nextn.py new file mode 100644 index 000000000000..7f3ab6df1dad --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4_nextn.py @@ -0,0 +1,248 @@ + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.environ import envs +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + dp_gather_partial, + get_attention_dp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.utils import get_moe_a2a_backend +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.deepseek_v4 import DeepseekV4DecoderLayer, DeepseekV4ForCausalLM +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +COMPRESS_RATIO_NEXTN_LAYER = 0 + + +class DeepseekV4ModelNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_id = 
config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rms_norm_eps = config.rms_norm_eps + + self.layers_to_capture = [] + if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + self.enable_a2a_moe = True + else: + self.enable_a2a_moe = False + + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.e_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("e_proj", prefix), + ) + self.h_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("h_proj", prefix), + ) + + layer_name = "decoder" + + self.decoder = DeepseekV4DecoderLayer( + config, + layer_id=0, + quant_config=quant_config, + is_nextn=True, + prefix=add_prefix(layer_name, prefix), + alt_streams=None, + compress_ratio_override=COMPRESS_RATIO_NEXTN_LAYER, + ) + + self.shared_head = nn.Module() + self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.rms_norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + 
hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + n_tokens = hidden_states.shape[0] + d = self.config.hidden_size + hc_flat = forward_batch.spec_info.hidden_states.view( + n_tokens * self.hc_mult, d + ) + h_proj_out, _ = self.h_proj(self.hnorm(hc_flat)) + h_proj_hidden_states = h_proj_out.view(n_tokens, self.hc_mult, d) + + e_proj_hidden_states, _ = self.e_proj(self.enorm(hidden_states)) + hidden_states = e_proj_hidden_states[:, None, :] + h_proj_hidden_states + else: + e_proj_hidden_states, _ = self.e_proj(self.enorm(hidden_states)) + h_proj_hidden_states, _ = self.h_proj( + self.hnorm(forward_batch.spec_info.hidden_states) + ) + hidden_states = e_proj_hidden_states + h_proj_hidden_states + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + else: + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = input_ids_global.squeeze(-1) + else: + input_ids_global = input_ids + + hidden_states = self.decoder( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + + pre_hc_head = ( + hidden_states.flatten(1) + if envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() 
== "2604" + else None + ) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.shared_head.norm(hidden_states) + + if pre_hc_head is not None: + return hidden_states, pre_hc_head + return hidden_states + + +class DeepseekV4ForCausalLMNextN(DeepseekV4ForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.determine_num_fused_shared_experts() + + self.model = DeepseekV4ModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + result = self.model(input_ids, positions, forward_batch) + pre_hc_head = None + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + hidden_states, pre_hc_head = result + else: + hidden_states = result + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + hidden_states_before_norm=pre_hc_head, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights, is_nextn=True) + + +EntryClass = [DeepseekV4ForCausalLMNextN] diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py index 066be3dc44b3..a28c15fa782b 100644 --- a/python/sglang/srt/models/registry.py +++ b/python/sglang/srt/models/registry.py 
@@ -104,7 +104,9 @@ def import_model_classes(package_name: str, strict: bool = False): except Exception as e: if strict: raise - logger.warning(f"Ignore import error when loading {name}: {e}") + logger.warning( + f"In import_model_classes: Ignore import error when loading {name}: {e}" + ) continue if hasattr(module, "EntryClass"): entry = module.EntryClass diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 8949ba5d75b4..6b5a469f982b 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -304,6 +304,7 @@ class ReasoningParser: DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = { "deepseek-r1": DeepSeekR1Detector, "deepseek-v3": Qwen3Detector, + "deepseek-v4": Qwen3Detector, "glm45": Qwen3Detector, "gpt-oss": GptOssDetector, "kimi": KimiDetector, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1a049ced4fcc..d48dab8757ab 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -122,6 +122,7 @@ "torch_native", "flex_attention", "nsa", + "compressed", # NVIDIA specific "cutlass_mla", "fa3", @@ -180,6 +181,7 @@ "flashinfer_mxfp4", "flashinfer_cutedsl", "cutlass", + "marlin", ] MOE_A2A_BACKEND_CHOICES = ["none", "deepep", "mooncake", "ascend_fuseep", "flashinfer"] @@ -519,6 +521,7 @@ class ServerArgs: hicache_storage_backend_extra_config: Optional[str] = None # Hierarchical sparse attention + enable_hisparse: bool = False hierarchical_sparse_attention_extra_config: Optional[str] = None # LMCache @@ -609,6 +612,7 @@ class ServerArgs: keep_mm_feature_on_device: bool = False enable_return_hidden_states: bool = False enable_return_routed_experts: bool = False + enable_return_indexer_topk: bool = False scheduler_recv_interval: int = 1 numa_node: Optional[List[int]] = None enable_deterministic_inference: bool = False @@ -657,6 +661,7 @@ class ServerArgs: # For model weight update and 
weight loading custom_weight_loader: Optional[List[str]] = None weight_loader_disable_mmap: bool = False + weight_loader_drop_cache_after_load: bool = False remote_instance_weight_loader_seed_instance_ip: Optional[str] = None remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None @@ -1172,6 +1177,48 @@ def _handle_model_specific_adjustments(self): ]: self.dtype = "bfloat16" + if model_arch in [ + "DeepseekV4ForCausalLM", + ]: + self.attention_backend = "compressed" + self.page_size = 256 + logger.info( + f"Use compressed attention backend for {model_arch}, setting page_size to 256." + ) + + if self.max_running_requests is None: + self.max_running_requests = 256 + logger.warning( + f"Setting max_running_requests to {self.max_running_requests} for {model_arch}." + ) + + if self.kv_cache_dtype == "auto": + self.kv_cache_dtype = "fp8_e4m3" + logger.warning( + f"Setting KV cache dtype to {self.kv_cache_dtype} for {model_arch}." + ) + assert self.kv_cache_dtype in [ + "fp8_e4m3" + ], f"{self.kv_cache_dtype} is not supported for {model_arch}" + + if self.speculative_algorithm is not None: + assert ( + self.speculative_algorithm == "EAGLE" + ), f"Only EAGLE speculative algorithm is supported for {model_arch}" + assert ( + self.speculative_eagle_topk == 1 + ), f"Only EAGLE speculative algorithm with topk == 1 is supported for {model_arch}" + + if not envs.SGLANG_ENABLE_SPEC_V2.get(): + envs.SGLANG_ENABLE_SPEC_V2.set(True) + logger.warning("Spec v2 is enabled for EAGLE speculative decoding.") + + if self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio: + self.swa_full_tokens_ratio = 0.1 + logger.info( + f"Setting swa_full_tokens_ratio to {self.swa_full_tokens_ratio} for {model_arch}." 
+ ) + if model_arch in [ "DeepseekV3ForCausalLM", "MistralLarge3ForCausalLM", @@ -1205,8 +1252,8 @@ def _handle_model_specific_adjustments(self): self.dp_size == 1 ), "For round-robin split mode, dp attention is not supported." assert ( - self.tp_size == 8 - ), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)" + self.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." logger.warning( f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} " @@ -1220,9 +1267,10 @@ def _handle_model_specific_adjustments(self): ) if is_hip(): - self.page_size = 1 + self.page_size = 64 logger.warning( - "Setting page size to 1 for DeepSeek DSA on ROCm." + "Setting page size to 64 for DeepSeek DSA on torch implementation.\n" + "Need to be changed based on ROCm implementation.\n" ) else: # For CUDA GPU @@ -1301,6 +1349,29 @@ def _handle_model_specific_adjustments(self): "Use triton fused moe by default for bf16 nextn layer in deepseek fp4 checkpoint." ) + elif model_arch in [ + "DeepseekV4ForCausalLM", + ]: + if self.enable_nsa_prefill_context_parallel: + if self.nsa_prefill_cp_mode == "round-robin-split": + self.moe_dense_tp_size = 1 + assert ( + self.dp_size == 1 + ), "For round-robin split mode, dp attention is not supported." + assert ( + self.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." 
+ logger.warning( + f"Enable Context Parallel for DeepSeekV4, " + f"dp_size={self.dp_size}, moe_dense_tp_size={self.moe_dense_tp_size}, " + f"ep_size={self.ep_size}, tp_size={self.tp_size}" + ) + else: + raise ValueError( + f"DeepSeekV4 only supports round-robin-split CP mode, " + f"got {self.nsa_prefill_cp_mode}" + ) + elif model_arch in ["GptOssForCausalLM"]: # Set attention backend for GPT-OSS if self.is_attention_backend_not_set(): @@ -2231,6 +2302,7 @@ def _handle_speculative_decoding(self): if model_arch in [ "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekV4ForCausalLM", "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", "BailingMoeForCausalLM", @@ -4054,6 +4126,11 @@ def add_cli_args(parser: argparse.ArgumentParser): help="A dictionary in JSON string format, or a string starting with a leading '@' and a config file in JSON/YAML/TOML format, containing extra configuration for the storage backend.", ) + parser.add_argument( + "--enable-hisparse", + action="store_true", + help="Enable hierarchical sparse attention", + ) # Hierarchical sparse attention parser.add_argument( "--hierarchical-sparse-attention-extra-config", @@ -4469,6 +4546,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable returning routed experts of each layer with responses.", ) + parser.add_argument( + "--enable-return-indexer-topk", + action="store_true", + help="Enable returning indexer topk indices of layers with indexer with responses.", + ) parser.add_argument( "--scheduler-recv-interval", type=int, @@ -4677,6 +4759,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable mmap while loading weight using safetensors.", ) + parser.add_argument( + "--weight-loader-drop-cache-after-load", + action="store_true", + help="Call posix_fadvise(DONTNEED) on each safetensors shard after loading it.", + ) parser.add_argument( "--remote-instance-weight-loader-seed-instance-ip", type=str, diff --git 
a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py index 9c630da72fb1..e0a47d61c7cb 100644 --- a/python/sglang/srt/speculative/draft_utils.py +++ b/python/sglang/srt/speculative/draft_utils.py @@ -55,6 +55,7 @@ def create_decode_backend(self): "trtllm_mla": self._create_trtllm_mla_decode_backend, "nsa": self._create_nsa_decode_backend, "ascend": self._create_ascend_decode_backend, + "compressed": self._create_compressed_decode_backend, } return self._create_backend( @@ -79,6 +80,7 @@ def create_draft_extend_backend(self): "trtllm_mla": self._create_trtllm_mla_prefill_backend, "nsa": self._create_nsa_prefill_backend, "ascend": self._create_ascend_prefill_backend, + "compressed": self._create_compressed_prefill_backend, } backend_name = ( "decode_attention_backend" @@ -189,6 +191,15 @@ def _create_ascend_decode_backend(self): self.draft_model_runner, self.topk, self.speculative_num_steps ) + def _create_compressed_decode_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4MultiStepBackend, + ) + + return DeepseekV4MultiStepBackend( + self.draft_model_runner, self.topk, self.speculative_num_steps + ) + def _create_flashinfer_prefill_backend(self): if not get_global_server_args().use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( @@ -247,3 +258,10 @@ def _create_flashmla_prefill_backend(self): "flashmla prefill backend is not yet supported for draft extend." 
) return None + + def _create_compressed_prefill_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4BackendRadix, + ) + + return DeepseekV4BackendRadix(self.draft_model_runner, skip_prefill=False) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 5fe45086ca4a..2211d22f7e3c 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -102,7 +102,7 @@ def __init__(self, eagle_worker: EAGLEWorker): self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) self.hidden_states = torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), + (self.max_bs, self.model_runner.model_config.spec_hidden_size), dtype=self.model_runner.dtype, ) @@ -166,6 +166,7 @@ def _capture_init(self, run_once_fn): torch.cuda.synchronize() self.model_runner.tp_group.barrier() run_once_fn() + self.model_runner.draft_attn_backend.on_after_cuda_graph_warmup_pass() def _capture_graph(self, graph, pool, stream, run_once_fn): with torch.cuda.graph(graph, pool=pool, stream=stream): diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index e1afdd84b547..38b3b4991620 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -121,7 +121,10 @@ def __init__(self, eagle_worker: EAGLEWorker): ) else: self.hidden_states = torch.zeros( - (self.max_num_token, self.model_runner.model_config.hidden_size), + ( + self.max_num_token, + self.model_runner.model_config.spec_hidden_size, + ), dtype=self.model_runner.dtype, ) self.seq_len_fill_value = ( diff --git 
a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index e22eeaee46cd..cf9ffb8855b7 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -235,7 +235,7 @@ def verify( return EagleVerifyOutput( draft_input=EagleDraftInput.create_idle_input( device=batch.device, - hidden_size=batch.model_config.hidden_size, + hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, @@ -597,7 +597,7 @@ def verify( else: draft_input = EagleDraftInput.create_idle_input( device=batch.device, - hidden_size=batch.model_config.hidden_size, + hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index b542e9615f2a..715d5ce2c3c8 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -267,7 +267,7 @@ def sample( (which contains spec decoding information). 
""" if batch.forward_mode.is_idle(): - predict = torch.empty(0, dtype=torch.long, device=batch.input_ids.device) + predict = torch.empty(0, dtype=torch.int32, device=batch.input_ids.device) accept_length = torch.empty( 0, dtype=torch.int32, device=batch.input_ids.device ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 0086e2aa700e..ea7c12cc84f6 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -515,7 +515,7 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch): def _draft_preprocess_idle(self, batch: ScheduleBatch): batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.model_config.hidden_size, + hidden_size=self.model_config.spec_hidden_size, dtype=self.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, @@ -901,12 +901,12 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): if not input_is_idle and batch.spec_info.verified_id.numel() == 0: batch = batch.copy() batch.prepare_for_idle() - hidden_size = ( - self.model_config.hidden_size * 3 - if self.speculative_algorithm.is_eagle3() + hidden_size = self.model_config.spec_hidden_size + if ( + self.speculative_algorithm.is_eagle3() and self.eagle_use_aux_hidden_state - else self.model_config.hidden_size - ) + ): + hidden_size = self.model_config.hidden_size * 3 batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, hidden_size=hidden_size, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 1c90a3041f62..c4d2e1f20200 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,6 +12,7 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) +from 
sglang.srt.layers.attention.deepseek_v4_backend_radix import DeepseekV4BackendRadix from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend from sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLAMultiStepDraftBackend, @@ -279,6 +280,11 @@ def init_cuda_graphs(self): _is_cuda and isinstance(self.draft_attn_backend, TRTLLMMLAMultiStepDraftBackend) ) + or ( + _is_cuda + and isinstance(self.draft_extend_attn_backend, DeepseekV4BackendRadix) + and envs.SGLANG_OPT_V4_DRAFT_EXTEND_CUDA_GRAPH.get() + ) ): tic = time.perf_counter() before_mem = get_available_gpu_memory(self.device, self.gpu_id) @@ -659,7 +665,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): if model_worker_batch.spec_info is None: model_worker_batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.target_worker.model_config.hidden_size, + hidden_size=self.target_worker.model_config.spec_hidden_size, dtype=self.target_worker.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index 44bb2f0de128..d5cb73e6354f 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -602,7 +602,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): if model_worker_batch.spec_info is None: model_worker_batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.target_worker.model_config.hidden_size, + hidden_size=self.target_worker.model_config.spec_hidden_size, dtype=self.target_worker.model_config.dtype, topk=self.topk * self.speculative_num_steps, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 8e39ee4cad66..d6cb5c65e90a 100644 --- 
a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -4086,3 +4086,11 @@ def bind_to_closest_numa_node_cuda(): if is_numa_available() and nvgpu_available(): node_id = get_current_device_numa_node_cuda() numa_bind_to_node(node_id) + + +def maybe_torch_compile(func): + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if get_is_capture_mode(): + return torch.compile(func) + return func diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index f8743b416eaf..31c92bfb1b2f 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -20,11 +20,12 @@ import tempfile import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Literal, Optional, Type, Union import torch from huggingface_hub import snapshot_download +from sglang.srt.environ import envs from sglang.srt.utils import get_bool_env_var # Conditional import based on SGLANG_USE_MODELSCOPE environment variable @@ -69,7 +70,6 @@ from sglang.srt.connector import create_remote_connector from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils -from sglang.srt.utils.patch_tokenizer import patch_tokenizer _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ AfmoeConfig, @@ -163,8 +163,10 @@ def get_hf_text_config(config: PretrainedConfig): # Temporary hack for DeepSeek-V3.2 model -def _load_deepseek_v32_model( +def _load_deepseek_temp_model( model_path: str, + model_type: Literal["deepseek_v32", "deepseek_ref"], + architecture: Literal["DeepseekV3ForCausalLM", "DeepseekV4ForCausalLM"], trust_remote_code: bool = False, revision: Optional[str] = None, **kwargs, @@ -172,20 +174,60 @@ def _load_deepseek_v32_model( # first get the local path local_path = 
download_from_hf(model_path) # then load the config file in json - config_file = os.path.join(local_path, "config.json") + backup_mode = envs.SGLANG_APPLY_CONFIG_BACKUP.get() + if backup_mode == "auto": + real_config_file = os.path.join(local_path, "config.json") + if not os.path.exists(real_config_file): + raise RuntimeError( + f"SGLANG_APPLY_CONFIG_BACKUP=auto requires the checkpoint's " + f"config.json at {real_config_file} to read num_hidden_layers." + ) + with open(real_config_file, "r") as f: + num_hidden_layers = json.load(f).get("num_hidden_layers") + if not isinstance(num_hidden_layers, int): + raise RuntimeError( + f"SGLANG_APPLY_CONFIG_BACKUP=auto could not read a numeric " + f"num_hidden_layers from {real_config_file} (got {num_hidden_layers!r})." + ) + backup_mode = "small" if num_hidden_layers <= 50 else "large" + logger.warning( + f"SGLANG_APPLY_CONFIG_BACKUP=auto: checkpoint has " + f"num_hidden_layers={num_hidden_layers}, dispatching to {backup_mode!r}." + ) + if backup_mode != "none": + backup_file = { + "small": "config_backup_small.json", + "large": "config_backup_large.json", + }.get(backup_mode) + if backup_file is None: + raise ValueError( + f"SGLANG_APPLY_CONFIG_BACKUP={backup_mode!r} is not recognized; " + f"use 'none' (off), 'small', 'large', or 'auto'." + ) + config_file = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), + "configs", + backup_file, + ) + logger.warning( + f"SGLANG_APPLY_CONFIG_BACKUP={backup_mode}: using packaged {config_file} " + f"instead of the checkpoint's config.json at {local_path}." 
+ ) + else: + config_file = os.path.join(local_path, "config.json") if not os.path.exists(config_file): - raise RuntimeError(f"Can't find config file in {local_path}.") + raise RuntimeError(f"Can't find config file at {config_file}.") with open(config_file, "r") as f: config_json = json.load(f) - config_json["architectures"] = ["DeepseekV3ForCausalLM"] + config_json["architectures"] = [architecture] config_json["model_type"] = "deepseek_v3" tmp_path = os.path.join(tempfile.gettempdir(), "_tmp_config_folder") os.makedirs(tmp_path, exist_ok=True) - unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}") + unique_path = os.path.join(tmp_path, f"{model_type}_{os.getpid()}") with open(unique_path, "w") as f: json.dump(config_json, f) @@ -271,17 +313,41 @@ def get_config( config = _load_mistral_large_3_for_causal_LM( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) + elif envs.SGLANG_APPLY_CONFIG_BACKUP.get() != "none": + config = _load_deepseek_temp_model( + model, + model_type="deepseek_ref", + architecture="DeepseekV4ForCausalLM", + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) else: try: config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) except ValueError as e: - if not "deepseek_v32" in str(e): + if "deepseek_ref" in str(e): + config = _load_deepseek_temp_model( + model, + model_type="deepseek_ref", + architecture="DeepseekV4ForCausalLM", + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + elif "deepseek_v32" in str(e): + config = _load_deepseek_temp_model( + model, + model_type="deepseek_v32", + architecture="DeepseekV3ForCausalLM", + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: raise e - config = _load_deepseek_v32_model( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) if ( config.architectures is not None @@ -504,7 +570,7 @@ def get_tokenizer( ) 
attach_additional_stop_token_ids(tokenizer) - tokenizer = patch_tokenizer(tokenizer) + return tokenizer diff --git a/python/sglang/srt/utils/profile_utils.py b/python/sglang/srt/utils/profile_utils.py index 1f44af979e89..99d949ebbc49 100644 --- a/python/sglang/srt/utils/profile_utils.py +++ b/python/sglang/srt/utils/profile_utils.py @@ -8,6 +8,7 @@ import torch +from sglang.srt.environ import envs from sglang.srt.managers.io_struct import ProfileReqOutput from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import get_global_server_args @@ -269,6 +270,17 @@ def start(self): activity_map[a] for a in self.activities if a in activity_map ] + if ( + envs.SGLANG_HACK_WARMUP_KINETO.get() + and not _is_npu + and torch.profiler.ProfilerActivity.CUDA in torchprof_activities + ): + from sglang.srt.managers.scheduler_profiler_mixin import ( + _warmup_kineto_once, + ) + + _warmup_kineto_once() + self.torch_profiler = torch.profiler.profile( activities=torchprof_activities, with_stack=self.with_stack if self.with_stack is not None else True, diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index ca5c2e08e24c..199273bfec3b 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -104,12 +104,15 @@ def _random_like(t: torch.Tensor): shape = t.shape dtype = t.dtype - if dtype.is_floating_point: + if dtype.is_floating_point or "float" in str(dtype): return torch.rand(shape, device=device, dtype=torch.float32).to(dtype) if dtype == torch.bool: return torch.rand(shape, device=device) > 0.5 + if dtype.is_complex: + return torch.randn(shape, device=device, dtype=dtype) + info = torch.iinfo(dtype) return torch.randint( low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype @@ -121,7 +124,16 @@ def _postprocess_tensors( ) -> Iterable[Tuple[str, bool, torch.Tensor]]: from sglang.srt.debug_utils.dumper import get_tensor_info - 
skip_compare_names = [] + skip_compare_names = [ + name + for name in raw + if any(pattern in name for pattern in ["attn_mqa.k_scale", "attn_mqa.v_scale"]) + ] + skip_compare_names += [ + name + for name in raw + if any(pattern in name for pattern in ["freqs_cis", "cos_sin_cache"]) + ] # dequant fp8 quant_names = [ @@ -130,19 +142,24 @@ def _postprocess_tensors( # Match: `something.weight`, `something.experts.w2_weight` if name.endswith("weight") and name.replace("weight", "weight_scale_inv") in raw ] + quant_scale_names = [ + name.replace("weight", "weight_scale_inv") for name in quant_names + ] skip_compare_names += quant_names + skip_compare_names += quant_scale_names for name in quant_names: w_q = raw[name] w_s = raw[name.replace("weight", "weight_scale_inv")] try: - # TODO this is only needed for Blackwell - w_s_inverse_transformed = inverse_transform_scale_ue8m0( - w_s, mn=w_q.shape[-2] - ) + if w_s.dtype == torch.int32: + w_s_for_dequant = inverse_transform_scale_ue8m0(w_s, mn=w_q.shape[-2]) + else: + w_s_for_dequant = w_s + w_dequant = block_quant_dequant( w_q, - w_s_inverse_transformed, + w_s_for_dequant, # TODO do not hardcode block_size=[128, 128], dtype=torch.bfloat16, diff --git a/python/sglang/test/bench_one_batch_server_internal.py b/python/sglang/test/bench_one_batch_server_internal.py index c7dc7b4f1a4a..1ba82d73aa72 100644 --- a/python/sglang/test/bench_one_batch_server_internal.py +++ b/python/sglang/test/bench_one_batch_server_internal.py @@ -531,27 +531,31 @@ def run_one_case( def should_skip_due_to_token_capacity( batch_size, input_len, output_len, skip_token_capacity_threshold ): - if batch_size * (input_len + output_len) > skip_token_capacity_threshold: - print( - "=" * 8 - + f"Skip benchmark {batch_size=} * ({input_len=} + {output_len=}) = {batch_size * (input_len + output_len)} > {skip_token_capacity_threshold=} due to kv cache limit." 
- + "=" * 8 - ) - return True + # NOTE HACK return False + # if batch_size * (input_len + output_len) > skip_token_capacity_threshold: + # print( + # "=" * 8 + # + f"Skip benchmark {batch_size=} * ({input_len=} + {output_len=}) = {batch_size * (input_len + output_len)} > {skip_token_capacity_threshold=} due to kv cache limit." + # + "=" * 8 + # ) + # return True + # return False def should_skip_due_to_max_running_requests( batch_size, skip_max_running_requests_threshold ): - if batch_size > skip_max_running_requests_threshold: - print( - "=" * 8 - + f"Skip benchmark {batch_size=} > {skip_max_running_requests_threshold=} due to max running requests limit." - + "=" * 8 - ) - return True + # NOTE HACK return False + # if batch_size > skip_max_running_requests_threshold: + # print( + # "=" * 8 + # + f"Skip benchmark {batch_size=} > {skip_max_running_requests_threshold=} due to max running requests limit." + # + "=" * 8 + # ) + # return True + # return False def get_report_summary( diff --git a/scripts/bench_gpqa_aime25.py b/scripts/bench_gpqa_aime25.py new file mode 100644 index 000000000000..9769dd8e0ae7 --- /dev/null +++ b/scripts/bench_gpqa_aime25.py @@ -0,0 +1,243 @@ +# This script should be used inside the container. Before testing anything, please +# 1. install typer +# 2. set the following environment variables: +# - HOST: the host to connect to (default 127.0.0.1) +# - PORT: the port to connect to (default 30010) +# - HF_TOKEN: needed for `setup-ns` +# 3. checkout to the commit you want to test + +# Caution!!! +# This script assumes that thinking mode can be controlled from SGLang side. (with an environ or argument) +# e.g. 
++chat_template_kwargs.thinking=true is not included in the nemo skills command + +# Test GPQA: +# python bench_gpqa_aime25.py setup-ns +# python bench_gpqa_aime25.py run-gpqa --num-repeats 16 --temperature 1.0 --max-tokens 400000 --max-concurrency 512 + +# Test AIME25: +# python bench_gpqa_aime25.py setup-ns +# python bench_gpqa_aime25.py run-aime25 --num-repeats 16 --temperature 1.0 --max-tokens 400000 --max-concurrency 512\ +# python bench_gpqa_aime25.py regrade-aime25 # Post process that bypasses box limitation + + +import os +import random +import subprocess +import time +from typing import Annotated + +import typer + +app = typer.Typer() + +# Some manually set configs: +MODEL_PATH = "deepseek-ai/DeepSeek-V4-Pro" +HOST = os.environ.get("HOST", "127.0.0.1") +PORT = int(os.environ.get("PORT", "30000")) +LOG_DIR = "/sgl-workspace/logs" + +NS_VENV = "/sgl-workspace/ns-venv" + +info_msg = f""" +Using configurations: +MODEL_PATH: {MODEL_PATH} +HOST: {HOST} +PORT: {PORT} +LOG_DIR: {LOG_DIR} +""" + +# input(info_msg + "\nPress Enter to continue...") +print(info_msg) + + +def _venv_cmd(venv_dir: str, cmd: str) -> str: + return f"source {venv_dir}/bin/activate && {cmd}" + + +def get_timestamp(): + return time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) + + +def get_random_int(): + return random.randint(0, 10000) + + +@app.command() +def setup_ns(): + HF_TOKEN = os.getenv("HF_TOKEN", None) + if HF_TOKEN is None: + raise ValueError("Please set HF_TOKEN for nemo skill setup") + exec_command(f"uv venv {NS_VENV}") + exec_command( + _venv_cmd( + NS_VENV, + "uv pip install git+https://github.com/NVIDIA-NeMo/Skills.git@d77caab 'tree_sitter_language_pack<1.0' --reinstall-package blinker", + ) + ) + exec_command(_venv_cmd(NS_VENV, f"HF_TOKEN={HF_TOKEN} ns prepare_data aime25")) + # User might be asked for access of GPQA dataset. Just click in the hugging face website to grant access. 
+ exec_command( + _venv_cmd(NS_VENV, f"HF_TOKEN={HF_TOKEN} ns prepare_data gpqa --split diamond") + ) + + +@app.command() +def run_gpqa( + num_repeats: Annotated[int, typer.Option()] = 16, + temperature: Annotated[float, typer.Option()] = 1.0, + max_tokens: Annotated[int, typer.Option()] = 60000, + max_concurrency: Annotated[int, typer.Option()] = 64, +): + if not os.path.exists(f"{LOG_DIR}/gpqa_logs"): + exec_command(f"mkdir -p {LOG_DIR}/gpqa_logs") + + random_seed = get_random_int() + gpqa_log_folder = f"{LOG_DIR}/gpqa_logs/{get_timestamp()}_{random_seed}" + exec_command(f"mkdir -p {gpqa_log_folder}") + print(f"Running GPQA, log folder: {gpqa_log_folder}") + + exec_command( + _venv_cmd( + NS_VENV, + f"nohup ns eval " + f"--server_type=openai " + f"--model={MODEL_PATH} " + f"--server_address=http://{HOST}:{PORT}/v1 " + f"--benchmarks=gpqa:{num_repeats} " + f"--output_dir={gpqa_log_folder} " + f"++inference.tokens_to_generate={max_tokens} " + f"++max_concurrent_requests={max_concurrency} " + f"++inference.temperature={temperature} " + f"++inference.top_p=1.0 " + f"++inference.timeout=25000000 " + f"--starting_seed {random_seed} " + f"> {gpqa_log_folder}/output.log 2>&1 &", + ) + ) + + +@app.command() +def run_aime25( + num_repeats: Annotated[int, typer.Option()] = 16, + temperature: Annotated[float, typer.Option()] = 1.0, + max_tokens: Annotated[int, typer.Option()] = 60000, + max_concurrency: Annotated[int, typer.Option()] = 64, +): + if not os.path.exists(f"{LOG_DIR}/aime25_logs"): + exec_command(f"mkdir -p {LOG_DIR}/aime25_logs") + + random_seed = get_random_int() + aime25_log_folder = f"{LOG_DIR}/aime25_logs/{get_timestamp()}_{random_seed}" + exec_command(f"mkdir -p {aime25_log_folder}") + print(f"Running AIME25, log folder: {aime25_log_folder}") + + exec_command( + _venv_cmd( + NS_VENV, + f"nohup ns eval " + f"--server_type=openai " + f"--model={MODEL_PATH} " + f"--server_address=http://{HOST}:{PORT}/v1 " + f"--benchmarks=aime25:{num_repeats} " + 
f"--output_dir={aime25_log_folder} " + f"++inference.tokens_to_generate={max_tokens} " + f"++max_concurrent_requests={max_concurrency} " + f"++inference.temperature={temperature} " + f"++inference.top_p=1.0 " + f"++inference.timeout=25000000 " + f"--starting_seed {random_seed} " + f"> {aime25_log_folder}/output.log 2>&1 &", + ) + ) + + +@app.command() +def exec_command(cmd: str, capture_output: bool = False) -> str | None: + print(f"EXEC: {cmd}", flush=True) + return subprocess.run( + ["bash", "-c", cmd], + shell=False, + check=True, + capture_output=capture_output, + **(dict(text=True) if capture_output else {}), + ) + + +# --------------------------------------------------------------------------- +# Post-eval relaxed grade for AIME25 +# --------------------------------------------------------------------------- +# Why this exists: nemo-skills' default extractor uses `\boxed{}` only, and the +# DeepSeek-V4-Pro generations sometimes finish with prose like +# "**Answer:** 336" or "the final answer is 821" instead of `\boxed{}`. Those +# get scored as no-answer. We can't pass a relaxed regex via CLI because +# nemo-run rebuilds the inner shell command and unquotes / strips backslashes, +# so the parens land in bash bare and break it. Instead, run this command +# *after* `run-aime25` finishes to re-extract via a fallback regex *only* when +# the boxed extractor returned None, then regenerate metrics.json via +# `ns summarize_results`. +# +# Usage: +# python bench_gpqa_aime25.py regrade-aime25 +# # e.g. 
/sgl-workspace/logs/aime25_logs/20260427005929_3235 + + +@app.command() +def regrade_aime25( + log_folder: Annotated[ + str, + typer.Argument(help="The aime25_logs/_ dir from run-aime25."), + ], +): + import glob + import json + import sys + + sys.path.insert(0, f"{NS_VENV}/lib/python3.12/site-packages") + from nemo_skills.evaluation.math_grader import extract_answer, math_equal + + FALLBACK_REGEX = ( + r"(?:\*\*Answer\*\*[^0-9\-]{0,30}" + r"|(?i:final answer)[^0-9\-]{0,30}" + r"|(?i:answer)\s*(?:is|=|:)[^0-9\-]{0,30})(-?\d+)" + ) + + eval_dir = f"{log_folder}/eval-results/aime25" + files = sorted(glob.glob(f"{eval_dir}/output-rs*.jsonl")) + if not files: + raise typer.Exit(f"No output-rs*.jsonl files in {eval_dir}") + print(f"Re-extracted {len(files)} files in {eval_dir}") + + total = recovered = 0 + for f in files: + lines_out = [] + changed = False + with open(f) as fp: + for line in fp: + r = json.loads(line) + total += 1 + if r.get("predicted_answer") is None: + new_pred = extract_answer( + r["generation"], + extract_from_boxed=False, + extract_regex=FALLBACK_REGEX, + ) + if new_pred is not None: + r["predicted_answer"] = new_pred + r["symbolic_correct"] = bool( + math_equal(r["expected_answer"], new_pred) + ) + recovered += 1 + changed = True + lines_out.append(json.dumps(r)) + if changed: + with open(f, "w") as fp: + fp.write("\n".join(lines_out) + "\n") + print(f"Re-extracted {recovered} / {total} previously-no-answer records.") + + # Regenerate metrics.json via ns summarize_results (runs inside the venv). 
+ exec_command(_venv_cmd(NS_VENV, f"ns summarize_results {log_folder}")) + print(f"Updated metrics: {eval_dir}/metrics.json") + + +if __name__ == "__main__": + app() diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 81a020d437e2..984ea74377a5 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -52,24 +52,6 @@ FetchContent_Declare( ) FetchContent_Populate(repo-cutlass) -# DeepGEMM -FetchContent_Declare( - repo-deepgemm - GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM - GIT_TAG 54f99a8af537b3c6eb4819b69907ccbe2b600792 - GIT_SHALLOW OFF -) -FetchContent_Populate(repo-deepgemm) - -# fmt -FetchContent_Declare( - repo-fmt - GIT_REPOSITORY https://github.com/fmtlib/fmt - GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28 - GIT_SHALLOW OFF -) -FetchContent_Populate(repo-fmt) - # Triton kernel FetchContent_Declare( repo-triton @@ -88,15 +70,6 @@ FetchContent_Declare( ) FetchContent_Populate(repo-flashinfer) -# flash-attention -FetchContent_Declare( - repo-flash-attention - GIT_REPOSITORY https://github.com/sgl-project/sgl-attn - GIT_TAG f866ec34002250e74c8bbcbcffa0e1ae71300b2d - GIT_SHALLOW OFF -) -FetchContent_Populate(repo-flash-attention) - # mscclpp FetchContent_Declare( repo-mscclpp @@ -225,6 +198,18 @@ if (ENABLE_BELOW_SM90) endif() +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4") + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-gencode=arch=compute_90a,code=sm_90a" + ) +endif() + +if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4) + list(APPEND SGL_KERNEL_CUDA_FLAGS + "-DENABLE_NVFP4=1" + ) +endif() + if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) list(APPEND SGL_KERNEL_CUDA_FLAGS "-gencode=arch=compute_100a,code=sm_100a" @@ -251,18 +236,6 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) endif() endif() -if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4") - set(SGL_KERNEL_ENABLE_FA3 ON) - list(APPEND SGL_KERNEL_CUDA_FLAGS - 
"-gencode=arch=compute_90a,code=sm_90a" - ) -endif() - -if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4) - list(APPEND SGL_KERNEL_CUDA_FLAGS - "-DENABLE_NVFP4=1" - ) -endif() # All source files # NOTE: Please sort the filenames alphabetically @@ -347,11 +320,6 @@ set(SOURCES "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform_cuda.cu" "${repo-fast-hadamard-transform_SOURCE_DIR}/csrc/fast_hadamard_transform.cpp" - "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu" - "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu" - "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu" - "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu" - "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp" ) set(INCLUDES @@ -362,7 +330,6 @@ set(INCLUDES ${repo-mscclpp_SOURCE_DIR}/include ${repo-cutlass_SOURCE_DIR}/examples/77_blackwell_fmha ${repo-cutlass_SOURCE_DIR}/examples/common - ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src ) # =========================== Common SM90 Build ============================= # @@ -379,20 +346,6 @@ set_target_properties(common_ops_sm90_build PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm90" ) -# =========================== Common SM100+ Build ============================= # -# Build SM100+ library with precise math (same namespace, different directory) -Python_add_library(common_ops_sm100_build MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES}) - -target_compile_options(common_ops_sm100_build PRIVATE - $<$:${SGL_KERNEL_CUDA_FLAGS}> -) -target_include_directories(common_ops_sm100_build PRIVATE ${INCLUDES}) -# Set output name and separate build directory to avoid conflicts -set_target_properties(common_ops_sm100_build PROPERTIES - OUTPUT_NAME "common_ops" - 
LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/sm100" -) - find_package(Python3 COMPONENTS Interpreter REQUIRED) execute_process( COMMAND ${Python3_EXECUTABLE} -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))" @@ -419,217 +372,11 @@ add_subdirectory( ) target_link_libraries(common_ops_sm90_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) -target_link_libraries(common_ops_sm100_build PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt mscclpp_static) -# sparse flash attention -target_compile_definitions(common_ops_sm90_build PRIVATE - FLASHATTENTION_DISABLE_BACKWARD - FLASHATTENTION_DISABLE_DROPOUT - FLASHATTENTION_DISABLE_UNEVEN_K -) -target_compile_definitions(common_ops_sm100_build PRIVATE - FLASHATTENTION_DISABLE_BACKWARD - FLASHATTENTION_DISABLE_DROPOUT - FLASHATTENTION_DISABLE_UNEVEN_K -) - -# Install to different subdirectories -# CMake will find the built libraries in their respective LIBRARY_OUTPUT_DIRECTORY locations -# and install them to the specified destinations install(TARGETS common_ops_sm90_build LIBRARY DESTINATION sgl_kernel/sm90) -install(TARGETS common_ops_sm100_build LIBRARY DESTINATION sgl_kernel/sm100) - -# ============================ Optional Install: FA3 ============================= # -# set flash-attention sources file -# Now FA3 support sm80/sm86/sm90 -if (SGL_KERNEL_ENABLE_FA3) - set(SGL_FLASH_KERNEL_CUDA_FLAGS - "-DNDEBUG" - "-DOPERATOR_NAMESPACE=sgl-kernel" - "-O3" - "-Xcompiler" - "-fPIC" - "-gencode=arch=compute_90a,code=sm_90a" - "-std=c++17" - "-DCUTE_USE_PACKED_TUPLE=1" - "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" - "-DCUTLASS_VERSIONS_GENERATED" - "-DCUTLASS_TEST_LEVEL=0" - "-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1" - "-DCUTLASS_DEBUG_TRACE_LEVEL=0" - "--expt-relaxed-constexpr" - "--expt-extended-lambda" - "--use_fast_math" - "-Xcompiler=-Wconversion" - "-Xcompiler=-fno-strict-aliasing" - ) - - if (ENABLE_BELOW_SM90) - list(APPEND SGL_FLASH_KERNEL_CUDA_FLAGS - 
"-gencode=arch=compute_80,code=sm_80" - "-gencode=arch=compute_86,code=sm_86" - ) - # SM8X Logic - file(GLOB FA3_SM8X_GEN_SRCS - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdim*_sm80.cu") - endif() - - file(GLOB FA3_BF16_GEN_SRCS - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_bf16*_sm90.cu") - file(GLOB FA3_BF16_GEN_SRCS_ - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_bf16*_sm90.cu") - list(APPEND FA3_BF16_GEN_SRCS ${FA3_BF16_GEN_SRCS_}) - - # FP16 source files - file(GLOB FA3_FP16_GEN_SRCS - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_fp16*_sm90.cu") - file(GLOB FA3_FP16_GEN_SRCS_ - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_fp16*_sm90.cu") - list(APPEND FA3_FP16_GEN_SRCS ${FA3_FP16_GEN_SRCS_}) - - # FP8 source files - file(GLOB FA3_FP8_GEN_SRCS - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimall_e4m3*_sm90.cu") - file(GLOB FA3_FP8_GEN_SRCS_ - "${repo-flash-attention_SOURCE_DIR}/hopper/instantiations/flash_fwd_hdimdiff_e4m3*_sm90.cu") - list(APPEND FA3_FP8_GEN_SRCS ${FA3_FP8_GEN_SRCS_}) - - set(FA3_GEN_SRCS ${FA3_BF16_GEN_SRCS} ${FA3_FP16_GEN_SRCS} ${FA3_FP8_GEN_SRCS} ${FA3_SM8X_GEN_SRCS}) - - set(FLASH_SOURCES - "csrc/flash_extension.cc" - "${repo-flash-attention_SOURCE_DIR}/hopper/flash_prepare_scheduler.cu" - "${repo-flash-attention_SOURCE_DIR}/hopper/flash_api.cpp" - "${repo-flash-attention_SOURCE_DIR}/hopper/flash_fwd_combine.cu" - "${FA3_GEN_SRCS}" - ) - - Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES}) - - target_compile_options(flash_ops PRIVATE $<$:${SGL_FLASH_KERNEL_CUDA_FLAGS}>) - target_include_directories(flash_ops PRIVATE - ${repo-cutlass_SOURCE_DIR}/include - ${repo-cutlass_SOURCE_DIR}/tools/util/include - ${repo-flash-attention_SOURCE_DIR}/hopper - ) - target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) - 
- install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel") - set(FLASH_OPS_COMPILE_DEFS - FLASHATTENTION_DISABLE_BACKWARD - FLASHATTENTION_DISABLE_DROPOUT - FLASHATTENTION_DISABLE_UNEVEN_K - FLASHATTENTION_VARLEN_ONLY - ) - - if(NOT ENABLE_BELOW_SM90) - list(APPEND FLASH_OPS_COMPILE_DEFS FLASHATTENTION_DISABLE_SM8x) - endif() - target_compile_definitions(flash_ops PRIVATE ${FLASH_OPS_COMPILE_DEFS}) -endif() - -# Build spatial_ops as a separate, optional extension for green contexts -set(SPATIAL_SOURCES - "csrc/spatial/greenctx_stream.cu" - "csrc/spatial_extension.cc" -) - -Python_add_library(spatial_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SPATIAL_SOURCES}) -target_compile_options(spatial_ops PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}>) -target_link_libraries(spatial_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda) -install(TARGETS spatial_ops LIBRARY DESTINATION sgl_kernel) - -# ============================ Extra Install: FLashMLA ============================= # -include(${CMAKE_CURRENT_LIST_DIR}/cmake/flashmla.cmake) - -# ============================ Extra Install: DeepGEMM (JIT) ============================= # -# Create a separate library for DeepGEMM's Python API. -# This keeps its compilation isolated from the main common_ops. -set(DEEPGEMM_SOURCES - "${repo-deepgemm_SOURCE_DIR}/csrc/python_api.cpp" -) - -Python_add_library(deep_gemm_cpp MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${DEEPGEMM_SOURCES}) - -# Link against necessary libraries, including nvrtc for JIT compilation. -target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} c10 cuda nvrtc mscclpp_static) - -# Add include directories needed by DeepGEMM. -target_include_directories(deep_gemm_cpp PRIVATE - ${repo-deepgemm_SOURCE_DIR}/deep_gemm/include - ${repo-cutlass_SOURCE_DIR}/include - ${repo-fmt_SOURCE_DIR}/include -) - -# Apply the same compile options as common_ops. 
-target_compile_options(deep_gemm_cpp PRIVATE $<$:${SGL_KERNEL_CUDA_FLAGS}>) - -# Create an empty __init__.py to make `deepgemm` a Python package. -file(WRITE ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py "") -install( - FILES ${CMAKE_CURRENT_BINARY_DIR}/deepgemm_pkg_init.py - DESTINATION deep_gemm - RENAME __init__.py -) - -# Install the compiled DeepGEMM API library. -install(TARGETS deep_gemm_cpp LIBRARY DESTINATION deep_gemm) - -# Install the source files required by DeepGEMM for runtime JIT compilation. -install( - DIRECTORY ${repo-deepgemm_SOURCE_DIR}/deep_gemm/ - DESTINATION deep_gemm -) - -install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cute/" - DESTINATION "deep_gemm/include/cute") - -install(DIRECTORY "${repo-cutlass_SOURCE_DIR}/include/cutlass/" - DESTINATION "deep_gemm/include/cutlass") # ============================ Extra Install: triton kernels ============================= # install(DIRECTORY "${repo-triton_SOURCE_DIR}/python/triton_kernels/triton_kernels/" DESTINATION "triton_kernels" PATTERN ".git*" EXCLUDE PATTERN "__pycache__" EXCLUDE) - -# ============================ Extra Install: FA4 ============================= # -# TODO: find a better install condition. 
-if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) - - set(FLASH_ATTN_CUTE_SRC "${repo-flash-attention_SOURCE_DIR}/flash_attn/cute") - set(FLASH_ATTN_CUTE_DST "${CMAKE_CURRENT_BINARY_DIR}/flash_attn_origin/cute") - - file(MAKE_DIRECTORY "${FLASH_ATTN_CUTE_DST}") - - file(COPY "${FLASH_ATTN_CUTE_SRC}/" - DESTINATION "${FLASH_ATTN_CUTE_DST}" - PATTERN ".git*" EXCLUDE - PATTERN "__pycache__" EXCLUDE) - - file(GLOB_RECURSE FLASH_ATTN_CUTE_DST_PY - "${FLASH_ATTN_CUTE_DST}/*.py") - - foreach(FILE_PATH IN LISTS FLASH_ATTN_CUTE_DST_PY) - file(READ "${FILE_PATH}" FILE_CONTENT) - - set(MODIFIED_CONTENT "${FILE_CONTENT}") - - # The main goal is to avoid using "flash_attn" so that other libraries (such as transformers) do not mistakenly assume that "flash_attn" is already installed. - - string(REPLACE "flash_attn.cute" - "flash_attn_origin.cute" - MODIFIED_CONTENT "${MODIFIED_CONTENT}") - - if (NOT FILE_CONTENT STREQUAL MODIFIED_CONTENT) - file(WRITE "${FILE_PATH}" "${MODIFIED_CONTENT}") - message(STATUS " - [FA4 Patch] Patched: ${FILE_PATH}") - endif() - endforeach() - - install(DIRECTORY "${FLASH_ATTN_CUTE_DST}/" - DESTINATION "flash_attn_origin/cute" - PATTERN ".git*" EXCLUDE - PATTERN "__pycache__" EXCLUDE) - -endif() diff --git a/sgl-kernel/cmake/flashmla.cmake b/sgl-kernel/cmake/flashmla.cmake index c17266af243f..8b7c1d5f0ca2 100644 --- a/sgl-kernel/cmake/flashmla.cmake +++ b/sgl-kernel/cmake/flashmla.cmake @@ -30,6 +30,11 @@ if(${CUDA_VERSION} VERSION_GREATER 12.8) "-gencode=arch=compute_100a,code=sm_100a" ) endif() +if(${CUDA_VERSION} VERSION_GREATER_EQUAL 13.0) +list(APPEND FLASHMLA_CUDA_FLAGS + "-gencode=arch=compute_103a,code=sm_103a" +) +endif() set(FlashMLA_SOURCES diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index b38eab218dff..8274896795a5 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -455,28 +455,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { 
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()"); m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits); - /* - * From Sparse Flash Attention - */ - m.def( - "fwd_sparse(Tensor! q, Tensor k, Tensor v, " - "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, " - "Tensor!? out, Tensor? alibi_slopes, " - "float p_dropout, float softmax_scale, bool is_causal, " - "float softcap, bool return_softmax, Generator? gen)" - "-> Tensor[]"); - m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse); - - m.def( - "varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, " - "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, " - "Tensor!? out, Tensor cu_seqlens_q, " - "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, " - "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, " - "bool is_causal, float softcap, bool return_softmax, " - "Generator? 
gen) -> Tensor[]"); - m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse); - // Sparse Attention utils m.def( "convert_vertical_slash_indexes(" diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h index 1ba99787ba60..01da81c95395 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -28,6 +28,8 @@ #include "gemm/marlin/marlin_dtypes.cuh" #include "scalar_type.hpp" +#include + #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert( \ std::is_same::value || std::is_same::value, \ @@ -357,6 +359,7 @@ __global__ void Marlin( constexpr bool has_zp = w_type == sglang::kU4 || w_type == sglang::kU8; constexpr bool is_int_type = w_type == sglang::kU4 || w_type == sglang::kU8 || w_type == sglang::kU4B8 || w_type == sglang::kU8B128; + constexpr bool is_8bit_scale = s_type.size_bits() == 8; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = w_type == sglang::kFE4M3fn || w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn || @@ -371,7 +374,7 @@ __global__ void Marlin( static_assert(thread_m_blocks == 1 || !m_block_size_8); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == sglang::kFE2M1f ? 16 : 8); + const int scales_expert_stride = prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8); const int zp_expert_stride = is_zp_float ? 
prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; @@ -442,52 +445,75 @@ __global__ void Marlin( locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred( + sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); -#pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; - } - __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + if constexpr (moe_block_size >= 16) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16); + if constexpr (moe_block_size >= 8) + local_count += __shfl_down_sync(0xFFFFFFFF, 
local_count, 8); + if constexpr (moe_block_size >= 4) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4); + if constexpr (moe_block_size >= 2) + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2); + + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1); + block_num_valid_tokens = local_count; +#else + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); +#endif -#pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k; + if (lane_id == 0) reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; if (mul_topk_weights) { -#pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - // idx = idx < block_num_valid_tokens ? idx : 0; - if (idx < block_num_valid_tokens) { - if constexpr (w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn) { - sh_block_topk_weights[idx] = - __hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = - Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } - } + idx = idx < prob_m_top_k ? 
idx : 0; + scalar_t topk_weight_tmp = Dtype::float2num(topk_weights_ptr[idx]); + if constexpr (w_type == sglang::kFE2M1f && s_type == sglang::kFE4M3fn) { + sh_block_topk_weights[threadIdx.x] = + __hmul2(global_scale, Dtype::num2num2(topk_weight_tmp)); + } else { + sh_block_topk_weights[threadIdx.x] = Dtype::num2num2(topk_weight_tmp); } } } + + __syncthreads(); + + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -629,11 +655,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == sglang::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -684,13 +709,15 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == sglang::kFE2M1f ? 
2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; @@ -708,15 +735,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == sglang::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -910,43 +929,21 @@ __global__ void Marlin( } else { if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && 
group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1024,35 +1021,33 @@ __global__ void Marlin( } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } } } else { auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == sglang::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != sglang::kFE2M1f.id()) { - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + if constexpr (!is_8bit_scale) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1246,17 +1241,27 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == sglang::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + // FP4/FP8 scale dequantization (E4M3 for NVFP4 and E8M0 for MXFP4). + // Aligns with vLLM's single-shot marlin_template.h path: convert the raw + // FP8 scale bytes packed into frag_s back to real bf16/half values before + // they multiply the dequantized weights. // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } + // The half2 + kFE8M0fnu specialization of dequant_fp8_scales is not + // defined on purpose (fp16 cannot safely represent 2^e for large |e|, so + // fp16+MXFP4 kernels are intentionally not instantiated in + // generate_kernels.py). Guard that combination at compile time so the + // unused code path is eliminated and no unresolved extern is emitted. 
+ if constexpr ((s_type == sglang::kFE4M3fn || s_type == sglang::kFE8M0fnu) && + !(std::is_same::value && + s_type == sglang::kFE8M0fnu)) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. @@ -1885,8 +1890,18 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } start_pipes(); } diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index 57334663ad48..20f7de059f37 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -274,6 +274,46 @@ int get_kernel_cache_size( return total_size; } +sglang::ScalarType infer_scale_type( + const sglang::ScalarType& q_type, + const torch::Tensor& b_scales, + at::ScalarType compute_dtype) { 
+ if (q_type == sglang::kFE2M1f) { + if (b_scales.scalar_type() == at::ScalarType::Float8_e4m3fn) { + return sglang::kFE4M3fn; + } + if (b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + return sglang::kFE8M0fnu; + } + TORCH_CHECK( + false, + "When q_type=float4_e2m1f, b_scales must use float8_e4m3fn (NVFP4) " + "or float8_e8m0fnu (MXFP4). Got scalar_type=", + b_scales.scalar_type()); + return sglang::kFloat16; + } + + if (b_scales.scalar_type() == at::ScalarType::Half) { + return sglang::kFloat16; + } + if (b_scales.scalar_type() == at::ScalarType::BFloat16) { + return sglang::kBFloat16; + } + + if (compute_dtype == at::ScalarType::Half) { + return sglang::kFloat16; + } + if (compute_dtype == at::ScalarType::BFloat16) { + return sglang::kBFloat16; + } + + TORCH_CHECK( + false, + "Unsupported scale dtype for Marlin MoE: ", + b_scales.scalar_type()); + return sglang::kFloat16; +} + bool is_valid_config( thread_config_t const& th_config, bool m_block_size_8, @@ -325,27 +365,25 @@ bool is_valid_config( return cache_size + 512 <= max_shared_mem; } -#define _GET_IF( \ - W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ - else if ( \ - q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = W_TYPE == sglang::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? sglang::kFE4M3fn : sglang::kFE8M0fnu) \ - : (std::is_same::value ? 
sglang::kFloat16 : sglang::kBFloat16); \ - kernel = Marlin< \ - scalar_t, \ - W_TYPE.id(), \ - S_TYPE.id(), \ - NUM_THREADS, \ - THREAD_M_BLOCKS, \ - THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, \ - pipe_stages, \ - GROUP_BLOCKS, \ - IS_ZP_FLOAT>; \ +#define _GET_IF( \ + W_TYPE, S_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && s_type == S_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + S_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -354,123 +392,124 @@ bool is_valid_config( // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) // FP4: cases for nvfp4(e2m1) (group_blocks == 1) -#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 
-1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - -#define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - 
BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - -#define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) - -#define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - -#define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - -#define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) +#define COMMON_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + 
_GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF(W_TYPE, S_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, 
K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE, S_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define NVFP4_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define NVFP4_GET_IF(W_TYPE, S_TYPE) \ + NVFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + NVFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + NVFP4_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + NVFP4_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) + +#define MXFP4_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, 
false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) + +#define MXFP4_GET_IF(W_TYPE, S_TYPE) \ + MXFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + MXFP4_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + MXFP4_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + MXFP4_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) // We currently have 4-bit models only with group_blocks == 4 -#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) +#define FZP_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) -#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) +#define FZP_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) -#define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) +#define FZP_GET_IF(W_TYPE, S_TYPE) \ + FZP_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) // We currently have 4-bit models only with group_blocks == 4 -#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, 
N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) +#define ACT_GET_IF_M1(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) -#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) +#define ACT_GET_IF_M234(W_TYPE, S_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, S_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, S_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) -#define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) +#define ACT_GET_IF(W_TYPE, S_TYPE) \ + ACT_GET_IF_M1(W_TYPE, S_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, S_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, S_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, S_TYPE, 8, 4, 128) template MarlinFuncPtr get_marlin_kernel( const sglang::ScalarType q_type, + const sglang::ScalarType s_type, int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, @@ -480,25 +519,28 @@ MarlinFuncPtr get_marlin_kernel( int group_blocks, int num_threads, bool is_zp_float) { - int num_bits = q_type.size_bits(); auto kernel = MarlinDefault; - if (false) { - } - - COMMON_GET_IF(sglang::kU4) - COMMON_GET_IF(sglang::kU4B8) - COMMON_GET_IF(sglang::kU8B128) - - NVFP4_GET_IF(sglang::kFE2M1f) - - BIGGROUP_GET_IF(sglang::kFE4M3fn) - - ACT_GET_IF(sglang::kU4B8) - ACT_GET_IF(sglang::kU8B128) - if (std::is_same::value) { + if constexpr 
(std::is_same::value) { if (false) { } - MXFP4_GET_IF(sglang::kFE2M1f) + COMMON_GET_IF(sglang::kU4, sglang::kFloat16) + COMMON_GET_IF(sglang::kU4B8, sglang::kFloat16) + COMMON_GET_IF(sglang::kU8B128, sglang::kFloat16) + NVFP4_GET_IF(sglang::kFE2M1f, sglang::kFE4M3fn) + BIGGROUP_GET_IF(sglang::kFE4M3fn, sglang::kFloat16) + ACT_GET_IF(sglang::kU4B8, sglang::kFloat16) + ACT_GET_IF(sglang::kU8B128, sglang::kFloat16) + } else { + if (false) { + } + COMMON_GET_IF(sglang::kU4, sglang::kBFloat16) + COMMON_GET_IF(sglang::kU4B8, sglang::kBFloat16) + COMMON_GET_IF(sglang::kU8B128, sglang::kBFloat16) + NVFP4_GET_IF(sglang::kFE2M1f, sglang::kFE4M3fn) + BIGGROUP_GET_IF(sglang::kFE4M3fn, sglang::kBFloat16) + ACT_GET_IF(sglang::kU4B8, sglang::kBFloat16) + ACT_GET_IF(sglang::kU8B128, sglang::kBFloat16) + MXFP4_GET_IF(sglang::kFE2M1f, sglang::kFE8M0fnu) } return kernel; @@ -507,6 +549,7 @@ MarlinFuncPtr get_marlin_kernel( template exec_config_t determine_exec_config( const sglang::ScalarType& q_type, + const sglang::ScalarType& s_type, int prob_m, int prob_n, int prob_k, @@ -567,6 +610,7 @@ exec_config_t determine_exec_config( auto kernel = get_marlin_kernel( q_type, + s_type, thread_m_blocks, th_config.thread_n / 16, th_config.thread_k / 16, @@ -624,6 +668,7 @@ void marlin_mm( int prob_k, void* workspace, sglang::ScalarType const& q_type, + sglang::ScalarType const& s_type, bool has_bias, bool has_act_order, bool is_k_full, @@ -677,6 +722,17 @@ void marlin_mm( } } + if (q_type == sglang::kFE2M1f) { + TORCH_CHECK( + (group_size == 16 && s_type == sglang::kFE4M3fn) || + (group_size == 32 && s_type == sglang::kFE8M0fnu), + "float4_e2m1f expects s_type=float8_e4m3fn for group_size=16 (NVFP4) " + "or s_type=float8_e8m0fnu for group_size=32 (MXFP4). 
Got group_size=", + group_size, + ", s_type=", + s_type.str()); + } + int num_bits = q_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; @@ -742,6 +798,7 @@ void marlin_mm( // Auto config exec_cfg = determine_exec_config( q_type, + s_type, prob_m, prob_n, prob_k, @@ -812,6 +869,7 @@ void marlin_mm( auto kernel = get_marlin_kernel( q_type, + s_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, @@ -845,7 +903,9 @@ void marlin_mm( ", thread_k_blocks = ", thread_k_blocks, ", num_bits = ", - num_bits); + num_bits, + ", s_type = ", + s_type.str()); } cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); @@ -889,6 +949,9 @@ torch::Tensor moe_wna16_marlin_gemm( bool use_fp32_reduce, bool is_zp_float) { sglang::ScalarType const b_q_type = sglang::ScalarType::from_id(b_q_type_id); + sglang::ScalarType const s_type = + MARLIN_NAMESPACE_NAME::infer_scale_type( + b_q_type, b_scales, a.scalar_type()); int pack_factor = 32 / b_q_type.size_bits(); if (moe_block_size != 8) { @@ -1123,13 +1186,10 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; - if (b_q_type == sglang::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); + if (s_type == sglang::kFE4M3fn) { + scales_ptr = b_scales.data_ptr(); + } else if (s_type == sglang::kFE8M0fnu) { + scales_ptr = b_scales.data_ptr(); } else { scales_ptr = b_scales.data_ptr(); } @@ -1159,6 +1219,7 @@ torch::Tensor moe_wna16_marlin_gemm( size_k, workspace.data_ptr(), b_q_type, + s_type, has_bias, has_act_order, is_k_full, @@ -1175,13 +1236,10 @@ torch::Tensor moe_wna16_marlin_gemm( is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; - if (b_q_type == 
sglang::kFE2M1f) { - if (group_size == 16) - scales_ptr = b_scales.data_ptr(); - else if (group_size == 32) - scales_ptr = b_scales.data_ptr(); - else - TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); + if (s_type == sglang::kFE4M3fn) { + scales_ptr = b_scales.data_ptr(); + } else if (s_type == sglang::kFE8M0fnu) { + scales_ptr = b_scales.data_ptr(); } else { scales_ptr = b_scales.data_ptr(); } @@ -1211,6 +1269,7 @@ torch::Tensor moe_wna16_marlin_gemm( size_k, workspace.data_ptr(), b_q_type, + s_type, has_bias, has_act_order, is_k_full, diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 7263a492cffe..c7d3a721fdde 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -5,6 +5,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.deepseekv4_detector import DeepSeekV4Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -1612,6 +1613,393 @@ def test_streaming_no_parameters_with_whitespace(self): params = json.loads(tool_calls_by_index[0]["parameters"]) self.assertEqual(params, {}) + def test_self_closing_zero_arg_invoke(self): + """V32 inherits the same regex; verify self-closing parses to empty + params here too (V32 model rarely emits this shape, but the parser + must agree with V4 since V4 inherits from V32).""" + submit_tool = Tool( + type="function", + function=Function( + name="submit", + parameters={"type": "object", "properties": {}}, + ), + ) + text = ( 
+ '<|DSML|function_calls>\n<|DSML|invoke name="submit"/>\n' + "" + ) + result = self.detector.detect_and_parse(text, [submit_tool]) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "submit") + self.assertEqual(json.loads(result.calls[0].parameters), {}) + + +class TestDeepSeekV4Detector(unittest.TestCase): + """DeepSeek V4 DSML tool-call tests. + + Mirrors TestDeepSeekV32Detector but targets the V4 outer block name + ``<|DSML|tool_calls>`` instead of ``<|DSML|function_calls>``. The V4 + reference encoder only emits XML-parameter form, so the V32 JSON-body + tests have no V4 analogue and are intentionally omitted. + """ + + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="search", + description="Searches for information related to query and displays topn results.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string", + }, + "topn": { + "type": "integer", + "description": "Number of top results to display", + "default": 10, + }, + "source": { + "type": "string", + "description": "Source to search within", + "enum": ["web", "news"], + "default": "web", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_favorite_tourist_spot", + description="Return the favorite tourist spot for a given city.", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + ), + ] + self.detector = DeepSeekV4Detector() + from transformers import AutoTokenizer + + # V3.2 tokenizer works for the chunk-split streaming test: it already + # has the DSML special tokens and decodes the test strings losslessly. 
+ self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") + self.interval = 1 + + def test_detect_and_parse_xml_format(self): + """Test parsing standard XML format (DSML)""" + text = """I'll help you with information about San Francisco and get its favorite tourist spot for you.\n\n + <|DSML|tool_calls>\n + <|DSML|invoke name="get_favorite_tourist_spot">\n + <|DSML|parameter name="city" string="true">San Francisco\n + \n + <|DSML|invoke name="search"> + <|DSML|parameter name="query" string="true">WebNav benchmark + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="source" string="true">web + + + """ + result = self.detector.detect_and_parse(text, self.tools) + + self.assertIn("I'll help you with information", result.normal_text) + self.assertEqual(len(result.calls), 2) + + call1 = result.calls[0] + self.assertEqual(call1.name, "get_favorite_tourist_spot") + params1 = json.loads(call1.parameters) + self.assertEqual(params1["city"], "San Francisco") + + call2 = result.calls[1] + self.assertEqual(call2.name, "search") + params2 = json.loads(call2.parameters) + self.assertEqual(params2["query"], "WebNav benchmark") + self.assertEqual(params2["topn"], 10) + self.assertEqual(params2["source"], "web") + + def test_streaming_xml_format(self): + """Test streaming parsing of XML format""" + text = """<|DSML|tool_calls> + <|DSML|invoke name="get_favorite_tourist_spot"> + <|DSML|parameter name="city" string="true">San Francisco + <|DSML|parameter name="another_city" string="true">London + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="obj" string="false">{"name": "John", "age": 30} + + """ + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = 
self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "San Francisco") + self.assertEqual(params["another_city"], "London") + self.assertEqual(params["topn"], 10) + self.assertEqual(params["obj"]["name"], "John") + self.assertEqual(params["obj"]["age"], 30) + + def test_detect_and_parse_no_parameters(self): + """Test parsing function calls with no parameters (non-streaming)""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """Let me get the current date for you. 
+ +<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + +""" + + result = self.detector.detect_and_parse(text, tools_with_no_param) + + self.assertIn("Let me get the current date", result.normal_text) + self.assertEqual(len(result.calls), 1) + + call = result.calls[0] + self.assertEqual(call.name, "get_date") + params = json.loads(call.parameters) + self.assertEqual(params, {}) + + def test_streaming_no_parameters(self): + """Test streaming parsing of function calls with no parameters.""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + +""" + + self.detector = DeepSeekV4Detector() + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + + params_str = tool_calls_by_index[0]["parameters"].strip() + params = json.loads(params_str) + self.assertEqual(params, {}) + + def test_streaming_no_parameters_with_whitespace(self): + """Test streaming parsing when invoke content has only whitespace (newlines).""" + 
tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + + +""" + + self.detector = DeepSeekV4Detector() + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params, {}) + + def test_self_closing_zero_arg_invoke(self): + """V4 emits `<|DSML|invoke name="x"/>` for zero-arg tools; the + detector must parse it as a complete tool call with empty params + instead of leaking the raw markup back into normal_text.""" + submit_tool = Tool( + type="function", + function=Function( + name="submit", + description="Submit the final answer.", + parameters={"type": "object", "properties": {}}, + ), + ) + + text = ( + "Final answer.\n" + '<|DSML|tool_calls>\n<|DSML|invoke name="submit"/>\n' + "" + ) + result = self.detector.detect_and_parse(text, [submit_tool]) + self.assertEqual(len(result.calls), 1) + self.assertEqual(result.calls[0].name, "submit") + 
self.assertEqual(json.loads(result.calls[0].parameters), {}) + self.assertNotIn("DSML", result.normal_text) + + def test_self_closing_mixed_with_long_form(self): + """Mix of long-form (with params) and self-closing tags in one block.""" + submit_tool = Tool( + type="function", + function=Function( + name="submit", + parameters={"type": "object", "properties": {}}, + ), + ) + text = ( + "<|DSML|tool_calls>\n" + '<|DSML|invoke name="get_favorite_tourist_spot">\n' + '<|DSML|parameter name="city" string="true">SF\n' + "\n" + '<|DSML|invoke name="submit"/>\n' + "" + ) + result = self.detector.detect_and_parse(text, self.tools + [submit_tool]) + self.assertEqual(len(result.calls), 2) + self.assertEqual(result.calls[0].name, "get_favorite_tourist_spot") + self.assertEqual(json.loads(result.calls[0].parameters), {"city": "SF"}) + self.assertEqual(result.calls[1].name, "submit") + self.assertEqual(json.loads(result.calls[1].parameters), {}) + + def test_streaming_self_closing_invoke(self): + """Self-closing invoke must terminate cleanly even when `/>` arrives + after the `name=` attribute crosses chunk boundaries.""" + submit_tool = Tool( + type="function", + function=Function( + name="submit", + parameters={"type": "object", "properties": {}}, + ), + ) + # Build the prompt and feed it through the tokenizer to exercise the + # same chunk shapes the runtime sees. 
+ text = ( + "<|DSML|tool_calls>\n" + '<|DSML|invoke name="submit"/>\n' + "" + ) + self.detector = DeepSeekV4Detector() + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunks = [ + self.tokenizer.decode(input_ids[i : i + self.interval]) + for i in range(0, len(input_ids), self.interval) + ] + + tool_calls_by_index = {} + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, [submit_tool]) + for call in result.calls: + if call.tool_index is None: + continue + slot = tool_calls_by_index.setdefault( + call.tool_index, {"name": "", "parameters": ""} + ) + if call.name: + slot["name"] = call.name + if call.parameters: + slot["parameters"] += call.parameters + + self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "submit") + self.assertEqual(json.loads(tool_calls_by_index[0]["parameters"]), {}) + class TestQwen3CoderDetector(unittest.TestCase): """Test suite for Qwen3CoderDetector.""" diff --git a/test/registered/openai_server/basic/test_protocol.py b/test/registered/openai_server/basic/test_protocol.py index 47bf563816ef..02158b81cff9 100644 --- a/test/registered/openai_server/basic/test_protocol.py +++ b/test/registered/openai_server/basic/test_protocol.py @@ -192,6 +192,37 @@ def test_chat_completion_reasoning_effort(self): self.assertEqual(request.reasoning_effort, "high") self.assertEqual(request.chat_template_kwargs, {"thinking": True}) + def test_chat_completion_reasoning_effort_max(self): + """`max` is an sglang extension on chat completion's top-level + `reasoning_effort` only; the Responses-API-style nested + `reasoning.effort` path stays aligned with OpenAI's three levels.""" + from pydantic import ValidationError + + messages = [{"role": "user", "content": "Hello"}] + request = ChatCompletionRequest( + model="test-model", + messages=messages, + reasoning_effort="max", + ) + self.assertEqual(request.reasoning_effort, "max") + + # Unknown values still rejected. 
+ with self.assertRaises(ValidationError): + ChatCompletionRequest( + model="test-model", + messages=messages, + reasoning_effort="ultra", + ) + + # Nested reasoning.effort=max is NOT promoted by normalize_reasoning_inputs: + # the Responses API path keeps the OpenAI low/medium/high contract. + request = ChatCompletionRequest( + model="test-model", + messages=messages, + reasoning={"effort": "max"}, + ) + self.assertNotEqual(request.reasoning_effort, "max") + def test_chat_completion_json_format(self): """Test chat completion json format""" transcript = "Good morning! It's 7:00 AM, and I'm just waking up. Today is going to be a busy day, " diff --git a/test/registered/openai_server/basic/test_serving_chat.py b/test/registered/openai_server/basic/test_serving_chat.py index d81f2efb051f..debbc66873f4 100644 --- a/test/registered/openai_server/basic/test_serving_chat.py +++ b/test/registered/openai_server/basic/test_serving_chat.py @@ -37,7 +37,7 @@ def __init__(self): tool_call_parser="hermes", reasoning_parser=None, ) - # Mock hf_config for _use_dpsk_v32_encoding check + # Mock hf_config for _resolve_chat_encoding_spec check mock_hf_config = Mock() mock_hf_config.architectures = ["LlamaForCausalLM"] self.model_config.hf_config = mock_hf_config @@ -614,18 +614,201 @@ def test_dpsk_v32_encoding_path(self): tokenizer_manager.tokenizer.chat_template = None serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertTrue(serving_chat.use_dpsk_v32_encoding) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv32") # Case 2: Chat template exists -> should NOT use dpsk encoding tokenizer_manager.tokenizer.chat_template = "some template" serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) # Case 3: Not DeepSeek V3.2 architecture -> should NOT use dpsk encoding tokenizer_manager.tokenizer.chat_template = None 
mock_hf_config.architectures = ["LlamaForCausalLM"] serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) + + # Case 4: DeepseekV4 arch -> always dsv4, even with chat_template + # (release ships a stale V3 jinja we deliberately override). + mock_hf_config.architectures = ["DeepseekV4ForCausalLM"] + tokenizer_manager.tokenizer.chat_template = "stale v3 jinja" + serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") + + tokenizer_manager.tokenizer.chat_template = None + serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") + + # ------------- dsv4 task + latest_reminder ------------- + def test_dsv4_task_field_schema(self): + """Top-level `task` accepts the 6 DS task tokens and rejects others.""" + for valid in ("action", "query", "authority", "domain", "title", "read_url"): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "hi"}], + task=valid, + ) + self.assertEqual(req.task, valid) + + # None / unset is fine + self.assertIsNone(self.basic_req.task) + + # Bogus value rejected at validation time + from pydantic import ValidationError + + with self.assertRaises(ValidationError): + ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "hi"}], + task="bogus", + ) + + def test_latest_reminder_role_accepted(self): + """`latest_reminder` is a first-class message role on generic param.""" + from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionMessageGenericParam, + ) + + msg = ChatCompletionMessageGenericParam( + role="latest_reminder", content="Be terse." + ) + self.assertEqual(msg.role, "latest_reminder") + + # Full request with reminder before user parses cleanly. 
+ req = ChatCompletionRequest( + model="x", + messages=[ + {"role": "latest_reminder", "content": "Be terse."}, + {"role": "user", "content": "Hi"}, + ], + ) + self.assertEqual(req.messages[0].role, "latest_reminder") + self.assertEqual(req.messages[1].role, "user") + + def test_attach_task_to_last_user_message(self): + """Helper attaches task to the nearest user/developer message.""" + from sglang.srt.entrypoints.openai import encoding_dsv4 + + messages = [{"role": "user", "content": "Hi"}] + encoding_dsv4.attach_task_to_last_user_message(messages, "domain") + self.assertEqual(messages[0]["task"], "domain") + + # Prefers the LAST user message across a multi-turn conversation. + messages = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "second"}, + ] + encoding_dsv4.attach_task_to_last_user_message(messages, "query") + self.assertNotIn("task", messages[0]) + self.assertEqual(messages[2]["task"], "query") + + # `developer` role is treated like `user` (matches encoder semantics). + messages = [{"role": "developer", "content": "dev"}] + encoding_dsv4.attach_task_to_last_user_message(messages, "authority") + self.assertEqual(messages[0]["task"], "authority") + + # No user/developer present -> raises. 
+ with self.assertRaises(ValueError): + encoding_dsv4.attach_task_to_last_user_message( + [{"role": "system", "content": "s"}], "domain" + ) + + def test_dsv4_content_parts_list_normalized(self): + """OpenAI list-of-parts content flattens to text before reaching the encoder.""" + from sglang.srt.entrypoints.openai import encoding_dsv4 + from sglang.srt.parser.jinja_template_utils import ( + process_content_for_template_format, + ) + + req = ChatCompletionRequest( + model="x", + messages=[ + { + "role": "user", + "content": [{"type": "text", "text": "say hi"}], + } + ], + ) + messages = [m.model_dump() for m in req.messages] + # Mirror the boundary normalization _process_messages does for any + # non-None chat_encoding_spec. + for i, msg in enumerate(messages): + if isinstance(msg.get("content"), list): + messages[i] = process_content_for_template_format( + msg, "string", [], [], [], [] + ) + out = encoding_dsv4.encode_messages(messages, thinking_mode="chat") + self.assertIn("<|User|>say hi", out) + + # Multiple text parts concat with single space; non-text parts dropped. + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "x"}}, + ], + } + ] + for i, msg in enumerate(messages): + if isinstance(msg.get("content"), list): + messages[i] = process_content_for_template_format( + msg, "string", [], [], [], [] + ) + out = encoding_dsv4.encode_messages(messages, thinking_mode="chat") + self.assertIn("<|User|>describe", out) + self.assertNotIn("image_url", out) + + def test_dsv4_task_and_reminder_encode_end_to_end(self): + """Task + latest_reminder plumb through to the dsv4 encoder correctly.""" + from sglang.srt.entrypoints.openai import encoding_dsv4 + + # 1) task='domain' in chat mode -> `<|domain|>` appended, no Assistant + # prefix (this is a single-shot classification, not a chat turn). 
+ req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "What is SGLang?"}], + task="domain", + ) + messages = [m.model_dump() for m in req.messages] + encoding_dsv4.attach_task_to_last_user_message(messages, req.task) + out = encoding_dsv4.encode_messages(messages, thinking_mode="chat") + self.assertIn("<|domain|>", out) + self.assertTrue(out.rstrip().endswith("<|domain|>")) + self.assertNotIn("<|Assistant|>", out) + + # 2) task='action' in thinking mode -> Assistant + + <|action|> + # (action is the one task that still runs a reasoning pass). + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "Hi"}], + task="action", + ) + messages = [m.model_dump() for m in req.messages] + encoding_dsv4.attach_task_to_last_user_message(messages, req.task) + out = encoding_dsv4.encode_messages(messages, thinking_mode="thinking") + self.assertIn("<|Assistant|>", out) + self.assertIn("", out) + self.assertTrue(out.rstrip().endswith("<|action|>")) + + # 3) latest_reminder preceding user -> reminder renders before user, + # Assistant prefix still comes after user. + req = ChatCompletionRequest( + model="x", + messages=[ + {"role": "latest_reminder", "content": "Be terse."}, + {"role": "user", "content": "Hello"}, + ], + ) + messages = [m.model_dump() for m in req.messages] + out = encoding_dsv4.encode_messages(messages, thinking_mode="chat") + self.assertIn("<|latest_reminder|>Be terse.", out) + self.assertIn("<|User|>Hello", out) + self.assertLess( + out.index("<|latest_reminder|>"), + out.index("<|User|>"), + ) + self.assertIn("<|Assistant|>", out) if __name__ == "__main__":