diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index aeed0b9c756f..9bcf60cd26bf 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -1419,6 +1419,110 @@ jobs: if: always() run: bash scripts/ci/cuda/ci_cleanup_venv.sh + stage-c-test-dsv4-4-gpu-b200: + needs: [check-changes, call-gate, wait-for-stage-b, sgl-kernel-build-wheels] + if: | + always() && + ( + (inputs.target_stage == 'stage-c-test-dsv4-4-gpu-b200') || + ( + !inputs.target_stage && + ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) && + ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true')) + ) + ) + runs-on: ${{ needs.check-changes.outputs.b200_runner }} + timeout-minutes: 240 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + + - uses: ./.github/actions/check-stage-health + + - uses: ./.github/actions/check-maintenance + + - name: Download artifacts + if: needs.check-changes.outputs.sgl_kernel == 'true' + uses: actions/download-artifact@v6 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda* + + - name: Install dependencies + timeout-minutes: 30 + run: | + CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_flash_mla.sh + + - name: Run test + timeout-minutes: 30 + env: + CONTINUE_ON_ERROR_FLAG: ${{ needs.check-changes.outputs.continue_on_error == 'true' && '--continue-on-error' || '' }} + run: | + cd test + python3 run_suite.py --hw cuda --suite stage-c-test-dsv4-4-gpu-b200 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG + + - uses: ./.github/actions/upload-cuda-coredumps + if: failure() + + - name: Cleanup venv + if: always() + run: bash scripts/ci/cuda/ci_cleanup_venv.sh + + stage-c-test-dsv4-8-gpu-h200: + needs: [check-changes, call-gate, wait-for-stage-b, sgl-kernel-build-wheels] + if: | + always() && + ( + (inputs.target_stage == 'stage-c-test-dsv4-8-gpu-h200') || + ( + !inputs.target_stage && + ((github.event_name == 'schedule' || inputs.test_parallel_dispatch == true) || (!failure() && !cancelled())) && + ((needs.check-changes.outputs.main_package == 'true') || (needs.check-changes.outputs.sgl_kernel == 'true')) + ) + ) + runs-on: 8-gpu-h200 + timeout-minutes: 240 + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.pr_head_sha || inputs.git_ref || github.sha }} + + - uses: ./.github/actions/check-stage-health + + - uses: ./.github/actions/check-maintenance + + - name: Download artifacts + if: needs.check-changes.outputs.sgl_kernel == 'true' + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-python3.10-cuda* + + - name: Install dependencies + timeout-minutes: 30 + run: | + CUSTOM_BUILD_SGL_KERNEL=${{needs.check-changes.outputs.sgl_kernel}} bash scripts/ci/cuda/ci_install_flash_mla.sh + + - name: Run test + timeout-minutes: 30 + env: + CONTINUE_ON_ERROR_FLAG: ${{ needs.check-changes.outputs.continue_on_error == 'true' && '--continue-on-error' || '' }} + run: | + cd test + python3 run_suite.py --hw cuda --suite stage-c-test-dsv4-8-gpu-h200 --timeout-per-file 1800 $CONTINUE_ON_ERROR_FLAG + + - uses: ./.github/actions/upload-cuda-coredumps + if: failure() + + - name: Cleanup venv + if: always() + run: bash scripts/ci/cuda/ci_cleanup_venv.sh + # NOTE: GB200 stage temporarily disabled — no 
company-owned GB200 runner available yet. # Re-enable when a 4-gpu-gb200 runner is provisioned. # stage-c-test-4-gpu-gb200: @@ -1500,6 +1604,8 @@ jobs: stage-c-test-deepep-4-gpu-h100, stage-c-test-deepep-8-gpu-h200, stage-c-test-4-gpu-b200, + stage-c-test-dsv4-4-gpu-b200, + stage-c-test-dsv4-8-gpu-h200, # stage-c-test-4-gpu-gb200, # Temporarily disabled — no GB200 runner ] if: always() diff --git a/docker/Dockerfile b/docker/Dockerfile index fe7c2680a7ef..2e57ed442e20 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,6 +13,7 @@ ARG DEEPEP_COMMIT=9af0e0d0e74f3577af1979c9b9e1ac2cad0104ee ARG BUILD_AND_DOWNLOAD_PARALLEL=8 ARG SGL_KERNEL_VERSION=0.4.2.post1 ARG SGL_VERSION +ARG SGL_DEEP_GEMM_VERSION=0.0.1 ARG USE_LATEST_SGLANG=0 ARG GDRCOPY_VERSION=2.5.1 ARG PIP_DEFAULT_INDEX @@ -244,6 +245,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ | xargs -r python3 -m pip uninstall -y && \ python3 -m pip install --index-url https://download.pytorch.org/whl/cu${CUINDEX} \ torch torchvision torchaudio --force-reinstall; \ + python3 -m pip install https://github.com/sgl-project/whl/releases/download/v${SGL_DEEP_GEMM_VERSION}/sgl_deep_gemm-${SGL_DEEP_GEMM_VERSION}+cu129-py3-none-manylinux2014_$(uname -m).whl --force-reinstall; \ fi \ && cd /sgl-workspace \ && rm -rf /tmp/sglang_deps \ @@ -539,12 +541,20 @@ RUN --mount=type=cache,target=/root/.cache/pip \ # the `nixl` import path) but unconditionally requires nixl-cu12, so we install # it with --no-deps and pair it with the matching nixl-cu12 / nixl-cu13 binary # to avoid shipping wrong-CUDA libs on cu13 images. +# The upstream flash-mla packages are required for running deepseek-v4 models RUN --mount=type=cache,target=/root/.cache/pip if [ "${CUDA_VERSION%%.*}" = "12" ]; then \ python3 -m pip install nixl nixl-cu12 --no-deps ; \ python3 -m pip install cuda-python==12.9 ; \ + cd /sgl-workspace && git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla \ + && cd flash-mla && git submodule update --init --recursive \ + && pip install --no-build-isolation -v . ; \ elif [ "${CUDA_VERSION%%.*}" = "13" ]; then \ python3 -m pip install nixl nixl-cu13 --no-deps ; \ python3 -m pip install cuda-python==13.2.0 ; \ + cd /sgl-workspace && git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla \ + && ln -s /usr/local/cuda/include/cccl/cuda /usr/local/cuda/include/cuda \ + && cd flash-mla && git submodule update --init --recursive \ + && pip install --no-build-isolation -v . ; \ fi # Add yank script diff --git a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx index 0e01e0a35902..e99e29499792 100644 --- a/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx +++ b/docs_new/cookbook/autoregressive/DeepSeek/DeepSeek-V4.mdx @@ -159,10 +159,6 @@ PD-Disagg recipes on H200 may require `docker run --privileged --ulimit memlock= can discover the IB HCAs; without IB exposure mooncake silently falls back to TCP, which can lead to garbled KV transfer on large checkpoints. -**Base model usage** - -In order to use base models, please enable `SGLANG_FIX_DSV4_BASE_MODEL_LOAD=1` and use latest code, before the next round of testing matrix is finished. 
- **GB300 PD-Disagg cross-pod MNNVL** On some GB300 clusters with cross-pod KV transfer over NVLink, mooncake may diff --git a/python/pyproject.toml b/python/pyproject.toml index 19eab8a8b837..477b2fe8e1f3 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ dependencies = [ "IPython", "aiohttp", - "apache-tvm-ffi>=0.1.5,<0.2", + "apache-tvm-ffi==0.1.9", "anthropic>=0.20.0", "blobfile==3.0.0", "build", @@ -63,6 +63,7 @@ dependencies = [ "sglang-kernel==0.4.2.post1", "soundfile==0.13.1", "tiktoken", + "tilelang==0.1.8", "timm==1.0.16", "torch_memory_saver>=0.0.9.post1", "torch==2.11.0", diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh new file mode 100644 index 000000000000..3a89e8114ce5 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh @@ -0,0 +1,522 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. 
*/ + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. 
We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
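+ /// Worked example with the constants above (kNumWarps = 16, kWarpThreads = 32,
+ /// kTileElements = 2): at i = 0, lanes 0..15 of hardware warp 0 get j = 0..15,
+ /// i.e. local_warp_id = 0..15 with local_elem_id = 0, so one 16-lane reduction
+ /// segment combines the same (lane, tile) slot across all 16 source warps;
+ /// lanes 16..31 do the same for local_elem_id = 1 (the second tile element).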
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? 
indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const 
tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh new file mode 100644 index 000000000000..b497470606cf --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_online.cuh @@ -0,0 +1,726 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace device::compress { + +/// \brief Plan entry for online compress 128 prefill. +/// Each entry describes a contiguous segment of tokens that lies inside a +/// single 128-chunk. 
Multiple segments can map to the same batch id when the +/// extend tokens span chunk boundaries. +/// +/// **Layout compatibility:** the field order/types match `PrefillPlan` so that +/// downstream kernels (e.g. `fused_norm_rope` in `CompressExtend` mode) can +/// consume the compress_plan tensor as-if it were a `PrefillPlan` tensor -- +/// they only read `ragged_id` and `position`, both of which carry identical +/// semantics here (the LAST token of the segment in q-ragged and global +/// coordinates respectively). +/// +/// Note that `window_len` here means "number of real tokens in this segment" +/// (1..128), which differs from `PrefillPlan::window_len`. Downstream kernels +/// that share the tensor MUST NOT read it under that name. +struct alignas(16) OnlinePrefillPlan { + /// \brief Ragged-q position of the LAST token in this segment. + /// Equal to `segment_start_ragged + window_len - 1`. + uint32_t ragged_id; + /// \brief Index into the `indices` / `load_indices` arrays. + uint32_t batch_id; + /// \brief Global position of the LAST token in this segment. + /// For compress plans, `position % 128 == 127` (chunk-closing); for write + /// plans, `position % 128 < 127`. + uint32_t position; + /// \brief Number of real tokens in this segment (1..128). + /// The first segment token sits at `position - window_len + 1` (global) and + /// at `ragged_id - window_len + 1` (ragged). + uint32_t window_len; +}; + +static_assert(alignof(OnlinePrefillPlan) == alignof(PrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == sizeof(PrefillPlan)); + +} // namespace device::compress + +namespace host::compress { + +using device::compress::OnlinePrefillPlan; +using OnlinePrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kOnlinePrefillPlanDim = 16; + +static_assert(alignof(OnlinePrefillPlan) == sizeof(OnlinePrefillPlan)); +static_assert(sizeof(OnlinePrefillPlan) == kOnlinePrefillPlanDim * sizeof(OnlinePrefillPlanTensorDtype)); + +} // namespace host::compress + +namespace { + +using OnlinePlan = device::compress::OnlinePrefillPlan; +using IndiceT = int32_t; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlineDecodeParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +/// \brief Need to reduce register usage to increase occupancy +struct Compress128OnlinePrefillParams { + /** \brief Shape: `[num_indices, 1, head_dim * 3 (max, sum, kv) ]` \n */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ load_indices; + /// \brief Plan for segments that close a chunk (write to 
`kv_compressed_output`). + /// Shape: `[num_compress, 16]` (uint8). + const OnlinePlan* __restrict__ compress_plan; + /// \brief Plan for the trailing partial segment of each batch (write back to + /// `kv_score_buffer`). Shape: `[num_write, 16]` (uint8). + const OnlinePlan* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +// 4 elements per thread, kHeadDim / 4 threads per block +template +__global__ void flash_c128_online_decode(const __grid_constant__ Compress128OnlineDecodeParams params) { + using namespace device; + constexpr uint32_t kVecSize = 4; + constexpr uint32_t kBlockSize = kHeadDim / kVecSize; + using Vec = AlignedVector; + const auto gmem = tile::Memory::cta(kBlockSize); + const auto batch_id = blockIdx.x; + const auto index = params.indices[batch_id]; + const auto seq_len = params.seq_lens[batch_id]; + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kHeadDim * 3); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_src = kv_score_input + batch_id * (kHeadDim * 2); + + /// NOTE: kv_score_buffer layout is [max, sum, kv] (slot 0 / 1 / 2). Reads, + /// writes, and the prefill kernel must all agree on this order. + const auto max_score_vec = gmem.load(kv_buf, 0); + const auto sum_score_vec = gmem.load(kv_buf, 1); + const auto old_kv_vec = gmem.load(kv_buf, 2); + + /// NOTE: kv_score_input layout is | kv | score | (head_dim each), matching + /// the offline c128 kernel and the online prefill kernel. + const auto new_kv_vec = gmem.load(kv_src, 0); + const auto new_score_raw_vec = gmem.load(kv_src, 1); + + /// NOTE: the new token sits at global position `seq_len - 1`, so its + /// position inside the 128-chunk is `(seq_len - 1) % 128`. The previous + /// `seq_len % 128` was off by one (`bias[127]` vs `bias[0]`, etc.). + const auto pos_in_chunk = (seq_len - 1) % 128; + const auto bias_vec = gmem.load(params.score_bias, pos_in_chunk); + + Vec out_kv_vec; + Vec out_max_vec; + Vec out_sum_vec; + if (pos_in_chunk != 0) { + // Mid-chunk: combine prior partial state with the new token via online softmax. +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto old_max = max_score_vec[i]; + const auto old_kv = old_kv_vec[i]; + const auto new_score = new_score_raw_vec[i] + bias_vec[i]; + const auto new_kv = new_kv_vec[i]; + const auto new_max = fmax(old_max, new_score); + const auto old_sum = sum_score_vec[i] * expf(old_max - new_max); + const auto new_exp = expf(new_score - new_max); + const auto new_sum = old_sum + new_exp; + out_kv_vec[i] = (old_kv * old_sum + new_kv * new_exp) / new_sum; + out_max_vec[i] = new_max; + out_sum_vec[i] = new_sum; + } + } else { + // First token of a new 128-chunk: initialize state with this token alone. +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + out_kv_vec[i] = new_kv_vec[i]; + out_max_vec[i] = new_score_raw_vec[i] + bias_vec[i]; + out_sum_vec[i] = 1.0f; // exp(score - max) with max == score + } + } + + if (pos_in_chunk == 127) { + // Chunk just closed: emit the compressed kv. No need to update the buffer + // -- the next chunk's first token will overwrite it. + const auto kv_out = static_cast(params.kv_compressed_output) + batch_id * kHeadDim; + gmem.store(kv_out, out_kv_vec); + } else { + // Otherwise persist the running [max, sum, kv] state for the next step. 
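+ // The persisted kv is already normalized by the running sum, so a later
+ // merge (e.g. the prefill pass) only has to rescale it by `sum` and an
+ // exp(old_max - new_max) factor instead of revisiting earlier tokens.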
+ gmem.store(kv_buf, out_max_vec, 0); + gmem.store(kv_buf, out_sum_vec, 1); + gmem.store(kv_buf, out_kv_vec, 2); + } +} + +constexpr int32_t kTileElements = 2; // split (along head-dim) +/// \brief Each warp will handle this many elements (split along softmax-128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kPrefillBlockSize = device::kWarpThreads * kNumWarps; +using PrefillStorage = device::AlignedVector; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_prefill_forward( + const PrefillStorage (&kv)[kElementsPerWarp], + const PrefillStorage (&score)[kElementsPerWarp], + float* kv_out, + float* max_out, + float* sum_out, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = score[j][i]; + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kPrefillBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kPrefillBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all 
`kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. + static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = global_product; + if constexpr (kNeedData) { + max_out[local_elem_id] = global_val_max; + sum_out[local_elem_id] = global_exp_sum; + } + } + if constexpr (kNeedData) __syncthreads(); +} + +/// \brief Sentinel score for padded positions in a 128-segment. +/// Must be finite so that `score - max` never produces NaN even when an +/// entire warp has only padded positions. +constexpr float kPadScore = -FLT_MAX; + +/// \brief Online compress 128 prefill. Two passes share this body: +/// - `kWrite=false` (compress pass): handles segments that close a chunk. +/// May load prior partial state from the buffer, but never writes to it, +/// so concurrent blocks can read the same slot without racing. +/// - `kWrite=true` (write pass): handles the trailing partial segment of each +/// batch. Each batch contributes at most one such plan, so concurrent blocks +/// touch disjoint buffer slots. +/// +/// The two passes MUST run as separate kernel launches (in stream order) so +/// that all reads in pass 1 finish before any writes in pass 2 start. +template +__global__ __launch_bounds__(kPrefillBlockSize, 2) // + void flash_c128_online_prefill(const __grid_constant__ Compress128OnlinePrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + /// NOTE: the compiler folds the if-else at compile time. + const auto num_plans = kWrite ? params.num_write : params.num_compress; + const auto plan_ptr = kWrite ? 
params.write_plan : params.compress_plan; + const uint32_t global_id = blockIdx.x; + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + if (global_pid >= num_plans) return; + const auto [ragged_id, batch_id, position, window_len] = plan_ptr[global_pid]; + if (ragged_id == 0xFFFFFFFFu) [[unlikely]] + return; + + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + const int32_t split_offset = global_sid * kTileDim; // int32 is enough + + const auto kv_score_buffer = static_cast(params.kv_score_buffer); + const auto kv_score_input = static_cast(params.kv_score_input); + const auto kv_compressed_output = static_cast(params.kv_compressed_output); + const auto score_bias_base = static_cast(params.score_bias); + + constexpr int64_t kElementSize = kHeadDim * 2; // | kv | score | + const uint32_t chunk_offset = (position % 128u) + 1u - window_len; + const uint32_t window_end = chunk_offset + window_len; // exclusive, in [1, 128] + const int32_t segment_start = ragged_id - (position % 128u); // can be negative, but safe + const int32_t load_index = chunk_offset != 0 ? params.load_indices[batch_id] : -1; + const int32_t store_index = kWrite ? params.indices[batch_id] : -1; + + PDLWaitPrimary(); + + // 2 * 8 = 16 register per elem. in theory we should consume 48 register here + PrefillStorage kv[kElementsPerWarp]; + PrefillStorage score[kElementsPerWarp]; + PrefillStorage bias[kElementsPerWarp]; + const auto warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + if (j >= chunk_offset && j < window_end) { + const auto kv_src_ptr = kv_score_input + (segment_start + j) * kElementSize + split_offset; + const auto score_src_ptr = kv_src_ptr + kHeadDim; + const auto bias_src_ptr = score_bias_base + j * kHeadDim + split_offset; + kv[i].load(kv_src_ptr, lane_id); + score[i].load(score_src_ptr, lane_id); + bias[i].load(bias_src_ptr, lane_id); + } + } + +#pragma unroll + for (uint32_t i = 0; i < kElementsPerWarp; ++i) { + const uint32_t j = i + warp_offset; + const bool is_valid = (j >= chunk_offset && j < window_end); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + score[i][ii] = is_valid ? score[i][ii] + bias[i][ii] : kPadScore; + /// NOTE: must zero out kv on padded slots -- `c128_prefill_forward` + /// computes `kv * exp_score` where `exp_score = expf(-FLT_MAX - max) ??? 0`, + /// and IEEE-754 makes `NaN * 0 = NaN` / `+-inf * 0 = NaN`. An + /// uninitialized register can hold a NaN/inf bit pattern, so without + /// this reset a single padded warp can poison the whole softmax. + kv[i][ii] = is_valid ? kv[i][ii] : 0.0f; + } + } + + __shared__ alignas(16) float seg_kv[kTileDim]; + __shared__ alignas(16) float seg_max[kTileDim]; + __shared__ alignas(16) float seg_sum[kTileDim]; + + c128_prefill_forward(kv, score, seg_kv, seg_max, seg_sum, warp_id, lane_id); + + PDLTriggerSecondary(); + + if (warp_id == 0) { + PrefillStorage out_kv_vec, out_max_vec, out_sum_vec; + out_kv_vec.load(seg_kv, lane_id); + out_max_vec.load(seg_max, lane_id); + out_sum_vec.load(seg_sum, lane_id); + if (chunk_offset != 0) { + /// NOTE: load (max, sum, kv) of the in-progress chunk for this index. + /// `load_indices` may differ from `indices` when the prior partial state + /// lives on a different slot than the slot we ultimately write to. 
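+ /// NOTE: the loop below is the standard two-way online-softmax combine:
+ ///   m  = max(m1, m2)
+ ///   s  = s1 * exp(m1 - m) + s2 * exp(m2 - m)
+ ///   kv = (k1 * s1 * exp(m1 - m) + k2 * s2 * exp(m2 - m)) / s
+ /// with (m1, s1, k1) the buffered partial state and (m2, s2, k2) the state
+ /// just reduced from this segment.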
+ const auto buf_load = kv_score_buffer + load_index * (kHeadDim * 3) + split_offset; + PrefillStorage buf_max_vec, buf_sum_vec, buf_kv_vec; + buf_max_vec.load(buf_load + 0 * kHeadDim, lane_id); + buf_sum_vec.load(buf_load + 1 * kHeadDim, lane_id); + buf_kv_vec.load(buf_load + 2 * kHeadDim, lane_id); +#pragma unroll + for (uint32_t ii = 0; ii < kTileElements; ++ii) { + const float m1 = buf_max_vec[ii]; + const float s1 = buf_sum_vec[ii]; + const float k1 = buf_kv_vec[ii]; + const float m2 = out_max_vec[ii]; + const float s2 = out_sum_vec[ii]; + const float k2 = out_kv_vec[ii]; + const float new_max = fmaxf(m1, m2); + const float new_s1 = s1 * expf(m1 - new_max); + const float new_s2 = s2 * expf(m2 - new_max); + const float new_sum = new_s1 + new_s2; + const float new_kv = (k1 * new_s1 + k2 * new_s2) / new_sum; + out_max_vec[ii] = new_max; + out_sum_vec[ii] = new_sum; + out_kv_vec[ii] = new_kv; + } + } + + if constexpr (kWrite) { + const auto buf_store = kv_score_buffer + store_index * (kHeadDim * 3) + split_offset; + reinterpret_cast(buf_store + 0 * kHeadDim)[lane_id] = out_max_vec; + reinterpret_cast(buf_store + 1 * kHeadDim)[lane_id] = out_sum_vec; + reinterpret_cast(buf_store + 2 * kHeadDim)[lane_id] = out_kv_vec; + } else { + const auto out_ptr = kv_compressed_output + ragged_id * kHeadDim + split_offset; + reinterpret_cast(out_ptr)[lane_id] = out_kv_vec; + } + } +} + +template +struct FlashCompress128OnlineKernel { + static constexpr auto decode_kernel = flash_c128_online_decode; + template + static constexpr auto prefill_kernel = flash_c128_online_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kDecodeBlockSize = kHeadDim / 4; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}).with_dtype().with_device(device).verify(indices); + TensorMatcher({B}).with_dtype().with_device(device).verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128OnlineDecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + LaunchKernel(batch_size, kDecodeBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + 
const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + using host::compress::kOnlinePrefillPlanDim; + using host::compress::OnlinePrefillPlanTensorDtype; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 1, kHeadDim * 3}) // kv score buffer (max, sum, kv) ??? 2D + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, kOnlinePrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, kOnlinePrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + /// NOTE: `extra` is `load_indices`. When the previous partial state lives + /// on a slot different from the destination slot (e.g. paged buffers), the + /// caller must supply this; otherwise it defaults to `indices`. + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}).with_dtype().with_device(device_).verify(load_indices); + + const auto device = device_.unwrap(); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128OnlinePrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + + /// NOTE: pass 1 reads the buffer (for the first segment of each batch + /// that started mid-chunk) and writes only to `kv_compressed_output`. + /// Pass 2 then writes the trailing partial state of each batch back to + /// the buffer. Stream serialization between the two launches enforces + /// read-before-write on shared buffer slots. 
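+ /// For example, if kHeadDim = 512 then kTileDim = 64 gives kNumSplit = 8,
+ /// so each compress plan expands to 8 blocks of kPrefillBlockSize = 512
+ /// threads in pass 1, and each write plan to 8 such blocks in pass 2.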
+ if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = num_w * kNumSplit) { + LaunchKernel(num_w_blocks, kPrefillBlockSize, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace + +namespace host::compress { + +using OnlinePlanResult = tvm::ffi::Tuple; + +struct OnlinePrefillCompressParams { + OnlinePrefillPlan* __restrict__ compress_plan; + OnlinePrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; +}; + +/// \brief Build the compress + write plans for online compress 128 prefill. +/// +/// Each batch's `[prefix_len, prefix_len + extend_len)` range is split at +/// 128-aligned boundaries. Every resulting segment falls into one of: +/// - **compress**: closes a 128-chunk (`chunk_offset + window_len == 128`). +/// These plans only read the buffer (when starting mid-chunk) and write the +/// compressed kv to `kv_compressed_output`. +/// - **write**: trailing partial of the batch (`chunk_offset + window_len < 128`). +/// May read the buffer and always writes the new partial state back to it. +/// Each batch produces at most one such plan. +/// +/// The two plans MUST be dispatched as separate kernel launches in stream +/// order so that pass-1 reads of a buffer slot complete before any pass-2 +/// write of the same slot. +inline OnlinePlanResult plan_online_prefill_host(const OnlinePrefillCompressParams& params, const bool use_cuda_graph) { + const auto& [compress_plan, write_plan, seq_lens, extend_lens, batch_size, num_tokens] = params; + + uint32_t counter = 0; + uint32_t compress_count = 0; + uint32_t write_count = 0; + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = static_cast(seq_lens[i]); + const uint32_t extend_len = static_cast(extend_lens[i]); + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t end_pos = prefix_len + extend_len; + /// NOTE: split the extend range into per-128-chunk segments. Each segment + /// stays inside one chunk, so the kernel can decide load/store from + /// `chunk_offset` and `window_len` alone. + uint32_t pos = prefix_len; + while (pos < end_pos) { + const uint32_t chunk_start = (pos / 128u) * 128u; + const uint32_t seg_end = std::min(end_pos, chunk_start + 128u); // exclusive + const uint32_t seg_len = seg_end - pos; + const uint32_t chunk_off = pos - chunk_start; + /// NOTE: store last-token coordinates so that downstream consumers + /// (e.g. `fused_norm_rope`) can read `ragged_id` and `position` with the + /// same semantics as `PrefillPlan`. The segment start is recoverable as + /// `ragged_id - window_len + 1` and `position - window_len + 1`. 
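+ /// Worked example: prefix_len = 100, extend_len = 200 (counter = 0) splits
+ /// into three segments:
+ ///   [100, 128) -> compress plan { ragged_id 27,  position 127, window_len 28 }
+ ///   [128, 256) -> compress plan { ragged_id 155, position 255, window_len 128 }
+ ///   [256, 300) -> write plan    { ragged_id 199, position 299, window_len 44 }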
+ const uint32_t last_pos = seg_end - 1; + const uint32_t last_ragged = counter + (last_pos - prefix_len); + const auto plan = OnlinePrefillPlan{ + .ragged_id = last_ragged, + .batch_id = i, + .position = last_pos, + .window_len = seg_len, + }; + if (chunk_off + seg_len == 128u) { + // full chunk, must be complete, maybe read the buffer, no write + RuntimeCheck(compress_count < num_tokens); + compress_plan[compress_count++] = plan; + } else { + // last chunk, must be incomplete, maybe read the buffer, must write + RuntimeCheck(write_count < num_tokens); + write_plan[write_count++] = plan; + } + pos = seg_end; + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return OnlinePlanResult{compress_count, write_count}; + /// NOTE: pad both plans with sentinel entries so cuda-graph runs always see + /// the same number of blocks. The kernel skips plans whose `ragged_id` is -1. + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = OnlinePrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_count, num_tokens)) { + compress_plan[i] = kInvalidPlan; + } + for (const auto i : irange(write_count, num_tokens)) { + write_plan[i] = kInvalidPlan; + } + return OnlinePlanResult{num_tokens, num_tokens}; +} + +inline OnlinePlanResult plan_online_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + /// NOTE: only host (CPU/cuda-host) planning is implemented for now. 
The + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kOnlinePrefillPlanDim}) // + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + const auto params = OnlinePrefillCompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + }; + return plan_online_prefill_host(params, use_cuda_graph); +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_online_prefill = host::compress::plan_online_prefill; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh new file mode 100644 index 000000000000..498f2601eeaa --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128_v2.cuh @@ -0,0 +1,543 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. 
*/ + + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + + uint32_t num_compress; + uint32_t num_write; + + uint32_t num_q_tokens; + uint32_t batch_size; + uint32_t num_indices; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. 
We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
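// --- Illustrative aside (not part of the patch) -----------------------------
// The cross-warp merge below relies on the standard online-softmax identity:
// per-warp partials (m_w = local max, e_w = sum of exp(s - m_w), p_w = sum of
// v * exp(s - m_w)) over disjoint score subsets combine exactly after
// rescaling by exp(m_w - M), where M is the global max. A host-side sketch of
// that merge, with hypothetical names; assumes <cmath> and <cstddef>, and
// would live at namespace scope rather than inside the kernel.
struct SoftmaxPartial {
  float val_max;   // m_w
  float exp_sum;   // e_w
  float product;   // p_w, the partial weighted sum of values
};

inline float merge_softmax_partials(const SoftmaxPartial* parts, size_t n) {
  float global_max = parts[0].val_max;
  for (size_t w = 1; w < n; ++w) global_max = fmaxf(global_max, parts[w].val_max);
  float global_exp_sum = 0.0f, global_product = 0.0f;
  for (size_t w = 0; w < n; ++w) {
    const float rescale = expf(parts[w].val_max - global_max);
    global_exp_sum += parts[w].exp_sum * rescale;
    global_product += parts[w].product * rescale;
  }
  // Equals sum_j v_j * softmax(s)_j over the union of all subsets, which is
  // what the per-element reduction below computes with warp reductions.
  return global_product / global_exp_sum;
}
// ----------------------------------------------------------------------------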
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. 
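// --- Illustrative aside (not part of the patch) -----------------------------
// On the write position used just below: the target ring-buffer slot is the
// mathematical (seq_len - 1) mod 128, but writing the subtraction directly can
// wrap for unsigned operands (or go negative with signed `%`). Adding the
// modulus first gives the same residue for any seq_len >= 0. Hypothetical
// helper name; assumes <cstdint>; would live at namespace scope.
inline int32_t ring_slot_of_last_token(int32_t seq_len, int32_t window = 128) {
  return (seq_len + window - 1) % window;  // == (seq_len - 1) mod window
}
// e.g. seq_len = 1 -> slot 0, seq_len = 128 -> slot 127, seq_len = 129 -> slot 0.
// ----------------------------------------------------------------------------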
To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write, // prefill plan + _num_q_tokens, _batch_size, _num_indices + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? 
indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + if (ragged_id >= _num_q_tokens) [[unlikely]] + return; + if (global_bid >= _batch_size) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + + if (index < 0 || static_cast(index) >= _num_indices) [[unlikely]] + return; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, 
params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto K = SymbolicSize{"num_indices"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({K, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto num_indices = static_cast(K.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + .num_q_tokens = num_q_tokens, + .batch_size = batch_size, + .num_indices = num_indices, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh new file mode 100644 index 000000000000..145ab1fb081e --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh @@ -0,0 +1,549 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include 
+#include +#include + +#include + +namespace { + +using Plan4 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) + +enum class PageMode { + RingBuffer = 8, + Page4Align = 4, +}; + +struct alignas(16) C4IndexBundle { + int32_t load_first_page; + int32_t load_second_page; + int32_t write_first_page; + int32_t last_position; +}; + +struct Compress4DecodeParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \brief Shape: `[batch_size, 1]` */ + const int32_t* __restrict__ extra; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, 4]` */ + const C4IndexBundle* __restrict__ extra; + /** \brief The following part is plan info. 
*/ + + const Plan4* __restrict__ compress_plan; + const Plan4* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +template +SGL_DEVICE void c4_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 4; + const auto gmem = tile::Memory::warp(); + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv overlap | [1] = kv | [2] = score overlap | [3] = score | + Storage kv_score[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t seq_len, + const int32_t window_len, + [[maybe_unused]] const InFloat* kv_score_overlap_buf = nullptr) { + using namespace device; + + const auto element_size = head_dim * 4; + const auto score_offset = head_dim * 2; + const auto overlap_stride = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const bool is_overlap = i < 4; + const InFloat* src; + if (i < window_len) { + /// NOTE: `seq_len` must be a multiple of 4 here + if constexpr (kPaged) { + const auto kv_score_ptr = is_overlap ? kv_score_overlap_buf : kv_score_buf; + const int32_t k = i % 4; + src = kv_score_ptr + k * element_size; + } else { + const int32_t k = (seq_len + i) % 8; + src = kv_score_buf + k * element_size; + } + } else { + /// NOTE: k in [-7, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = i - 7; + src = kv_score_src + k * element_size; + } + src += (is_overlap ? 
0 : overlap_stride); + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + if (seq_len == 4) { + [[unlikely]]; + constexpr float kFloatNegInf = -1e9f; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[8]; + +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + gmem_out.store(kv_out, result); +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, extra, batch_size // decode info + ] = params; + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / kNumSplit; // batch id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + page_size - 1` + if constexpr (kMode == PageMode::Page4Align) { + const auto index_prev = extra[global_bid]; + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 3) % 4); + if (seq_len % 4 == 0) { + const auto kv_overlap = kv_buf + (index_prev - index) * (kElementSize * 4); + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, 8, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 7) % 8); + if (seq_len % 4 == 0) { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_size=*/8); + } + } + + PDLTriggerSecondary(); +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, extra, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / kNumSplit; // plan id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? 
write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + const auto seq_len = position + 1; + const int32_t index = indices[global_bid]; + + PDLWaitPrimary(); + + if constexpr (kMode == PageMode::Page4Align) { + const auto write_second_page = index; + const auto [load_first_page, load_second_page, write_first_page, last_pos] = extra[global_bid]; + if constexpr (kWrite) { + int32_t index; + if (position < static_cast(last_pos)) { + index = write_first_page; + } else { + index = write_second_page; + } + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 4); + } else { + int32_t index_overlap, index_normal; + if (window_len <= 4) { + index_overlap = load_second_page; + index_normal = load_second_page; // not used + } else { + index_overlap = load_first_page; + index_normal = load_second_page; + } + const auto kv_buf = kv_score_buffer + index_normal * (kElementSize * 4) + split_offset; + const auto kv_overlap = kv_score_buffer + index_overlap * (kElementSize * 4) + split_offset; + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + if constexpr (kWrite) { + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 8); + } else { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len); + } + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress4Kernel { + template + static constexpr auto decode_kernel = flash_c4_decode; + template + static constexpr auto prefill_kernel = flash_c4_prefill; + template + static constexpr auto prefill_c_kernel = prefill_kernel; + template + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + + using Self = FlashCompress4Kernel; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional extra) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extra = static_cast(extra_ptr), + .batch_size = batch_size, + }; + const auto kernel = extra_ptr != nullptr ? decode_kernel // + : decode_kernel; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra, /*is_prefill=*/true); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .extra = static_cast(extra_ptr), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + const auto c_kernel = extra_ptr != nullptr ? prefill_c_kernel // + : prefill_c_kernel; + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + const auto w_kernel = extra_ptr != nullptr ? prefill_w_kernel // + : prefill_w_kernel; + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(w_kernel, params); + } + } + + // some auxiliary functions + private: + static const void* _get_extra_pointer( + host::SymbolicSize& B, // batch_size + host::SymbolicDevice& device, + const tvm::ffi::Optional& extra, + bool is_prefill = false) { + // only have value when using page-aligned mode + if (!extra.has_value()) return nullptr; + const auto& extra_tensor = extra.value(); + /// NOTE: the metadata layout is different for prefill and decode: + /// for prefill, last 4 are: + /// load overlap | load normal | write overlap | last written page + /// for decode, last 1 is the write (also load) overlap + host::TensorMatcher({B, is_prefill ? 
4 : 1}) // extra tensor + .with_dtype() + .with_device(device) + .verify(extra_tensor); + const auto data_ptr = extra_tensor.data_ptr(); + host::RuntimeCheck(data_ptr != nullptr, "extra tensor data ptr is null"); + if (is_prefill) { + static_assert(alignof(C4IndexBundle) == 16); + host::RuntimeCheck(std::bit_cast(data_ptr) % 16 == 0, "extra tensor is not properly aligned"); + } + return data_ptr; + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh new file mode 100644 index 000000000000..46acaa9c46b3 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh @@ -0,0 +1,208 @@ +#include +#include + +#include + +#include + +namespace host::compress { + +using PlanResult = tvm::ffi::Tuple; + +struct CompressParams { + PrefillPlan* __restrict__ compress_plan; + PrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; + uint32_t compress_ratio; + bool is_overlap; +}; + +inline constexpr uint32_t kBlockSize = 1024; + +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) inline + +PLAN_KERNEL void plan_prefill_cuda(const __grid_constant__ CompressParams params) { + const auto &[ + compress_plan, write_plan, seq_lens, extend_lens, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + __shared__ uint32_t compress_counter; + __shared__ uint32_t write_counter; + + uint32_t batch_id = 0; + uint32_t counter = 0; + uint32_t extend_len = extend_lens[0]; + + const auto tid = threadIdx.x; + if (tid == 0) { + compress_counter = 0; + write_counter = 0; + } + __syncthreads(); + + for (uint32_t i = tid; i < num_tokens; i += blockDim.x) { + const uint32_t ragged_id = i; + uint32_t j = ragged_id - counter; + while (j >= extend_len) { + j -= extend_len; + batch_id += 1; + if (batch_id >= batch_size) [[unlikely]] + break; + counter += extend_len; + extend_len = extend_lens[batch_id]; + } + if (batch_id >= batch_size) [[unlikely]] + break; + const uint32_t seq_len = seq_lens[batch_id]; + const uint32_t extend_len = extend_lens[batch_id]; + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t ratio = compress_ratio * (1 + is_overlap); + const uint32_t window_len = j + 1 < ratio ? ratio - (j + 1) : 0; + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = ragged_id, + .batch_id = batch_id, + .position = position, + .window_len = window_len, + }; + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + return pos >= compress_ratio ? 
pos - compress_ratio : 0; + }(); + if ((position + 1) % compress_ratio == 0) { + const auto write_pos = atomicAdd(&compress_counter, 1); + compress_plan[write_pos] = plan; + } + if (position >= start_write_pos) { + const auto write_pos = atomicAdd(&write_counter, 1); + write_plan[write_pos] = plan; + } + } + __syncthreads(); + constexpr auto kInvalid = static_cast(-1); + const auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + const auto compress_count = compress_counter; + const auto write_count = write_counter; + for (uint32_t i = compress_count + tid; i < num_tokens; i += blockDim.x) { + compress_plan[i] = kInvalidPlan; + } + for (uint32_t i = write_count + tid; i < num_tokens; i += blockDim.x) { + write_plan[i] = kInvalidPlan; + } +} + +inline PlanResult plan_prefill_host(const CompressParams& params, const bool use_cuda_graph) { + const auto &[ + compress_ptr, write_ptr, seq_lens_ptr, extend_lens_ptr, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + uint32_t counter = 0; + uint32_t compress_counter = 0; + uint32_t write_counter = 0; + const auto ratio = compress_ratio * (1 + is_overlap); + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = seq_lens_ptr[i]; + const uint32_t extend_len = extend_lens_ptr[i]; + const uint32_t prefix_len = seq_len - extend_len; + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + /// NOTE: `start_write_pos` must be a multiple of `compress_ratio` + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + /// NOTE: to avoid unsigned integer underflow, don't use `pos - compress_ratio` + return pos >= compress_ratio ? 
pos - compress_ratio : 0; + }(); + /// NOTE: `position` is within [prefix_len, seq_len) + for (const auto j : irange(extend_len)) { + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = counter + j, + .batch_id = i, + .position = position, + .window_len = ratio - std::min(j + 1, ratio), + }; + RuntimeCheck(plan.is_valid(compress_ratio, is_overlap), "Internal error!"); + if ((position + 1) % compress_ratio == 0) { + compress_ptr[compress_counter++] = plan; + } + if (position >= start_write_pos) { + write_ptr[write_counter++] = plan; + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return PlanResult{compress_counter, write_counter}; + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_counter, num_tokens)) { + compress_ptr[i] = kInvalidPlan; + } + for (const auto i : irange(write_counter, num_tokens)) { + write_ptr[i] = kInvalidPlan; + } + return PlanResult{num_tokens, num_tokens}; +} + +inline PlanResult plan_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const uint32_t compress_ratio, + const bool is_overlap, // for overlap transform, we have to keep 1 more extra window + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + const bool is_cuda = [&] { + if (extend_lens.device().device_type == kDLCUDA) { + device.set_options(); + return true; + } else { + device.set_options(); + return false; + } + }(); + TensorMatcher({N}) // extend_lens and seq_lens + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kPrefillPlanDim}) // compress_plan and write_plan + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + + const auto params = CompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + .compress_ratio = compress_ratio, + .is_overlap = is_overlap, + }; + + if (!is_cuda) return plan_prefill_host(params, use_cuda_graph); + /// NOTE: cuda kernel plan is naturally compatible with cuda graph + LaunchKernel(1, kBlockSize, device.unwrap())(plan_prefill_cuda, params); + return PlanResult{params.num_tokens, params.num_tokens}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_prefill = host::compress::plan_prefill; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh new file mode 100644 index 000000000000..d3953578b925 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh @@ -0,0 +1,254 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using Plan = device::compress::PrefillPlan; + +/// \brief common block size for memory-bound kernel +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / 
device::kWarpThreads; + +struct FusedNormRopeParams { + void* __restrict__ input; + const void* __restrict__ weight; + float eps; + uint32_t num_works; + const void* __restrict__ handle; + const float* __restrict__ freqs_cis; + uint32_t compress_ratio; +}; + +enum class ForwardMode { + CompressExtend = 0, + CompressDecode = 1, + DefaultForward = 2, +}; + +template +__global__ void fused_norm_rope(const __grid_constant__ FusedNormRopeParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr int64_t kRopeVecSize = kRopeDim / (kWarpThreads * 2); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % (kWarpThreads * 2) == 0); + static_assert(kRopeDim % (kVecSize * kLocalSize) == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeVecSize == 1, "only support rope dim = 64"); + + const auto& [ + _input, _weight, eps, num_works, // norm + handle, freqs_cis, compress_ratio // rope + ] = params; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + + if (work_id >= num_works) return; + + DType* input; + int32_t position; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(handle)[work_id]; + input = static_cast(_input) + plan.ragged_id * kHeadDim; + position = plan.position + 1 - compress_ratio; + if (plan.ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + } else if constexpr (kMode == CompressDecode) { + input = static_cast(_input) + work_id * kHeadDim; + const auto seq_len = static_cast(handle)[work_id]; + if (seq_len % compress_ratio != 0) return; + position = seq_len - compress_ratio; + } else if constexpr (kMode == DefaultForward) { + input = static_cast(_input) + work_id * kHeadDim; + position = static_cast(handle)[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + + using Storage = AlignedVector; + __shared__ Storage s_rope_input[kNumWarps][kRopeSize]; + + // prefetch freq + const auto mem_freq = tile::Memory::warp(); + const auto freq = mem_freq.load(freqs_cis + position * kRopeDim); + + PDLWaitPrimary(); + + // part 1: norm + { + const auto gmem = tile::Memory::warp(); + Storage input_vec[kLocalSize]; + Storage weight_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input, i); + } + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + weight_vec[i] = gmem.load(_weight, i); + } + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + sum_of_squares += fp32_input * fp32_input; + } + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + const auto fp32_weight = cast(weight_vec[i][j]); + input_vec[i][j] = cast(fp32_input * norm_factor * fp32_weight); + } + } + + const bool is_rope_lane = lane_id >= 
kWarpThreads - kRopeSize; + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope_input[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(input, input_vec[i], i); + } + } + + __syncwarp(); + } + + // part 2: rope + { + // mem elem = DType x 2 + using DTypex2_t = packed_t; + const auto mem_elem = tile::Memory::warp(); + const auto elem = mem_elem.load(s_rope_input[warp_id]); + const auto [x_real, x_imag] = cast(elem); + const auto [freq_real, freq_imag] = freq; + const fp32x2_t output = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(input + (kHeadDim - kRopeDim), cast(output)); + } + + PDLTriggerSecondary(); +} + +template +struct FusedNormRopeKernel { + template + static constexpr auto fused_kernel = fused_norm_rope; + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView handle, + const tvm::ffi::TensorView freqs_cis, + int32_t _mode, + float eps, + uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(_mode); + + auto B = SymbolicSize{"num_q_tokens"}; + auto N = SymbolicSize{"num_compress_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + switch (mode) { + case CompressExtend: + TensorMatcher({N, compress::kPrefillPlanDim}) // plan + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case CompressDecode: + TensorMatcher({N}) // seq_len + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case DefaultForward: + TensorMatcher({N}) // position + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio == 0); + break; + default: + Panic("unsupported forward mode: ", static_cast(mode)); + } + + // launch kernel + const auto num_compress_tokens = static_cast(N.unwrap()); + if (num_compress_tokens == 0) return; + const auto params = FusedNormRopeParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .eps = eps, + .num_works = num_compress_tokens, + .handle = handle.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .compress_ratio = compress_ratio, + }; + const auto num_blocks = div_ceil(num_compress_tokens, kNumWarps); + using KernelType = std::decay_t)>; + static constexpr KernelType kernel_table[3] = { + [static_cast(CompressExtend)] = fused_kernel, + [static_cast(CompressDecode)] = fused_kernel, + [static_cast(DefaultForward)] = fused_kernel, + }; + const auto kernel = kernel_table[static_cast(mode)]; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh new file mode 100644 index 000000000000..90dec3c1178d --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh @@ -0,0 +1,214 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + 
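// --- Illustrative aside (not part of the patch) -----------------------------
// The activation defined just below computes sqrt(softplus(x)) with the usual
// overflow-safe rewrite: log(1 + e^x) == max(x, 0) + log1p(e^{-|x|}), which
// never evaluates e^x for large positive x. A naive host-side reference for
// comparison (hypothetical name; assumes <cmath>):
inline float sqrt_softplus_reference(double x) {
  // Naive form; overflows to inf for large x, unlike the rewrite used below.
  return static_cast<float>(std::sqrt(std::log1p(std::exp(x))));
}
// For x = 3: softplus = log(1 + e^3) ~= 3.0486 and sqrt ~= 1.746; both forms
// agree wherever the naive one does not overflow.
// ----------------------------------------------------------------------------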
+[[maybe_unused]] +SGL_DEVICE float act_sqrt_softplus(float x) { + const float softplus = fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); + return sqrtf(softplus); +} + +struct MoEHashTopKParams { + const float* __restrict__ router_logits; + const int64_t* __restrict__ input_id; + const int32_t* __restrict__ tid2eid; + int32_t* __restrict__ topk_ids; + float* __restrict__ topk_weights; + uint32_t num_tokens; + uint32_t topk; + uint32_t num_routed_experts; + uint32_t num_shared_experts; + float routed_scaling_factor; +}; + +template +__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) { + using namespace device; + const auto& [ + router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers + num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + params; + + const uint32_t topk_fused = topk + num_shared_experts; + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / kWarpThreads; + const uint32_t lane_id = tid % kWarpThreads; + if (warp_id >= num_tokens) return; + // we can safely prefetch the token id + const auto token_id = input_id[warp_id]; + + PDLWaitPrimary(); + + float routed_weight = 0.0f; + int32_t expert_id = 0; + if (lane_id < topk) { + expert_id = tid2eid[token_id * topk + lane_id]; + routed_weight = Fn(router_logits[warp_id * num_routed_experts + expert_id]); + } + + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk_fused) { + const bool is_shared = lane_id >= topk; + const auto output_offset = warp_id * topk_fused + lane_id; + topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; + topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + } + + PDLTriggerSecondary(); +} + +struct TopKParams { + int32_t* __restrict__ topk_ids; + // Exactly one is active: ntn_ptr == nullptr means use ntn_value. + const int32_t* __restrict__ ntn_ptr; + int32_t ntn_value; + int64_t stride; + uint32_t topk; + uint32_t num_tokens; +}; + +__global__ void mask_topk_ids_padded_region(const TopKParams __grid_constant__ params) { + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / device::kWarpThreads; + const uint32_t lane_id = tid % device::kWarpThreads; + if (warp_id >= params.num_tokens || lane_id >= params.topk) return; + device::PDLWaitPrimary(); + const uint32_t num = (params.ntn_ptr != nullptr) // + ? 
static_cast(params.ntn_ptr[0]) + : static_cast(params.ntn_value); + if (warp_id >= num) params.topk_ids[warp_id * params.stride + lane_id] = -1; + device::PDLTriggerSecondary(); +} + +template +struct HashTopKKernel { + static constexpr auto kernel = moe_hash_topk_fused; + + static void + run(const tvm::ffi::TensorView router_logits, + const tvm::ffi::TensorView input_id, + const tvm::ffi::TensorView tid2eid, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView topk_ids, + float routed_scaling_factor) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto E = SymbolicSize{"num_routed_experts"}; + auto K = SymbolicSize{"topk_fused"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, E}) // + .with_dtype() + .with_device(device) + .verify(router_logits); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input_id); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device) + .verify(tid2eid); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_ids); + + const auto num_tokens = static_cast(N.unwrap()); + const auto topk_fused = static_cast(K.unwrap()); + const auto topk = static_cast(tid2eid.size(1)); + const auto shared_experts = topk_fused - topk; + RuntimeCheck(topk <= topk_fused, "HashTopKKernel requires topk <= topk_fused"); + RuntimeCheck(topk_fused <= device::kWarpThreads, "HashTopKKernel requires topk_fused <= warp size"); + + const auto params = MoEHashTopKParams{ + .router_logits = static_cast(router_logits.data_ptr()), + .input_id = static_cast(input_id.data_ptr()), + .tid2eid = static_cast(tid2eid.data_ptr()), + .topk_ids = static_cast(topk_ids.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .num_tokens = num_tokens, + .topk = topk, + .num_routed_experts = static_cast(E.unwrap()), + .num_shared_experts = shared_experts, + .routed_scaling_factor = routed_scaling_factor, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +// TODO this may not be related to *hash* topk, thus may move +struct MaskKernel { + static constexpr auto kernel = mask_topk_ids_padded_region; + + static void run(tvm::ffi::TensorView topk_ids, tvm::ffi::TensorView num_token_non_padded) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto K = SymbolicSize{"topk"}; + auto D = SymbolicSize{"stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N, K}) // + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(topk_ids); + RuntimeCheck(num_token_non_padded.numel() == 1, "num_token_non_padded should be a scalar"); + RuntimeCheck(K.unwrap() <= device::kWarpThreads, "MaskKernel requires topk <= warp size"); + const int32_t* ntn_ptr = nullptr; + int32_t ntn_value = 0; + const auto ntn_dev = num_token_non_padded.device().device_type; + if (ntn_dev == kDLCUDA) { + RuntimeCheck(is_type(num_token_non_padded.dtype()), "num_token_non_padded on CUDA must be int32"); + ntn_ptr = static_cast(num_token_non_padded.data_ptr()); + } else if (ntn_dev == kDLCPU) { + if (is_type(num_token_non_padded.dtype())) { + ntn_value = *static_cast(num_token_non_padded.data_ptr()); + } else if (is_type(num_token_non_padded.dtype())) 
{ + ntn_value = static_cast(*static_cast(num_token_non_padded.data_ptr())); + } else { + RuntimeCheck(false, "num_token_non_padded on CPU must be int32 or int64"); + } + } else { + RuntimeCheck(false, "num_token_non_padded must be on CPU or CUDA"); + } + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = TopKParams{ + .topk_ids = static_cast(topk_ids.data_ptr()), + .ntn_ptr = ntn_ptr, + .ntn_value = ntn_value, + .stride = static_cast(D.unwrap()), + .topk = static_cast(K.unwrap()), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(true)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh new file mode 100644 index 000000000000..aefec24372a8 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh @@ -0,0 +1,82 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace { + +/// NOTE: for offload to cpu kernel, we use persistent kernel +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kBlockQuota = 4; + +#define OFFLOAD_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct OffloadParams { + void** gpu_caches; + void** cpu_caches; + const int64_t* gpu_indices; + const int64_t* cpu_indices; + uint32_t num_items; + uint32_t num_layers; +}; + +OFFLOAD_KERNEL void offload_to_cpu(const __grid_constant__ OffloadParams params) { + using namespace device::hisparse; + const auto [gpu_caches, cpu_caches, gpu_indices, cpu_indices, num_items, num_layers] = params; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + constexpr auto kNumWarps = (kBlockSize / 32) * kBlockQuota; + for (auto i = global_tid / 32; i < num_items; i += kNumWarps) { + const int32_t gpu_index = gpu_indices[i]; + const int32_t cpu_index = cpu_indices[i]; + for (auto j = 0u; j < num_layers; ++j) { + const auto gpu_cache = gpu_caches[j]; + const auto cpu_cache = cpu_caches[j]; + transfer_item( + /*dst_cache=*/cpu_cache, + /*src_cache=*/gpu_cache, + /*dst_index=*/cpu_index, + /*src_index=*/gpu_index); + } + } +} + +[[maybe_unused]] +void hisparse_transfer( + tvm::ffi::TensorView gpu_ptrs, + tvm::ffi::TensorView cpu_ptrs, + tvm::ffi::TensorView gpu_indices, + tvm::ffi::TensorView cpu_indices) { + using namespace host; + auto N = SymbolicSize{"num_items"}; + auto L = SymbolicSize{"num_layers"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({L}) // 1D cache pointers + .with_dtype() + .with_device(device_) + .verify(gpu_ptrs) + .verify(cpu_ptrs); + TensorMatcher({N}) // 1D indices + .with_dtype() + .with_device(device_) + .verify(gpu_indices) + .verify(cpu_indices); + const auto params = OffloadParams{ + .gpu_caches = static_cast(gpu_ptrs.data_ptr()), + .cpu_caches = static_cast(cpu_ptrs.data_ptr()), + .gpu_indices = static_cast(gpu_indices.data_ptr()), + .cpu_indices = static_cast(cpu_indices.data_ptr()), + .num_items = static_cast(N.unwrap()), + .num_layers = static_cast(L.unwrap()), + }; + LaunchKernel(kBlockQuota, kBlockSize, device_.unwrap())(offload_to_cpu, params); +} + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh new file 
mode 100644 index 000000000000..7d5f97824b06 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/mega_moe_pre_dispatch.cuh @@ -0,0 +1,219 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct MegaMoEPreDispatchParams { + const bf16_t* __restrict__ x; // [num_tokens, hidden] + const int32_t* __restrict__ topk_idx; // [num_tokens, top_k] + const float* __restrict__ topk_weights; // [num_tokens, top_k] + + fp8_e4m3_t* __restrict__ buf_x; // [padded_max, hidden] + int32_t* __restrict__ buf_x_sf; // contiguous int32 [P, G/4]; see layout comment + int64_t* __restrict__ buf_topk_idx; // [padded_max, top_k] + float* __restrict__ buf_topk_weights; // [padded_max, top_k] + + uint32_t num_tokens; + uint32_t padded_max; + uint32_t hidden; + uint32_t num_groups; // hidden / group_size + uint32_t top_k; +}; + +// kGroupSize must match sglang_per_token_group_quant_fp8_ue8m0(group_size=). +template +__global__ __launch_bounds__(1024, 2) void // + mega_moe_pre_dispatch_kernel(const MegaMoEPreDispatchParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kVecElems = 8; // 8 bf16 = 16B load per thread + static_assert(kGroupSize % kVecElems == 0, "group_size must be a multiple of 8"); + constexpr uint32_t kThreadsPerGroup = kGroupSize / kVecElems; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + + PDLWaitPrimary(); + if (bid < params.num_tokens) { + // ---- Quantize path: one CTA per valid token ---- + + const uint32_t token_id = bid; + const auto token_in = params.x + static_cast(token_id) * params.hidden; + const auto token_out = params.buf_x + static_cast(token_id) * params.hidden; + + InputVec in_vec; + in_vec.load(token_in, tid); + + float local_max = 0.0f; + float vals[kVecElems]; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + const auto [v0, v1] = cast(in_vec[i]); + vals[2 * i + 0] = v0; + vals[2 * i + 1] = v1; + local_max = fmaxf(local_max, fmaxf(fabsf(v0), fabsf(v1))); + } + + // Absmax across the kThreadsPerGroup threads that cover one group. + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + const float raw_scale = absmax / math::FP8_E4M3_MAX; + const uint32_t ue8m0_exp = cast_to_ue8m0(raw_scale); + // 2^-ue8m0_exp as fp32 (equivalent to 1 / __uint_as_float(ue8m0 << 23)). + const float inv_scale = __uint_as_float((127u + 127u - ue8m0_exp) << 23); + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < kVecElems / 2; ++i) { + out_vec[i] = pack_fp8(vals[2 * i + 0] * inv_scale, vals[2 * i + 1] * inv_scale); + } + out_vec.store(token_out, tid); + + // One thread per group writes its UE8M0 byte into the contiguous + // row-major int32-packed layout: byte address = t*num_groups + g + // (see layout comment at the top of the file). + const uint32_t group_id = tid / kThreadsPerGroup; + const uint32_t within_group_id = tid % kThreadsPerGroup; + if (within_group_id == 0 && group_id < params.num_groups) { + const uint32_t byte_off = token_id * params.num_groups + group_id; + reinterpret_cast(params.buf_x_sf)[byte_off] = static_cast(ue8m0_exp); + } + + // Copy this token's topk row (no alignment assumptions; top_k is small). 
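+ // (Clarifying note, not part of the original patch.) One thread per (token, k)
+ // slot: the int32 expert id is widened into the int64 dispatch buffer and the
+ // float weight is copied unchanged; the host wrapper enforces num_threads >= top_k,
+ // so a single quant CTA always covers the whole row.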
+ if (tid < params.top_k) { + const uint32_t off = token_id * params.top_k + tid; + params.buf_topk_idx[off] = params.topk_idx[off]; + params.buf_topk_weights[off] = params.topk_weights[off]; + } + } else { + // ---- Pad path: trailing blocks fill [num_tokens, padded_max) with (-1, 0) ---- + const uint32_t copy_bid = bid - params.num_tokens; + const uint32_t pad_base = params.num_tokens * params.top_k; + const uint32_t slot = pad_base + copy_bid * blockDim.x + tid; + const uint32_t total_slots = params.padded_max * params.top_k; + + if (slot < total_slots) { + params.buf_topk_idx[slot] = -1; + params.buf_topk_weights[slot] = 0.0f; + } + } + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct MegaMoEPreDispatchKernel { + static_assert(kGroupSize == 32 || kGroupSize == 64 || kGroupSize == 128, "unsupported group_size"); + static constexpr auto kernel = mega_moe_pre_dispatch_kernel(kGroupSize), kUsePDL>; + + static void + run(const tvm::ffi::TensorView x, + const tvm::ffi::TensorView topk_idx, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView buf_x, + const tvm::ffi::TensorView buf_x_sf, + const tvm::ffi::TensorView buf_topk_idx, + const tvm::ffi::TensorView buf_topk_weights) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto P = SymbolicSize{"padded_max"}; + auto H = SymbolicSize{"hidden"}; + auto K = SymbolicSize{"top_k"}; + auto G4 = SymbolicSize{"num_groups_div_4"}; + device.set_options(); + + TensorMatcher({M, H}) // input x + .with_dtype() + .with_device(device) + .verify(x); + TensorMatcher({M, K}) // topk_idx + .with_dtype() + .with_device(device) + .verify(topk_idx); + TensorMatcher({M, K}) // topk_weights + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({P, H}) // buf.x + .with_dtype() + .with_device(device) + .verify(buf_x); + // buf.x_sf is the contiguous row-major int32 view from DeepGEMM's mega + // symm buffer (DeepGEMM/csrc/apis/mega.hpp): shape (P, G/4), strides + // (G/4, 1). No explicit strides required -> TensorMatcher enforces + // is_contiguous(). + TensorMatcher({P, G4}) // buf_x_sf + .with_dtype() + .with_device(device) + .verify(buf_x_sf); + TensorMatcher({P, K}) // buf.topk_idx + .with_dtype() + .with_device(device) + .verify(buf_topk_idx); + TensorMatcher({P, K}) // buf.topk_weights + .with_dtype() + .with_device(device) + .verify(buf_topk_weights); + + const auto num_tokens = static_cast(M.unwrap()); + const auto padded_max = static_cast(P.unwrap()); + const auto hidden = static_cast(H.unwrap()); + const auto top_k = static_cast(K.unwrap()); + const auto num_groups_div_4 = static_cast(G4.unwrap()); + + RuntimeCheck(num_tokens <= padded_max, "num_tokens must not exceed padded_max"); + RuntimeCheck(hidden % kGroupSize == 0, "hidden must be a multiple of group_size"); + const auto num_groups = hidden / static_cast(kGroupSize); + RuntimeCheck(num_groups == num_groups_div_4 * 4u, "num_groups must be a multiple of 4"); + RuntimeCheck(hidden % 8u == 0, "hidden must be a multiple of 8 (16B bf16 loads)"); + const auto num_threads = hidden / 8u; + RuntimeCheck(num_threads <= 1024, "hidden too large for single-block-per-row quant"); + RuntimeCheck(num_threads >= top_k, "top_k must fit into one quant CTA"); + + const auto pad_slots = (padded_max - num_tokens) * top_k; + const uint32_t num_pad_blocks = pad_slots == 0 ? 
0u : ((pad_slots + num_threads - 1u) / num_threads); + const auto num_total_blocks = num_tokens + num_pad_blocks; + + const auto params = MegaMoEPreDispatchParams{ + .x = static_cast(x.data_ptr()), + .topk_idx = static_cast(topk_idx.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .buf_x = static_cast(buf_x.data_ptr()), + .buf_x_sf = static_cast(buf_x_sf.data_ptr()), + .buf_topk_idx = static_cast(buf_topk_idx.data_ptr()), + .buf_topk_weights = static_cast(buf_topk_weights.data_ptr()), + .num_tokens = num_tokens, + .padded_max = padded_max, + .hidden = hidden, + .num_groups = num_groups, + .top_k = top_k, + }; + + if (num_total_blocks == 0) return; + LaunchKernel(num_total_blocks, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh new file mode 100644 index 000000000000..38be97555853 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kBlockSize = 1024; +constexpr uint32_t kSplitKV = 256; // const for both SM90 and SM100 + +struct MetadataParams { + /// NOTE: batch_size > 0 + uint32_t batch_size; + uint32_t num_sm; + const uint32_t* __restrict__ context_lens; + uint32_t* __restrict__ schedule_metadata; + bool use_smem = true; +}; + +__global__ __launch_bounds__(kBlockSize, 1) // + void smxx_paged_mqa_logits_metadata(const MetadataParams params) { + using namespace device; + extern __shared__ uint32_t s_length[]; + static constexpr auto kNumWarps = kBlockSize / kWarpThreads; + static_assert(kNumWarps == kWarpThreads); + + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + __shared__ uint32_t s_warp_sum[kNumWarps]; + + uint32_t local_sum = 0; + for (uint32_t i = tx; i < params.batch_size; i += kBlockSize) { + const auto length = params.context_lens[i]; + local_sum += (length + kSplitKV - 1) / kSplitKV; + if (params.use_smem) s_length[i] = length; + } + + s_warp_sum[warp_id] = warp::reduce_sum(local_sum); + __syncthreads(); + + const auto global_sum = warp::reduce_sum(s_warp_sum[lane_id]); + if (lane_id != 0) return; + + const auto length_ptr = params.use_smem ? 
s_length : params.context_lens; + + const auto avg = global_sum / params.num_sm; + const auto ret = global_sum % params.num_sm; + uint32_t q = 0; + uint32_t num_work = (length_ptr[0] + kSplitKV - 1) / kSplitKV; + uint32_t sum_work = num_work; + for (auto i = warp_id; i <= params.num_sm; i += kNumWarps) { + const auto target = i * avg + min(i, ret); + while (sum_work <= target) { + if (++q >= params.batch_size) break; + num_work = (length_ptr[q] + kSplitKV - 1) / kSplitKV; + sum_work += num_work; + } + if (q >= params.batch_size) { + params.schedule_metadata[2 * i + 0] = params.batch_size; + params.schedule_metadata[2 * i + 1] = 0; + } else { + // sum > target && (sum - length) <= target + params.schedule_metadata[2 * i + 0] = q; + params.schedule_metadata[2 * i + 1] = target - (sum_work - num_work); + } + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct IndexerMetadataKernel { + static constexpr auto kMaxBatchSizeInSmem = 16384 * 2; // 128 KB smem + static void run(tvm::ffi::TensorView seq_lens, tvm::ffi::TensorView metadata) { + using namespace host; + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_sm"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({M, 2}) // + .with_dtype() + .with_device(device) + .verify(metadata); + const auto batch_size = static_cast(N.unwrap()); + const auto num_sm = static_cast(M.unwrap()) - 1; + RuntimeCheck(num_sm <= 1024); + const auto use_smem = batch_size <= kMaxBatchSizeInSmem; + const auto params = MetadataParams{ + .batch_size = batch_size, + .num_sm = num_sm, + .context_lens = static_cast(seq_lens.data_ptr()), + .schedule_metadata = static_cast(metadata.data_ptr()), + .use_smem = use_smem, + }; + constexpr auto kernel = smxx_paged_mqa_logits_metadata; + setup_kernel_smem_once(); + const auto smem = use_smem ?
(batch_size + 1) * sizeof(uint32_t) : 0; + LaunchKernel(1, kBlockSize, device.unwrap(), smem)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh new file mode 100644 index 000000000000..f9407ec84db0 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/rmsnorm.cuh @@ -0,0 +1,133 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct RMSNormSelfParams { + const void* __restrict__ input; + void* __restrict__ output; + int64_t stride_batch_bytes; + int64_t stride_head_bytes; + uint32_t batch_size; + uint32_t num_head; + float eps; +}; + +template +__global__ __launch_bounds__(kBlockSize, 20) // + void rmsnorm_self(const __grid_constant__ RMSNormSelfParams params) { + using namespace device; + constexpr int64_t kVecSize = 16 / sizeof(DType); + constexpr uint32_t kNumLoop = kHeadDim / (kVecSize * kWarpThreads); + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + using DType2 = packed_t; + using Vec = AlignedVector; + + const auto warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads; + const auto batch_id = warp_id / params.num_head; + const auto head_id = warp_id % params.num_head; + const auto gmem = tile::Memory::warp(); + if (batch_id >= params.batch_size) return; + const auto input_ptr = pointer::offset( // + params.input, + batch_id * params.stride_batch_bytes, + head_id * params.stride_head_bytes); + // use contiguous layout + const auto output_ptr = pointer::offset( // + params.output, + warp_id * kHeadDim * sizeof(DType)); + PDLWaitPrimary(); // wait for primary kernel + + Vec inputs[kNumLoop]; +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { + inputs[i] = gmem.load(input_ptr, i); + } + + // compute sum of squares + float local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { +#pragma unroll + for (uint32_t j = 0; j < kVecSize / 2; ++j) { + const auto [x, y] = cast(inputs[i][j]); + local_sum += x * x + y * y; + } + } + + const auto sum_of_squares = warp::reduce_sum(local_sum); + const auto factor = math::rsqrt(sum_of_squares / kHeadDim + params.eps); + + // weight must be identity (null, not used) +#pragma unroll + for (uint32_t i = 0; i < kNumLoop; ++i) { +#pragma unroll + for (uint32_t j = 0; j < kVecSize / 2; ++j) { + const auto [x, y] = cast(inputs[i][j]); + inputs[i][j] = cast(fp32x2_t{x * factor, y * factor}); + } + gmem.store(output_ptr, inputs[i], i); + } + + PDLTriggerSecondary(); // launch secondary kernel +} + +template +struct RMSNormKernel { + static constexpr auto kernel_self = rmsnorm_self; + + static void run_self(tvm::ffi::TensorView input, tvm::ffi::TensorView output, float eps) { + using namespace host; + + auto N = SymbolicSize{"batch_size"}; + auto H = SymbolicSize{"num_heads"}; + auto Dn = SymbolicSize{"stride_head"}; + auto Dh = SymbolicSize{"stride_batch"}; + constexpr auto D = kHeadDim; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, H, D}) // input + .with_strides({Dh, Dn, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({N, H, D}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(output); + + const auto batch_size = static_cast(N.unwrap()); + const auto num_head = static_cast(H.unwrap()); + const auto stride_head_bytes = 
static_cast(Dn.unwrap() * sizeof(DType)); + const auto stride_batch_bytes = static_cast(Dh.unwrap() * sizeof(DType)); + const auto params = RMSNormSelfParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .stride_batch_bytes = stride_batch_bytes, + .stride_head_bytes = stride_head_bytes, + .batch_size = batch_size, + .num_head = num_head, + .eps = eps, + }; + if (batch_size == 0 || num_head == 0) return; + const auto needed_warps = batch_size * num_head; + const auto num_blocks = div_ceil(needed_warps, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel_self, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh new file mode 100644 index 000000000000..2239d3972d64 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +using DType = bf16_t; +constexpr int64_t kRopeDim = 64; +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedQKRopeParams { + void* __restrict__ q; + void* __restrict__ k; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + int64_t q_stride_batch; + int64_t k_stride_batch; + int64_t q_stride_head; + int64_t k_stride_head; + uint32_t num_q_heads; + uint32_t num_k_heads; + uint32_t batch_size; +}; + +template +__global__ __launch_bounds__(kBlockSize, 16) // + void deepseek_rope_kernel(const __grid_constant__ FusedQKRopeParams param) { + using namespace device; + using DType2 = packed_t; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto global_warp_id = blockIdx.x * kNumWarps + warp_id; + + const auto& [ + q, k, freqs_cis, positions, // + q_stride_batch, k_stride_batch, q_stride_head, k_stride_head, // + num_q_heads, num_k_heads, batch_size + ] = param; + + const auto num_total_heads = num_q_heads + num_k_heads; + const auto head_id = global_warp_id % num_total_heads; + const auto batch_id = global_warp_id / num_total_heads; + if (batch_id >= batch_size) return; + + const auto position = static_cast(positions)[batch_id]; + const auto is_q = head_id < num_q_heads; + const auto local_head = is_q ? head_id : (head_id - num_q_heads); + const auto stride_batch = is_q ? q_stride_batch : k_stride_batch; + const auto stride_head = is_q ? q_stride_head : k_stride_head; + const auto base_ptr = is_q ? 
q : k; + const auto input = static_cast(pointer::offset(base_ptr, batch_id * stride_batch, local_head * stride_head)); + + const auto freq_ptr = reinterpret_cast(freqs_cis + position * kRopeDim); + const auto [f_real, f_imag] = freq_ptr[lane_id]; + PDLWaitPrimary(); + + const auto data = input[lane_id]; + const auto [x_real, x_imag] = cast(data); + fp32x2_t output; + if constexpr (kInverse) { + // (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + output = { + x_real * f_real + x_imag * f_imag, + x_imag * f_real - x_real * f_imag, + }; + } else { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + output = { + x_real * f_real - x_imag * f_imag, + x_real * f_imag + x_imag * f_real, + }; + } + input[lane_id] = cast(output); + + PDLTriggerSecondary(); +} + +template +struct FusedQKRopeKernel { + // 4 kernel variants: {forward, inverse} x {int32, int64} + static constexpr auto kernel_fwd_i32 = deepseek_rope_kernel; + static constexpr auto kernel_fwd_i64 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i32 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i64 = deepseek_rope_kernel; + + static void forward( + const tvm::ffi::TensorView q, + const tvm::ffi::Optional k, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + bool inverse) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto Q = SymbolicSize{"num_q_heads"}; + auto K = SymbolicSize{"num_k_heads"}; + constexpr auto D = kRopeDim; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, Q, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(q); + if (k.has_value()) { + TensorMatcher({B, K, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(k.value()); + } else { + K.set_value(0); + } + TensorMatcher({-1, D}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + const bool pos_i32 = pos_dtype.is_type(); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto num_q_heads = static_cast(Q.unwrap()); + const auto num_k_heads = static_cast(K.unwrap()); + const auto num_total_heads = num_q_heads + num_k_heads; + const auto total_warps = batch_size * num_total_heads; + const auto num_blocks = div_ceil(total_warps, kNumWarps); + + const auto elem_size = static_cast(sizeof(DType)); + const auto params = FusedQKRopeParams{ + .q = q.data_ptr(), + .k = k ? k.value().data_ptr() : nullptr, + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_stride_batch = q.stride(0) * elem_size, + .k_stride_batch = k ? k.value().stride(0) * elem_size : 0, + .q_stride_head = q.stride(1) * elem_size, + .k_stride_head = k ? k.value().stride(1) * elem_size : 0, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .batch_size = batch_size, + }; + + // dispatch: {inverse} x {pos_i32} + using KernelType = decltype(kernel_fwd_i32); + const KernelType kernel = + inverse ? (pos_i32 ? kernel_inv_i32 : kernel_inv_i64) : (pos_i32 ? 
kernel_fwd_i32 : kernel_fwd_i64); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh new file mode 100644 index 000000000000..be0e759445f9 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh @@ -0,0 +1,540 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantVarlenParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +template +SGL_DEVICE fp32x2_t silu_and_mul(DType2 gate, DType2 up, float limit) { + using namespace device; + // refer to as implementation. TL;DR: must clamp in bf16 + // https://github.com/deepseek-ai/DeepGEMM/blob/7f2a703ed51ac1f7af07f5e1453b2d3267d37d50/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh#L984-L997 + if constexpr (kApplySwigluLimit) { + static_assert(std::is_same_v); + gate = __hmin2(gate, {limit, limit}); + up = __hmax2(up, {-limit, -limit}); + up = __hmin2(up, {limit, limit}); + } + const auto [g0, g1] = cast(gate); + const auto [u0, u1] = cast(up); + const auto silu0 = g0 / (1.0f + __expf(-g0)); + const auto silu1 = g1 / (1.0f + __expf(-g1)); + const float val0 = silu0 * u0; + const float val1 = silu1 * u1; + if constexpr (kPrecise) { // I don't know if we should enable this? + return {val0, val1}; + } else { + return cast(cast(fp32x2_t{val0, val1})); + } +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantVarlenParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? 
s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_varlen_kernel(const SiluMulQuantVarlenParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + // gran=8 interleaved: every 16-element chunk on the N axis is + // [gate[0..7], up[0..7]]. Each thread handles 8 consecutive output + // elements, so its gate chunk lives at vec index 2*threadIdx.x and its + // up chunk at 2*threadIdx.x+1. 
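+ // (Illustrative sketch, not part of the original patch.) Concretely:
+ //   thread 0 -> vec 0 = gate[0..7],  vec 1 = up[0..7]
+ //   thread 1 -> vec 2 = gate[8..15], vec 3 = up[8..15]
+ // i.e. logical gate[n] sits at bf16 element (n/8)*16 + (n%8) of the row and
+ // up[n] at (n/8)*16 + 8 + (n%8).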
+ gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for (uint32_t i = 0; i < kVecSize / 2; ++i) { + out[i] = cast(silu_and_mul(cast(gate[i]), cast(up[i]), limit)); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_varlen_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_varlen_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + 
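+ // (Clarifying note, not part of the original patch.) In the transposed case the
+ // scale tensor is shaped (E, G//4, T), with every element packing the four
+ // consecutive UE8M0 group exponents of one token, matching the byte addressing
+ // in the kernel above; G is recovered as 4 * (G//4) so the later
+ // num_groups * kGroupSize == hidden_dim check still holds.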
RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + .with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantVarlenParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +struct SiluMulQuantContigParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t scale_row_stride_int32; // only used when kTransposed=true +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_contig_kernel(const SiluMulQuantContigParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 
* kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto token_id = blockIdx.x; + const auto work_id = threadIdx.x / kWorkThreads; + + const auto input = params.input + token_id * params.hidden_dim * 2; + const auto output = params.output + token_id * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + // Physical layout is (G//4_pad, M_pad) int32; each int32 packs 4 + // consecutive UE8M0 exponents for the same token. Byte address: + // (work_id / 4) * M_pad * 4 + token * 4 + (work_id % 4). + const auto base = reinterpret_cast(params.output_scale); + return base + (work_id / 4u) * (params.scale_row_stride_int32 * 4u) + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + token_id * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + if constexpr (kSwizzle) { + gate_vec.load(input, threadIdx.x * 2); + up_vec.load(input, threadIdx.x * 2 + 1); + } else { + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + } + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const auto [x, y] = silu_and_mul(gate_vec[i], up_vec[i], params.swiglu_limit); + results[2 * i + 0] = x; + results[2 * i + 1] = y; + local_max = fmaxf(local_max, fmaxf(fabsf(x), fabsf(y))); + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +template +struct SiluAndMulContigPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = + silu_mul_quant_contig_kernel; + static constexpr auto kernel_transposed = + silu_mul_quant_contig_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate/up, natural or gran=8 interleaved on last dim) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, N}) // fp8 output + .with_dtype() + .with_device(device) + .verify(output); + + const auto hidden_dim = N.unwrap(); + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + const auto num_groups = static_cast(hidden_dim / kGroupSize); + + uint32_t scale_row_stride_int32 = 0; + if (!transposed) { + G.set_value(num_groups); + 
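+ // (Clarifying note, not part of the original patch.) Non-transposed scales are a
+ // plain row-major (num_tokens, num_groups) fp32 buffer, one scale per
+ // 128-element group, written by the kernel at output_scale[token * num_groups + group].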
TensorMatcher({M, G}) // (M, G) fp32 natural row-major + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + RuntimeCheck(num_groups % 4 == 0, "transposed layout requires num_groups % 4 == 0"); + auto G_ = SymbolicSize{"G // 4"}; + G_.set_value(num_groups / 4); + auto M_pad = SymbolicSize{"M padded"}; + TensorMatcher({M, G_}) // `.transpose(-1,-2)[:M,:]` view of (G//4_pad, M_pad) int32 + .with_strides({int64_t{1}, M_pad}) // col-major transposed + .with_dtype() + .with_device(device) + .verify(output_scale); + scale_row_stride_int32 = static_cast(M_pad.unwrap()); + } + + const auto num_tokens = static_cast(M.unwrap()); + + const auto params = SiluMulQuantContigParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .scale_row_stride_int32 = scale_row_stride_int32, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + const auto kernel = transposed ? kernel_transposed : kernel_normal; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh new file mode 100644 index 000000000000..3e2bd92589b7 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant_tmp.cuh @@ -0,0 +1,371 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct SiluMulQuantParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. 
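+ // (Added explanation, not part of the original patch.) After the barrier below,
+ // each thread sums the totals of the warps preceding its own and adds its
+ // in-warp exclusive prefix, giving the exclusive prefix of masked_m over experts
+ // [0, tx); the CTA whose blockIdx.x lands in [prefix, prefix + masked_m[tx])
+ // claims expert tx and token blockIdx.x - prefix.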
+ if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_kernel(const SiluMulQuantParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + if constexpr (kApplySwigluLimit) { + // Fused fp32 path: bf16 load -> fp32 clamp -> fp32 silu -> fp32 mul -> fp32 result. + // Avoids the silu->bf16->mul->fp32 round-trip of the non-fused path since we already + // have gate/up in fp32 registers after clamp. + const float limit = params.swiglu_limit; + + const auto [g0_raw, g1_raw] = cast(gate_vec[i]); + const float g0 = fminf(g0_raw, limit); + const float g1 = fminf(g1_raw, limit); + + const float silu0 = g0 / (1.0f + expf(-g0)); + const float silu1 = g1 / (1.0f + expf(-g1)); + + const auto [u0_raw, u1_raw] = cast(up_vec[i]); + const float u0 = fmaxf(fminf(u0_raw, limit), -limit); + const float u1 = fmaxf(fminf(u1_raw, limit), -limit); + + const float val0 = u0 * silu0; + const float val1 = u1 * silu1; + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } else { + // original code path -- must stay byte-equal to pre-fusion kernel.
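+ // (Added explanation, not part of the original patch.) Byte-equality is kept by
+ // rounding silu(gate) to bf16 first, multiplying against the bf16 `up` value,
+ // and only then widening to fp32 -- the same rounding sequence as the
+ // pre-fusion kernel. The fused fp32 branch above skips that round-trip, which
+ // is why it is gated behind kApplySwigluLimit.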
+ const auto [g0, g1] = cast(gate_vec[i]); + + float silu0 = g0 / (1.0f + expf(-g0)); + float silu1 = g1 / (1.0f + expf(-g1)); + + bf16x2_t silu_d = cast(fp32x2_t{silu0, silu1}); + auto [val0, val1] = cast(up_vec[i] * silu_d); + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +struct SiluAndMulClampParams { + const void* __restrict__ input; + void* __restrict__ output; + float swiglu_limit; +}; + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_clamp_kernel(const SiluAndMulClampParams __grid_constant__ params) { + using namespace device; + static_assert(sizeof(DType) == 2, "only fp16/bf16 supported"); + using DType2 = packed_t; + constexpr auto kVecSize = 16 / sizeof(DType); + static_assert(kVecSize % 2 == 0 && kVecSize > 0); + using Vec = AlignedVector; + const auto bid = blockIdx.x; + const auto tile = tile::Memory::cta(); + const float limit = params.swiglu_limit; + + PDLWaitPrimary(); + const auto gate = tile.load(params.input, bid * 2 + 0); + const auto up = tile.load(params.input, bid * 2 + 1); + Vec out; + +#pragma unroll + for (uint32_t i = 0; i < kVecSize / 2; ++i) { + const auto [g0_raw, g1_raw] = cast(gate[i]); + const float g0 = fminf(g0_raw, limit); + const float g1 = fminf(g1_raw, limit); + const float silu0 = g0 / (1.0f + expf(-g0)); + const float silu1 = g1 / (1.0f + expf(-g1)); + const auto [u0_raw, u1_raw] = cast(up[i]); + const float u0 = fmaxf(fminf(u0_raw, limit), -limit); + const float u1 = fmaxf(fminf(u1_raw, limit), -limit); + const float val0 = u0 * silu0; + const float val1 = u1 * silu1; + out[i] = cast(fp32x2_t{val0, val1}); + } + + tile.store(params.output, out, bid); + PDLTriggerSecondary(); +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = silu_mul_quant_kernel; + static constexpr auto kernel_transposed = silu_mul_quant_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + 
.with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + .with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + .num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct SiluAndMulClampKernel { + static constexpr auto kernel = silu_mul_clamp_kernel; + + static void run(const tvm::ffi::TensorView input, const tvm::ffi::TensorView output, const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto M = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"gate_up_dim"}; // 2 * out_dim + auto H = SymbolicSize{"out_dim"}; + device.set_options(); + + TensorMatcher({M, D}) // input (gate || up) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({M, H}) // output + .with_dtype() + .with_device(device) + .verify(output); + RuntimeCheck(D.unwrap() == 2 * H.unwrap(), "input last dim must be 2 * output last dim"); + + constexpr uint32_t kVecSize = 16 / sizeof(DType); + const auto out_dim = static_cast(H.unwrap()); + const auto num_tokens = static_cast(M.unwrap()); + RuntimeCheck(out_dim % kVecSize == 0, "out_dim must be divisible by vector size"); + const auto num_threads = out_dim / kVecSize; + RuntimeCheck(num_threads <= 1024, "out_dim too large for single-block-per-row launch"); + + const auto params = SiluAndMulClampParams{ + .input = input.data_ptr(), + .output = output.data_ptr(), + .swiglu_limit = static_cast(swiglu_limit), + }; + LaunchKernel(num_tokens, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh new file mode 100644 index 000000000000..49f6f5596377 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh @@ -0,0 +1,205 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include 
+#include + +namespace { + +using deepseek_v4::fp8::cast_to_ue8m0; +using deepseek_v4::fp8::inv_scale_ue8m0; +using deepseek_v4::fp8::pack_fp8; + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +template +__global__ void fused_store_flashmla_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 584 = 576 + 8 + constexpr int64_t kPageBytes = host::div_ceil(584 << kPageBits, 576) * 576; + + // each warp handles 64 elements, 8 warps, each block handles 1 row + const auto& [input, cache, indices, num_tokens] = param; + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t wid = tid / 32; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[bid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + const auto elems = static_cast(input)[tid + bid * 256]; + if (wid != 7) { + const auto [x, y] = cast(elems); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576); + const auto scale_ptr = pointer::offset(page_ptr, 576 << kPageBits, offset * 8); + static_cast(value_ptr)[tid] = result; + static_cast(scale_ptr)[wid] = scale_ue8m0; + } else { + const auto result = cast(elems); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576, 448); + static_cast(value_ptr)[tid - 7 * 32] = result; + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, 1 warp, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = 
pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); +} + +template +struct FusedStoreCacheFlashMLAKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr auto kernel = fused_store_flashmla_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 512}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 256; + const auto num_blocks = num_tokens; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = 132 * kPageSize; + static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh new file mode 100644 index 000000000000..ef2be43c07e2 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kTopK = 512; +constexpr uint32_t kTopKBlockSize = 512; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct 
TopK512Params { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = 
::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_512_transform(const __grid_constant__ TopK512Params params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? 
raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopK512Kernel { + static constexpr auto kernel = topk_512_transform; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 512}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, 512}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopK512Params{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh new file mode 100644 index 000000000000..6774734ec187 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_1024.cuh @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kTopK = 1024; +constexpr uint32_t kTopKBlockSize = 1024; +constexpr uint32_t kSMEM = 16 * 
1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopK1024Params { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = 
convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_1024_transform(const __grid_constant__ TopK1024Params params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? 
raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopK1024Kernel { + static constexpr auto kernel = topk_1024_transform; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 1024}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, 1024}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopK1024Params{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh new file mode 100644 index 000000000000..8c4a526575ea --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh @@ -0,0 +1,493 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace { + +#ifndef SGL_TOPK +#define 
SGL_TOPK 512 +#endif + +inline constexpr uint32_t K = SGL_TOPK; + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +namespace impl = device::top512; +using Large = impl::ClusterTopK; +using Medium = impl::StreamingTopK; +using Small = impl::RegisterTopK; + +using Metadata = Large::Metadata; +constexpr uint32_t kBlockSize = impl::kBlockSize; +constexpr uint32_t kNumClusters = 15; // based on hardware limits +constexpr uint32_t kClusterSize = Large::kClusterSize; +constexpr uint32_t kMax2PassLength = Small::kMax2PassLength; +constexpr uint32_t kMaxSupportedLength = Large::kMaxLength; + +/// Common metadata lives at metadata[0] (first row of the [batch_size+1, 4] tensor). +/// Per-item metadata starts at metadata[1..batch_size]. The plan kernel writes both. +struct alignas(16) GlobalMetadata { + uint32_t cluster_threshold; // decided per-batch in plan kernel + uint32_t num_cluster_items; // N = number of items routed to the cluster path + uint32_t reserved[2]; +}; +static_assert(sizeof(GlobalMetadata) == sizeof(Metadata), "layout: row 0 must occupy one Metadata-sized slot"); + +// optimize occupancy for prefill +#define SMALL_TOPK_KERNEL __global__ __launch_bounds__(kBlockSize, 2) +// cluster at y dim +#define LARGE_CLUSTER __cluster_dims__(1, kClusterSize, 1) +// stage-1 is persistent cluster, and shared memory usage is huge (can not 2) +#define LARGE_TOPK_STAGE_1 __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// stage-2 is non-persistent non-cluster, with less shared memory and higher occupancy +#define LARGE_TOPK_STAGE_2 __global__ __launch_bounds__(kBlockSize, 2) +// fused into 1 stage when batch-size <= kNumPersistentClusters +#define FUSED_COMBINE_KERNEL __global__ __launch_bounds__(kBlockSize, 1) LARGE_CLUSTER +// plan runs once as a single block before the combine kernels +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct TopKParams { + const uint32_t* __restrict__ seq_lens; + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int64_t score_stride; + int64_t page_table_stride; + uint8_t* __restrict__ workspace; // [batch, kWorkspaceBytes] -- internally allocated + /// Pointer to the full metadata tensor: metadata[0] is GlobalMetadata, metadata[1..] + /// are per-item entries (at most kNumClusters * rounds of them). 
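+  /// Illustrative layout (assuming Metadata packs {batch_id, seq_len, has_next} into the
+  /// same 16-byte footprint as GlobalMetadata, per the sizeof static_assert above):
+  ///   row 0     -> GlobalMetadata {cluster_threshold, num_cluster_items, reserved}
+  ///   row 1 + i -> i-th compacted long-sequence item; cluster c consumes rows
+  ///               1 + c, 1 + c + kNumClusters, 1 + c + 2*kNumClusters, ... until has_next is false.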
+ const Metadata* __restrict__ metadata = nullptr; + int64_t workspace_stride; // bytes per batch + uint32_t batch_size; + uint32_t page_bits; + + SGL_DEVICE const float* get_scores(const uint32_t batch_id) const { + return scores + batch_id * score_stride; + } + SGL_DEVICE impl::TransformParams get_transform(const uint32_t batch_id, int32_t* indices) const { + return { + .page_table = page_table + batch_id * page_table_stride, + .indices_in = indices, + .indices_out = page_indices + batch_id * K, + .page_bits = page_bits, + }; + } + SGL_DEVICE const GlobalMetadata& get_global_metadata() const { + return *reinterpret_cast(metadata); + } + SGL_DEVICE const Metadata& get_item_metadata(uint32_t work_id) const { + return metadata[1 + work_id]; // +1 to skip the GlobalMetadata row + } +}; + +SGL_DEVICE uint2 partition_work(uint32_t length, uint32_t rank) { + constexpr uint32_t kTMAAlign = 4; + const auto total_units = (length + kTMAAlign - 1) / kTMAAlign; + const auto base = total_units / kClusterSize; + const auto extra = total_units % kClusterSize; + const auto local_units = base + (rank < extra ? 1u : 0u); + const auto offset_units = rank * base + min(rank, extra); + const auto offset = offset_units * kTMAAlign; + const auto finish = min(offset + local_units * kTMAAlign, length); + return {offset, finish - offset}; +} + +/// Persistent scheduler. A single block: +/// 1. Decides a cluster_threshold from the real seq_lens distribution (or +/// uses the caller-supplied `static_cluster_threshold` when non-zero). +/// 2. Writes that threshold + N into metadata[0] (the GlobalMetadata row). +/// 3. Compacts items with seq_len > threshold into metadata[1..N+1), laid out +/// to match the persistent consumer's round-robin stride (kNumClusters). +/// Entries for clusters that get no work are zero-filled. +PLAN_KERNEL void topk_plan( + const uint32_t* __restrict__ seq_lens, + Metadata* __restrict__ metadata, + const uint32_t batch_size, + const uint32_t static_cluster_threshold) { + // Candidate thresholds, strictly increasing. Picked to give the auto-heuristic + // reasonable granularity without needing a full sort. Must all be >= kMax2PassLength. + + struct Pair { + uint32_t threshold; + uint32_t max_batch_size; + }; + /// NOTE: only tuned on B200 + constexpr Pair kCandidates[] = { + {32768, 30}, + {40960, 45}, + {49152, 45}, + {65536, 60}, + {98304, 60}, + {131072, 75}, + {196608, 90}, + {262144, 105}, + }; + constexpr uint32_t kNumCandidates = std::size(kCandidates); + constexpr uint32_t kMinBatchSize = kCandidates[0].max_batch_size; + static_assert(kCandidates[0].threshold == kMax2PassLength); + static_assert(kCandidates[kNumCandidates - 1].threshold == kMaxSupportedLength); + + __shared__ uint32_t s_count; // final N after compaction + __shared__ uint32_t s_counts[kNumCandidates]; + __shared__ uint32_t s_threshold; + + const auto tx = threadIdx.x; + if (tx == 0) s_count = 0; + if (tx < kNumCandidates) s_counts[tx] = 0; + __syncthreads(); + + // --- Phase 1: decide threshold ------------------------------------------ + if (static_cluster_threshold > 0) { + if (tx == 0) s_threshold = static_cluster_threshold; + } else if (batch_size <= kMinBatchSize) { + if (tx == 0) s_threshold = kMax2PassLength; // always prefer cluster + } else { + // Count items above each candidate threshold. Monotonically non-increasing in T. 
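+    // Worked example (illustrative counts, not from any real trace): with batch_size = 80,
+    // suppose 50 rows have seq_len > 65536, 52 rows have seq_len > 49152, and 20 rows have
+    // seq_len > 131072. Scanning candidates from the largest threshold downwards, `accum`
+    // stays within each candidate's max_batch_size until 49152 (52 > 45), so the scan breaks
+    // there and `chosen` settles on 65536: the 50 longest rows take the cluster path and the
+    // rest stay on the cheaper small/medium paths.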
+ for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + assert(sl <= kMaxSupportedLength); + uint32_t count = 0; +#pragma unroll + for (uint32_t j = 0; j < kNumCandidates; ++j) { + count += (sl > kCandidates[j].threshold ? 1 : 0); + } + if (count > 0) { + atomicAdd(&s_counts[count - 1], 1); + } + } + __syncthreads(); + if (tx == 0) { + uint32_t accum = 0; + uint32_t chosen = kMaxSupportedLength; +#pragma unroll + for (uint32_t i = 0; i < kNumCandidates; ++i) { + const auto j = kNumCandidates - 1 - i; + accum += s_counts[j]; + /// NOTE: `accum` increasing, while `max_batch_size` decreasing + if (accum > kCandidates[j].max_batch_size) break; + chosen = kCandidates[j].threshold; + } + s_threshold = chosen; + } + } + __syncthreads(); + // sanity check: below 2 pass threshold, must fits in small path + const auto cluster_threshold = max(s_threshold, kMax2PassLength); + + // --- Phase 2: compact items with seq_len > threshold into metadata[1..] - + // Per-item rows live at metadata[1 + pos]; metadata[0] is the GlobalMetadata row. + for (uint32_t i = tx; i < batch_size; i += kBlockSize) { + const uint32_t sl = seq_lens[i]; + if (sl > cluster_threshold) { + const auto pos = atomicAdd(&s_count, 1); + metadata[1 + pos] = {i, sl, false}; + } + } + __syncthreads(); + const auto N = s_count; + + // --- Phase 3: has_next + sentinels + GlobalMetadata --------------------- + for (uint32_t i = tx; i < N; i += kBlockSize) { + if (i + kNumClusters < N) metadata[1 + i].has_next = true; + } + // Zero-fill the first kNumClusters sentinel slots that got no valid entry. + if (tx < kNumClusters && tx >= N) metadata[1 + tx] = {0, 0, false}; + // Write global metadata (row 0). + if (tx == 0) { + auto* g = reinterpret_cast(metadata); + *g = { + .cluster_threshold = cluster_threshold, + .num_cluster_items = N, + .reserved = {0, 0}, + }; + } +} + +SMALL_TOPK_KERNEL void // short context +topk_short_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + // trivial case + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + device::PDLTriggerSecondary(); + Small::transform(transform); + } +} + +LARGE_TOPK_STAGE_1 void // long context, middle to large batch size +topk_combine_preprocess(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + uint32_t work_id = blockIdx.x; + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + uint32_t length; + uint32_t offset; + const auto cluster_rank = blockIdx.y; + + const auto prefetch_metadata = [&] { + const auto metadata = params.get_item_metadata(work_id); + batch_id = metadata.batch_id; + seq_len = metadata.seq_len; + has_next = metadata.has_next; + work_id += kNumClusters; // advance to the next item for this cluster + }; + const auto launch_prologue = [&] { + const auto partition = partition_work(seq_len, cluster_rank); + offset = partition.x; + length = partition.y; + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + }; + + device::PDLWaitPrimary(); + device::PDLTriggerSecondary(); + + prefetch_metadata(); + if (seq_len == 0) return; + Large::stage1_init(smem); 
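+  // Persistent-loop sketch (descriptive only): each cluster owns the compacted items at
+  // blockIdx.x, blockIdx.x + kNumClusters, blockIdx.x + 2*kNumClusters, ... The loop below
+  // prefetches the next item's metadata and issues its prologue loads while the current
+  // item's stage-1 selection and epilogue are still being processed, so the next tile's
+  // global loads overlap with compute on the current one.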
+ launch_prologue(); + while (true) { + const auto this_length = length; + const auto this_offset = offset; + const auto need_prefetch = has_next; + const auto transform = params.get_transform(batch_id, s_topk_indices); + const auto ws = params.workspace + batch_id * params.workspace_stride; + if (need_prefetch) prefetch_metadata(); + Large::stage1(s_topk_indices, this_length, smem, /*reuse=*/true); + if (need_prefetch) launch_prologue(); + Large::stage1_epilogue(transform, this_offset, ws, smem); + if (!need_prefetch) break; + } +} + +LARGE_TOPK_STAGE_2 void // long context, middle to large batch size +topk_combine_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto cluster_threshold = params.get_global_metadata().cluster_threshold; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= kMax2PassLength) { + if (seq_len <= Small::kMax1PassLength) { + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } else { + __syncwarp(); + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem); + } + Small::transform(transform); + } else if (seq_len <= cluster_threshold) { + Medium::run(params.get_scores(batch_id), seq_len, s_topk_indices, smem); + Medium::transform(transform, smem); + } else { + const auto ws = params.workspace + batch_id * params.workspace_stride; + device::PDLWaitPrimary(); + Large::transform(transform, ws, smem); + } +} + +FUSED_COMBINE_KERNEL void // long context, small batch size +topk_fused_transform(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + __shared__ int32_t s_topk_indices[K]; + const auto batch_id = blockIdx.x; + const auto cluster_rank = blockIdx.y; + const auto seq_len = params.seq_lens[batch_id]; + const auto transform = params.get_transform(batch_id, s_topk_indices); + if (seq_len <= K) { + if (cluster_rank != 0) return; // only first rank work + impl::trivial_transform(transform, seq_len, K); + } else if (seq_len <= Small::kMax1PassLength) { + if (cluster_rank != 0) return; // only first rank work + Small::run(params.get_scores(batch_id), s_topk_indices, seq_len, smem, /*use_pdl=*/true); + Small::transform(transform); + } else { + const auto [offset, length] = partition_work(seq_len, cluster_rank); + const auto ws = params.workspace + batch_id * params.workspace_stride; + Large::stage1_init(smem); + device::PDLWaitPrimary(); + Large::stage1_prologue(params.get_scores(batch_id) + offset, length, smem); + Large::stage1(s_topk_indices, length, smem); + Large::stage1_epilogue(transform, offset, ws, smem); + cooperative_groups::this_cluster().sync(); + if (cluster_rank != 0) return; // only first rank do the stage-2 + Large::transform(transform, ws, smem); + } +} + +struct CombinedTopKKernel { + static constexpr auto kStage1SMEM = sizeof(Large::Smem) + 128; + static constexpr auto kStage2SMEM = std::max(sizeof(Small::Smem), sizeof(Medium::Smem)) + 128; + + static void plan( // + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView metadata, + const uint32_t static_cluster_threshold) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + 
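+    // Expected host-side sequencing (sketch): callers run
+    //   CombinedTopKKernel::plan(seq_lens, metadata, /*static_cluster_threshold=*/0)  // 0 = auto heuristic
+    // before CombinedTopKKernel::transform(...); plan() returns early when
+    // batch_size <= kNumClusters because the fused single-stage path never reads the metadata rows.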
TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto batch_size = static_cast(B.unwrap()); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1); + if (batch_size <= kNumClusters) return; // metadata unused in fused path + + const auto device = device_.unwrap(); + constexpr auto kernel = topk_plan; + LaunchKernel(1, kBlockSize, device)( // + kernel, + static_cast(seq_lens.data_ptr()), + static_cast(metadata.data_ptr()), + batch_size, + static_cluster_threshold); + } + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::TensorView workspace, + const tvm::ffi::TensorView metadata) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto Bp1 = SymbolicSize{"batch_size_plus_1"}; + auto L = SymbolicSize{"max_seq_len"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto W = SymbolicSize{"workspace_stride"}; + constexpr auto D = Large::kWorkspaceInts; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, L}) // + .with_strides({S, 1}) + .with_dtype() + .with_device(device_) + .verify(scores); + TensorMatcher({B}) // + .with_dtype() + .with_device(device_) + .verify(seq_lens); + TensorMatcher({B, -1}) // + .with_strides({P, 1}) + .with_dtype() + .with_device(device_) + .verify(page_table); + TensorMatcher({B, K}) // + .with_dtype() + .with_device(device_) + .verify(page_indices); + TensorMatcher({B, D}) // + .with_strides({W, 1}) + .with_dtype() + .with_device(device_) + .verify(workspace); + TensorMatcher({Bp1, 4}) // + .with_dtype() + .with_device(device_) + .verify(metadata); + + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto max_seq_len = static_cast(L.unwrap()); + const auto device = device_.unwrap(); + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + RuntimeCheck(S.unwrap() % 4 == 0, "score_stride must be a multiple of 4 (TMA 16-byte alignment)"); + RuntimeCheck(Bp1.unwrap() == B.unwrap() + 1, "invalid metadata shape"); + + // NOTE: this should be fixed later + // RuntimeCheck(max_seq_len <= kMaxSupportedLength, max_seq_len, " exceeds the maximum supported length"); + + const auto params = TopKParams{ + .seq_lens = static_cast(seq_lens.data_ptr()), + .scores = static_cast(scores.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .workspace = static_cast(workspace.data_ptr()), + .metadata = static_cast(metadata.data_ptr()), + .workspace_stride = W.unwrap() * static_cast(sizeof(int32_t)), + .batch_size = batch_size, + .page_bits = page_bits, + }; + + if (max_seq_len <= Small::kMax1PassLength) { + // All items fit in the short path -- no stage-1 needed + constexpr auto kernel = topk_short_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel, params); + } else { + // Some items may be large -- launch stage-1 + main + if (batch_size <= kNumClusters) { + // can fuse into 1 stage + constexpr auto kernel = topk_fused_transform; + constexpr auto kSMEM = std::max(kStage1SMEM, kStage2SMEM); + 
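+        // The fused kernel runs stage-1 and the final transform inside one block, so it
+        // reserves the larger of the two shared-memory footprints.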
setup_kernel_smem_once(); + LaunchKernel({batch_size, kClusterSize}, kBlockSize, device, kSMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel, params); + } else { + // stage 1 + stage 2 + constexpr auto kernel_stage_1 = topk_combine_preprocess; + setup_kernel_smem_once(); + const auto num_clusters = std::min(batch_size, kNumClusters); + LaunchKernel({num_clusters, kClusterSize}, kBlockSize, device, kStage1SMEM) + .enable_cluster({1, kClusterSize}) + .enable_pdl(true)(kernel_stage_1, params); + constexpr auto kernel_stage_2 = topk_combine_transform; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device, kStage2SMEM) // + .enable_pdl(true)(kernel_stage_2, params); + } + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h index bf7dcb202301..a06054990db0 100644 --- a/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h +++ b/python/sglang/jit_kernel/csrc/gemm/marlin_moe/marlin_template.h @@ -24,6 +24,7 @@ #include "../marlin/dequant.h" #include "../marlin/marlin.cuh" #include "../marlin/marlin_dtypes.cuh" +#include #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert( \ @@ -355,6 +356,7 @@ __global__ void Marlin( constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; constexpr bool is_int_type = w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + constexpr bool is_8bit_scale = s_type.size_bits() == 8; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = w_type == host::kFE4M3fn || w_type == host::kFE2M1f && s_type == host::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || @@ -368,7 +370,7 @@ __global__ void Marlin( static_assert(thread_m_blocks == 1 || !m_block_size_8); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / (w_type == host::kFE2M1f ? 16 : 8); + const int scales_expert_stride = prob_n * prob_k / group_size / (is_8bit_scale ? 16 : 8); const int zp_expert_stride = is_zp_float ? 
prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); const int b_bias_expert_stride = prob_n / 8; @@ -439,52 +441,69 @@ __global__ void Marlin( locks_off = (iters * blockIdx.x) / k_tiles - 1; } + int prob_m_top_k = prob_m * top_k; // read moe block data given block_id // block_sorted_ids / block_num_valid_tokens / block_topk_weights auto read_moe_block_data = [&](int block_id) { block_num_valid_tokens = moe_block_size; + + cp_async4_pred( + sh_block_sorted_ids_int4 + threadIdx.x, + reinterpret_cast(sorted_token_ids_ptr) + (block_id * moe_block_size / 4 + threadIdx.x), + threadIdx.x < moe_block_size / 4); + + cp_async_fence(); + cp_async_wait<0>(); + + __syncthreads(); + + if (threadIdx.x >= threads - 32) { + constexpr int size_per_thread = div_ceil(moe_block_size, 32); + int lane_id = threadIdx.x - (threads - 32); + + int local_count = 0; #pragma unroll - for (int i = 0; i < moe_block_size / 4; i++) { - int4 sorted_token_ids_int4 = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + i]; - int* sorted_token_ids = reinterpret_cast(&sorted_token_ids_int4); -#pragma unroll - for (int j = 0; j < 4; j++) { - if (sorted_token_ids[j] >= prob_m * top_k) { - block_num_valid_tokens = i * 4 + j; - break; + for (int i = 0; i < size_per_thread; i++) { + int j = lane_id * size_per_thread + i; + if (j < moe_block_size) { + int idx = sh_block_sorted_ids[j]; + if (idx < prob_m_top_k) local_count++; } } - if (block_num_valid_tokens != moe_block_size) break; - } - __syncthreads(); - int tid4 = threadIdx.x / 4; - if (threadIdx.x % 4 == 0 && threadIdx.x < block_num_valid_tokens) { - sh_block_sorted_ids_int4[tid4] = - reinterpret_cast(sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + if constexpr (moe_block_size >= 16) local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 16); + if constexpr (moe_block_size >= 8) local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 8); + if constexpr (moe_block_size >= 4) local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 4); + if constexpr (moe_block_size >= 2) local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 2); -#pragma unroll - for (int i = 0; i < 4; i++) - sh_rd_block_sorted_ids[tid4 * 4 + i] = sh_block_sorted_ids[tid4 * 4 + i] / top_k; + local_count += __shfl_down_sync(0xFFFFFFFF, local_count, 1); + block_num_valid_tokens = local_count; +#else + block_num_valid_tokens = __reduce_add_sync(0xffffffff, local_count); +#endif + + if (lane_id == 0) reinterpret_cast(sh_new)[0] = block_num_valid_tokens; + } + + if (threadIdx.x < moe_block_size) { + int idx = sh_block_sorted_ids[threadIdx.x]; + sh_rd_block_sorted_ids[threadIdx.x] = idx / top_k; if (mul_topk_weights) { -#pragma unroll - for (int i = 0; i < 4; i++) { - int idx = tid4 * 4 + i; - // idx = idx < block_num_valid_tokens ? idx : 0; - if (idx < block_num_valid_tokens) { - if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) { - sh_block_topk_weights[idx] = - __hmul2(global_scale, Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]]))); - } else { - sh_block_topk_weights[idx] = - Dtype::num2num2(Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); - } - } + idx = idx < prob_m_top_k ? 
idx : 0; + scalar_t topk_weight_tmp = Dtype::float2num(topk_weights_ptr[idx]); + if constexpr (w_type == host::kFE2M1f && s_type == host::kFE4M3fn) { + sh_block_topk_weights[threadIdx.x] = __hmul2(global_scale, Dtype::num2num2(topk_weight_tmp)); + } else { + sh_block_topk_weights[threadIdx.x] = Dtype::num2num2(topk_weight_tmp); } } } + + __syncthreads(); + + block_num_valid_tokens = reinterpret_cast(sh_new)[0]; __syncthreads(); }; @@ -626,11 +645,10 @@ __global__ void Marlin( constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) - : 1; + int s_gl_stride = prob_n / (is_8bit_scale ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (is_8bit_scale ? 16 : 8); + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks ? thread_k_blocks / group_blocks : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -681,13 +699,15 @@ __global__ void Marlin( if constexpr (!has_act_order) { if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + - s_sh_stride * slice_col + threadIdx.x; + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; } } auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; // Zero-points int zp_gl_rd; @@ -705,15 +725,7 @@ __global__ void Marlin( // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. 
int s_sh_rd; - if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - int warp_row = warp_id / n_warps; - - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - s_sh_rd = s_sh_rd * 2 + (warp_row / group_blocks) % 2; - - } else if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; @@ -907,43 +919,21 @@ __global__ void Marlin( } else { if constexpr (group_blocks != -1) { int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta * s_tb_groups; } } if constexpr (has_zp && group_blocks != -1) { int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; } } } @@ -1021,35 +1011,32 @@ __global__ void Marlin( } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - if (k % b_sh_wr_iters == 0) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } } } else { auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; - int cur_k = warp_row * 16; cur_k += k_iter_size * (k % b_sh_wr_iters); - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + int cur_group_id = k_blocks / group_blocks; int4* sh_s_stage = sh_s + s_sh_stage * pipe; - if constexpr (w_type_id != host::kFE2M1f.id()) { + if constexpr (!is_8bit_scale) { reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } else if constexpr (group_blocks == 1 || thread_k_blocks > 4) { - reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } else { reinterpret_cast(&frag_s[k % 2])[0] = - reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride) + k % 2]; + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; } } } @@ -1243,17 +1230,16 @@ __global__ void Marlin( } } - // Commented out FP4/FP8 scale dequantization since we don't generate - // kFE2M1f kernels to reduce compilation time - // if constexpr (w_type == host::kFE2M1f) { - // int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; - // int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; - // - // dequant_fp8_scales( - // s_quant_0, reinterpret_cast(&frag_s[k2])); - // dequant_fp8_scales( - // s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); - // } + // FP4/FP8 scale dequantization (E4M3 for NVFP4 and E8M0 for MXFP4). + if constexpr ( + (s_type == host::kFE4M3fn || s_type == host::kFE8M0fnu) && + !(std::is_same::value && s_type == host::kFE8M0fnu)) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. @@ -1882,8 +1868,20 @@ __global__ void Marlin( slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = + s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = + zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) + + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } } start_pipes(); } diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh index 998ce2e25d0d..15da350a4e24 100644 --- a/python/sglang/jit_kernel/csrc/hisparse.cuh +++ b/python/sglang/jit_kernel/csrc/hisparse.cuh @@ -3,6 +3,8 @@ #include +#include + #include #include @@ -81,10 +83,21 @@ struct SmemLayout { }; // Each block processes one request -// req_pool_indices are int64_t (pool indices can be large), seq_lens can be int32_t or int64_t +// req_pool_indices and seq_lens can each be int32_t or int64_t // Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] // newest_slot is at HOT_BUFFER_SIZE (first position of extra page) -template +// +// IsDsv4Layout selects the miss-copy addressing: +// false -> generic byte-stride: 
device + host both linear, stride = item_size_bytes +// true -> DSv4 page-padded device + linear host (kvcacheio.cuh hardcoded constants) +template < + int BLOCK_SIZE, + int NUM_TOP_K, + int HOT_BUFFER_SIZE, + bool IsMLA, + bool IsDsv4Layout, + typename SeqLensT, + typename ReqPoolIndicesT> __global__ void load_cache_to_device_buffer_kernel( const int32_t* __restrict__ top_k_tokens, int32_t* __restrict__ device_buffer_tokens, @@ -95,7 +108,7 @@ __global__ void load_cache_to_device_buffer_kernel( void* __restrict__ device_buffer_k, void* __restrict__ device_buffer_v, int32_t* __restrict__ top_k_device_locs, - const int64_t* __restrict__ req_pool_indices, + const ReqPoolIndicesT* __restrict__ req_pool_indices, const SeqLensT* __restrict__ seq_lens, int16_t* __restrict__ lru_slots, const int32_t* __restrict__ num_real_reqs, @@ -106,6 +119,7 @@ __global__ void load_cache_to_device_buffer_kernel( int64_t top_k_device_locs_stride, int64_t page_size, int64_t item_size_bytes) { + static_assert(!IsDsv4Layout || IsMLA, "DSv4 page-padded layout is K-only (MLA)."); // todo hisparse: support page wise sparsity constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE; constexpr int NUM_TOKEN_CHUNKS = (NUM_TOP_K + WARP_SIZE - 1) / WARP_SIZE; @@ -157,16 +171,16 @@ __global__ void load_cache_to_device_buffer_kernel( int32_t* s_chunk_offset = s_top_k_tokens + NUM_TOP_K; // Prefix-sum offsets for evictable counting int32_t* s_evict_chunk_offset = s_chunk_offset + (NUM_BUFFER_CHUNKS + 1); - // Open-addressing hash table: top-k token_id → top-k index (keys) + // Open-addressing hash table: top-k token_id -> top-k index (keys) int32_t* s_hash_keys = s_evict_chunk_offset + (NUM_BUFFER_CHUNKS + 1); // Scalar counters int32_t& s_total_hits = s_hash_keys[HASH_SIZE]; int32_t& s_newest_hit = s_hash_keys[HASH_SIZE + 1]; int16_t* smem_i16 = reinterpret_cast(smem_i32 + Layout::TOTAL_INT32); - // Compacted slot ordering: [hits fwd→ ... ←evictables bwd] + // Compacted slot ordering: [hits fwd-> ... <-evictables bwd] int16_t* s_lru_slots_out = smem_i16; - // Open-addressing hash table: top-k token_id → top-k index (values) + // Open-addressing hash table: top-k token_id -> top-k index (values) int16_t* s_hash_vals = s_lru_slots_out + HOT_BUFFER_SIZE; // Initialize shared memory: counters, hash table, prefix-sum offsets. @@ -362,19 +376,30 @@ __global__ void load_cache_to_device_buffer_kernel( const int64_t src_loc = req_host_cache_locs[miss_token]; const int64_t dst_loc = static_cast(req_device_buffer_locs[evict_slot]); - const auto src_k = static_cast(host_cache_k) + src_loc * item_size_bytes; - auto dst_k = static_cast(device_buffer_k) + dst_loc * item_size_bytes; - transfer_item_warp(lane_id, src_k, dst_k, item_size_bytes); - - if constexpr (!IsMLA) { - const auto src_v = static_cast(host_cache_v) + src_loc * item_size_bytes; - auto dst_v = static_cast(device_buffer_v) + dst_loc * item_size_bytes; - transfer_item_warp(lane_id, src_v, dst_v, item_size_bytes); + if constexpr (IsDsv4Layout) { + // DSv4 path: page-padded device layout + linear host layout, K-only. + // Uses kvcacheio.cuh's hardcoded constants (kGPUPageSize=64, kCPUItemBytes=584). + device::hisparse::transfer_item( + /*dst_cache=*/device_buffer_k, + /*src_cache=*/const_cast(host_cache_k), + /*dst_index=*/static_cast(dst_loc), + /*src_index=*/static_cast(src_loc)); + } else { + // Generic path: device + host both linear, stride = item_size_bytes. 
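+        // (For comparison, the DSv4 branch above copies what is presumably the same 584-byte
+        // per-token record written by fused_store_flashmla_cache earlier in this diff:
+        // 448 bytes of fp8 values, 128 bytes kept in bf16, and an 8-byte ue8m0 scale slot.)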
+ const auto src_k = static_cast(host_cache_k) + src_loc * item_size_bytes; + auto dst_k = static_cast(device_buffer_k) + dst_loc * item_size_bytes; + transfer_item_warp(lane_id, src_k, dst_k, item_size_bytes); + + if constexpr (!IsMLA) { + const auto src_v = static_cast(host_cache_v) + src_loc * item_size_bytes; + auto dst_v = static_cast(device_buffer_v) + dst_loc * item_size_bytes; + transfer_item_warp(lane_id, src_v, dst_v, item_size_bytes); + } } } } -template +template void load_cache_to_device_buffer( tvm::ffi::TensorView top_k_tokens, tvm::ffi::TensorView device_buffer_tokens, @@ -401,9 +426,9 @@ void load_cache_to_device_buffer( const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); - // Generic lambda: both int32 and int64 kernel variants are compiled; - // the correct one is selected at runtime based on seq_lens dtype. - auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr) { + // Generic lambda: int32/int64 kernel variants are compiled for both + // seq_lens and req_pool_indices; the correct combo is selected at runtime. + auto launch = [&](auto kernel_fn, const auto* seq_lens_ptr, const auto* req_pool_indices_ptr) { constexpr size_t smem_bytes = SmemLayout::BYTES; if constexpr (smem_bytes > 48u * 1024u) { cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); @@ -419,7 +444,7 @@ void load_cache_to_device_buffer( device_buffer_k.data_ptr(), (IsMLA || device_buffer_v.ndim() == 0) ? (void*)nullptr : device_buffer_v.data_ptr(), static_cast(top_k_device_locs.data_ptr()), - static_cast(req_pool_indices.data_ptr()), + req_pool_indices_ptr, seq_lens_ptr, static_cast(lru_slots.data_ptr()), static_cast(num_real_reqs.data_ptr()), @@ -432,15 +457,59 @@ void load_cache_to_device_buffer( item_size_bytes); }; - const auto dtype = seq_lens.dtype(); - if (dtype.code == kDLInt && dtype.bits == 64) { + const auto seq_dtype = seq_lens.dtype(); + const auto rpi_dtype = req_pool_indices.dtype(); + const bool seq_is_i64 = (seq_dtype.code == kDLInt && seq_dtype.bits == 64); + const bool rpi_is_i64 = (rpi_dtype.code == kDLInt && rpi_dtype.bits == 64); + + if (seq_is_i64 && rpi_is_i64) { + launch( + load_cache_to_device_buffer_kernel< + BLOCK_SIZE, + NUM_TOP_K, + HOT_BUFFER_SIZE, + IsMLA, + IsDsv4Layout, + int64_t, + int64_t>, + static_cast(seq_lens.data_ptr()), + static_cast(req_pool_indices.data_ptr())); + } else if (seq_is_i64 && !rpi_is_i64) { + launch( + load_cache_to_device_buffer_kernel< + BLOCK_SIZE, + NUM_TOP_K, + HOT_BUFFER_SIZE, + IsMLA, + IsDsv4Layout, + int64_t, + int32_t>, + static_cast(seq_lens.data_ptr()), + static_cast(req_pool_indices.data_ptr())); + } else if (!seq_is_i64 && rpi_is_i64) { launch( - load_cache_to_device_buffer_kernel, - static_cast(seq_lens.data_ptr())); + load_cache_to_device_buffer_kernel< + BLOCK_SIZE, + NUM_TOP_K, + HOT_BUFFER_SIZE, + IsMLA, + IsDsv4Layout, + int32_t, + int64_t>, + static_cast(seq_lens.data_ptr()), + static_cast(req_pool_indices.data_ptr())); } else { launch( - load_cache_to_device_buffer_kernel, - static_cast(seq_lens.data_ptr())); + load_cache_to_device_buffer_kernel< + BLOCK_SIZE, + NUM_TOP_K, + HOT_BUFFER_SIZE, + IsMLA, + IsDsv4Layout, + int32_t, + int32_t>, + static_cast(seq_lens.data_ptr()), + static_cast(req_pool_indices.data_ptr())); } } diff --git a/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh new file mode 100644 index 000000000000..6476a3be232f 
--- /dev/null +++ b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh @@ -0,0 +1,363 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +constexpr uint32_t kWarpSize = 32; +constexpr uint32_t kWarpsPerCTA = 6; +constexpr uint32_t kSmallTokenThreshold = 512; +constexpr uint32_t kMaxExperts = 512; +constexpr uint32_t kMaxTopK = 16; + +enum class ScoringFunc : uint32_t { + kSigmoid = 0, + kSqrtSoftplus = 1, +}; + +struct MoEFusedGateParams { + const float* __restrict__ input; + const float* __restrict__ bias; + float* __restrict__ output; + int32_t* __restrict__ indices; + uint32_t num_rows; + uint32_t num_experts; + uint32_t topk; + uint32_t num_fused_shared_experts; + bool renormalize; + float routed_scaling_factor; + bool apply_routed_scaling_factor_on_output; +}; + +template +__device__ __forceinline__ float compute_score(float x) { + if constexpr (kScoringFunc == ScoringFunc::kSigmoid) { + // sigmoid(x) = 1 / (1 + exp(-x)) + return 1.0f / (1.0f + expf(-x)); + } else { + // sqrt(softplus(x)) = sqrt(log(1 + exp(x))) + float softplus = log1pf(expf(x)); + return sqrtf(softplus); + } +} + +template +__global__ void moe_fused_gate_kernel_small_token(const MoEFusedGateParams __grid_constant__ params) { + const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] = + params; + + uint32_t row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + // number of routed experts to select (excluding fused shared experts) + const uint32_t topk_routed = topk - num_fused_shared_experts; + + uint32_t tid = threadIdx.x; + uint32_t warp_id = tid / kWarpSize; + uint32_t lane_id = tid % kWarpSize; + + extern __shared__ float shared_mem[]; + float* shared_scores = shared_mem; + float* shared_original_scores = shared_mem + num_experts; + + // For warp-level reduction + __shared__ float warp_maxs[kWarpsPerToken]; + __shared__ int warp_experts[kWarpsPerToken]; + __shared__ int selected_experts[kMaxTopK]; + + for (uint32_t e = tid; e < num_experts; e += blockDim.x) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncthreads(); + + // only select topk_routed experts (excluding shared experts) + for (uint32_t k = 0; k < topk_routed; k++) { + float my_val = -FLT_MAX; + int my_expert = -1; + for (uint32_t e = tid; e < num_experts; e += blockDim.x) { + if (shared_scores[e] > my_val) { + my_val = shared_scores[e]; + my_expert = e; + } + } + + float warp_max_val = my_val; + int warp_max_expert = my_expert; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, warp_max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, warp_max_expert, offset); + if (other_val > warp_max_val) { + warp_max_val = other_val; + warp_max_expert = other_expert; + } + } + + if (lane_id == 0 && warp_id < kWarpsPerToken) { + warp_maxs[warp_id] = warp_max_val; + warp_experts[warp_id] = warp_max_expert; + } + + __syncthreads(); + + if (warp_id == 0) { + float final_max = (lane_id < kWarpsPerToken) ? warp_maxs[lane_id] : -FLT_MAX; + int final_expert = (lane_id < kWarpsPerToken) ? 
warp_experts[lane_id] : -1; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, final_max, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, final_expert, offset); + if (other_val > final_max) { + final_max = other_val; + final_expert = other_expert; + } + } + + if (lane_id == 0) { + selected_experts[k] = final_expert; + } + } + + __syncthreads(); + + int selected = selected_experts[k]; + if (selected >= 0 && tid == 0) { + shared_scores[selected] = -FLT_MAX; + } + + __syncthreads(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + if (tid >= device::kWarpThreads) return; + + // only use the first warp to perform write to global operation + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (tid < topk_routed) { + int expert_id = selected_experts[tid]; + float score = shared_original_scores[expert_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = score; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (tid < topk) { + const bool is_shared = tid >= topk_routed; + const auto output_offset = row_idx * topk + tid; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + tid - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? routed_sum : 1.0f; + output[output_offset] = weight / norm * scale; + indices[output_offset] = expert_id; + } +} + +template +__global__ void moe_fused_gate_kernel(const MoEFusedGateParams __grid_constant__ params) { + const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] = + params; + + uint32_t row_idx = blockIdx.x * kWarpsPerCTA + threadIdx.y; + if (row_idx >= num_rows) return; + + // number of routed experts to select (excluding fused shared experts) + const uint32_t topk_routed = topk - num_fused_shared_experts; + + uint32_t lane_id = threadIdx.x; + uint32_t warp_id = threadIdx.y; + + extern __shared__ float shared_mem[]; + float* shared_scores = shared_mem + warp_id * num_experts * 2; + float* shared_original_scores = shared_scores + num_experts; + __shared__ int selected_experts[kWarpsPerCTA][kMaxTopK]; + int* warp_selected_experts = selected_experts[warp_id]; + + for (uint32_t e = lane_id; e < num_experts; e += kWarpSize) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncwarp(); + + // only select topk_routed experts + for (uint32_t k = 0; k < topk_routed; k++) { + float max_val = -FLT_MAX; + int max_expert = -1; + + for (uint32_t expert = lane_id; expert < num_experts; expert += kWarpSize) { + if (shared_scores[expert] > max_val) { + max_val = shared_scores[expert]; + max_expert = expert; + } + } + + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, max_expert, offset); + + if (other_val > max_val || (other_val == max_val && other_expert < max_expert)) { + max_val = other_val; + max_expert = other_expert; + } + } + + 
if (lane_id == 0) { + warp_selected_experts[k] = max_expert; + if (max_expert != -1) { + shared_scores[max_expert] = -FLT_MAX; + } + } + + __syncwarp(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (lane_id < topk_routed) { + int expert_id = warp_selected_experts[lane_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = shared_original_scores[expert_id]; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk) { + const bool is_shared = lane_id >= topk_routed; + const auto output_idx = row_idx * topk + lane_id; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + lane_id - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? routed_sum : 1.0f; + output[output_idx] = weight / norm * scale; + indices[output_idx] = expert_id; + } +} + +template +void dispatch_small_token_kernel( + uint32_t num_rows, + uint32_t threads_per_block, + uint32_t warps_per_token, + DLDevice device, + size_t smem_per_row, + const MoEFusedGateParams& params) { + using namespace host; + if (warps_per_token <= 8) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<8, kScoringFunc>, params); + } else if (warps_per_token <= 12) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<12, kScoringFunc>, params); + } else { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<16, kScoringFunc>, params); + } +} + +struct MoEFusedGateKernel { + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView bias, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView indices, + uint32_t topk, + uint32_t scoring_func, // 0 = sigmoid, 1 = sqrtsoftplus + uint32_t num_fused_shared_experts, + bool renormalize, + float routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { + using namespace host; + + auto N = SymbolicSize{"num_rows"}; + auto E = SymbolicSize{"num_experts"}; + auto K = SymbolicSize{"topk"}; + auto device = SymbolicDevice{}; + K.set_value(topk); + device.set_options(); + + TensorMatcher({N, E}).with_dtype().with_device(device).verify(input); + TensorMatcher({E}).with_dtype().with_device(device).verify(bias); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(output); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(indices); + + const auto num_rows = static_cast(N.unwrap()); + const auto num_experts = static_cast(E.unwrap()); + + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum supported value"); + RuntimeCheck(scoring_func <= 1, "scoring_func must be 0 (sigmoid) or 1 (sqrtsoftplus)"); + RuntimeCheck(topk > num_fused_shared_experts, "topk must be greater than num_fused_shared_experts"); + + const auto params = MoEFusedGateParams{ + .input = static_cast(input.data_ptr()), + .bias = static_cast(bias.data_ptr()), + .output = static_cast(output.data_ptr()), + .indices = static_cast(indices.data_ptr()), + .num_rows = num_rows, + .num_experts = num_experts, + .topk = topk, + .num_fused_shared_experts = num_fused_shared_experts, + .renormalize = renormalize, + .routed_scaling_factor = 
routed_scaling_factor, + .apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output, + }; + + const size_t smem_per_row = 2 * num_experts * sizeof(float); + + bool use_small_token_kernel = num_rows <= kSmallTokenThreshold; + + if (use_small_token_kernel) { + // 1 token per block + uint32_t warps_per_token = div_ceil(num_experts, kWarpSize); + warps_per_token = std::min(warps_per_token, 16u); + uint32_t threads_per_block = warps_per_token * kWarpSize; + + if (scoring_func == 0) { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } else { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } + } else { + // multiple tokens per block + uint32_t num_blocks = div_ceil(num_rows, kWarpsPerCTA); + dim3 block_dim(kWarpSize, kWarpsPerCTA); + size_t large_smem = smem_per_row * kWarpsPerCTA; + + if (scoring_func == 0) { + LaunchKernel(num_blocks, block_dim, device.unwrap(), large_smem)( + moe_fused_gate_kernel, params); + } else { + LaunchKernel(num_blocks, block_dim, device.unwrap(), large_smem)( + moe_fused_gate_kernel, params); + } + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/deepseek_v4.py b/python/sglang/jit_kernel/deepseek_v4.py new file mode 100644 index 000000000000..72192b533a3d --- /dev/null +++ b/python/sglang/jit_kernel/deepseek_v4.py @@ -0,0 +1,908 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from sglang.srt.environ import envs + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +def make_name(name: str) -> str: + return f"dpsk_v4_{name}" + + +@cache_once +def _jit_common_module() -> Module: + return load_jit( + make_name("common"), + cuda_files=["deepseek_v4/common.cuh"], + cuda_wrappers=[("plan_compress_prefill", "plan_compress_prefill")], + ) + + +@cache_once +def _jit_compress_128_online_plan_module() -> Module: + """Host-side plan generator for online compress 128 (no template args).""" + return load_jit( + make_name("compress_128_online_plan"), + cuda_files=["deepseek_v4/c128_online.cuh"], + cuda_wrappers=[ + ("plan_compress_online_prefill", "plan_compress_online_prefill"), + ], + ) + + +@cache_once +def _jit_compress_128_online_module(head_dim: int) -> Module: + """Online compress 128 kernel: ring_size=1, per-index (max, sum, kv) state.""" + args = make_cpp_args(head_dim, is_arch_support_pdl()) + kernel_class = f"FlashCompress128OnlineKernel<{args}>" + return load_jit( + make_name("compress_128_online"), + *args, + cuda_files=["deepseek_v4/c128_online.cuh"], + cuda_wrappers=[ + ("decode", f"{kernel_class}::run_decode"), + ("prefill", f"{kernel_class}::run_prefill"), + ], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_topk_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("topk"), + *args, + cuda_files=["deepseek_v4/topk.cuh"], + cuda_wrappers=[("topk_transform", f"TopK512Kernel<{args}>::transform")], + ) + + +@cache_once +def _jit_topk1024_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("topk1024"), + *args, + cuda_files=["deepseek_v4/topk_1024.cuh"], + cuda_wrappers=[("topk_transform", f"TopK1024Kernel<{args}>::transform")], + 
) + + +@cache_once +def _jit_topk_v2_module(topk: int) -> Module: + return load_jit( + make_name("topk_v2"), + str(topk), + cuda_files=["deepseek_v4/topk_v2.cuh"], + cuda_wrappers=[ + ("topk_transform", "CombinedTopKKernel::transform"), + ("topk_plan", "CombinedTopKKernel::plan"), + ], + extra_cuda_cflags=[f"-DSGL_TOPK={topk}"], + ) + + +@cache_once +def _jit_mask_topk_module() -> Module: + return load_jit( + make_name("mask_topk"), + cuda_files=["deepseek_v4/hash_topk.cuh"], + cuda_wrappers=[("run", "MaskKernel::run")], + ) + + +@cache_once +def _jit_hash_topk_module() -> Module: + args = make_cpp_args("act_sqrt_softplus", is_arch_support_pdl()) + return load_jit( + make_name("hash_topk"), + *args, + cuda_files=["deepseek_v4/hash_topk.cuh"], + cuda_wrappers=[("hash_topk", f"HashTopKKernel<{args}>::run")], + ) + + +@cache_once +def _jit_compress_module( + head_dim: int, + dtype_in: torch.dtype, + dtype_out: torch.dtype, + ratio: Literal[4, 128], +) -> Module: + args = make_cpp_args(head_dim, dtype_in, dtype_out, is_arch_support_pdl()) + kernel_class = f"FlashCompress{ratio}Kernel<{args}>" + return load_jit( + make_name(f"compress_{ratio}"), + *args, + cuda_files=[f"deepseek_v4/c{ratio}.cuh"], + cuda_wrappers=[ + ("decode", f"{kernel_class}::run_decode"), + ("prefill", f"{kernel_class}::run_prefill"), + ], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_rmsnorm_head_module(head_dim: int, dtype: torch.dtype): + args = make_cpp_args(head_dim, dtype, is_arch_support_pdl()) + kernel_class = f"RMSNormKernel<{args}>" + return load_jit( + make_name("rmsnorm_head"), + *args, + cuda_files=["deepseek_v4/rmsnorm.cuh"], + cuda_wrappers=[("run_self", f"{kernel_class}::run_self")], + ) + + +@cache_once +def _jit_fused_rope_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("fused_rope"), + *args, + cuda_files=["deepseek_v4/rope.cuh"], + cuda_wrappers=[("forward", f"FusedQKRopeKernel<{args}>::forward")], + ) + + +@cache_once +def _jit_norm_rope_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, +) -> Module: + args = make_cpp_args(dtype, head_dim, rope_dim, is_arch_support_pdl()) + return load_jit( + make_name("fused_norm_rope"), + *args, + cuda_files=["deepseek_v4/fused_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedNormRopeKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_fused_store_module( + name: Literal["flashmla", "indexer"], + input_dtype: torch.dtype, + index_dtype: torch.dtype, + page_size: int, +) -> Module: + args = make_cpp_args(input_dtype, index_dtype, page_size, is_arch_support_pdl()) + cname = "FlashMLA" if name == "flashmla" else "Indexer" + kernel_class = f"FusedStoreCache{cname}Kernel<{args}>" + return load_jit( + make_name("store_" + name), + *args, + cuda_files=["deepseek_v4/store.cuh"], + cuda_wrappers=[("run", f"{kernel_class}::run")], + ) + + +@cache_once +def _jit_metadata_module(): + return load_jit( + make_name("metadata"), + cuda_files=["deepseek_v4/paged_mqa_metadata.cuh"], + cuda_wrappers=[("run", "IndexerMetadataKernel::run")], + ) + + +@cache_once +def _jit_silu_mul_quant_varlen_module( + quant_group_size: int, + scale_ue8m0: bool, + swizzle: bool, + apply_swiglu_limit: bool, +) -> Module: + args = make_cpp_args( + quant_group_size, + scale_ue8m0, + swizzle, + is_arch_support_pdl(), + apply_swiglu_limit, + ) + return load_jit( + make_name("silu_mul_quant_varlen"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", 
f"SiluAndMulMaskedPostQuantKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_silu_mul_quant_contig_module( + quant_group_size: int, + scale_ue8m0: bool, + swizzle: bool, + apply_swiglu_limit: bool, +) -> Module: + args = make_cpp_args( + quant_group_size, + scale_ue8m0, + swizzle, + is_arch_support_pdl(), + apply_swiglu_limit, + ) + return load_jit( + make_name("silu_mul_quant_contig"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulContigPostQuantKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_silu_and_mul_clamp_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype, is_arch_support_pdl()) + return load_jit( + make_name("silu_and_mul_clamp"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulClampKernel<{args}>::run")], + extra_cuda_cflags=["-use_fast_math"], + ) + + +@cache_once +def _jit_mega_moe_pre_dispatch_module(quant_group_size: int) -> Module: + args = make_cpp_args(quant_group_size, is_arch_support_pdl()) + return load_jit( + make_name("mega_moe_pre_dispatch"), + *args, + cuda_files=["deepseek_v4/mega_moe_pre_dispatch.cuh"], + cuda_wrappers=[("run", f"MegaMoEPreDispatchKernel<{args}>::run")], + ) + + +@cache_once +def _jit_hisparse_transfer_module() -> Module: + return load_jit( + make_name("hisparse_transfer"), + cuda_files=["deepseek_v4/hisparse_transfer.cuh"], + cuda_wrappers=[("hisparse_transfer", "hisparse_transfer")], + ) + + +def hisparse_offload_to_host( + gpu_ptrs: torch.Tensor, + cpu_ptrs: torch.Tensor, + gpu_indices: torch.Tensor, + cpu_indices: torch.Tensor, +) -> None: + module = _jit_hisparse_transfer_module() + module.hisparse_transfer(gpu_ptrs, cpu_ptrs, gpu_indices, cpu_indices) + + +def topk_transform_512( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + if out_page_indices.shape[1] == 512: + module = _jit_topk_module() + else: + module = _jit_topk1024_module() + module.topk_transform( + scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices + ) + + +_WORKSPACE_INTS_PER_BATCH = 2 + 1024 * 2 +_PLAN_METADATA_INTS_PER_BATCH = 4 + + +def plan_topk_v2(seq_lens: torch.Tensor, static_threshold: int = 0) -> torch.Tensor: + module = _jit_topk_v2_module(512) # does not matter + bs = seq_lens.shape[0] + metadata = seq_lens.new_empty(bs + 1, _PLAN_METADATA_INTS_PER_BATCH) + module.topk_plan(seq_lens, metadata, static_threshold) + return metadata + + +def topk_transform_512_v2( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + metadata: torch.Tensor, +) -> None: + module = _jit_topk_v2_module(out_page_indices.shape[1]) + bs = scores.shape[0] + workspace = seq_lens.new_empty(bs, _WORKSPACE_INTS_PER_BATCH) + module.topk_transform( + scores, + seq_lens, + page_tables, + out_page_indices, + page_size, + workspace, + metadata, + ) + + +def hash_topk( + router_logits: torch.Tensor, + input_ids: torch.Tensor, + tid2eid: torch.Tensor, + num_fused_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + scoring_func: str = "sqrtsoftplus", +) -> Tuple[torch.Tensor, torch.Tensor]: + assert scoring_func == "sqrtsoftplus" + num_tokens = router_logits.size(0) + topk_routed = tid2eid.size(1) + topk_fused = 
topk_routed + num_fused_shared_experts + topk_ids = torch.empty( + (num_tokens, topk_fused), dtype=torch.int32, device=router_logits.device + ) + topk_weights = torch.empty( + (num_tokens, topk_fused), dtype=torch.float32, device=router_logits.device + ) + module = _jit_hash_topk_module() + module.hash_topk( + router_logits, + input_ids, + tid2eid, + topk_weights, + topk_ids, + routed_scaling_factor, + ) + return topk_weights, topk_ids + + +def mask_topk_ids(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor): + return _jit_mask_topk_module().run(topk_ids, num_token_non_padded) + + +class CompressorPrefillPlan(NamedTuple): + compress_ratio: int + compress_plan: torch.Tensor + write_plan: torch.Tensor + + def copy_(self, other: CompressorPrefillPlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.compress_plan.copy_(other.compress_plan) + self.write_plan.copy_(other.write_plan) + + @staticmethod + def generate( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + device: torch.device, + use_cuda_graph: bool = False, + ) -> CompressorPrefillPlan: + from sglang.srt.environ import envs + + # Online c128 keeps the same NamedTuple shape (compress_plan, write_plan) + # so call sites that splat `*plan[1:]` continue to work, but the C++ + # plan struct semantics differ (last-token coords + window_len). + if compress_ratio == 128 and envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + return CompressorPrefillPlan._generate_online( + num_q_tokens=num_q_tokens, + seq_lens=seq_lens, + extend_lens=extend_lens, + device=device, + use_cuda_graph=use_cuda_graph, + ) + assert seq_lens.device == extend_lens.device + seq_lens = seq_lens.to(torch.int64) + extend_lens = extend_lens.to(torch.int64) + plan_tensor = torch.empty( + (2, num_q_tokens, 16), + dtype=torch.uint8, + device=seq_lens.device, + pin_memory=seq_lens.is_cpu, + ) + module = _jit_common_module() + is_overlap = compress_ratio == 4 + plan_lens = module.plan_compress_prefill( + extend_lens, + seq_lens, + plan_tensor[0], + plan_tensor[1], + compress_ratio, + is_overlap, + use_cuda_graph, + ) + return CompressorPrefillPlan( + compress_ratio, + plan_tensor[0, : plan_lens[0]].to(device, non_blocking=True), + plan_tensor[1, : plan_lens[1]].to(device, non_blocking=True), + ) + + @staticmethod + def _generate_online( + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + device: torch.device, + use_cuda_graph: bool, + ) -> CompressorPrefillPlan: + # Online plan host-side path: only CPU/cuda-host implemented today. + # Move inputs to CPU pinned memory then bounce the result to device. 
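+ # Pinned host memory is what lets the `.to(device, non_blocking=True)`
+ # copies below run asynchronously; with pageable tensors they would
+ # silently degrade to synchronous transfers.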
+ seq_lens_cpu = seq_lens.detach().to(torch.int64).cpu() + extend_lens_cpu = extend_lens.detach().to(torch.int64).cpu() + plan_tensor = torch.empty( + (2, num_q_tokens, 16), + dtype=torch.uint8, + device="cpu", + pin_memory=True, + ) + module = _jit_compress_128_online_plan_module() + plan_lens = module.plan_compress_online_prefill( + extend_lens_cpu, + seq_lens_cpu, + plan_tensor[0], + plan_tensor[1], + use_cuda_graph, + ) + return CompressorPrefillPlan( + 128, + plan_tensor[0, : plan_lens[0]].to(device, non_blocking=True), + plan_tensor[1, : plan_lens[1]].to(device, non_blocking=True), + ) + + @property + def is_decode(self) -> bool: + return False + + +class CompressorDecodePlan(NamedTuple): + compress_ratio: int + seq_lens: torch.Tensor + + def copy_(self, other: CompressorDecodePlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.seq_lens.copy_(other.seq_lens) + + @property + def is_decode(self) -> bool: + return True + + +def compress_plan( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: Optional[torch.Tensor], + device: torch.device, +) -> Union[CompressorDecodePlan, CompressorPrefillPlan]: + if extend_lens is not None: + return CompressorPrefillPlan.generate( + compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + device, + ) + else: + assert num_q_tokens == len(seq_lens) + seq_lens = seq_lens.to(device, non_blocking=True) + return CompressorDecodePlan(compress_ratio, seq_lens) + + +def compress_forward( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan, None] = None, + extra_data: Optional[torch.Tensor] = None, + *, + head_dim: int, + compress_ratio: Literal[4, 128], + out: Optional[torch.Tensor] = None, + seq_lens: Optional[torch.Tensor] = None, + extend_lens: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert head_dim % 128 == 0 + num_q_tokens = kv_score_input.shape[0] + if out is None: + out = kv_score_input.new_empty((num_q_tokens, head_dim)) + if plan is None: + assert seq_lens is not None + plan = compress_plan( + compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + kv_score_input.device, + ) + assert plan.compress_ratio == compress_ratio, "Mismatched compress ratio in plan!" + # Online c128: separate JIT module, fp32 state, no compile-time dtypes. 
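+ # Every other configuration falls through to the templated compress module
+ # below, which bakes head_dim and both dtypes in at JIT-compile time.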
+ if compress_ratio == 128 and envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + online_module = _jit_compress_128_online_module(head_dim=head_dim) + F = online_module.decode if plan.is_decode else online_module.prefill + F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) + return out + module = _jit_compress_module( + head_dim, + kv_score_input.dtype, + out.dtype, + compress_ratio, + ) + F = module.decode if plan.is_decode else module.prefill + F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) + return out + + +def compress_fused_norm_rope_inplace( + kv: torch.Tensor, + weight: torch.Tensor, + eps: float, + freq_cis: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan], +) -> None: + freq_cis = torch.view_as_real(freq_cis).flatten(-2) + module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) + module.forward( + kv, + weight, + plan[1], + freq_cis, + int(plan.is_decode), + eps, + plan.compress_ratio, + ) + + +def fused_rope( + q: torch.Tensor, + k: Optional[torch.Tensor], + freqs_cis: torch.Tensor, + positions: torch.Tensor, + inverse: bool = False, +) -> None: + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + + +@triton.jit +def create_paged_compress_data_kernel( + req_pool_indices_ptr, + seq_lens_ptr, + extend_seq_lens_ptr, + req_to_token_ptr, + full_to_swa_index_mapping_ptr, + out_0_ptr, + out_1_ptr, + batch_size, + stride_req_to_token_0, + stride_req_to_token_1: tl.constexpr, + stride_out_1_0, + stride_out_1_1: tl.constexpr, + compress_ratio: tl.constexpr, + is_overlap: tl.constexpr, + swa_page_size: tl.constexpr, + ring_size: tl.constexpr, + BLOCK: tl.constexpr, +) -> None: + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < batch_size + + rid = tl.load(req_pool_indices_ptr + offs, mask=mask, other=0).to(tl.int32) + seq_len = tl.load(seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + extend_len = tl.load(extend_seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + prefix_len = seq_len - extend_len + + cr = compress_ratio + write_pos = ((seq_len - 1) // cr) * cr + load_pos = ((prefix_len - 1) // cr) * cr + write_overlap_pos = write_pos - cr + load_overlap_pos = load_pos - cr + v0 = tl.zeros([BLOCK], tl.int32) + v1 = tl.zeros([BLOCK], tl.int32) + v2 = tl.zeros([BLOCK], tl.int32) + v3 = tl.zeros([BLOCK], tl.int32) + + for i in tl.static_range(4): + if i == 0: + pos = load_pos + elif i == 1: + pos = write_pos + elif i == 2: + pos = load_overlap_pos + else: + pos = write_overlap_pos + pos = tl.maximum(pos, 0) + loc = tl.load( + req_to_token_ptr + + rid.to(tl.int64) * stride_req_to_token_0 + + pos.to(tl.int64) * stride_req_to_token_1, + mask=mask, + other=0, + ).to(tl.int32) + swa_loc = tl.load(full_to_swa_index_mapping_ptr + loc, mask=mask, other=0).to( + tl.int32 + ) + swa_page = swa_loc // swa_page_size + state_loc = swa_page * ring_size + (swa_loc % ring_size) + state_loc = state_loc // cr + if i == 0: + v0 = state_loc + elif i == 1: + v1 = state_loc + elif i == 2: + v2 = state_loc + else: + v3 = state_loc + + tl.store(out_0_ptr + offs, v1, mask=mask) + + if is_overlap: + base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v2, mask=mask) + tl.store(base + 1 * stride_out_1_1, v0, mask=mask) + tl.store(base + 2 * stride_out_1_1, v3, mask=mask) + tl.store(base + 3 * stride_out_1_1, write_pos.to(tl.int32), mask=mask) + else: + 
base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v0, mask=mask) + + +def triton_create_paged_compress_data( + *, + compress_ratio: int, + is_overlap: bool, + swa_page_size: int, + ring_size: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + req_to_token: torch.Tensor, + full_to_swa_index_mapping: torch.Tensor, + block: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = req_pool_indices.shape[0] + out_dim = 4 if is_overlap else 1 + device_args: dict = dict(device=req_pool_indices.device, dtype=torch.int32) + out_0 = torch.empty((batch_size,), **device_args) + out_1 = torch.empty((batch_size, out_dim), **device_args) + grid = (triton.cdiv(batch_size, block),) + create_paged_compress_data_kernel[grid]( + req_pool_indices, + seq_lens, + extend_seq_lens, + req_to_token, + full_to_swa_index_mapping, + out_0, + out_1, + batch_size=batch_size, + stride_req_to_token_0=req_to_token.stride(0), + stride_req_to_token_1=req_to_token.stride(1), + stride_out_1_0=out_1.stride(0), + stride_out_1_1=out_1.stride(1), + compress_ratio=compress_ratio, + is_overlap=1 if is_overlap else 0, + swa_page_size=swa_page_size, + ring_size=ring_size, + BLOCK=block, + ) + + if not is_overlap: + out_1.squeeze_(1) + return out_0, out_1 + + +def fused_store_cache( + input: torch.Tensor, + cache: torch.Tensor, + indices: torch.Tensor, + *, + page_size: int, + type: Literal["flashmla", "indexer"], +) -> None: + module = _jit_fused_store_module( + name=type, + input_dtype=input.dtype, + index_dtype=indices.dtype, + page_size=page_size, + ) + module.run(input, cache, indices) + + +def silu_and_mul_clamp( + input: torch.Tensor, + output: torch.Tensor, + swiglu_limit: float, +) -> None: + module = _jit_silu_and_mul_clamp_module(input.dtype) + module.run(input, output, float(swiglu_limit)) + + +def silu_and_mul_masked_post_quant( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + scale_ue8m0: bool = False, + topk: int = 8, + transposed: bool = False, + swiglu_limit: Optional[float] = None, + swizzle: bool = False, +) -> None: + apply_swiglu_limit = swiglu_limit is not None + module = _jit_silu_mul_quant_varlen_module( + quant_group_size, scale_ue8m0, swizzle, apply_swiglu_limit + ) + module.run( + input, + output, + output_scale, + masked_m, + topk, + transposed, + float(swiglu_limit) if apply_swiglu_limit else 0.0, + ) + + +def silu_and_mul_contig_post_quant( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + scale_ue8m0: bool = False, + transposed: bool = False, + swiglu_limit: Optional[float] = None, + swizzle: bool = False, +) -> None: + apply_swiglu_limit = swiglu_limit is not None + module = _jit_silu_mul_quant_contig_module( + quant_group_size, scale_ue8m0, swizzle, apply_swiglu_limit + ) + module.run( + input, + output, + output_scale, + transposed, + float(swiglu_limit) if apply_swiglu_limit else 0.0, + ) + + +def mega_moe_pre_dispatch( + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + buf_x: torch.Tensor, + buf_x_sf: torch.Tensor, + buf_topk_idx: torch.Tensor, + buf_topk_weights: torch.Tensor, + quant_group_size: int = 32, +) -> None: + module = _jit_mega_moe_pre_dispatch_module(quant_group_size) + module.run( + x, + topk_idx, + topk_weights, + buf_x, + buf_x_sf, + buf_topk_idx, + buf_topk_weights, + ) + + +def get_paged_mqa_logits_metadata(seq_lens: torch.Tensor, 
page_size: int, num_sm: int): + assert page_size == 64 + seq_lens = seq_lens.view(-1).to(torch.int32) + metadata = seq_lens.new_empty(num_sm + 1, 2) + module = _jit_metadata_module() + module.run(seq_lens, metadata) + return metadata + + +def rmsnorm_self(q: torch.Tensor, eps: float) -> torch.Tensor: + module = _jit_rmsnorm_head_module(q.shape[-1], q.dtype) + out = q.new_empty(q.shape) + module.run_self(q, out, eps) + return out + + +@cache_once +def _jit_torch_cublas_bf16_fp32() -> Any: + import torch.utils.cpp_extension + + source = """ +#include +#include +#include + +torch::Tensor linear_bf16_fp32( + torch::Tensor X, + torch::Tensor W) +{ + int batch = X.size(0); + int in_features = X.size(1); + int out_features = W.size(0); + + auto Y = torch::empty( + {batch, out_features}, + torch::dtype(torch::kFloat32).device(X.device())); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + + float alpha = 1.0f; + float beta = 0.0f; + + cublasGemmEx( + handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_features, + batch, + in_features, + &alpha, + W.data_ptr(), CUDA_R_16BF, in_features, + X.data_ptr(), CUDA_R_16BF, in_features, + &beta, + Y.data_ptr(), CUDA_R_32F, out_features, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP + ); + + return Y; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_bf16_fp32", &linear_bf16_fp32, "BF16xBF16 -> FP32 linear (no bias)"); +} +""" + module = torch.utils.cpp_extension.load_inline( + name="linear_bf16_fp32", + cpp_sources="", + cuda_sources=source, + extra_cflags=["-O3"], + extra_cuda_cflags=["-O3"], + verbose=False, + ) + return module + + +def linear_bf16_fp32(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + from sglang.srt.environ import envs + + algo = envs.SGLANG_OPT_BF16_FP32_GEMM_ALGO.get() + return _dispatch_bf16_fp32_backend(x, y, algo=algo) + + +def _dispatch_bf16_fp32_backend( + x: torch.Tensor, y: torch.Tensor, *, algo: str +) -> torch.Tensor: + if algo == "cublas": + module = _jit_torch_cublas_bf16_fp32() + return module.linear_bf16_fp32(x, y) + elif algo == "deep_gemm": + import deep_gemm + + z = x.new_empty(x.size(0), y.size(0), dtype=torch.float32) + deep_gemm.bf16_gemm_nt(x, y, z) + return z + else: + return torch.nn.functional.linear(x.float(), y.float()) diff --git a/python/sglang/jit_kernel/hisparse.py b/python/sglang/jit_kernel/hisparse.py index 25db57b0b6ae..a6930ecb7f95 100644 --- a/python/sglang/jit_kernel/hisparse.py +++ b/python/sglang/jit_kernel/hisparse.py @@ -18,10 +18,13 @@ def _jit_sparse_module( num_top_k: int, hot_buffer_size: int, is_mla: bool = False, + is_dsv4_layout: bool = False, ) -> Module: - template_args = make_cpp_args(block_size, num_top_k, hot_buffer_size, is_mla) + template_args = make_cpp_args( + block_size, num_top_k, hot_buffer_size, is_mla, is_dsv4_layout + ) cache_args = make_cpp_args( - item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla + item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla, is_dsv4_layout ) return load_jit( "sparse_cache", @@ -36,7 +39,9 @@ def _jit_sparse_module( ) -def load_cache_to_device_buffer_mla( +def _load_cache_to_device_buffer_mla( + *, + is_dsv4_layout: bool, top_k_tokens: torch.Tensor, device_buffer_tokens: torch.Tensor, host_cache_locs: torch.Tensor, @@ -50,16 +55,21 @@ def load_cache_to_device_buffer_mla( item_size_bytes: int, num_top_k: int, hot_buffer_size: int, - page_size: int = 1, - block_size: int = 256, - num_real_reqs: torch.Tensor | None = None, + page_size: int, + block_size: int, + num_real_reqs: torch.Tensor | 
None, ) -> None: assert ( hot_buffer_size >= num_top_k ), f"hot_buffer_size ({hot_buffer_size}) must be >= num_top_k ({num_top_k})" module = _jit_sparse_module( - item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla=True + item_size_bytes, + block_size, + num_top_k, + hot_buffer_size, + is_mla=True, + is_dsv4_layout=is_dsv4_layout, ) empty = torch.empty(0) @@ -86,3 +96,83 @@ def load_cache_to_device_buffer_mla( page_size, item_size_bytes, ) + + +def load_cache_to_device_buffer_mla( + top_k_tokens: torch.Tensor, + device_buffer_tokens: torch.Tensor, + host_cache_locs: torch.Tensor, + device_buffer_locs: torch.Tensor, + host_cache: torch.Tensor, + device_buffer: torch.Tensor, + top_k_device_locs: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + lru_slots: torch.Tensor, + item_size_bytes: int, + num_top_k: int, + hot_buffer_size: int, + page_size: int = 1, + block_size: int = 256, + num_real_reqs: torch.Tensor | None = None, +) -> None: + """Generic MLA hisparse swap-in: device + host both linear (stride=item_size_bytes).""" + _load_cache_to_device_buffer_mla( + is_dsv4_layout=False, + top_k_tokens=top_k_tokens, + device_buffer_tokens=device_buffer_tokens, + host_cache_locs=host_cache_locs, + device_buffer_locs=device_buffer_locs, + host_cache=host_cache, + device_buffer=device_buffer, + top_k_device_locs=top_k_device_locs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + lru_slots=lru_slots, + item_size_bytes=item_size_bytes, + num_top_k=num_top_k, + hot_buffer_size=hot_buffer_size, + page_size=page_size, + block_size=block_size, + num_real_reqs=num_real_reqs, + ) + + +def load_cache_to_device_buffer_dsv4_mla( + top_k_tokens: torch.Tensor, + device_buffer_tokens: torch.Tensor, + host_cache_locs: torch.Tensor, + device_buffer_locs: torch.Tensor, + host_cache: torch.Tensor, + device_buffer: torch.Tensor, + top_k_device_locs: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + lru_slots: torch.Tensor, + item_size_bytes: int, + num_top_k: int, + hot_buffer_size: int, + page_size: int = 1, + block_size: int = 256, + num_real_reqs: torch.Tensor | None = None, +) -> None: + """DSv4 hisparse swap-in: page-padded device + linear host (kvcacheio.cuh layout).""" + _load_cache_to_device_buffer_mla( + is_dsv4_layout=True, + top_k_tokens=top_k_tokens, + device_buffer_tokens=device_buffer_tokens, + host_cache_locs=host_cache_locs, + device_buffer_locs=device_buffer_locs, + host_cache=host_cache, + device_buffer=device_buffer, + top_k_device_locs=top_k_device_locs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + lru_slots=lru_slots, + item_size_bytes=item_size_bytes, + num_top_k=num_top_k, + hot_buffer_size=hot_buffer_size, + page_size=page_size, + block_size=block_size, + num_real_reqs=num_real_reqs, + ) diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh new file mode 100644 index 000000000000..02b166d01c73 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +#include +#include + +#include + +namespace device::compress { + +struct alignas(16) PrefillPlan { + uint32_t ragged_id; + uint32_t batch_id; + uint32_t position; + uint32_t window_len; // must be in `[0, compress_ratio * (1 + is_overlap))` + + bool is_valid(const uint32_t ratio, const bool is_overlap) const { + const uint32_t max_window_len = ratio * (1 + is_overlap); + 
return window_len < max_window_len; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::PrefillPlan; +using PrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kPrefillPlanDim = 16; + +static_assert(alignof(PrefillPlan) == sizeof(PrefillPlan)); +static_assert(sizeof(PrefillPlan) == kPrefillPlanDim * sizeof(PrefillPlanTensorDtype)); + +} // namespace host::compress diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh new file mode 100644 index 000000000000..4fdbb062c3cd --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/fp8_utils.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +#include +#include + +// Small helpers shared by the DeepSeek-V4 FP8/UE8M0 quantization kernels +// (silu_and_mul_masked_post_quant, store, mega_moe_pre_dispatch, ...). +// All functions are `SGL_DEVICE` (= `__forceinline__ __device__`) so +// including this header in multiple translation units is ODR-safe. + +namespace deepseek_v4::fp8 { + +// Round `x` to the nearest representable UE8M0 value. Returns the raw +// 8-bit biased exponent; the actual fp32 scale is `2^(exp - 127)` +// (i.e. `__uint_as_float(exp << 23)`). +SGL_DEVICE int32_t cast_to_ue8m0(float x) { + uint32_t u = __float_as_uint(x); + int32_t exp = int32_t((u >> 23) & 0xFF); + uint32_t mant = u & 0x7FFFFF; + return exp + (mant != 0); +} + +// 1 / 2^(exp - 127) as fp32. Equivalent to `1.0f / __uint_as_float(exp << 23)`. +SGL_DEVICE float inv_scale_ue8m0(int32_t exp) { + return __uint_as_float((127 + 127 - exp) << 23); +} + +// Clamp to [-FP8_E4M3_MAX, FP8_E4M3_MAX]. +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); +} + +// Pack two fp32 values into a single fp8x2_e4m3 with clamping. +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} + +} // namespace deepseek_v4::fp8 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh new file mode 100644 index 000000000000..0a3acc47734a --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh @@ -0,0 +1,96 @@ +#include +#include + +#include + +#include + +namespace device::hisparse { + +/// NOTE: We call nope+rope as a "value" here. +/// GPU Cache layout: +/// VALUE 0, VALUE 1, ..., VALUE 63, +/// SCALE 0, SCALE 1, ..., SCALE 63, +/// [Padding to align to 576 bytes] +/// CPU Cache follow a trivial linear layout without any padding. 
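To make the addressing concrete, here is a small Python model of the layout described above. The numbers mirror the constants defined immediately below (64-token pages, 576-byte values, 8-byte scales, pages padded to a 576-byte multiple); the helper names are illustrative only and not part of the kernel.

import math

GPU_PAGE_SIZE = 64
VALUE_BYTES = 576
SCALE_BYTES = 8
CPU_ITEM_BYTES = VALUE_BYTES + SCALE_BYTES                              # 584, linear CPU layout
GPU_PAGE_BYTES = math.ceil(CPU_ITEM_BYTES * GPU_PAGE_SIZE / 576) * 576  # 37440, 576-byte aligned
GPU_SCALE_OFFSET = VALUE_BYTES * GPU_PAGE_SIZE                          # 36864, start of scale region

def gpu_offsets(index: int) -> tuple[int, int]:
    # Byte offsets of a token's value and scale inside the paged GPU cache.
    page, slot = divmod(index, GPU_PAGE_SIZE)
    base = page * GPU_PAGE_BYTES
    return base + slot * VALUE_BYTES, base + GPU_SCALE_OFFSET + slot * SCALE_BYTES

def cpu_offsets(index: int) -> tuple[int, int]:
    # Byte offsets in the flat, unpadded CPU cache.
    base = index * CPU_ITEM_BYTES
    return base, base + VALUE_BYTES

# Token 130 lands in page 2, slot 2.
assert gpu_offsets(130) == (2 * 37440 + 2 * 576, 2 * 37440 + 36864 + 2 * 8)
assert cpu_offsets(130) == (130 * 584, 130 * 584 + 576)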
+inline constexpr int64_t kGPUPageSize = 64; +inline constexpr int64_t kGPUPageBits = 6; // log2(kGPUPageSize) +inline constexpr int64_t kValueBytes = 576; +inline constexpr int64_t kScaleBytes = 8; +/// NOTE: FlashMLA requires each page to be aligned to 576 bytes +inline constexpr int64_t kCPUItemBytes = kValueBytes + kScaleBytes; +inline constexpr int64_t kGPUPageBytes = host::div_ceil(kCPUItemBytes * kGPUPageSize, 576) * 576; +inline constexpr int64_t kGPUScaleOffset = kValueBytes * kGPUPageSize; + +struct PointerInfo { + int64_t* value_ptr; + int64_t* scale_ptr; +}; + +SGL_DEVICE PointerInfo get_pointer_gpu(void* cache, int32_t index) { + using namespace device; + static_assert(1 << kGPUPageBits == kGPUPageSize); + const int32_t page_num = index >> kGPUPageBits; + const int32_t page_offset = index & (kGPUPageSize - 1); + const auto page_ptr = pointer::offset(cache, page_num * kGPUPageBytes); + const auto value_ptr = pointer::offset(page_ptr, page_offset * kValueBytes); + const auto scale_ptr = pointer::offset(page_ptr, kGPUScaleOffset + page_offset * kScaleBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +SGL_DEVICE PointerInfo get_pointer_cpu(void* cache, int32_t index) { + using namespace device; + const auto value_ptr = pointer::offset(cache, index * kCPUItemBytes); + const auto scale_ptr = pointer::offset(value_ptr, kValueBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +enum class TransferDirection { + DeviceToDevice = 0, + DeviceToHost = 1, + HostToDevice = 2, +}; + +template +SGL_DEVICE void transfer_item(void* dst_cache, void* src_cache, const int32_t dst_index, const int32_t src_index) { + constexpr bool is_dst_device = (direction != TransferDirection::DeviceToHost); + constexpr bool is_src_device = (direction != TransferDirection::HostToDevice); + constexpr auto dst_fn = is_dst_device ? get_pointer_gpu : get_pointer_cpu; + constexpr auto src_fn = is_src_device ? 
get_pointer_gpu : get_pointer_cpu; + + const auto [dst_value_ptr, dst_scale_ptr] = dst_fn(dst_cache, dst_index); + const auto [src_value_ptr, src_scale_ptr] = src_fn(src_cache, src_index); + + int64_t local_items[2]; + const int64_t* tail_src_ptr; + int64_t* tail_dst_ptr; + + const int32_t lane_id = threadIdx.x % 32; + + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + local_items[i] = src_value_ptr[j]; + } + + if (lane_id < 8) { // handle the tail element safely + const auto last_id = 64 + lane_id; + tail_src_ptr = src_value_ptr + last_id; + tail_dst_ptr = dst_value_ptr + last_id; + } else { // broadcast load/store is safe + tail_src_ptr = src_scale_ptr; + tail_dst_ptr = dst_scale_ptr; + } + + const auto tail_item = *tail_src_ptr; + + // store first 512 bytes of value + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + dst_value_ptr[j] = local_items[i]; + } + + // store the tail element + *tail_dst_ptr = tail_item; +} + +} // namespace device::hisparse diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh new file mode 100644 index 000000000000..e58214c95148 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/cluster.cuh @@ -0,0 +1,257 @@ +#pragma once +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct ClusterTopK { + static constexpr uint32_t kClusterSize = 8; + static constexpr uint32_t kHistBits = 10; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 4; + static constexpr uint32_t kMaxLength = kClusterSize * kNumStages * kSizePerStage; + static constexpr uint32_t kStoreLane = kBlockSize - 1; + static constexpr uint32_t kAboveBits = 11; + + // --------------------------------------------------------------------------- + // Shared memory layouts + // --------------------------------------------------------------------------- + + struct Smem { + uint64_t barrier[kNumStages]; + uint32_t local_above_equal[kClusterSize]; + uint32_t prefix_above_equal; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + uint32_t histogram[kHistBins]; + alignas(128) float score_buffer[kNumStages][kSizePerStage]; + Tie tie_buffer[kMaxTies]; + }; + + struct alignas(16) Metadata { + uint32_t batch_id; + uint32_t seq_len; + bool has_next; + }; + + struct WorkSpace { + uint2 metadata; // {num_above, num_ties} + Tie ties[kMaxTies]; + }; + + static constexpr uint32_t kWorkspaceInts = sizeof(WorkSpace) / sizeof(uint32_t); + + // --------------------------------------------------------------------------- + // Stage 1: histogram + cluster reduce + find threshold + scatter + // --------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_init(void* _smem) { + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto smem = static_cast(_smem); + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < kNumStages) ptx::mbarrier_init(&smem->barrier[tx], 1); + __syncthreads(); + } + + SGL_DEVICE static void stage1_prologue(const float* scores, uint32_t length, void* _smem) { + if (threadIdx.x 
== 0) { + const auto smem = static_cast(_smem); + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + const auto length_aligned = (length + 3u) & ~3u; // align to 4 for TMA +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + if (stage >= num_stages) break; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length_aligned - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[stage]; + ptx::tma_load(smem->score_buffer[stage], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + } + } + + SGL_DEVICE static void stage1(int32_t* indices, uint32_t length, void* _smem, bool reuse = false) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram, counters, and barriers +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; + const auto size = min(kSizePerStage, length - offset); + if (lane_id == 0) ptx::mbarrier_wait(&smem->barrier[stage], 0); + __syncwarp(); +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto idx = tx + i * kBlockSize; + if (idx >= size) break; + const auto score = smem->score_buffer[stage][idx]; + const auto bin = extract_coarse_bin(score); + atomicAdd(&smem->histogram[bin], 1); + } + } + + static_assert(kHistBins <= kBlockSize); + + // 2-shot all-reduce + { + auto cluster = cooperative_groups::this_cluster(); + cluster.sync(); + const auto cluster_rank = blockIdx.y; + const auto kLocalSize = kHistBins / kClusterSize; + const auto offset = kLocalSize * cluster_rank; + + const auto src_tx = tx / kClusterSize; + const auto src_rank = tx % kClusterSize; + + if (tx < kHistBins) { + const auto addr = &smem->histogram[offset + src_tx]; + const auto src_addr = cluster.map_shared_rank(addr, src_rank); + *src_addr = warp::reduce_sum(*src_addr); + } + cluster.sync(); + } + + // now each block holds the whole histogram, find the threshold bin + { + const auto value = tx < kHistBins ? smem->histogram[tx] : 0; + const auto warp_inc = warp_inclusive_sum(lane_id, value); + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + const auto tmp = smem->warp_sum[lane_id]; + // total_length = sum of all bins in the globally-reduced histogram + // (problem.length is block-local; after cluster reduction we need the global total) + const auto total_length = warp::reduce_sum(tmp); + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? 
tmp : 0); + prefix_sum += warp_inc; + const auto above = total_length - prefix_sum; + if (tx < kHistBins && above < K && above + value >= K) { + smem->counter_gt = smem->counter_eq = 0; + smem->match = { + .bin = tx, + .above_count = above, + .equal_count = value, + }; + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // write above and equal results to global memory +#pragma unroll + for (uint32_t stage = 0; stage < kNumStages; stage++) { + const auto offset = stage * kSizePerStage; + if (offset >= length) break; +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; ++i) { + const auto buf_idx = tx + i * kBlockSize; + const auto global_idx = offset + buf_idx; + if (global_idx >= length) break; + const auto score = smem->score_buffer[stage][buf_idx]; + const auto bin = extract_coarse_bin(score); + if (bin > thr_bin) { + indices[atomicAdd(&smem->counter_gt, 1)] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } + } + if (reuse) { + const auto num_stages = (length + kSizePerStage - 1) / kSizePerStage; + if (tx < kHistBins) smem->histogram[tx] = 0; + if (tx < num_stages) ptx::mbarrier_arrive(&smem->barrier[tx]); + } + __syncthreads(); + } + + // --------------------------------------------------------------------------- + // Stage 1 epilogue: cross-block prefix sum + page translate + tie store + // --------------------------------------------------------------------------- + + SGL_DEVICE static void stage1_epilogue(const TransformParams params, const uint32_t offset, void* _ws, void* _smem) { + auto cluster = cooperative_groups::this_cluster(); + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto local_above = smem->counter_gt; + const auto local_equal = smem->counter_eq; + const auto cluster_rank = blockIdx.y; + + constexpr uint32_t kAboveMask = (1 << kAboveBits) - 1; + static_assert(kAboveMask >= K); + + // Pack local counts -- NO alignment rounding (contiguous layout) + static_assert(kMaxTies <= kBlockSize); + const auto idx_above = tx < local_above ? params.indices_in[tx] : 0; + const auto tie_value = tx < local_equal ? smem->tie_buffer[tx] : Tie{0, 0.0f}; + + // push to remote shared memory, can reduce latency of reading remote + if (tx < kClusterSize) { + const auto value = (local_equal << kAboveBits) | local_above; + const auto dst_addr = cluster.map_shared_rank(smem->local_above_equal, tx); + dst_addr[cluster_rank] = value; + } + // after this last sync, only read local shared memory + // so that it is safe when peer rank has already exited the kernel + cluster.sync(); + if (tx < kClusterSize) { + const auto value = tx < cluster_rank ? 
smem->local_above_equal[tx] : 0; + const auto kActiveMask = (1u << kClusterSize) - 1; + smem->prefix_above_equal = warp::reduce_sum(value, kActiveMask); + } + __syncthreads(); + + const auto prefix_packed = smem->prefix_above_equal; + const auto prefix_above = prefix_packed & kAboveMask; + const auto prefix_equal = prefix_packed >> kAboveBits; + + // Page-translate above elements + if (tx < local_above) { + params.write(tx + prefix_above, idx_above + offset); + } + // Contiguous tie store via regular global writes (no TMA, no gaps) + const auto ws = static_cast(_ws); + if (tx < local_equal && tx + prefix_equal < kMaxTies) { + ws->ties[tx + prefix_equal] = {tie_value.idx + offset, tie_value.score}; + } + // Block 0 writes global metadata {num_above, num_ties} + if (cluster_rank == kClusterSize - 1 && tx == 0) { + const auto sum_above = prefix_above + local_above; + const auto sum_equal = prefix_equal + local_equal; + ws->metadata = make_uint2(sum_above, sum_equal); + } + } + + SGL_DEVICE static void transform(const TransformParams params, const void* _ws, void* _smem) { + const auto ws = static_cast(_ws); + const auto meta = &ws->metadata; + const auto [num_above, num_equal] = *meta; + if (num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(ws->ties, clamped_ties, num_above, K, params, _smem); + } +}; + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh new file mode 100644 index 000000000000..d553032d799a --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/common.cuh @@ -0,0 +1,176 @@ +#pragma once +#include +#include +#include +#include + +#include + +namespace device::top512 { + +inline constexpr uint32_t kMaxTopK = 1024; +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kNumWarps = kBlockSize / kWarpThreads; +inline constexpr uint32_t kMaxTies = 1024; // == kBlockSize: 1 element per thread in stage2 +static constexpr uint32_t kRadixBins = 256; +static_assert(kMaxTopK <= kBlockSize && kMaxTies <= kBlockSize); + +// always use float4 to load from global memory +using Vec4 = AlignedVector; + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +struct TransformParams { + const int32_t* __restrict__ page_table; + const int32_t* __restrict__ indices_in; + int32_t* __restrict__ indices_out; + uint32_t page_bits; + + SGL_DEVICE void transform(const uint32_t idx) const { + indices_out[idx] = page_to_indices(page_table, indices_in[idx], page_bits); + } + SGL_DEVICE void write(const uint32_t dst, const uint32_t src) const { + indices_out[dst] = page_to_indices(page_table, src, page_bits); + } +}; + +struct alignas(16) MatchBin { + uint32_t bin; + uint32_t above_count; + uint32_t equal_count; +}; + +struct alignas(8) Tie { + uint32_t idx; + float score; +}; + +struct TieHandleSmem { + alignas(128) uint32_t counter; // output position counter + alignas(128) MatchBin match; + uint32_t histogram[kRadixBins]; // 256-bin radix histogram + uint32_t warp_sum[kNumWarps]; // for 2-pass prefix sum +}; + +template +SGL_DEVICE uint32_t extract_coarse_bin(float x) { + static_assert(0 < kBits && kBits < 15); + const auto hx = cast(x); + const uint16_t bits = *reinterpret_cast(&hx); + 
const uint16_t key = (bits & 0x8000) ? ~bits : bits | 0x8000; + return key >> (16 - kBits); +} + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +/// Order-preserving float32 -> uint32 for radix select +SGL_DEVICE uint32_t extract_exact_bin(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +SGL_DEVICE void trivial_transform(const TransformParams& params, uint32_t length, uint32_t K) { + if (const auto tx = threadIdx.x; tx < length) { + params.write(tx, tx); + } else if (tx < K) { + params.indices_out[tx] = -1; + } +} + +SGL_DEVICE void tie_handle_transform( + const Tie* __restrict__ ties, // + const uint32_t num_ties, + const uint32_t num_above, + const uint32_t K, + const TransformParams params, + void* _smem) { + auto* smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Each thread loads one element (or becomes inactive) + const bool has_elem = tx < num_ties; + const auto tie = has_elem ? ties[tx] : Tie{0, 0.0f}; + const uint32_t key = extract_exact_bin(tie.score); + const uint32_t idx = tie.idx; + bool active = has_elem; + uint32_t topk_remain = K - num_above; + uint32_t write_pos = K; + + smem->counter = 0; + __syncthreads(); + + // Number of warps covering the 256-bin histogram (256/32 = 8) + constexpr uint32_t kRadixWarps = kRadixBins / kWarpThreads; + +#pragma unroll + for (int round = 0; round < 4; round++) { + const uint32_t shift = 24 - round * 8; + const uint32_t bin = (key >> shift) & 0xFFu; + + // 1. Build histogram + if (tx < kRadixBins) smem->histogram[tx] = 0; + __syncthreads(); + if (active) atomicAdd(&smem->histogram[bin], 1); + __syncthreads(); + + // 2. v2-style 2-pass prefix sum on 256 bins + // Only first 256 threads (8 warps) carry histogram bins. + // Other threads get hist_val=0 and harmless prefix results. + uint32_t hist_val = 0; + uint32_t warp_inc = 0; + if (tx < kRadixBins) { + hist_val = smem->histogram[tx]; + warp_inc = warp_inclusive_sum(lane_id, hist_val); + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + } + __syncthreads(); + if (tx < kRadixBins) { + // Inter-warp prefix (only first kHistWarps warp totals matter) + const auto tmp = (lane_id < kRadixWarps) ? smem->warp_sum[lane_id] : 0; + const auto total = warp::reduce_sum(tmp); + const auto inter = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + const auto prefix = inter + warp_inc; // inclusive prefix through this bin + const auto above = total - prefix; // elements in bins ABOVE this one + // 3. Find threshold bin + if (above < topk_remain && above + hist_val >= topk_remain) { + smem->match = {tx, above, topk_remain - above}; + } + } + __syncthreads(); + + const auto [thr, n_above, _] = smem->match; + + // 4. 
Scatter + if (active) { + if (bin > thr) { + write_pos = num_above + atomicAdd(&smem->counter, 1); + active = false; + } else if (bin < thr) { + active = false; + } else if (round == 3) { + write_pos = K - atomicAdd(&smem->match.equal_count, -1u); + } + // my_bin == thr && round < 3: stay active for next round + } + + topk_remain -= n_above; + if (topk_remain == 0) break; + } + + if (write_pos < K) params.write(write_pos, idx); +} + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh new file mode 100644 index 000000000000..73eef555f4db --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/ptx.cuh @@ -0,0 +1,54 @@ +#pragma once +#include + +#include + +#include + +namespace device::top512 { + +namespace ptx { + +SGL_DEVICE void mbarrier_wait(uint64_t* addr, uint32_t phase) { + while (!cuda::ptx::mbarrier_try_wait_parity(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, addr, phase)) + ; +} + +SGL_DEVICE void mbarrier_init(uint64_t* addr, uint32_t arrives) { + cuda::ptx::mbarrier_init(addr, arrives); +} + +SGL_DEVICE void mbarrier_arrive_expect_tx(uint64_t* addr, uint32_t tx) { + cuda::ptx::mbarrier_arrive_expect_tx(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr, tx); +} + +SGL_DEVICE void mbarrier_arrive(uint64_t* addr) { + cuda::ptx::mbarrier_arrive(cuda::ptx::sem_relaxed, cuda::ptx::scope_cta, cuda::ptx::space_shared, addr); +} + +SGL_DEVICE void tma_load(void* dst, const void* src, uint32_t num_bytes, uint64_t* mbar) { + cuda::ptx::cp_async_bulk(cuda::ptx::space_shared, cuda::ptx::space_global, dst, src, num_bytes, mbar); +} + +SGL_DEVICE uint32_t elect_sync() { + uint32_t pred = 0; + asm volatile( + "{\n\t" + ".reg .pred %%px;\n\t" + "elect.sync _|%%px, %1;\n\t" + "@%%px mov.s32 %0, 1;\n\t" + "}" + : "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +} + +SGL_DEVICE bool elect_sync_cta(uint32_t tx) { + const auto warp_id = tx / 32; + const auto uniform_warp_id = __shfl_sync(0xFFFFFFFF, warp_id, 0); + return (uniform_warp_id == 0 && elect_sync()); +} + +} // namespace ptx + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh new file mode 100644 index 000000000000..77d7361ee871 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/register.cuh @@ -0,0 +1,302 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct RegisterTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kVecsPerThread = 4; + static constexpr uint32_t kMaxTolerance = 0; + static constexpr uint32_t kMax1PassLength = kVecsPerThread * 4 * kBlockSize; + static constexpr uint32_t kMaxExtraLength = kMax1PassLength; + static constexpr uint32_t kMax2PassLength = kMax1PassLength + kMaxExtraLength; + + struct Smem { + using HistVec = AlignedVector; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + uint64_t mbarrier; // for cp.async + MatchBin match; + uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + alignas(16) float score_buffer[kMaxExtraLength]; + }; + + template + SGL_DEVICE 
static void + run(const float* scores, // + int32_t* indices, + const uint32_t length, + void* _smem, + const bool use_pdl = false) { + const auto smem = static_cast<Smem*>(_smem); + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + // Initialize shared memory histogram + { + typename Smem::HistVec hist_vec; + hist_vec.fill(0); + smem->histogram_vec[tx] = hist_vec; + if (tx == 0) { + smem->counter_gt = smem->counter_eq = 0; + if constexpr (kIs2Pass) { + ptx::mbarrier_init(&smem->mbarrier, 1); + } + } + __syncthreads(); + } + + if (use_pdl) device::PDLWaitPrimary(); + + // Load scores into registers + Vec4 local[kVecsPerThread]; +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { + const uint32_t base = (tx + v * kBlockSize) * 4; + if (base >= length) break; + local[v].load(scores, tx + v * kBlockSize); + } + + // Fetch the next chunk of scores + if constexpr (kIs2Pass) { + if (ptx::elect_sync_cta(tx)) { + const auto length_aligned = (length + 3u - kMax1PassLength) & ~3u; + const auto size_bytes = length_aligned * sizeof(float); + ptx::tma_load(smem->score_buffer, scores + kMax1PassLength, size_bytes, &smem->mbarrier); + ptx::mbarrier_arrive_expect_tx(&smem->mbarrier, size_bytes); + } + __syncwarp(); // reconverge the warp after the elected thread issues the TMA load + } + + // Accumulate histogram via shared-memory atomics +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + if constexpr (!kIs2Pass) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= length) goto LABEL_ACC_FINISH; + } + atomicAdd(&smem->histogram[extract_coarse_bin(local[v][e])], 1); + } + } + if constexpr (kIs2Pass) { + // 16K ~ 32K tail: `i` indexes single floats in score_buffer + if (lane_id == 0) ptx::mbarrier_wait(&smem->mbarrier, 0); + __syncwarp(); + for (uint32_t i = tx; i + kMax1PassLength < length; i += kBlockSize) { + const auto val = smem->score_buffer[i]; + atomicAdd(&smem->histogram[extract_coarse_bin(val)], 1); + } + } + [[maybe_unused]] LABEL_ACC_FINISH: + __syncthreads(); + + // Phase 2: Exclusive prefix scan -> find threshold bin + { + constexpr uint32_t kItems = kHistBins / kBlockSize; + uint32_t orig[kItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t tmp_local_sum = 0; + +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + orig[i] = hist_vec[i]; + tmp_local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, tmp_local_sum); + const auto warp_exc = warp_inc - tmp_local_sum; + if (lane_id == kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inc; + } + + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + // Exactly one bin satisfies: above < K && above + count >= K + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + + // Phase 3: Scatter + // Elements strictly above threshold go directly to output. + // Tied elements: simple path admits first-come; tiebreak path collects into tie_buffer.
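(Editor's sketch, not part of this PR.) The scatter phase above is easier to follow next to a host-side reference. The Python sketch below assumes NumPy and len(scores) >= k, and uses a full 32-bit order-preserving key in place of the kernel's coarse bin derived from the 16-bit score representation; it implements the same "threshold bin + tie bucket" selection.

import numpy as np

def topk_indices_via_histogram(scores: np.ndarray, k: int, bits: int = 12) -> np.ndarray:
    # Order-preserving float32 -> uint32 key: flip all bits of negatives, set the
    # sign bit of non-negatives (same trick as extract_exact_bin); the coarse bin
    # is the key's top `bits` bits, analogous to kHistBits = 12 above.
    raw = scores.astype(np.float32).view(np.uint32)
    keys = np.where(raw & np.uint32(0x80000000), ~raw, raw | np.uint32(0x80000000))
    bins = (keys >> np.uint32(32 - bits)).astype(np.int64)

    hist = np.bincount(bins, minlength=1 << bits)
    above = hist[::-1].cumsum()[::-1] - hist          # elements in strictly higher bins
    thr = int(np.flatnonzero((above < k) & (above + hist >= k))[0])

    out = list(np.flatnonzero(bins > thr))            # everything above the threshold bin (unordered)
    ties = np.flatnonzero(bins == thr)                # threshold-bin ties: exact score desc, index asc
    ties = ties[np.lexsort((ties, -scores[ties]))]
    out.extend(ties[: k - len(out)])
    return np.asarray(out)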
+ const bool need_tiebreak = (num_equal + num_above > K + kMaxTolerance); + const auto topk_indices = indices; + const auto tie_buffer = smem->tie_buffer; + +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if constexpr (!kIs2Pass) { + if (idx >= length) goto LABEL_SCATTER_DONE; + } + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + // prefetch the next scores + if constexpr (kIs2Pass) { + local[v].load(smem->score_buffer, tx + v * kBlockSize); + } + } + + // 16K ~ 32K, already in registers: similar loop as above but read from smem->score_buffer + if constexpr (kIs2Pass) { +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e + kMax1PassLength; + if (idx >= length) goto LABEL_SCATTER_DONE; + const uint32_t bin = extract_coarse_bin(local[v][e]); + if (bin > thr_bin) { + topk_indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kMaxTies) { + tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + topk_indices[which] = idx; + } + } + } + } + } + } + + [[maybe_unused]] LABEL_SCATTER_DONE: + if (!need_tiebreak) return; + + // Phase 4: Tie-breaking within the threshold bin. + // Assume num_ties <= kBlockSize (at most 1 block of ties). + // Each thread takes one tied element, computes its rank (number of + // elements with strictly higher score, breaking exact float ties by + // original index), and writes to output if rank < topk_remain. + __syncthreads(); + static_assert(kMaxTies <= kBlockSize); + + const uint32_t num_ties = min(num_equal, kMaxTies); + const uint32_t topk_remain = K - num_above; + + const auto is_greater = [](const Tie& a, const Tie& b) { + return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); + }; + + if (num_ties <= kWarpThreads) { + static_assert(kWarpThreads <= kNumWarps); + if (lane_id >= num_ties || warp_id >= num_ties) return; // some threads are idle + /// NOTE: use long long to avoid mask overflow when num_ties == 32 + const uint32_t mask = (1ull << num_ties) - 1u; + const auto tie = tie_buffer[lane_id]; + const auto target_tie = tie_buffer[warp_id]; + const bool pred = is_greater(tie, target_tie); + const auto rank = static_cast(__popc(__ballot_sync(mask, pred))); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } else if (num_ties <= kWarpThreads * 2) { + // 64 x 64 topk implementation: each thread takes 2 elements + const auto lane_id_1 = lane_id + kWarpThreads; + const auto warp_id_1 = warp_id + kWarpThreads; + const auto invalid = Tie{.idx = 0xFFFFFFFF, .score = -FLT_MAX}; + const auto tie_0 = tie_buffer[lane_id]; + const auto tie_1 = lane_id_1 < num_ties ? 
tie_buffer[lane_id_1] : invalid; + if (true) { + const auto target = tie_buffer[warp_id]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + if (warp_id_1 < num_ties) { + const auto target = tie_buffer[warp_id_1]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target.idx; + } + } + } else { + /// NOTE: Based on my observation, this path is very rarely reached + [[unlikely]]; + // Block-level: each thread reads from tie_buffer in shared memory + for (auto i = warp_id; i < num_ties; i += kNumWarps) { + const auto target_tie = tie_buffer[i]; + uint32_t local_rank = 0; + for (auto j = lane_id; j < num_ties; j += kWarpThreads) { + const auto tie = tie_buffer[j]; + if (is_greater(tie, target_tie)) local_rank++; + } + // sum the rank across the warp + const auto rank = warp::reduce_sum(local_rank); + if (lane_id == 0 && rank < topk_remain) { + topk_indices[num_above + rank] = target_tie.idx; + } + } + } + } + + SGL_DEVICE static void transform(const TransformParams params) { + __syncthreads(); + if (const auto tx = threadIdx.x; tx < K) params.transform(tx); + } +}; + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh new file mode 100644 index 000000000000..4462b89a1930 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/topk/streaming.cuh @@ -0,0 +1,213 @@ +#pragma once + +#include +#include +#include + +#include "common.cuh" +#include "ptx.cuh" +#include +#include + +namespace device::top512 { + +template +struct StreamingTopK { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kRadixBins = 256; + static constexpr uint32_t kElemPerStage = 8; + static constexpr uint32_t kSizePerStage = kElemPerStage * kBlockSize; + static constexpr uint32_t kNumStages = 2; // double buffer + + static constexpr uint32_t kHistItems = kHistBins / kBlockSize; // 4 + static_assert(kHistItems * kBlockSize == kHistBins); + using HistVec = AlignedVector; + + struct Smem { + uint64_t barrier[2][kNumStages]; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kMaxTies]; + }; + union { + float score_buffer[kNumStages][kSizePerStage]; + TieHandleSmem stage2; // reuse smem for tie handling in phase D + }; + }; + + // --------------------------------------------------------------------------- + // Helpers + // --------------------------------------------------------------------------- + + /// NOTE: length must be 4-aligned since we load 4 floats/thread. Caller should round up. 
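(Editor's sketch, not part of this PR.) The 4-alignment requirement noted above, together with kElemPerStage = 8 and kBlockSize = 1024 declared earlier, fixes the per-stage transfer sizes of the double-buffered stream. A small Python illustration with hypothetical helper names; the rounding matches the (length + 3u) & ~3u used by stream_pass below.

ELEM_PER_STAGE = 8                             # kElemPerStage
BLOCK_SIZE = 1024                              # kBlockSize
SIZE_PER_STAGE = ELEM_PER_STAGE * BLOCK_SIZE   # floats moved per TMA stage

def round_up_to_vec4(length: int) -> int:
    # Pad so every lane can issue a full float4 load.
    return (length + 3) & ~3

def stage_sizes(length: int) -> list:
    # Floats copied by each pipelined stage; only the last stage may be short.
    length = round_up_to_vec4(length)
    return [min(SIZE_PER_STAGE, length - off) for off in range(0, length, SIZE_PER_STAGE)]

# stage_sizes(20000) -> [8192, 8192, 3616]; the two smem buffers alternate over these stages.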
+ template + SGL_DEVICE static void issue_tma(const float* scores, uint32_t stage, uint32_t length, Smem* smem) { + const auto buf_idx = stage % kNumStages; + const auto offset = stage * kSizePerStage; + const auto size = min(kSizePerStage, length - offset); + const auto size_bytes = size * sizeof(float); + const auto bar = &smem->barrier[kIsScatter][buf_idx]; + ptx::tma_load(smem->score_buffer[buf_idx], scores + offset, size_bytes, bar); + ptx::mbarrier_arrive_expect_tx(bar, size_bytes); + } + + // --------------------------------------------------------------------------- + // Unified streaming pass. Used for both phase A (kIsScatter=false) and + // phase C (kIsScatter=true). Each buffer is reused across iterations via the + // reuse-arrive trick (same pattern as ClusterTopKImpl::stage1). + // --------------------------------------------------------------------------- + + template + SGL_DEVICE static void stream_pass( + const float* scores, + const uint32_t length, + const uint32_t thr_bin, // ignored when !kIsScatter + int32_t* s_topk_indices, // ignored when !kIsScatter + Smem* smem) { + const auto tx = threadIdx.x; + const auto num_iters = (length + kSizePerStage - 1) / kSizePerStage; + const auto lane_id = tx % kWarpThreads; + + // Initial double-buffer TMA prologue. + const auto length_aligned = (length + 3u) & ~3u; + if (tx == 0) { +#pragma unroll + for (uint32_t i = 0; i < kNumStages; i++) { + if (i >= num_iters) break; + issue_tma(scores, i, length_aligned, smem); + } + } + + for (uint32_t iter = 0; iter < num_iters; iter++) { + const auto buf_idx = iter % kNumStages; + const auto offset = iter * kSizePerStage; + const auto this_size = min(kSizePerStage, length - offset); + + if (lane_id == 1) { + const auto phase_bit = (iter / kNumStages) & 1; + ptx::mbarrier_wait(&smem->barrier[kIsScatter][buf_idx], phase_bit); + } + __syncwarp(); + +#pragma unroll + for (uint32_t i = 0; i < kElemPerStage; i++) { + const auto local_idx = tx + i * kBlockSize; + if (local_idx >= this_size) break; + const auto score = smem->score_buffer[buf_idx][local_idx]; + const auto bin = extract_coarse_bin(score); + if constexpr (kIsScatter) { + const auto global_idx = offset + local_idx; + if (bin > thr_bin) { + const auto pos = atomicAdd(&smem->counter_gt, 1); + if (pos < K) s_topk_indices[pos] = global_idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (pos < kMaxTies) smem->tie_buffer[pos] = {global_idx, score}; + } + } else { + atomicAdd(&smem->histogram[bin], 1); + } + } + + __syncthreads(); + if (tx == 0) { + if (const auto next_iter = iter + kNumStages; next_iter < num_iters) { + issue_tma(scores, next_iter, length_aligned, smem); + } + } + } + } + + // --------------------------------------------------------------------------- + // Phase B: find the threshold bin via a warp-level prefix scan. + // Same structure as SmallTopKImpl's phase 2 (4 bins/thread, warp_sum relay). 
+ // --------------------------------------------------------------------------- + + SGL_DEVICE static void find_threshold(uint32_t length, Smem* smem) { + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + uint32_t orig[kHistItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t local_sum = 0; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + orig[i] = hist_vec[i]; + local_sum += orig[i]; + } + + const auto warp_inc = warp_inclusive_sum(lane_id, local_sum); + const auto warp_exc = warp_inc - local_sum; + if (lane_id == kWarpThreads - 1) smem->warp_sum[warp_id] = warp_inc; + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exc; +#pragma unroll + for (uint32_t i = 0; i < kHistItems; ++i) { + prefix_sum += orig[i]; + const auto above = length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kHistItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + SGL_DEVICE static void run(const float* scores, const uint32_t length, int32_t* topk_indices, void* _smem) { + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + __builtin_assume(tx < kBlockSize); + + // Init histogram, barriers, counters. + { + HistVec zero; + zero.fill(0); + smem->histogram_vec[tx] = zero; + if (tx < 2 * kNumStages) { + const auto base_barrier = &smem->barrier[0][0]; + ptx::mbarrier_init(&base_barrier[tx], 1); + } + if (tx == 0) { + smem->counter_gt = 0; + smem->counter_eq = 0; + } + __syncthreads(); + } + + // Phase A: histogram pass (pipelined TMA stream). + stream_pass(scores, length, 0, nullptr, smem); + + // Phase B: locate threshold bin & re-init barriers + find_threshold(length, smem); + + // Phase C: scatter pass. + stream_pass(scores, length, smem->match.bin, topk_indices, smem); + } + + SGL_DEVICE static void transform(const TransformParams params, void* _smem) { + // Phase D: page-translate above entries, then refine ties. 
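(Editor's note, not part of this PR.) "Page-translate" in Phase D refers to page_to_indices() from common.cuh earlier in this diff; the same arithmetic in Python, for readers skimming the kernel:

def page_to_indices(page_table, i, page_bits):
    # Split slot i into (logical page, in-page offset), look the page up in the
    # page table, and re-attach the offset to get the physical token index.
    mask = (1 << page_bits) - 1
    return (page_table[i >> page_bits] << page_bits) | (i & mask)

# With the 256-token pages this PR configures (page_bits = 8), slot 513 is offset 1
# of logical page 2; if page_table[2] == 7, the physical index is 7 * 256 + 1 == 1793.
assert page_to_indices([0, 0, 7, 0], 513, 8) == 1793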
+ const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + const auto num_above = smem->match.above_count; + if (tx < num_above) params.transform(tx); + const auto num_equal = smem->counter_eq; + if (num_above >= K || num_equal == 0) return; + const auto clamped_ties = min(num_equal, kMaxTies); + tie_handle_transform(smem->tie_buffer, clamped_ties, num_above, K, params, &smem->stage2); + } +}; + +} // namespace device::top512 diff --git a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh index 486ea530a17a..a5abcdd4fa55 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/utils.cuh @@ -259,17 +259,27 @@ struct LaunchKernel { m_config.numAttrs = 0; #else if (enabled) { - m_attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - m_attrs[0].val.programmaticStreamSerializationAllowed = true; - m_config.numAttrs = 1; + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = true; m_config.attrs = m_attrs; - } else { - m_config.numAttrs = 0; } #endif return *this; } + auto enable_cluster(dim3 cluster_dim) -> LaunchKernel& { +#ifdef USE_ROCM + (void)cluster_dim; +#else + auto& attr = m_attrs[m_config.numAttrs++]; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {cluster_dim.x, cluster_dim.y, cluster_dim.z}; + m_config.attrs = m_attrs; +#endif + return *this; + } + template auto operator()(T&& kernel, Args&&... args) const -> void { #ifdef USE_ROCM @@ -303,7 +313,7 @@ struct LaunchKernel { cudaLaunchConfig_t m_config; const DebugInfo m_location; - cudaLaunchAttribute m_attrs[1]; + cudaLaunchAttribute m_attrs[2]; }; } // namespace host diff --git a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh index 769c16c08603..975065e035c9 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh @@ -3,6 +3,7 @@ #pragma once #include +#include namespace device::warp { @@ -16,6 +17,7 @@ static constexpr uint32_t kFullMask = 0xffffffffu; * `active_mask` using butterfly (XOR) shuffles. The result is * broadcast to all participating lanes. * + * \tparam kNumThreads Group size for the reduction (defaults to a full warp). * \tparam T Numeric type (e.g. float). * \param value Per-lane input value. * \param active_mask Bitmask of participating lanes (default: all 32). @@ -38,15 +40,18 @@ SGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) { * butterfly shuffles. The result is broadcast to all participating * lanes. * + * \tparam kNumThreads Group size for the reduction (defaults to a full warp). * \tparam T Numeric type (must be supported by `math::max`). * \param value Per-lane input value. * \param active_mask Bitmask of participating lanes (default: all 32). * \return The maximum across all active lanes. 
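+ * Example: with kNumThreads = 8, lanes are reduced in aligned groups of 8 and every lane in a group receives that group's maximum.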
*/ -template +template SGL_DEVICE T reduce_max(T value, uint32_t active_mask = kFullMask) { + static_assert(kNumThreads >= 1 && kNumThreads <= kWarpThreads); + static_assert(std::has_single_bit(kNumThreads), "must be pow of 2"); #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) + for (int mask = kNumThreads / 2; mask > 0; mask >>= 1) value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); return value; } diff --git a/python/sglang/jit_kernel/moe_fused_gate.py b/python/sglang/jit_kernel/moe_fused_gate.py new file mode 100644 index 000000000000..d0daad0a3180 --- /dev/null +++ b/python/sglang/jit_kernel/moe_fused_gate.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Tuple + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +_SCORING_FUNC_MAP = { + "sigmoid": 0, + "sqrtsoftplus": 1, +} + + +@cache_once +def _jit_moe_fused_gate_module() -> Module: + return load_jit( + "moe_fused_gate", + cuda_files=["moe/moe_fused_gate.cuh"], + cuda_wrappers=[("moe_fused_gate", "MoEFusedGateKernel::run")], + ) + + +@cache_once +def can_use_moe_fused_gate() -> bool: + logger = logging.getLogger(__name__) + try: + _jit_moe_fused_gate_module() + return True + except Exception as e: + logger.warning(f"Failed to load JIT MoE fused gate kernel: {e}") + return False + + +def moe_fused_gate( + input: torch.Tensor, + bias: torch.Tensor, + topk: int, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + renormalize: bool = True, + routed_scaling_factor: float = 1.0, + apply_routed_scaling_factor_on_output: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + scoring_func_int = _SCORING_FUNC_MAP.get(scoring_func.lower()) + assert ( + scoring_func_int is not None + ), f"Unknown scoring_func '{scoring_func}', must be one of {list(_SCORING_FUNC_MAP.keys())}" + + assert input.dtype == torch.float32, "input must be float32" + assert bias.dtype == torch.float32, "bias must be float32" + assert input.ndim == 2, "input must be 2D" + assert bias.ndim == 1, "bias must be 1D" + assert input.size(1) == bias.size(0), "input and bias must have same num_experts" + assert topk > num_fused_shared_experts, "topk must be > num_fused_shared_experts" + + num_rows, _ = input.shape + device = input.device + + output = torch.empty(num_rows, topk, dtype=torch.float32, device=device) + indices = torch.empty(num_rows, topk, dtype=torch.int32, device=device) + + module = _jit_moe_fused_gate_module() + module.moe_fused_gate( + input, + bias, + output, + indices, + topk, + scoring_func_int, + num_fused_shared_experts, + renormalize, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, + ) + + return output, indices diff --git a/python/sglang/srt/arg_groups/deepseek_v4_hook.py b/python/sglang/srt/arg_groups/deepseek_v4_hook.py new file mode 100644 index 000000000000..b3af8e95f82c --- /dev/null +++ b/python/sglang/srt/arg_groups/deepseek_v4_hook.py @@ -0,0 +1,86 @@ +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sglang.srt.server_args import ServerArgs + +logger = logging.getLogger(__name__) + + +def apply_deepseek_v4_defaults(server_args: "ServerArgs", model_arch: str) -> None: + """Apply DeepSeek V4 model-specific server arg defaults and constraints.""" + from sglang.srt.environ import envs + from sglang.srt.server_args import ServerArgs + + server_args.attention_backend = "dsv4" + server_args.page_size = 256 + 
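+ # Forced for the dsv4 backend: these assignments override any user-provided attention_backend / page_size.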
logger.info( + f"Use dsv4 attention backend for {model_arch}, setting page_size to 256." + ) + + if server_args.max_running_requests is None: + server_args.max_running_requests = 256 + logger.warning( + f"Setting max_running_requests to {server_args.max_running_requests} for {model_arch}." + ) + + if server_args.kv_cache_dtype == "auto": + server_args.kv_cache_dtype = "fp8_e4m3" + logger.warning( + f"Setting KV cache dtype to {server_args.kv_cache_dtype} for {model_arch}." + ) + assert server_args.kv_cache_dtype in [ + "fp8_e4m3" + ], f"{server_args.kv_cache_dtype} is not supported for {model_arch}" + + if server_args.speculative_algorithm is not None: + assert ( + server_args.speculative_algorithm == "EAGLE" + ), f"Only EAGLE speculative algorithm is supported for {model_arch}" + assert ( + server_args.speculative_eagle_topk == 1 + ), f"Only EAGLE speculative algorithm with topk == 1 is supported for {model_arch}" + + if not envs.SGLANG_ENABLE_SPEC_V2.get(): + envs.SGLANG_ENABLE_SPEC_V2.set(True) + logger.warning("Spec v2 is enabled for EAGLE speculative decoding.") + + if server_args.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio: + server_args.swa_full_tokens_ratio = 0.1 + logger.info( + f"Setting swa_full_tokens_ratio to {server_args.swa_full_tokens_ratio} for {model_arch}." + ) + + if server_args.disaggregation_mode != "null" and server_args.pp_size > 1: + # get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized + # flat KV ptrs by PP layer range. + raise ValueError( + f"V4 PD disaggregation requires pp_size=1, got pp_size={server_args.pp_size}." + ) + + +def validate_deepseek_v4_cp(server_args: "ServerArgs") -> None: + """Validate DeepSeek V4 context-parallel configuration.""" + if not server_args.enable_nsa_prefill_context_parallel: + return + + if server_args.nsa_prefill_cp_mode != "round-robin-split": + raise ValueError( + f"DeepSeekV4 only supports round-robin-split CP mode, " + f"got {server_args.nsa_prefill_cp_mode}" + ) + + server_args.enable_dp_attention = True + server_args.moe_dense_tp_size = 1 + server_args.attn_cp_size = server_args.tp_size // server_args.dp_size + assert ( + server_args.dp_size == 1 + ), "For round-robin split mode, dp attention is not supported." + assert ( + server_args.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." + logger.warning( + f"Enable Context Parallel for DeepSeekV4, " + f"dp_size={server_args.dp_size}, moe_dense_tp_size={server_args.moe_dense_tp_size}, " + f"attn_cp_size={server_args.attn_cp_size}, ep_size={server_args.ep_size}, tp_size={server_args.tp_size}" + ) diff --git a/python/sglang/srt/arg_groups/hisparse_hook.py b/python/sglang/srt/arg_groups/hisparse_hook.py new file mode 100644 index 000000000000..379d76dd79cf --- /dev/null +++ b/python/sglang/srt/arg_groups/hisparse_hook.py @@ -0,0 +1,95 @@ +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sglang.srt.server_args import ServerArgs + +logger = logging.getLogger(__name__) + + +# Backend/dtype pairing: flashmla_sparse only takes BF16 KV; +# flashmla_kv only supports FP8 (it always reads KV as FP8 via +# is_fp8_kvcache=True, inline-quantizing BF16 would defeat HiSparse). 
+_HISPARSE_ALLOWED_BACKENDS_BY_DTYPE = { + "bfloat16": {"flashmla_sparse"}, + "fp8_e4m3": {"flashmla_kv"}, +} + + +def _hisparse_default_backend(kv_cache_dtype: str) -> str: + return "flashmla_kv" if kv_cache_dtype == "fp8_e4m3" else "flashmla_sparse" + + +def apply_hisparse_nsa_backend_defaults( + server_args: "ServerArgs", + user_set_prefill: bool, + user_set_decode: bool, + kv_cache_dtype: str, +) -> bool: + """Pick NSA backends for --enable-hisparse based on KV dtype. + + BF16 KV -> flashmla_sparse, FP8 KV -> flashmla_kv. Returns True if hisparse + handled backend selection (caller should skip its own default logic). + """ + if not server_args.enable_hisparse: + return False + + backend = _hisparse_default_backend(kv_cache_dtype) + if not user_set_prefill: + server_args.nsa_prefill_backend = backend + if not user_set_decode: + server_args.nsa_decode_backend = backend + logger.warning( + f"HiSparse enabled ({kv_cache_dtype}): using NSA backends " + f"prefill={server_args.nsa_prefill_backend}, decode={server_args.nsa_decode_backend}." + ) + return True + + +def validate_hisparse(server_args: "ServerArgs") -> None: + """Validate --enable-hisparse constraints (model class, radix cache, NSA backend).""" + if not server_args.enable_hisparse: + return + + from sglang.srt.configs.model_config import ( + is_deepseek_nsa, + is_deepseek_v4, + ) + + hf_config = server_args.get_model_config().hf_config + is_v4_hisparse = is_deepseek_v4(hf_config) + assert is_deepseek_nsa(hf_config) or is_v4_hisparse, ( + "--enable-hisparse is only supported for DSA (DeepSeek Sparse Attention) " + "models (e.g., DeepSeek V3.2, GLM-5) and DeepSeek V4 now. " + ) + + assert ( + server_args.disable_radix_cache + ), "Hierarchical sparse attention currently requires --disable-radix-cache." + + # DSv4 hisparse handles its own dtype/backend pairing elsewhere; the dtype- + # aware checks below only apply to the DSA hisparse path. + if is_v4_hisparse: + return + + if server_args.kv_cache_dtype not in ("bfloat16", "auto", "fp8_e4m3"): + raise ValueError( + f"HiSparse requires bfloat16 or fp8_e4m3 KV cache, " + f"but got --kv-cache-dtype={server_args.kv_cache_dtype}. " + f"Please use --kv-cache-dtype=bfloat16 or fp8_e4m3." + ) + + allowed_backends = _HISPARSE_ALLOWED_BACKENDS_BY_DTYPE.get( + server_args.kv_cache_dtype, {"flashmla_sparse", "flashmla_kv"} + ) + for attr, label in [ + ("nsa_prefill_backend", "prefill"), + ("nsa_decode_backend", "decode"), + ]: + backend = getattr(server_args, attr) + if backend is not None and backend not in allowed_backends: + raise ValueError( + f"HiSparse with --kv-cache-dtype={server_args.kv_cache_dtype} requires " + f"--nsa-{label}-backend in {sorted(allowed_backends)}, " + f"but got {backend}." + ) diff --git a/python/sglang/srt/configs/deepseek_v4.py b/python/sglang/srt/configs/deepseek_v4.py new file mode 100644 index 000000000000..54c83b7dd184 --- /dev/null +++ b/python/sglang/srt/configs/deepseek_v4.py @@ -0,0 +1,110 @@ +import logging +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +from transformers import PretrainedConfig + +from sglang.srt.layers.quantization.base_config import QuantizationConfig + +logger = logging.getLogger(__name__) + + +def try_detect_fp4_experts(model_path: str) -> Optional[bool]: + """True = mxfp4-packed (U8/I8/F4), False = converted FP8 (F8_E4M3), + None when the header isn't readable (HF slug not cached yet, etc.). + Caller falls back to user default. Pure read; never mutates env. 
+ """ + from sglang.srt.model_loader.weight_utils import ( + probe_routed_expert_weight_dtype, + ) + from sglang.srt.utils import find_local_repo_dir + + if os.path.isdir(model_path): + local_path = model_path + else: + local_path = find_local_repo_dir(model_path) + if not local_path or not os.path.isdir(local_path): + return None + + try: + dtype = probe_routed_expert_weight_dtype(local_path) + except Exception as e: + logger.warning("Failed to probe routed-expert dtype for %s: %s", model_path, e) + return None + if dtype is None: + return None + if dtype in ("U8", "I8", "F4"): + return True + if dtype == "F8_E4M3": + return False + logger.warning( + "Unexpected routed-expert safetensors dtype=%s for DeepSeek V4", dtype + ) + return None + + +@dataclass(kw_only=True) +class DeepSeekV4Config(PretrainedConfig): + architectures: List[str] + attention_bias: bool = False + attention_dropout: float = 0.0 + bos_token_id: int = 0 + eos_token_id: int = 1 + ep_size: int = 1 + first_k_dense_replace: int = 0 + hidden_act: str = "silu" + hidden_size: int = 4096 + index_head_dim: int = 128 + index_n_heads: int = 64 + index_topk: int = 512 + initializer_range: float = 0.02 + intermediate_size: int = 2048 + kv_lora_rank: int = 512 + max_position_embeddings: int = 65536 + model_type: str = "deepseek_v4" + moe_intermediate_size: int = 2048 + moe_layer_freq: int = 1 + n_group: int = 8 + n_routed_experts: int = 256 + n_shared_experts: int = 1 + norm_topk_prob: bool = True + + num_attention_heads: int = 64 + num_experts_per_tok: int = 6 + num_hidden_layers: int = 43 + num_key_value_heads: int = 1 + + q_lora_rank: int = 1024 + qk_nope_head_dim: int = 448 + qk_rope_head_dim: int = 64 + + quantization_config: QuantizationConfig = field(default_factory=QuantizationConfig) + + rms_norm_eps: float = 1e-6 + + rope_scaling: Dict[str, float] = field(default_factory=dict) + rope_theta: int = 10000 + + routed_scaling_factor: float = 1.5 + scoring_func: str = "sqrtsoftplus" + + tie_word_embeddings: bool = False + + topk_group: int = 8 + topk_method: str = "noaux_tc" + + use_cache: bool = True + v_head_dim: int = 512 + vocab_size: int = 129280 + o_lora_rank: int = 1024 + o_groups: int = 8 + window_size: int = 128 + + compress_rope_theta: int = 40000 + compress_ratios: List[int] = field(default_factory=list) + + n_hash_layers: int = 3 + hc_mult: int = 4 + hc_sinkhorn_iters: int = 20 + hc_eps: float = 1e-6 diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index ca494730b3d8..4c2a7231e760 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -114,8 +114,15 @@ def is_deepseek_nsa(config) -> bool: ) +def is_deepseek_v4(config) -> bool: + return _hf_arch(config) in ( + "DeepseekV4ForCausalLM", + "DeepseekV4ForCausalLMNextN", + ) + + def get_nsa_index_head_dim(config: PretrainedConfig) -> int: - assert is_deepseek_nsa(config) + assert is_deepseek_nsa(config) or is_deepseek_v4(config) return config.index_head_dim @@ -134,11 +141,15 @@ def get_num_indexer_layers(config) -> int: NSA models (V3.2) instantiate an Indexer on every transformer layer. With index_topk_freq > 1 some layers reuse prev layer's topk; those still - get a slot (mirrored at the MLA call site). Other architectures: set + get a slot (mirrored at the MLA call site). DSv4 has C4 indexers only on + layers whose compress_ratio == 4. Other architectures: set num_indexer_layers on hf_text_config; 0 disables the capturer. 
""" if is_deepseek_nsa(config): return config.num_hidden_layers + if is_deepseek_v4(config): + compress_ratios = getattr(config, "compress_ratios", None) or [] + return sum(1 for r in compress_ratios if r == 4) return getattr(config, "num_indexer_layers", 0) @@ -221,6 +232,34 @@ def __init__( # Config draft model self._config_draft_model() + # DSV4 expert layout: env (default True = mxfp4) applies only to V4. + # Other FP8 MoE models (for example DeepSeek V3.2) must keep the normal + # FP8 expert tensor layout. + self.is_fp4_experts: bool = False + if is_deepseek_v4(self.hf_config): + self.is_fp4_experts = envs.SGLANG_DSV4_FP4_EXPERTS.get() + if not envs.SGLANG_DSV4_FP4_EXPERTS.is_set(): + from sglang.srt.configs.deepseek_v4 import try_detect_fp4_experts + + detected = try_detect_fp4_experts(self.model_path) + if detected is not None: + self.is_fp4_experts = detected + logger.info( + "Auto-detected DSV4 routed-expert layout: is_fp4_experts=%s", + self.is_fp4_experts, + ) + + # HF config.json inherits topk_group=4 from the V3 template, but + # DSV4 trains with no group limiting (sqrtsoftplus + full-expert + # top-k). Force topk_group == n_group so deepseek_v2.py:531's + # `n_group > topk_group` evaluates False and routes to the + # ungrouped sqrtsoftplus path. The grouped impl only supports + # sigmoid scoring (topk.py:722) and would silently corrupt expert + # weights if hit. + n_group = getattr(self.hf_config, "n_group", None) + if n_group is not None: + self.hf_config.topk_group = n_group + # Check model type self.attention_chunk_size = getattr( self.hf_text_config, "attention_chunk_size", None @@ -367,6 +406,13 @@ def _config_draft_model(self): ]: self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if ( + is_draft_model + and self.hf_config.architectures[0] == "DeepseekV4ForCausalLM" + ): + self.hf_config.architectures[0] = "DeepseekV4ForCausalLMNextN" + self.hf_config.num_nextn_predict_layers = 1 + if is_draft_model and self.hf_config.architectures[0] in [ "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", @@ -434,13 +480,21 @@ def _derive_hybrid_model(self): ) if self.is_hybrid_swa: - self.swa_attention_layer_ids, self.full_attention_layer_ids = ( - get_hybrid_layer_ids( - self.hf_config.architectures, - self.hf_text_config, - ) + logger.info(f"Hybrid swa model: {self.hf_config.architectures=}") + + self.is_deepseek_v4_arch = any( + arch in ["DeepseekV4ForCausalLM", "DeepseekV4ForCausalLMNextN"] + for arch in self.hf_config.architectures ) + if not self.is_deepseek_v4_arch: + self.swa_attention_layer_ids, self.full_attention_layer_ids = ( + get_hybrid_layer_ids( + self.hf_config.architectures, + self.hf_text_config, + ) + ) + self.has_attention_sinks = self._detect_attention_sinks() self.is_hybrid_swa_compress = self.hf_config.architectures[0] in [ @@ -571,6 +625,23 @@ def _derive_model_shapes(self): self.scaling = compute_mla_mscale_scaling( rope_scaling, self.scaling ) + elif ( + "DeepseekV4ForCausalLM" in self.hf_config.architectures + or "DeepseekV4ForCausalLMNextN" in self.hf_config.architectures + ): + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + self.qk_nope_head_dim = self.hf_config.head_dim - self.qk_rope_head_dim + self.window_size = self.hf_config.sliding_window + self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = self.head_dim + self.index_head_dim = self.hf_config.index_head_dim + self.compress_ratios = self.hf_config.compress_ratios + self.attention_arch = AttentionArch.MHA + self.scaling = 1 / 
math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + if self.hf_config.rope_scaling: + self.scaling = compute_mla_mscale_scaling( + self.hf_config.rope_scaling, self.scaling + ) elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: self.head_dim = 128 self.attention_arch = AttentionArch.MLA @@ -1466,6 +1537,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal piecewise_cuda_graph_disabled_model_archs = [ "DeepseekV32ForCausalLM", + "DeepseekV4ForCausalLM", + "DeepseekV4ForCausalLMNextN", "Qwen3NextForCausalLM", "GlmMoeDsaForCausalLM", "BailingMoeV2_5ForCausalLM", @@ -1568,6 +1641,8 @@ def is_hybrid_swa_model(model_architectures: List[str]): hybrid_swa_archs = { "Llama4ForConditionalGeneration", + "DeepseekV4ForCausalLM", + "DeepseekV4ForCausalLMNextN", "GptOssForCausalLM", *MIMO_V2_MODEL_ARCHS, "MiMoV2MTP", diff --git a/python/sglang/srt/disaggregation/base/conn.py b/python/sglang/srt/disaggregation/base/conn.py index 87244e04451c..8dc32839ac30 100644 --- a/python/sglang/srt/disaggregation/base/conn.py +++ b/python/sglang/srt/disaggregation/base/conn.py @@ -31,7 +31,7 @@ class KVArgs: state_data_ptrs: List[int] state_data_lens: List[int] state_item_lens: List[int] - state_type: str # "none", "mamba", "swa" + state_type: str # "none", "mamba", "swa", "nsa", "dsv4" # for mamba state different tp slice transfer state_dim_per_tensor: List[int] # dimension to slice for each state tensor ib_device: str diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index d2895242afdd..5d3d64f1cff2 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -55,10 +55,8 @@ from sglang.srt.managers.schedule_policy import match_prefix_for_req from sglang.srt.managers.utils import GenerationBatchResult from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator -from sglang.srt.mem_cache.base_prefix_cache import ( - BasePrefixCache, - EvictParams, -) +from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, EvictParams +from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.common import ( kv_to_page_indices, page_align_floor, @@ -71,7 +69,6 @@ NSATokenToKVPool, ReqToTokenPool, ) -from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.observability.req_time_stats import ( set_schedule_time_batch, set_time_batch, @@ -822,8 +819,7 @@ def pop_preallocated( .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool, SWAKVPool): - # SWA hybrid model: send decode-side SWA window indices + elif isinstance(self.token_to_kv_pool, BaseSWAKVPool): seq_len = len(decode_req.req.origin_input_ids) window_size = self.scheduler.sliding_window_size diff --git a/python/sglang/srt/disaggregation/nixl/conn.py b/python/sglang/srt/disaggregation/nixl/conn.py index cb21fd5c945b..97520499c310 100644 --- a/python/sglang/srt/disaggregation/nixl/conn.py +++ b/python/sglang/srt/disaggregation/nixl/conn.py @@ -1010,6 +1010,62 @@ def _send_mamba_state_slice( raise Exception("Failed to post Mamba state slice transfer") return xfer_handle + def _send_state_pages_flat( + self, + peer_name: str, + prefill_state_indices: List[int], + dst_state_data_ptrs: list[int], + dst_state_indices: List[int], + dst_state_item_lens: list[int], + dst_gpu_id: int, + notif: str, + ): + """Per-page WRITE transfer of a flat (heterogeneous) state pool. 
+ + Used by V4 whose state pool is a flat list of buffers (SWA + compress + + indexer pools) that does not match the per-layer K/V layout assumed + by ``_send_kvcache_generic``. Both sides must have identical + ``state_item_lens`` (no TP-slicing path). + """ + src_state_ptrs = self.kv_args.state_data_ptrs + src_state_item_lens = self.kv_args.state_item_lens + assert len(src_state_ptrs) == len(dst_state_data_ptrs) + assert len(src_state_item_lens) == len(dst_state_item_lens) + assert len(prefill_state_indices) == len(dst_state_indices), ( + f"State index length mismatch: prefill={len(prefill_state_indices)}, " + f"dst={len(dst_state_indices)}" + ) + for i in range(len(src_state_item_lens)): + assert src_state_item_lens[i] == dst_state_item_lens[i], ( + f"V4 state item length mismatch at index {i}: " + f"{src_state_item_lens[i]} != {dst_state_item_lens[i]}" + ) + + src_addrs = [] + dst_addrs = [] + for i in range(len(src_state_ptrs)): + item_len = src_state_item_lens[i] + for src_idx, dst_idx in zip(prefill_state_indices, dst_state_indices): + src_addr = src_state_ptrs[i] + int(src_idx) * item_len + dst_addr = dst_state_data_ptrs[i] + int(dst_idx) * item_len + src_addrs.append((src_addr, item_len, self.kv_args.gpu_id)) + dst_addrs.append((dst_addr, item_len, dst_gpu_id)) + + if not src_addrs: + return None + + src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM") + dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM") + xfer_handle = self.agent.initialize_xfer( + "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii") + ) + if not xfer_handle: + raise Exception("KVSender failed to create state transfer") + state = self.agent.transfer(xfer_handle) + if state == "ERR": + raise Exception("KVSender failed to post state transfer") + return xfer_handle + def maybe_send_extra( self, peer_name: str, @@ -1068,6 +1124,16 @@ def maybe_send_extra( dst_gpu_id=dst_gpu_id, notif=notif, ) + elif state_type == "dsv4": + return self._send_state_pages_flat( + peer_name, + prefill_state_indices, + dst_state_data_ptrs, + dst_state_indices, + dst_state_item_lens or [], + dst_gpu_id, + notif, + ) else: if state_type != "none": raise RuntimeError( diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 6db94de40055..bcf77768c436 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -48,6 +48,7 @@ Req, ScheduleBatch, ) +from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.common import ( kv_to_page_indices, kv_to_page_num, @@ -55,7 +56,6 @@ release_kv_cache, ) from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool -from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.observability.req_time_stats import set_schedule_time_batch if TYPE_CHECKING: @@ -787,7 +787,9 @@ def send_kv_chunk( .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool): + elif isinstance( + self.token_to_kv_pool_allocator.get_kvcache(), BaseSWAKVPool + ): # SWA hybrid model: send last window KV indices seq_len = len(req.fill_ids) window_size = self.sliding_window_size diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 26817b3de065..37cb474228ba 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -526,9 +526,10 @@ def filter_kv_indices_for_cp_rank( def is_mla_backend(target_kv_pool) -> bool: + 
from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool - return isinstance(target_kv_pool, MLATokenToKVPool) + return isinstance(target_kv_pool, (MLATokenToKVPool, DeepSeekV4TokenToKVPool)) def setup_state_kv_args( @@ -541,8 +542,9 @@ def setup_state_kv_args( Shared by prefill and decode bootstrap paths so the state_type dispatch lives in one place. """ + from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool + from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool - from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool if not hasattr(token_to_kv_pool, "get_state_buf_infos"): kv_args.state_data_ptrs = [] @@ -558,7 +560,12 @@ def setup_state_kv_args( kv_args.state_data_lens = state_data_lens kv_args.state_item_lens = state_item_lens - if isinstance(token_to_kv_pool, SWAKVPool): + # V4 must be checked before BaseSWAKVPool: V4's state pool is a flat + # heterogeneous list (SWA + compress + indexer), so the per-layer K/V + # transfer path used for "swa"/"nsa" does not apply. + if isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool): + kv_args.state_type = "dsv4" + elif isinstance(token_to_kv_pool, BaseSWAKVPool): kv_args.state_type = "swa" elif isinstance(token_to_kv_pool, HybridLinearKVPool): kv_args.state_type = "mamba" diff --git a/python/sglang/srt/entrypoints/openai/encoding_dsv4.py b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py new file mode 100644 index 000000000000..0822c9ad15e7 --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py @@ -0,0 +1,850 @@ +# Adapted from the DeepSeek-V4 release reference implementation. +""" +DeepSeek-V4 Encoding + +A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages +with tool calling, thinking mode, and quick instruction task support. 
+""" + +import copy +import json +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +# ============================================================ +# Special Tokens +# ============================================================ + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" + +USER_SP_TOKEN = "<|User|>" +ASSISTANT_SP_TOKEN = "<|Assistant|>" +LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>" + +# Task special tokens for internal classification tasks +DS_TASK_SP_TOKENS = { + "action": "<|action|>", + "query": "<|query|>", + "authority": "<|authority|>", + "domain": "<|domain|>", + "title": "<|title|>", + "read_url": "<|read_url|>", +} +VALID_TASKS = set(DS_TASK_SP_TOKENS.keys()) + +# ============================================================ +# Templates +# ============================================================ + +system_msg_template: str = "{content}" +user_msg_template: str = "{content}" +latest_reminder_msg_template: str = "{content}" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token +assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}" +thinking_template: str = "{reasoning_content}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = ( + '<{dsml_token}invoke name="{name}">\n{arguments}\n' +) +tool_calls_template = ( + "<{dsml_token}{tc_block_name}>\n{tool_calls}\n" +) +tool_calls_block_name: str = "tool_calls" + +tool_output_template: str = "{content}" + +REASONING_EFFORT_MAX = ( + "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n" + "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n" + "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n" +) + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response. + +Otherwise, output directly after {thinking_end_token} with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. 
+""" + +# ============================================================ +# Utility Functions +# ============================================================ + + +def to_json(value: Any) -> str: + """Serialize a value to JSON string.""" + try: + return json.dumps(value, ensure_ascii=False) + except: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + """Extract function definitions from OpenAI-format tool list.""" + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + """Convert OpenAI-format tool calls to internal format.""" + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + """Convert internal tool calls to OpenAI format.""" + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str: + """ + Encode tool call arguments into DSML parameter format. + + Args: + tool_call: Dict with "name" and "arguments" (JSON string) keys. + + Returns: + DSML-formatted parameter string. + """ + p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' + P_dsml_strs = [] + + try: + arguments = json.loads(tool_call["arguments"]) + except Exception as err: + arguments = {"arguments": tool_call["arguments"]} + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments( + tool_name: str, tool_args: Dict[str, Tuple[str, str]] +) -> Dict[str, str]: + """ + Decode DSML parameters back to a tool call dict. + + Args: + tool_name: Name of the tool. + tool_args: Dict mapping param_name -> (value, is_string_flag). + + Returns: + Dict with "name" and "arguments" (JSON string) keys. + """ + + def _decode_value(key: str, value: str, string: str): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = ( + "{" + + ", ".join( + [_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()] + ) + + "}" + ) + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str: + """ + Render tool schemas into the system prompt format. + + Args: + tools: List of tool schema dicts (each with name, description, parameters). + + Returns: + Formatted tools section string. 
+ """ + tools_json = [to_json(t) for t in tools] + + return TOOLS_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: List[Dict[str, Any]]) -> int: + """Find the index of the last user/developer message.""" + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +def attach_task_to_last_user_message(messages: List[Dict[str, Any]], task: str) -> None: + """Set `task` on the most recent user/developer message; raise if none exists.""" + idx = find_last_user_index(messages) + if idx == -1: + raise ValueError( + "`task` requires at least one message with role='user' or 'developer'." + ) + messages[idx]["task"] = task + + +# ============================================================ +# Message Rendering +# ============================================================ + + +def render_message( + index: int, + messages: List[Dict[str, Any]], + thinking_mode: str, + drop_thinking: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Render a single message at the given index into its encoded string form. + + This is the core function that converts each message in the conversation + into the DeepSeek-V4 format. + + Args: + index: Index of the message to render. + messages: Full list of messages in the conversation. + thinking_mode: Either "chat" or "thinking". + drop_thinking: Whether to drop reasoning content from earlier turns. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + Encoded string for this message. 
+ """ + assert 0 <= index < len(messages) + assert thinking_mode in [ + "chat", + "thinking", + ], f"Invalid thinking_mode `{thinking_mode}`" + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning_content = msg.get("reasoning_content") + wo_eos = msg.get("wo_eos", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + # Reasoning effort prefix (only at index 0 in thinking mode with max effort) + assert reasoning_effort in [ + "max", + None, + "high", + ], f"Invalid reasoning effort: {reasoning_effort}" + if index == 0 and thinking_mode == "thinking" and reasoning_effort == "max": + prompt += REASONING_EFFORT_MAX + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + if response_format: + prompt += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + + content_developer = USER_SP_TOKEN + content_developer += content + + if tools: + content_developer += "\n\n" + render_tools(tools) + if response_format: + content_developer += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + prompt += user_msg_template.format(content=content_developer) + + elif role == "user": + prompt += USER_SP_TOKEN + + # Handle content blocks (tool results mixed with text) + content_blocks = msg.get("content_blocks") + if content_blocks: + parts = [] + for block in content_blocks: + block_type = block.get("type") + if block_type == "text": + parts.append(block.get("text", "")) + elif block_type == "tool_result": + tool_content = block.get("content", "") + if isinstance(tool_content, list): + text_parts = [] + for b in tool_content: + if b.get("type") == "text": + text_parts.append(b.get("text", "")) + else: + text_parts.append(f"[Unsupported {b.get('type')}]") + tool_content = "\n\n".join(text_parts) + parts.append(tool_output_template.format(content=tool_content)) + else: + parts.append(f"[Unsupported {block_type}]") + prompt += "\n\n".join(parts) + else: + prompt += content or "" + + elif role == "latest_reminder": + prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format( + content=content + ) + + elif role == "tool": + raise NotImplementedError( + "deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()" + ) + + elif role == "assistant": + thinking_part = "" + tc_content = "" + + if tool_calls: + tc_list = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tc.get("name"), + arguments=encode_arguments_to_dsml(tc), + ) + for tc in tool_calls + ] + tc_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, + tool_calls="\n".join(tc_list), + tc_block_name=tool_calls_block_name, + ) + + summary_content = content or "" + rc = reasoning_content or "" + + # Check if previous message has a task - if so, this is a task output (no thinking) + prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None + + if thinking_mode == "thinking" and not prev_has_task: + if not drop_thinking or index > last_user_idx: + thinking_part = ( + thinking_template.format(reasoning_content=rc) + thinking_end_token + ) 
+ else: + thinking_part = "" + + if wo_eos: + prompt += assistant_msg_wo_eos_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + # Append transition tokens based on what follows + if index + 1 < len(messages) and messages[index + 1].get("role") not in [ + "assistant", + "latest_reminder", + ]: + return prompt + + task = messages[index].get("task") + if task is not None: + # Task special token for internal classification tasks + assert ( + task in VALID_TASKS + ), f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}" + task_sp_token = DS_TASK_SP_TOKENS[task] + + if task != "action": + # Non-action tasks: append task sp token directly after the message + prompt += task_sp_token + else: + # Action task: append Assistant + thinking token + action sp token + prompt += ASSISTANT_SP_TOKEN + prompt += ( + thinking_end_token + if thinking_mode != "thinking" + else thinking_start_token + ) + prompt += task_sp_token + + elif messages[index].get("role") in ["user", "developer"]: + # Normal generation: append Assistant + thinking token + prompt += ASSISTANT_SP_TOKEN + if not drop_thinking and thinking_mode == "thinking": + prompt += thinking_start_token + elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx: + prompt += thinking_start_token + else: + prompt += thinking_end_token + + return prompt + + +# ============================================================ +# Preprocessing +# ============================================================ + + +def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merge tool messages into the preceding user message using content_blocks format. + + DeepSeek-V4 does not have a standalone "tool" role; instead, tool results + are encoded as blocks within user messages. + + This function converts a standard OpenAI-format conversation (with separate + "tool" role messages) into V4 format where tool results are merged into + user messages. + + Args: + messages: List of message dicts in OpenAI format. + + Returns: + Processed message list with tool messages merged into user messages. + """ + merged: List[Dict[str, Any]] = [] + + for msg in messages: + msg = copy.deepcopy(msg) + role = msg.get("role") + + if role == "tool": + # Convert tool message to a user message with tool_result block + tool_block = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": msg.get("content", ""), + } + # Merge into previous message if it's already a user (merged tool) + if ( + merged + and merged[-1].get("role") == "user" + and "content_blocks" in merged[-1] + ): + merged[-1]["content_blocks"].append(tool_block) + else: + merged.append( + { + "role": "user", + "content_blocks": [tool_block], + } + ) + elif role == "user": + text_block = {"type": "text", "text": msg.get("content", "")} + if ( + merged + and merged[-1].get("role") == "user" + and "content_blocks" in merged[-1] + and merged[-1].get("task") is None + ): + merged[-1]["content_blocks"].append(text_block) + else: + new_msg = { + "role": "user", + "content": msg.get("content", ""), + "content_blocks": [text_block], + } + # Preserve extra fields (task, wo_eos, mask, etc.) 
+ for key in ("task", "wo_eos", "mask"): + if key in msg: + new_msg[key] = msg[key] + merged.append(new_msg) + else: + merged.append(msg) + + return merged + + +def sort_tool_results_by_call_order( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + Sort tool_result blocks within user messages by the order of tool_calls + in the preceding assistant message. + + Args: + messages: Preprocessed message list (after merge_tool_messages). + + Returns: + Message list with sorted tool result blocks. + """ + last_tool_call_order: Dict[str, int] = {} + + for msg in messages: + role = msg.get("role") + if role == "assistant" and msg.get("tool_calls"): + last_tool_call_order = {} + for idx, tc in enumerate(msg["tool_calls"]): + tc_id = tc.get("id") or tc.get("function", {}).get("id", "") + if tc_id: + last_tool_call_order[tc_id] = idx + + elif role == "user" and msg.get("content_blocks"): + tool_blocks = [ + b for b in msg["content_blocks"] if b.get("type") == "tool_result" + ] + if len(tool_blocks) > 1 and last_tool_call_order: + sorted_blocks = sorted( + tool_blocks, + key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0), + ) + sorted_idx = 0 + new_blocks = [] + for block in msg["content_blocks"]: + if block.get("type") == "tool_result": + new_blocks.append(sorted_blocks[sorted_idx]) + sorted_idx += 1 + else: + new_blocks.append(block) + msg["content_blocks"] = new_blocks + + return messages + + +# ============================================================ +# Main Encoding Function +# ============================================================ + + +def encode_messages( + messages: List[Dict[str, Any]], + thinking_mode: str, + context: Optional[List[Dict[str, Any]]] = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Encode a list of messages into the DeepSeek-V4 prompt format. + + This is the main entry point for encoding conversations. It handles: + - BOS token insertion + - Thinking mode with optional reasoning content dropping + - Tool message merging into user messages + - Multi-turn conversation context + + Args: + messages: List of message dicts to encode. + thinking_mode: Either "chat" or "thinking". + context: Optional preceding context messages (already encoded prefix). + drop_thinking: If True, drop reasoning_content from earlier assistant turns + (only keep reasoning for messages after the last user message). + add_default_bos_token: Whether to prepend BOS token at conversation start. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + The encoded prompt string. 
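As a usage sketch of the entry point documented above (import path taken from this diff; the assertions follow from the rendering logic rather than from any exact token literals):

```python
# Illustrative only: encoding a short conversation in "thinking" mode.
from sglang.srt.entrypoints.openai import encoding_dsv4 as enc

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather in Beijing?"},
]
prompt = enc.encode_messages(messages, thinking_mode="thinking")

# No context was passed, so the prompt starts with the BOS token; the last user
# turn is followed by "<|Assistant|>" plus the thinking start token, prompting
# the model to open its reply with reasoning.
assert prompt.startswith(enc.bos_token)
assert prompt.endswith(enc.ASSISTANT_SP_TOKEN + enc.thinking_start_token)
```

OpenAI-style `tool` messages are merged into the preceding user turn as `content_blocks` by `merge_tool_messages` before rendering, so standard tool-result conversations can be passed in unchanged.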
+ """ + context = context if context else [] + + # Preprocess: merge tool messages and sort tool results + messages = merge_tool_messages(messages) + messages = sort_tool_results_by_call_order(context + messages)[len(context) :] + if context: + context = merge_tool_messages(context) + context = sort_tool_results_by_call_order(context) + + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + # Resolve drop_thinking: if any message has tools defined, don't drop thinking + effective_drop_thinking = drop_thinking + if any(m.get("tools") for m in full_messages): + effective_drop_thinking = False + + if thinking_mode == "thinking" and effective_drop_thinking: + full_messages = _drop_thinking_messages(full_messages) + # After dropping, recalculate how many messages to render + # (context may have shrunk too) + num_to_render = len(full_messages) - len(_drop_thinking_messages(context)) + context_len = len(full_messages) - num_to_render + else: + num_to_render = len(messages) + context_len = len(context) + + for idx in range(num_to_render): + prompt += render_message( + idx + context_len, + full_messages, + thinking_mode=thinking_mode, + drop_thinking=effective_drop_thinking, + reasoning_effort=reasoning_effort, + ) + + return prompt + + +def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Drop reasoning_content and non-essential messages before the last user message. + + Behavior: + - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept. + - Messages at or after the last user index are always kept. + - Assistant messages before the last user get reasoning_content removed. + - Developer messages before the last user are dropped entirely. + """ + last_user_idx = find_last_user_index(messages) + result = [] + keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"} + + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in keep_roles or idx >= last_user_idx: + result.append(msg) + elif role == "assistant": + msg = copy.copy(msg) + msg.pop("reasoning_content", None) + result.append(msg) + # developer and other roles before last_user_idx are dropped + + return result + + +# ============================================================ +# Parsing (Decoding model output) +# ============================================================ + + +def _read_until_stop( + index: int, text: str, stop: List[str] +) -> Tuple[int, str, Optional[str]]: + """ + Read text from index until one of the stop strings is found. + + Returns: + Tuple of (new_index, content_before_stop, matched_stop_string_or_None). + """ + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls( + index: int, text: str +) -> Tuple[int, Optional[str], List[Dict[str, str]]]: + """ + Parse DSML tool calls from text starting at the given index. + + Args: + index: Starting position in text. + text: The full text to parse. + + Returns: + Tuple of (new_index, last_stop_token, list_of_tool_call_dicts). + Each tool call dict has "name" and "arguments" keys. 
+ """ + tool_calls: List[Dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}invoke", tool_calls_end_token] + ) + if _ != ">\n": + raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise ValueError("Missing special token in tool calls") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL + ) + if len(p_tool_name) != 1: + raise ValueError(f"Tool name format error: '{tool_name_content}'") + tool_name = p_tool_name[0] + + tool_args: Dict[str, Tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop( + index, text, [f"/{dsml_token}parameter"] + ) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise ValueError(f"Parameter format error: '{param_content}'") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise ValueError(f"Duplicate parameter name: '{param_name}'") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise ValueError( + f"Parameter format error: expected '>\\n' but got '{content}'" + ) + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]: + """ + Parse a model completion text into a structured assistant message. + + This function takes the raw text output from the model (a single assistant turn) + and extracts: + - reasoning_content (thinking block) + - content (summary/response) + - tool_calls (if any) + + NOTE: This function is designed to parse only correctly formatted strings and + will raise ValueError for malformed output. + + Args: + text: The raw completion text (including EOS token). + thinking_mode: Either "chat" or "thinking". + + Returns: + Dict with keys: "role", "content", "reasoning_content", "tool_calls". + tool_calls are in OpenAI format. 
+ """ + summary_content, reasoning_content, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}" + + is_thinking = thinking_mode == "thinking" + is_tool_calling = False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop( + index, text, [thinking_end_token, tool_calls_start_token] + ) + reasoning_content = content_delta + assert ( + stop_token == thinking_end_token + ), "Invalid thinking format: missing " + + index, content_delta, stop_token = _read_until_stop( + index, text, [eos_token, tool_calls_start_token] + ) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + assert stop_token == eos_token, "Invalid format: missing EOS token" + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + assert not tool_ends_text, "Unexpected content after tool calls" + + assert len(text) == index and stop_token in [ + eos_token, + None, + ], "Unexpected content at end" + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + assert ( + sp_token not in summary_content and sp_token not in reasoning_content + ), f"Unexpected special token '{sp_token}' in content" + + return { + "role": "assistant", + "content": summary_content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 7bf3581cfdba..ab47fd7f9836 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -17,7 +17,17 @@ import time import uuid from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TypeAlias, Union +from typing import ( + Any, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TypeAlias, + Union, + get_args, +) from openai.types.responses import ( ResponseFunctionToolCall, @@ -500,8 +510,14 @@ class ToolCall(BaseModel): function: FunctionResponse +_GenericMessageRole = Literal[ + "system", "assistant", "tool", "function", "developer", "latest_reminder" +] +_GENERIC_MESSAGE_ROLES: Tuple[str, ...] = get_args(_GenericMessageRole) + + class ChatCompletionMessageGenericParam(BaseModel): - role: Literal["system", "assistant", "tool", "function", "developer"] + role: _GenericMessageRole content: Union[str, List[ChatCompletionMessageContentPart], None] = Field( default=None ) @@ -516,10 +532,9 @@ class ChatCompletionMessageGenericParam(BaseModel): def _normalize_role(cls, v): if isinstance(v, str): v_lower = v.lower() - if v_lower not in {"system", "assistant", "tool", "function", "developer"}: - raise ValueError( - "'role' must be one of 'system', 'developer', 'assistant', 'tool', or 'function' (case-insensitive)." - ) + if v_lower not in _GENERIC_MESSAGE_ROLES: + allowed = ", ".join(repr(r) for r in _GENERIC_MESSAGE_ROLES) + raise ValueError(f"'role' must be one of {allowed} (case-insensitive).") return v_lower raise ValueError("'role' must be a string") @@ -626,6 +641,15 @@ class ChatCompletionRequest(BaseModel): "in a response. 'none' defaults thinking and enable_thinking to false in " "chat_template_kwargs (unless explicitly overridden). 
Not supported in the harmony path.", ) + task: Optional[ + Literal["action", "query", "authority", "domain", "title", "read_url"] + ] = Field( + default=None, + description="DeepSeek-V4 quick instruction task. When set, the last " + "user/developer message is treated as a single-shot classification prompt " + "and the corresponding task special token (e.g. `<|domain|>`) is appended " + "before generation. Only honored by the dsv4 chat encoder; ignored otherwise.", + ) # Extra parameters for SRT backend only and will be ignored by OpenAI models. top_k: Optional[int] = None diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 7447262bfbe5..7bf381d629e4 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -15,7 +15,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse from jsonschema import Draft202012Validator, SchemaError -from sglang.srt.entrypoints.openai.encoding_dsv32 import encode_messages +from sglang.srt.entrypoints.openai import encoding_dsv4, encoding_dsv32 from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -46,6 +46,7 @@ should_include_usage, to_openai_style_logprobs, ) +from sglang.srt.environ import envs from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.json_array_parser import JsonArrayParser @@ -232,7 +233,9 @@ def __init__( and self.tokenizer_manager.model_config.hf_config.model_type == "gemma4" ) - self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() + # Which Python-based chat encoder (if any) bypasses apply_chat_template. + # Values: "dsv32", "dsv4", or None. + self.chat_encoding_spec = self._resolve_chat_encoding_spec() def _handle_last_assistant_message( self, @@ -290,14 +293,25 @@ def _append_assistant_prefix_to_prompt_ids( encoded = encoded[1:] return prompt_ids + encoded - def _use_dpsk_v32_encoding(self) -> bool: + def _resolve_chat_encoding_spec(self) -> Optional[str]: + if self.tool_call_parser == "deepseekv4": + return "dsv4" + if self.tool_call_parser == "deepseekv32": + return "dsv32" + + architectures = self.tokenizer_manager.model_config.hf_config.architectures + arch = architectures[0] if architectures else "" + + if "DeepseekV4" in arch: + return "dsv4" + has_chat_template = ( self.tokenizer_manager.tokenizer is not None and self.tokenizer_manager.tokenizer.chat_template is not None ) - architectures = self.tokenizer_manager.model_config.hf_config.architectures - is_dpsk_v32 = "DeepseekV3" in architectures[0] if architectures else False - return not has_chat_template and is_dpsk_v32 + if "DeepseekV3" in arch and not has_chat_template: + return "dsv32" + return None def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -515,14 +529,22 @@ def _apply_jinja_template( template_content_format = self.template_manager.jinja_template_content_format - if self.use_dpsk_v32_encoding: - thinking_mode = ( - "thinking" - if (request.chat_template_kwargs or {}).get("thinking") - else "chat" + if self.chat_encoding_spec is not None: + # Per-request wins; env is fallback default for benchmark + # workflows that can't pass per-request chat_template_kwargs. 
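To make the new request-level knobs concrete, a hypothetical OpenAI-compatible payload (model name and message content are invented; only the field names come from this diff):

```python
# Illustrative only: a dsv4 quick-instruction request. `task` makes the encoder
# append the <|domain|> token after the last user message; the per-request
# chat_template_kwargs["thinking"] value overrides the SGLANG_DEFAULT_THINKING
# environment default when choosing between "thinking" and "chat" encoding.
payload = {
    "model": "deepseek-v4",  # hypothetical served model name
    "messages": [{"role": "user", "content": "https://example.com/article"}],
    "task": "domain",
    "chat_template_kwargs": {"thinking": False},
    "max_tokens": 16,
}
```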
+ thinking_requested = (request.chat_template_kwargs or {}).get( + "thinking", envs.SGLANG_DEFAULT_THINKING.get() ) - messages = request.messages - messages = [msg.model_dump() for msg in messages] + thinking_mode = "thinking" if thinking_requested else "chat" + messages = [msg.model_dump() for msg in request.messages] + + # dsv4/dsv32 are text-only and consume string content; flatten + # OpenAI parts-list content here so the encoder sees a plain string. + for i, msg in enumerate(messages): + if isinstance(msg.get("content"), list): + messages[i] = process_content_for_template_format( + msg, "string", [], [], [], [] + ) for msg in messages: if msg.get("content") is None: @@ -534,7 +556,7 @@ def _apply_jinja_template( video_data, audio_data, modalities, - use_dpsk_v32_encoding=self.use_dpsk_v32_encoding, + use_dpsk_v32_encoding=self.chat_encoding_spec == "dsv32", ) msg.update(processed_msg) @@ -548,7 +570,32 @@ def _apply_jinja_template( messages.insert(0, {"role": "system", "content": ""}) if request.tools: messages[0]["tools"] = [tool.model_dump() for tool in request.tools] - real_input = encode_messages(messages, thinking_mode=thinking_mode) + + if self.chat_encoding_spec == "dsv4": + # V4 encoder only accepts "max" / "high" / None. + # OpenAI protocol defaults to "medium" which V4 rejects; drop it. + # Fallback: if request didn't set it, try env SGLANG_DSV4_REASONING_EFFORT. + effort_source = request.reasoning_effort + if effort_source is None: + env_val = envs.SGLANG_DSV4_REASONING_EFFORT.get() + if env_val: + effort_source = env_val + v4_reasoning_effort = ( + effort_source if effort_source in ("max", "high") else None + ) + if request.task is not None: + encoding_dsv4.attach_task_to_last_user_message( + messages, request.task + ) + real_input = encoding_dsv4.encode_messages( + messages, + thinking_mode=thinking_mode, + reasoning_effort=v4_reasoning_effort, + ) + else: + real_input = encoding_dsv32.encode_messages( + messages, thinking_mode=thinking_mode + ) prompt_ids = self.tokenizer_manager.tokenizer.encode(real_input) # Append assistant prefix if continue_final_message is enabled diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 53bc5927963d..3167b8ee2ca2 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -553,6 +553,64 @@ class Envs: # TokenizerManager SGLANG_REQUEST_STATE_WAIT_TIMEOUT = EnvInt(4) + SGLANG_DEFAULT_THINKING = EnvBool(False) + + # ==================================================================== + # DeepSeek V4 + # ==================================================================== + + # Set False when using FP4-to-FP8 converted DeepSeek V4 checkpoint. + SGLANG_DSV4_FP4_EXPERTS = EnvBool(True) + # Default reasoning_effort for dsv4 chat encoder when request doesn't set it. + # Accepts "", "max", "high" (empty string means unset); other values filtered to None. 
+ SGLANG_DSV4_REASONING_EFFORT = EnvStr("") + + # CUDA kernels + SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True) + SGLANG_OPT_USE_TILELANG_INDEXER = EnvBool(False) + SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False) + SGLANG_OPT_USE_ONLINE_COMPRESS = EnvBool(False) + SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) + SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) + + # SWA radix cache + SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) + # TODO(DSV4): @ispobock this has bug on main branch when retract + SGLANG_OPT_SWA_RADIX_CACHE_COMPACT = EnvBool(False) + SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT = EnvBool(False) + SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW = EnvBool(False) + + # DeepGemm Mega MoE + SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE = EnvBool(False) + SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK = EnvInt(1024) + SGLANG_OPT_FIX_MEGA_MOE_MEMORY = EnvBool(False) + + # TopK + SGLANG_OPT_USE_FUSED_HASH_TOPK = EnvBool(True) + SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK = EnvBool(True) + SGLANG_OPT_USE_TOPK_V2 = EnvBool(False) + + # GEMM / kernel fusion + SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False) + SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas") + SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True) + SGLANG_OPT_USE_JIT_NORM = EnvBool(False) + SGLANG_OPT_FUSE_WQA_WKV = EnvBool(True) + SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(True) + + # Cache / overlap + SGLANG_OPT_USE_FUSED_STORE_CACHE = EnvBool(True) + SGLANG_OPT_USE_OVERLAP_STORE_CACHE = EnvBool(True) + SGLANG_OPT_USE_MULTI_STREAM_OVERLAP = EnvBool(True) + + # CUDA graph + SGLANG_PREP_IN_CUDA_GRAPH = EnvBool(True) + + # Distributed + SGLANG_DSV4_FIX_TP_ATTN_A2A_SCATTER = EnvBool(True) + # Symmetric Memory SGLANG_SYMM_MEM_PREALLOC_GB_SIZE = EnvInt(-1) SGLANG_DEBUG_SYMM_MEM = EnvBool(False) @@ -611,6 +669,8 @@ def _convert_SGL_to_SGLANG(): "SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK", ) _print_deprecated_env("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2") + _print_deprecated_env("SGLANG_ENABLE_THINKING", "SGLANG_DEFAULT_THINKING") + _print_deprecated_env("SGLANG_REASONING_EFFORT", "SGLANG_DSV4_REASONING_EFFORT") _print_deprecated_env( "SGLANG_USE_JIT_ALL_REDUCE", "SGLANG_OPT_USE_CUSTOM_ALL_REDUCE_V2" ) diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index c9524ac5ad41..8123da1972be 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -13,6 +13,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.deepseekv4_detector import DeepSeekV4Detector from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.gemma4_detector import Gemma4Detector @@ -55,6 +56,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, "deepseekv32": DeepSeekV32Detector, + "deepseekv4": DeepSeekV4Detector, "glm": Glm4MoeDetector, "glm45": Glm4MoeDetector, "glm47": Glm47MoeDetector, diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index e870d8f77e5c..d5e3286aa716 100644 --- 
a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -92,6 +92,15 @@ def create_nsa_backend(runner): return NativeSparseAttnBackend(runner) +@register_attention_backend("dsv4") +def create_dsv4_backend(runner): + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DeepseekV4AttnBackend, + ) + + return DeepseekV4AttnBackend(runner) + + @register_attention_backend("triton") def create_triton_backend(runner): assert not runner.model_config.is_encoder_decoder, ( diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py new file mode 100644 index 000000000000..93e4507656c1 --- /dev/null +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -0,0 +1,1255 @@ +from __future__ import annotations + +import enum +import functools +import logging +from dataclasses import dataclass, field +from typing import ( + TYPE_CHECKING, + Dict, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) + +import torch +import torch.nn.functional as F + +from sglang.srt.environ import envs +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.dsv4.compressor import ( + CompressorBackendMixin, + FusedCompressMetadata, + create_paged_compressor_data, +) +from sglang.srt.layers.attention.dsv4.indexer import C4IndexerBackendMixin +from sglang.srt.layers.attention.dsv4.metadata import ( + PagedIndexerMetadata, + copy_metadata, + maybe_copy_inplace, +) +from sglang.srt.layers.attention.dsv4.metadata_kernel import ( + init_compression_metadata as _init_compression_metadata_triton, +) +from sglang.srt.layers.attention.dsv4.quant_k_cache import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.layers.dp_attention import ( + get_attention_cp_rank, + get_attention_cp_size, +) +from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.spec_info import SpecInput +from sglang.srt.utils import ceil_align + +if TYPE_CHECKING: + from flash_mla.flash_mla_interface import FlashMLASchedMeta + + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + +SWA_WINDOW = 128 +C4_TOPK = 512 +PAGE_INDEX_ALIGNED_SIZE = 64 + + +T = TypeVar("T", bound=Optional[torch.Tensor]) + + +def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: + if x is None: + return None + curr_size = x.shape[-1] + target_size = ceil_align(curr_size, multiples_of) + return F.pad(x, pad=(0, target_size - curr_size), mode="constant", value=-1) + + +def _create_flashmla_metadata(): + import flash_mla + + return flash_mla.get_mla_metadata()[0] + + +def _create_dummy_paged_compress_data(compress_ratio: int): + return None + + +@dataclass +class DSV4AttnMetadata: + page_size: int + page_table: torch.Tensor + raw_out_loc: torch.Tensor + cuda_int32_kwargs: dict + + seq_lens_casual: torch.Tensor + positions_casual: torch.Tensor + + swa_page_indices: torch.Tensor + swa_topk_lengths: torch.Tensor + + c4_sparse_topk: int + c4_out_loc: Optional[torch.Tensor] = None + c4_topk_lengths_raw: Optional[torch.Tensor] = None + c4_topk_lengths_clamp1: Optional[torch.Tensor] = None + c4_sparse_topk_lengths: torch.Tensor = field(init=False) + c4_sparse_page_indices: torch.Tensor = 
field(init=False) + + c128_out_loc: Optional[torch.Tensor] = None + c128_page_indices: Optional[torch.Tensor] = None + c128_topk_lengths_clamp1: Optional[torch.Tensor] = None + + c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + + @property + def positions(self) -> torch.Tensor: + return self.positions_casual + + def get_flashmla_metadata(self, compress_ratio: Literal[0, 4, 128]): + if compress_ratio == 0: + return self.c1_flashmla_metadata + elif compress_ratio == 4: + return self.c4_flashmla_metadata + elif compress_ratio == 128: + return self.c128_flashmla_metadata + else: + raise ValueError(f"invalid {compress_ratio=}") + + def copy_(self, other: DSV4AttnMetadata) -> None: + copy_metadata( + src=other, + dst=self, + check_eq_fields=[ + "c4_sparse_topk", + "page_size", + "cuda_int32_kwargs", + ], + copy_fields=[ + "raw_out_loc", + "seq_lens_casual", + "positions_casual", + "c4_out_loc", + "c128_out_loc", + "page_table", + "swa_page_indices", + "swa_topk_lengths", + "c128_page_indices", + "c128_topk_lengths_clamp1", + "c4_topk_lengths_raw", + "c4_topk_lengths_clamp1", + "c4_sparse_topk_lengths", + "c4_sparse_page_indices", + ], + assign_fields=[ + "c1_flashmla_metadata", + "c4_flashmla_metadata", + "c128_flashmla_metadata", + ], + ) + + def init_compression_metadata(self): + assert self.page_table.dim() == 2 + assert ( + self.raw_out_loc.shape == self.seq_lens_casual.shape + ), f"{self.raw_out_loc.shape=}, {self.seq_lens_casual.shape=}" + + ( + self.c4_out_loc, + _, + self.c4_topk_lengths_raw, + self.c4_topk_lengths_clamp1, + self.c128_out_loc, + _, + self.c128_topk_lengths_clamp1, + self.c128_page_indices, + ) = _init_compression_metadata_triton( + self.seq_lens_casual, + self.positions_casual, + self.raw_out_loc, + self.page_table, + self.page_size, + compute_page_indices=True, + ) + + self.c128_page_indices = _pad_last_dim(self.c128_page_indices) + self.swa_page_indices = _pad_last_dim(self.swa_page_indices) + + _CP_REINDEX_FIELDS = [ + "seq_lens_casual", + "positions_casual", + "swa_page_indices", + "swa_topk_lengths", + "page_table", + "c4_topk_lengths_raw", + "c4_topk_lengths_clamp1", + "c128_page_indices", + "c128_topk_lengths_clamp1", + ] + _CP_GLOBAL_FIELDS = [ + "raw_out_loc", + "c4_out_loc", + "c128_out_loc", + ] + + def apply_cp_reindex(self) -> None: + cp_rank = get_attention_cp_rank() + cp_size = get_attention_cp_size() + idx = slice(cp_rank, None, cp_size) + pre_global_len = self.seq_lens_casual.shape[0] + assert pre_global_len % cp_size == 0, ( + f"apply_cp_reindex: global token count {pre_global_len} is not divisible by cp_size={cp_size}. " + "CP round-robin requires padding to ensure divisibility." 
+ ) + expected_local_len = pre_global_len // cp_size + for field_name in self._CP_REINDEX_FIELDS: + val = getattr(self, field_name, None) + assert isinstance( + val, torch.Tensor + ), f"CP reindex: {field_name} is {type(val)}, expected Tensor" + setattr(self, field_name, val[idx].contiguous()) + + for field_name in self._CP_REINDEX_FIELDS: + val = getattr(self, field_name) + assert val.shape[0] == expected_local_len, ( + f"apply_cp_reindex post-condition: {field_name}.shape[0]={val.shape[0]} " + f"!= expected_local_len={expected_local_len} (cp_size={cp_size})" + ) + for field_name in self._CP_GLOBAL_FIELDS: + val = getattr(self, field_name, None) + if val is None: + continue + assert val.shape[0] == pre_global_len, ( + f"apply_cp_reindex post-condition: global field {field_name}.shape[0]={val.shape[0]} " + f"!= pre_global_len={pre_global_len} (must remain global for compressor write path)" + ) + + def init_flashmla_related(self): + # c4_sparse_topk is set from model_config.index_topk per-model + # (small model: 512, large model: 1024). + assert self.c4_sparse_topk in (512, 1024), ( + f"unexpected c4_sparse_topk={self.c4_sparse_topk}; " + "supported: 512 (small) or 1024 (large)" + ) + assert self.c4_topk_lengths_clamp1 is not None + self.c4_sparse_topk_lengths = torch.clamp( + self.c4_topk_lengths_clamp1, max=self.c4_sparse_topk + ) + self.c4_sparse_page_indices = torch.full( + (self.c4_topk_lengths_clamp1.size(0), self.c4_sparse_topk), + -1, + dtype=torch.int32, + device=self.c4_topk_lengths_clamp1.device, + ) + self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices) + self.c1_flashmla_metadata = _create_flashmla_metadata() + self.c4_flashmla_metadata = _create_flashmla_metadata() + self.c128_flashmla_metadata = _create_flashmla_metadata() + + +@dataclass +class DSV4Metadata: + core_attn_metadata: DSV4AttnMetadata + indexer_metadata: Optional[PagedIndexerMetadata] + + c4_compress_metadata: Optional[FusedCompressMetadata] = None + c128_compress_metadata: Optional[FusedCompressMetadata] = None + + @property + def core_metadata(self) -> DSV4AttnMetadata: + return self.core_attn_metadata + + def copy_(self, other: DSV4Metadata): + self.core_attn_metadata.copy_(other.core_attn_metadata) + maybe_copy_inplace(self.indexer_metadata, src=other.indexer_metadata) + maybe_copy_inplace(self.c4_compress_metadata, src=other.c4_compress_metadata) + maybe_copy_inplace( + self.c128_compress_metadata, src=other.c128_compress_metadata + ) + + +@dataclass +class DSV4RawVerifyMetadata: + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + out_cache_loc: torch.Tensor + + extend_seq_lens: Optional[torch.Tensor] = None + + def copy_(self, other: DSV4RawVerifyMetadata): + self.req_pool_indices.copy_(other.req_pool_indices) + self.seq_lens.copy_(other.seq_lens) + self.out_cache_loc.copy_(other.out_cache_loc) + + self.extend_seq_lens = other.extend_seq_lens + + +@dataclass +class DSV4RawDecodeMetadata: + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + out_cache_loc: torch.Tensor + + def copy_(self, other: DSV4RawDecodeMetadata): + self.req_pool_indices.copy_(other.req_pool_indices) + self.seq_lens.copy_(other.seq_lens) + self.out_cache_loc.copy_(other.out_cache_loc) + + +class _GraphBucket(enum.Enum): + DECODE_OR_IDLE = "decode_or_idle" + TARGET_VERIFY = "target_verify" + DRAFT_EXTEND = "draft_extend" + + @classmethod + def of(cls, forward_mode: ForwardMode) -> _GraphBucket: + if forward_mode.is_decode_or_idle(): + return cls.DECODE_OR_IDLE + if forward_mode.is_target_verify(): + 
return cls.TARGET_VERIFY + if forward_mode.is_draft_extend(include_v2=True): + return cls.DRAFT_EXTEND + raise NotImplementedError(f"unsupported {forward_mode=}") + + +class DeepseekV4AttnBackend( + AttentionBackend, C4IndexerBackendMixin, CompressorBackendMixin +): + def __init__( + self, + model_runner: ModelRunner, + skip_prefill: bool = False, + speculative_step_id=0, + topk=0, + speculative_num_steps=0, + ): + super().__init__() + self.device = torch.device(model_runner.device) + head_dim = model_runner.model_config.head_dim + assert ( + head_dim == 512 + ), "DSV4 MQA head_dim = qk_nope_head_dim(448) + qk_rope_head_dim(64) = 512" + self.softmax_scale: float = head_dim**-0.5 + self.head_dim_v: int = model_runner.model_config.v_head_dim + self.cuda_int32_kwargs = {"device": self.device, "dtype": torch.int32} + self.swa_page_size = 128 + assert model_runner.page_size is not None + assert model_runner.req_to_token_pool is not None + self.page_size = model_runner.page_size + assert self.page_size == 256, "the system hardcodes page_size=256" + + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool: DeepSeekV4TokenToKVPool = model_runner.token_to_kv_pool + self.MAX_SEQ_LEN_FOR_CAPTURE = self.req_to_token.shape[1] + + assert isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool) + self.c4_topk = getattr( + model_runner.model_config.hf_text_config, "index_topk", C4_TOPK + ) + + self.topk = model_runner.server_args.speculative_eagle_topk or 0 + assert self.topk in [0, 1], "MTP Topk > 1 not supported for DeepSeek V4" + self.mtp_enabled = self.topk > 0 + self.speculative_num_steps = speculative_num_steps + self.speculative_num_draft_tokens: int = ( + model_runner.server_args.speculative_num_draft_tokens + ) + self.speculative_step_id = speculative_step_id + self.forward_metadata: Union[ + DSV4Metadata, + DSV4RawVerifyMetadata, + DSV4RawDecodeMetadata, + ] = None + self._replay_forward_batch: Optional[ForwardBatch] = None # FIXME: out-of-band + + def _move_to_device(self, x: List[int]) -> torch.Tensor: + pin_tensor = torch.tensor(x, dtype=torch.int32, pin_memory=True) + return pin_tensor.to(self.device, non_blocking=True) + + def init_forward_metadata_indexer(self, core_attn_metadata: DSV4AttnMetadata): + return PagedIndexerMetadata( + page_size=self.page_size, + page_table=core_attn_metadata.page_table, + c4_seq_lens=core_attn_metadata.c4_topk_lengths_raw, + ) + + def init_forward_metadata_decode( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + ) -> Union[DSV4Metadata, DSV4RawDecodeMetadata]: + assert ( + req_pool_indices.shape[0] == seq_lens.shape[0] == out_cache_loc.shape[0] + ), f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}" + + if envs.SGLANG_PREP_IN_CUDA_GRAPH.get(): + return DSV4RawDecodeMetadata( + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + ) + + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices, + seq_lens_casual=seq_lens, + max_seq_len=max_seq_len, + out_loc=out_cache_loc, + need_compress=True, + ) + + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + + create = functools.partial( + create_paged_compressor_data, + is_prefill=False, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) + + return DSV4Metadata( + 
core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_prefill( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: List[int], + out_cache_loc: torch.Tensor, + num_tokens: int, + extend_seq_lens: torch.Tensor, + extend_seq_lens_cpu: List[int], + need_compress: bool = True, + use_prefill_cuda_graph: bool = False, + ) -> DSV4Metadata: + seq_lens_casual, req_pool_indices_repeated = self.expand_prefill_casually( + num_tokens=num_tokens, + seq_lens=seq_lens_cpu, + extend_seq_lens=extend_seq_lens_cpu, + req_pool_indices=req_pool_indices, + padded_num_tokens=out_cache_loc.shape[0], + ) + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices_repeated, + seq_lens_casual=seq_lens_casual, + max_seq_len=max_seq_len, + out_loc=out_cache_loc, + need_compress=need_compress, + is_prefill=True, + ) + indexer_metadata = ( + self.init_forward_metadata_indexer(core_attn_metadata) + if need_compress + else None + ) + if not need_compress: + create = _create_dummy_paged_compress_data + else: + create = functools.partial( + create_paged_compressor_data, + is_prefill=True, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + extend_lens=extend_seq_lens, + extend_lens_cpu=extend_seq_lens_cpu, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + return DSV4Metadata( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_target_verify( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> Union[DSV4Metadata, DSV4RawVerifyMetadata]: + if envs.SGLANG_PREP_IN_CUDA_GRAPH.get(): + assert out_cache_loc is not None + if not hasattr(self, "extend_seq_lens_buffer"): + self.extend_seq_lens_buffer = torch.tensor( + [self.speculative_num_draft_tokens] * 1025, device=self.device + ) + extend_seq_lens = self.extend_seq_lens_buffer[: len(seq_lens)] + + return DSV4RawVerifyMetadata( + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + extend_seq_lens=extend_seq_lens, + ) + else: + seq_lens_cpu = seq_lens.tolist() + return self.init_forward_metadata_target_verify_old( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def init_forward_metadata_target_verify_old( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[List[int]] = None, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> DSV4Metadata: + batch_size = len(seq_lens) + seq_lens = seq_lens + self.speculative_num_draft_tokens + seq_lens_cpu = [x + self.speculative_num_draft_tokens for x in seq_lens_cpu] + extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size + extend_seq_lens = self._move_to_device(extend_seq_lens_cpu) + num_tokens = self.speculative_num_draft_tokens * batch_size + if out_cache_loc is None: + out_cache_loc = seq_lens.new_zeros(num_tokens) + return 
self.init_forward_metadata_prefill( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + num_tokens=num_tokens, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=True, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def make_forward_metadata_from_raw_verify( + self, raw_metadata: DSV4RawVerifyMetadata + ) -> DSV4Metadata: + req_pool_indices = raw_metadata.req_pool_indices + seq_lens = raw_metadata.seq_lens + out_cache_loc = raw_metadata.out_cache_loc + + bs, num_draft_tokens = len(seq_lens), self.speculative_num_draft_tokens + seq_lens = seq_lens + self.speculative_num_draft_tokens + extend_seq_lens = raw_metadata.extend_seq_lens + + seq_lens_casual, req_pool_indices_repeated = ( + self.expand_extend_with_same_length( + bs, num_draft_tokens, seq_lens, req_pool_indices + ) + ) + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices_repeated, + seq_lens_casual=seq_lens_casual, + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + out_loc=out_cache_loc, + need_compress=True, + ) + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + create = functools.partial( + create_paged_compressor_data, + is_prefill=True, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_lens=extend_seq_lens, + seq_lens_cpu=None, + extend_lens_cpu=None, + use_prefill_cuda_graph=True, + num_q_tokens=num_draft_tokens * bs, + ) + return DSV4Metadata( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def make_forward_metadata_from_raw_decode( + self, raw_metadata: DSV4RawDecodeMetadata + ) -> DSV4Metadata: + req_pool_indices = raw_metadata.req_pool_indices + seq_lens = raw_metadata.seq_lens + out_cache_loc = raw_metadata.out_cache_loc + + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices, + seq_lens_casual=seq_lens, + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + out_loc=out_cache_loc, + need_compress=True, + ) + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + + create = functools.partial( + create_paged_compressor_data, + is_prefill=False, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) + + return DSV4Metadata( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_draft_extend( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: List[int], + num_tokens_per_bs: int, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> DSV4Metadata: + batch_size = len(seq_lens) + extend_seq_lens_cpu = [num_tokens_per_bs] * batch_size + extend_seq_lens = self._move_to_device(extend_seq_lens_cpu) + num_tokens = num_tokens_per_bs * batch_size + if out_cache_loc is None: + out_cache_loc = seq_lens.new_zeros(num_tokens) + return self.init_forward_metadata_prefill( + seq_lens=seq_lens, + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + 
num_tokens=num_tokens, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=False, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch) -> None: + if self.mtp_enabled and forward_batch.forward_mode.is_idle(): + return + + req_pool_indices = forward_batch.req_pool_indices + seq_lens = forward_batch.seq_lens.to(torch.int32) + seq_lens_cpu = forward_batch.seq_lens_cpu + assert forward_batch.req_to_token_pool.req_to_token is self.req_to_token + + assert self.swa_page_size % SWA_WINDOW == 0 and self.page_size % 128 == 0 + assert seq_lens_cpu is not None + max_seq_len = int(seq_lens_cpu.max().item()) + + if forward_batch.forward_mode.is_decode_or_idle(): + metadata = self.init_forward_metadata_decode( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=forward_batch.out_cache_loc, + ) + elif forward_batch.forward_mode.is_target_verify(): + metadata = self.init_forward_metadata_target_verify( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=forward_batch.out_cache_loc, + ) + elif forward_batch.forward_mode.is_prefill(include_draft_extend_v2=True): + extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + extend_seq_lens = forward_batch.extend_seq_lens + assert ( + seq_lens is not None + and seq_lens_cpu is not None + and extend_seq_lens is not None + and extend_seq_lens_cpu is not None + ) + is_draft = forward_batch.forward_mode.is_draft_extend(include_v2=True) + metadata = self.init_forward_metadata_prefill( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu.tolist(), + out_cache_loc=forward_batch.out_cache_loc, + num_tokens=sum(extend_seq_lens_cpu), + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=not is_draft, + ) + else: + raise NotImplementedError(f"unsupported mode {forward_batch.forward_mode=}") + + self.forward_metadata = metadata + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: + self.cuda_graph_metadata_of_bucket_and_bs: Dict[ + _GraphBucket, + Dict[ + int, + Union[DSV4Metadata, DSV4RawDecodeMetadata, DSV4RawVerifyMetadata], + ], + ] = {bucket: {} for bucket in _GraphBucket} + self.draft_extend_num_tokens_per_bs = ( + max_num_tokens // max_bs if max_bs > 0 else 1 + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ) -> None: + assert req_pool_indices.size(0) == bs + assert seq_lens.size(0) == bs + + bucket = _GraphBucket.of(forward_mode) + raw_type: Optional[type] = None + if bucket == _GraphBucket.DECODE_OR_IDLE: + metadata = self.init_forward_metadata_decode( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=torch.zeros_like(seq_lens), + ) + raw_type = DSV4RawDecodeMetadata + elif bucket == _GraphBucket.TARGET_VERIFY: + out_cache_loc = torch.zeros(num_tokens, **self.cuda_int32_kwargs) + metadata = self.init_forward_metadata_target_verify( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + use_prefill_cuda_graph=True, + ) + raw_type = DSV4RawVerifyMetadata + elif bucket == 
_GraphBucket.DRAFT_EXTEND: + num_tokens_per_bs = num_tokens // bs + metadata = self.init_forward_metadata_draft_extend( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.tolist(), + num_tokens_per_bs=num_tokens_per_bs, + use_prefill_cuda_graph=True, + ) + else: + raise NotImplementedError(f"{forward_mode=} not supported yet") + + self.cuda_graph_metadata_of_bucket_and_bs[bucket][bs] = metadata + self.forward_metadata = metadata + if raw_type is not None: + self._current_capture_raw = ( + metadata if isinstance(metadata, raw_type) else None + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + ) -> None: + bucket = _GraphBucket.of(forward_mode) + + # FIXME: see cuda_graph_runner — this attribute is set out-of-band. + fb = self._replay_forward_batch + out_cache_loc = fb.out_cache_loc + actual_forward_mode = fb.forward_mode + + if actual_forward_mode == ForwardMode.IDLE: + logger.debug( + f"[IDLE replay] bs={bs}, " + f"local_seq_lens_len={len(seq_lens)}, " + f"has_graph={bs in self.cuda_graph_metadata_of_bucket_and_bs[_GraphBucket.DECODE_OR_IDLE]}" + ) + device = seq_lens.device + seq_lens = torch.ones(bs, dtype=seq_lens.dtype, device=device) + seq_lens_cpu = torch.ones(bs, dtype=torch.int64) + seq_lens_sum = bs + req_pool_indices = torch.zeros( + bs, dtype=req_pool_indices.dtype, device=device + ) + out_cache_loc = torch.zeros(bs, dtype=torch.int64, device=device) + + assert seq_lens_cpu is not None + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] + + actual_max_seq_len = seq_lens_cpu.max().item() + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + + if bucket == _GraphBucket.DECODE_OR_IDLE: + assert out_cache_loc is not None + assert len(out_cache_loc.shape) == 1, f"{out_cache_loc.shape=}" + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, bs - len(out_cache_loc)), + mode="constant", + value=0, + ) + temp_metadata = self.init_forward_metadata_decode( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + ) + elif bucket == _GraphBucket.TARGET_VERIFY: + assert out_cache_loc is not None + num_tokens = self.speculative_num_draft_tokens * bs + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, num_tokens - len(out_cache_loc)), + mode="constant", + value=0, + ) + temp_metadata = self.init_forward_metadata_target_verify( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + use_prefill_cuda_graph=True, + ) + elif bucket == _GraphBucket.DRAFT_EXTEND: + num_tokens_per_bs = self.draft_extend_num_tokens_per_bs + temp_metadata = self.init_forward_metadata_draft_extend( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu.tolist(), + num_tokens_per_bs=num_tokens_per_bs, + use_prefill_cuda_graph=True, + ) + else: + raise NotImplementedError + + self.replay_cuda_graph_metadata_from( + bs=bs, temp_metadata=temp_metadata, bucket=bucket + ) + + def replay_cuda_graph_metadata_from( + self, + bs: int, + temp_metadata: 
Union[ + DSV4Metadata, + DSV4RawVerifyMetadata, + DSV4RawDecodeMetadata, + ], + bucket: _GraphBucket, + ) -> None: + chosen_metadata = self.cuda_graph_metadata_of_bucket_and_bs[bucket][bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + + def get_cuda_graph_seq_len_fill_value(self): + return 1 + + def on_after_cuda_graph_warmup(self): + metadata = self.forward_metadata + if isinstance(metadata, DSV4Metadata) and isinstance( + metadata.core_attn_metadata, DSV4AttnMetadata + ): + core = metadata.core_attn_metadata + core.c1_flashmla_metadata = _create_flashmla_metadata() + core.c4_flashmla_metadata = _create_flashmla_metadata() + core.c128_flashmla_metadata = _create_flashmla_metadata() + + # PREP_IN_CUDA_GRAPH=True: warmup upgraded raw->full on the host; + # restore raw so capture re-runs the upgrade inside the graph. + current_raw = getattr(self, "_current_capture_raw", None) + if current_raw is not None: + self.forward_metadata = current_raw + + def store_cache( + self, layer_id: int, swa_k: torch.Tensor, forward_batch: ForwardBatch + ) -> None: + raw_loc = forward_batch.out_cache_loc + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + self.token_to_kv_pool.set_swa_key_buffer_radix_fused( + layer_id=layer_id, + raw_loc=raw_loc, + cache_k=swa_k, + ) + else: + swa_k_pack = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k) + self.token_to_kv_pool.set_swa_key_buffer_radix( + layer_id=layer_id, + raw_loc=raw_loc, + cache_nope_fp8_rope_bf16_pack=swa_k_pack, + ) + + def _maybe_upgrade_forward_metadata(self) -> None: + # With SGLANG_PREP_IN_CUDA_GRAPH=1, init_forward_metadata_* + # returns a Raw metadata that only carries a few tensors. The + # full DSV4Metadata (including c4/c128 compress + core_attn + + # indexer metadata) must be materialized before any caller that + # touches those fields. For 1.6T the first two layers have + # compress_ratio=128, so forward_core_compressor / forward_c4_indexer + # can fire before attn_backend.forward(), and must trigger the + # upgrade themselves. 
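# A minimal, self-contained sketch of the lazy Raw -> full upgrade described in the
# comment above, for illustration only: RawMeta, FullMeta, and LazyBackend are
# made-up stand-ins, not the real DSV4RawDecodeMetadata / DSV4RawVerifyMetadata /
# DSV4Metadata classes.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class RawMeta:
    # Cheap to build at CUDA-graph capture/replay time.
    seq_lens: List[int]


@dataclass
class FullMeta:
    # Carries the expensive per-step fields (compress + indexer metadata).
    seq_lens: List[int]
    page_indices: List[List[int]]


class LazyBackend:
    def __init__(self) -> None:
        self.forward_metadata: Union[RawMeta, FullMeta] = RawMeta(seq_lens=[3, 5])

    def _maybe_upgrade(self) -> None:
        # Idempotent: whichever entry point runs first (attention forward, the
        # core compressor, or the c4 indexer) pays the upgrade cost exactly once.
        if isinstance(self.forward_metadata, RawMeta):
            raw = self.forward_metadata
            self.forward_metadata = FullMeta(
                seq_lens=raw.seq_lens,
                page_indices=[list(range(n)) for n in raw.seq_lens],
            )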
+ if isinstance(self.forward_metadata, DSV4RawVerifyMetadata): + self.forward_metadata = self.make_forward_metadata_from_raw_verify( + raw_metadata=self.forward_metadata, + ) + elif isinstance(self.forward_metadata, DSV4RawDecodeMetadata): + self.forward_metadata = self.make_forward_metadata_from_raw_decode( + raw_metadata=self.forward_metadata, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + compress_ratio: Literal[0, 4, 128], + save_kv_cache: bool = True, + attn_sink: Optional[torch.Tensor] = None, + **_, + ) -> torch.Tensor: + self._maybe_upgrade_forward_metadata() + + if self.mtp_enabled and forward_batch.forward_mode.is_idle(): + return q.new_empty(q.shape[0], q.shape[1], layer.v_head_dim) + + assert k is v, "DeepseekV4 shares k and v" + swa_k = k + + layer_id = layer.layer_id + metadata = self.forward_metadata + core_attn_metadata = metadata.core_attn_metadata + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + if isinstance(core_attn_metadata, DSV4AttnMetadata): + if save_kv_cache: + self.store_cache(layer_id, swa_k, forward_batch) + swa_k_cache = token_to_kv_pool.get_swa_key_buffer_radix(layer_id) + + extra_k_cache, extra_indices, extra_topk_lengths = None, None, None + if compress_ratio == 4: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_attn_metadata.c4_sparse_page_indices + extra_topk_lengths = core_attn_metadata.c4_sparse_topk_lengths + elif compress_ratio == 128: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_attn_metadata.c128_page_indices + extra_topk_lengths = core_attn_metadata.c128_topk_lengths_clamp1 + + swa_window_size = token_to_kv_pool.swa_window_size + assert swa_k_cache.ndim == 2 + k_cache_total_dim = token_to_kv_pool.swa_kv_pool.kv_cache_total_dim + swa_k_cache = swa_k_cache[:, : swa_window_size * k_cache_total_dim].view( + swa_k_cache.shape[0], swa_window_size, 1, k_cache_total_dim + ) + + if extra_k_cache is not None: + page_sizes = { + 4: token_to_kv_pool.page_size // 4, + 128: token_to_kv_pool.page_size // 128, + } + extra_k_cache = extra_k_cache[ + :, : page_sizes[compress_ratio] * k_cache_total_dim + ].view( + extra_k_cache.shape[0], + page_sizes[compress_ratio], + 1, + k_cache_total_dim, + ) + swa_page_indices = core_attn_metadata.swa_page_indices + swa_topk_lengths = core_attn_metadata.swa_topk_lengths + + if self.mtp_enabled: + if swa_page_indices.shape[0] != q.shape[0]: + swa_page_indices = _pad_tensor_to_size( + swa_page_indices, q.shape[0], value=0 + ) + + if swa_topk_lengths.shape[0] != q.shape[0]: + swa_topk_lengths = _pad_tensor_to_size( + swa_topk_lengths, q.shape[0], value=1 + ) + + if q.ndim == 3: + q = q.unsqueeze(1) + if swa_page_indices.ndim == 2: + swa_page_indices = swa_page_indices.unsqueeze(1) + if extra_indices is not None and extra_indices.ndim == 2: + extra_indices = extra_indices.unsqueeze(1) + + assert attn_sink is not None + + flashmla_metadata = core_attn_metadata.get_flashmla_metadata(compress_ratio) + + assert ( + swa_page_indices.shape[-1] % 64 == 0 + ), f"{swa_page_indices.shape=}'s last dimension is not aligned to 64" + if extra_indices is not None: + assert ( + extra_indices.shape[-1] % 64 == 0 + ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" + + import flash_mla + + o = flash_mla.flash_mla_with_kvcache( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + 
block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=self.softmax_scale, + is_fp8_kvcache=True, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + )[0] + + o = o.squeeze(1) + return o + + raise NotImplementedError("ragged attention") + + def expand_prefill_casually( + self, + num_tokens: int, + seq_lens: List[int], + extend_seq_lens: List[int], + req_pool_indices: torch.Tensor, + padded_num_tokens: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + seq_lens_casual = torch.empty(num_tokens, **self.cuda_int32_kwargs) + idx_to_req_repeated = torch.empty(num_tokens, **self.cuda_int32_kwargs) + offset = 0 + for i, (kv_len, qo_len) in enumerate(zip(seq_lens, extend_seq_lens)): + out = seq_lens_casual[offset : offset + qo_len] + offset += qo_len + torch.arange(kv_len - qo_len + 1, kv_len + 1, out=out) + idx_to_req_repeated[offset - qo_len : offset].fill_(i) + + assert offset == num_tokens + req_pool_indices_repeated = req_pool_indices[idx_to_req_repeated] + + if padded_num_tokens is not None and padded_num_tokens > num_tokens: + pad_size = padded_num_tokens - num_tokens + seq_lens_casual = torch.nn.functional.pad( + seq_lens_casual, + (0, pad_size), + value=1, + ) + req_pool_indices_repeated = torch.nn.functional.pad( + req_pool_indices_repeated, + (0, pad_size), + value=req_pool_indices_repeated[-1].item(), + ) + + return seq_lens_casual, req_pool_indices_repeated + + def expand_extend_with_same_length( + self, + bs: int, + qo_len: int, + seq_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + ): + seq_lens_casual = seq_lens[:, None] + torch.arange( + -qo_len + 1, 1, **self.cuda_int32_kwargs + ) + seq_lens_casual = seq_lens_casual.flatten() + idx_to_req_repeated = torch.arange( + bs, **self.cuda_int32_kwargs + ).repeat_interleave(qo_len) + req_pool_indices_repeated = req_pool_indices[idx_to_req_repeated] + return seq_lens_casual, req_pool_indices_repeated + + def make_core_attn_metadata( + self, + req_to_token: torch.Tensor, + req_pool_indices_repeated: torch.Tensor, + seq_lens_casual: torch.Tensor, + max_seq_len: int, + out_loc: torch.Tensor, + need_compress: bool = True, + is_prefill: bool = False, + ) -> DSV4AttnMetadata: + assert self.swa_page_size == SWA_WINDOW + + swa_page_indices = self.get_swa_page_indices( + seq_lens_casual=seq_lens_casual, + req_pool_indices_repeated=req_pool_indices_repeated, + ) + + swa_page_indices = _pad_last_dim( + swa_page_indices, multiples_of=PAGE_INDEX_ALIGNED_SIZE + ) + + raw_positions = seq_lens_casual - 1 + swa_topk_lengths = torch.clamp(seq_lens_casual, max=SWA_WINDOW) + + page_table = req_to_token[ + req_pool_indices_repeated, : max_seq_len : self.page_size + ] + page_table = (page_table // self.page_size).to(torch.int32) + + core_attn_metadata = DSV4AttnMetadata( + page_size=self.page_size, + raw_out_loc=out_loc, + seq_lens_casual=seq_lens_casual, + cuda_int32_kwargs=self.cuda_int32_kwargs, + positions_casual=raw_positions, + page_table=page_table, + swa_page_indices=swa_page_indices, + swa_topk_lengths=swa_topk_lengths, + c4_sparse_topk=self.c4_topk, + ) + + if need_compress: + core_attn_metadata.init_compression_metadata() + core_attn_metadata.init_flashmla_related() + else: + core_attn_metadata.c4_sparse_topk_lengths = None + core_attn_metadata.c4_sparse_page_indices = None + core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata() + 
core_attn_metadata.c4_flashmla_metadata = None + core_attn_metadata.c128_flashmla_metadata = None + return core_attn_metadata + + def get_swa_page_indices( + self, + seq_lens_casual: torch.Tensor, + req_pool_indices_repeated: torch.Tensor, + ) -> torch.Tensor: + pos_causal = seq_lens_casual - 1 + num_qo_tokens = seq_lens_casual.size(0) + offsets = pos_causal.unsqueeze(1) - torch.arange( + SWA_WINDOW, **self.cuda_int32_kwargs + ).unsqueeze(0) + invalid_offset_mask = offsets < 0 + offsets.masked_fill_(invalid_offset_mask, 0) + raw_indices = self.req_to_token[req_pool_indices_repeated[:, None], offsets] + assert raw_indices.shape == (num_qo_tokens, SWA_WINDOW) + raw_indices.masked_fill_(invalid_offset_mask, -1) + swa_indices = self.token_to_kv_pool.translate_loc_from_full_to_swa(raw_indices) + return swa_indices + + +class DeepseekV4MultiStepBackend(DeepseekV4AttnBackend): + def __init__( + self, model_runner: ModelRunner, topk: int, speculative_num_steps: int + ): + super().__init__(model_runner) + self.model_runner = model_runner + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.attn_backends: List[DeepseekV4AttnBackend] = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + DeepseekV4AttnBackend( + model_runner, + speculative_step_id=i, + topk=self.topk, + speculative_num_steps=self.speculative_num_steps, + ) + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata(forward_batch) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + def on_after_cuda_graph_warmup(self): + for backend in self.attn_backends: + backend.on_after_cuda_graph_warmup() + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): + if self.speculative_num_steps == 1: + return + + self.attn_backends[0]._replay_forward_batch = forward_batch + self.attn_backends[0].init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.seq_lens_cpu, + ) + self.attn_backends[0]._replay_forward_batch = None + temp_metadata = self.attn_backends[0].forward_metadata + + for i in range(1, self.speculative_num_steps - 1): + self.attn_backends[i].replay_cuda_graph_metadata_from( + bs=bs, + temp_metadata=temp_metadata, + bucket=_GraphBucket.DECODE_OR_IDLE, + ) + + +def _pad_tensor_to_size(tensor: torch.Tensor, size: int, *, value: int = 0): + if value == 0: + return torch.cat( + [tensor, tensor.new_zeros(size - tensor.shape[0], *tensor.shape[1:])], + dim=0, + ) + else: + return torch.cat( + [ + tensor, + tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value), + ], + dim=0, + ) diff --git a/python/sglang/srt/layers/attention/dsv4/__init__.py 
b/python/sglang/srt/layers/attention/dsv4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/srt/layers/attention/dsv4/compressor.py b/python/sglang/srt/layers/attention/dsv4/compressor.py new file mode 100644 index 000000000000..a09d8cbd713e --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/compressor.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Literal, NamedTuple, Optional, Union + +import torch +import torch.nn as nn + +from sglang.jit_kernel.deepseek_v4 import ( + CompressorDecodePlan, + CompressorPrefillPlan, + compress_forward, + compress_fused_norm_rope_inplace, + linear_bf16_fp32, + triton_create_paged_compress_data, +) +from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config +from sglang.srt.environ import envs +from sglang.srt.layers.attention.dsv4.quant_k_cache import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.layers.attention.nsa.triton_kernel import act_quant +from sglang.srt.layers.attention.nsa.utils import nsa_use_prefill_cp +from sglang.srt.layers.dp_attention import get_attention_cp_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.utils.cp_utils import cp_all_gather_rerange_output +from sglang.srt.mem_cache.deepseek_v4_compress_state import CompressStatePool +from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.utils import add_prefix + +if TYPE_CHECKING: + from sglang.srt.layers.attention.deepseek_v4_backend import DeepseekV4AttnBackend + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +class FusedCompressMetadata(NamedTuple): + write_loc: torch.Tensor + extra_data: Optional[torch.Tensor] + plan: Union[CompressorDecodePlan, CompressorPrefillPlan] + + def copy_(self, other: FusedCompressMetadata) -> None: + from .metadata import maybe_copy_inplace + + self.write_loc.copy_(other.write_loc) + maybe_copy_inplace(self.extra_data, src=other.extra_data) + self.plan.copy_(other.plan) + + +class CompressorBackendMixin: + def get_paged_compress_metadata(self, compress_ratio: int) -> FusedCompressMetadata: + attr_name = f"c{compress_ratio}_compress_metadata" + metadata = getattr(self.forward_metadata, attr_name) + assert isinstance(metadata, FusedCompressMetadata) + return metadata + + def forward_compress( + self, + *, + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + ape: torch.Tensor, + head_dim: int, + norm: RMSNorm, + freqs_cis_cache: torch.Tensor, + rotate: bool, + forward_batch: ForwardBatch, + compress_ratio: int, + is_paged: bool = False, + ) -> torch.Tensor: + from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation + + assert compress_ratio in ( + 4, + 128, + ), f"DSV4 supports CSA(4x) and HCA(128x) only, got {compress_ratio=}" + if is_paged: + metadata = self.get_paged_compress_metadata(compress_ratio) + coff = 2 if is_overlap_compress(compress_ratio) else 1 + if compress_ratio == 128 and envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + kv_score_buffer = kv_score_buffer.view(-1, 1, head_dim * 3) + else: + last_dim = 2 * head_dim * coff + assert kv_score_buffer.shape[-1] == last_dim + kv_score_buffer = kv_score_buffer.view(-1, compress_ratio, last_dim) + else: + plan = make_compressor_plan(compress_ratio, forward_batch) + metadata = (forward_batch.req_pool_indices.to(torch.int32), None, plan) + indices, extra_data, plan = metadata + + kv_compressed = 
compress_forward( + kv_score_buffer=kv_score_buffer, + kv_score_input=kv_score_input, + ape=ape, + indices=indices, + plan=plan, + compress_ratio=compress_ratio, + head_dim=head_dim, + extra_data=extra_data, + ) + compress_fused_norm_rope_inplace( + kv_compressed, + norm.weight, + norm.variance_epsilon, + freqs_cis_cache, + plan, + ) + return rotate_activation(kv_compressed.bfloat16()) if rotate else kv_compressed + + def forward_core_compressor( + self, + x: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + compressor: Compressor, + ) -> None: + if forward_batch.forward_mode.is_idle(): + return + # PREP_IN_CG lazy upgrade: the concrete backend (DeepseekV4AttnBackend) + # owns this helper. MQALayer._forward_prepare calls us before + # attn_backend.forward(), so Raw -> DSV4Metadata must happen here too + # (e.g. 1.6T layer 0 has compress_ratio=128 and needs cX_compress_metadata). + self._maybe_upgrade_forward_metadata() + token_to_kv_pool = forward_batch.token_to_kv_pool + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + new_compressed_kv = compressor(x, forward_batch) + core_metadata = self.forward_metadata.core_metadata + out_loc = ( + core_metadata.c4_out_loc + if compressor.ratio == 4 + else core_metadata.c128_out_loc + ) + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + token_to_kv_pool.set_extra_key_buffer_fused( + layer_id=layer_id, + loc=out_loc, + cache_k=new_compressed_kv, + ) + else: + pack = quant_to_nope_fp8_rope_bf16_pack_triton(new_compressed_kv.bfloat16()) + token_to_kv_pool.set_extra_key_buffer(layer_id, out_loc, pack) + + def forward_indexer_compressor( + self, + x: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + compressor: Compressor, + ) -> None: + assert is_overlap_compress(compressor.ratio) + # PREP_IN_CG lazy upgrade (see forward_core_compressor for rationale). 
+ self._maybe_upgrade_forward_metadata() + token_to_kv_pool = forward_batch.token_to_kv_pool + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + new_compressed_kv = compressor(x, forward_batch) + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + token_to_kv_pool.set_index_k_fused( + layer_id=layer_id, + loc=self.forward_metadata.core_metadata.c4_out_loc, + cache_k=new_compressed_kv, + ) + else: + new_compressed_kv_fp8, new_compressed_kv_scale = act_quant( + new_compressed_kv + ) + token_to_kv_pool.set_index_k_scale_buffer( + layer_id=layer_id, + loc=self.forward_metadata.core_metadata.c4_out_loc, + index_k=new_compressed_kv_fp8, + index_k_scale=new_compressed_kv_scale, + ) + + +def is_overlap_compress(compress_ratio: int) -> bool: + return compress_ratio == 4 + + +def make_compressor_plan( + compress_ratio: Literal[4, 128], + forward_batch: ForwardBatch, +) -> Union[CompressorDecodePlan, CompressorPrefillPlan]: + if forward_batch.forward_mode.is_decode(): + seq_lens_32 = forward_batch.seq_lens.to(torch.int32) + return CompressorDecodePlan(compress_ratio, seq_lens_32) + if forward_batch.forward_mode.is_prefill(): + assert not forward_batch.forward_mode.is_target_verify() + extend_lens_list = forward_batch.extend_seq_lens_cpu + seq_lens_cpu = forward_batch.seq_lens_cpu + assert extend_lens_list is not None and seq_lens_cpu is not None + return CompressorPrefillPlan.generate( + compress_ratio=compress_ratio, + num_q_tokens=sum(extend_lens_list), + seq_lens=seq_lens_cpu, + extend_lens=torch.tensor(extend_lens_list), + device=forward_batch.seq_lens.device, + ) + elif forward_batch.forward_mode.is_target_verify(): + raise NotImplementedError("target verify mode to be implemented") + else: + raise NotImplementedError(f"unsupported mode {forward_batch.forward_mode=}") + + +def create_paged_compressor_data( + compress_ratio: Literal[4, 128], + *, + is_prefill: bool, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_lens: Optional[torch.Tensor] = None, + seq_lens_cpu: Optional[List[int]] = None, + extend_lens_cpu: Optional[List[int]] = None, + use_prefill_cuda_graph: bool = False, + num_q_tokens: Optional[int] = None, +) -> FusedCompressMetadata: + swa_page_size = token_to_kv_pool.swa_page_size + ring_size = token_to_kv_pool.get_ring_size(compress_ratio=compress_ratio) + # assert ring_size % compress_ratio == 0 + + def clip_down(positions: torch.Tensor) -> torch.Tensor: + return positions // compress_ratio * compress_ratio + + def get_raw_loc(positions: torch.Tensor) -> torch.Tensor: + positions = positions.masked_fill(positions < 0, 0) + loc = req_to_token[req_pool_indices, positions] + swa_loc = token_to_kv_pool.translate_loc_from_full_to_swa(loc) + swa_pages = swa_loc // swa_page_size + state_loc = swa_pages * ring_size + swa_loc % ring_size + return (state_loc // compress_ratio).to(torch.int32) + + is_overlap = is_overlap_compress(compress_ratio) + + if is_prefill: + assert extend_lens is not None + write_loc, extra_data = triton_create_paged_compress_data( + compress_ratio=compress_ratio, + is_overlap=is_overlap, + swa_page_size=swa_page_size, + ring_size=ring_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_seq_lens=extend_lens, + req_to_token=req_to_token, + full_to_swa_index_mapping=token_to_kv_pool.full_to_swa_index_mapping, + ) + + plan_kwargs: dict + if seq_lens_cpu is None: + assert num_q_tokens is not None + plan_kwargs = dict( + 
num_q_tokens=num_q_tokens, + seq_lens=seq_lens, + extend_lens=extend_lens, + ) + else: + assert extend_lens_cpu is not None + plan_kwargs = dict( + num_q_tokens=sum(extend_lens_cpu), + seq_lens=torch.tensor(seq_lens_cpu), + extend_lens=torch.tensor(extend_lens_cpu), + ) + plan = CompressorPrefillPlan.generate( + compress_ratio=compress_ratio, + device=seq_lens.device, + use_cuda_graph=use_prefill_cuda_graph, + **plan_kwargs, + ) + else: + write_positions = clip_down(seq_lens - 1) + write_loc = get_raw_loc(write_positions) + if is_overlap: + write_overlap_loc = get_raw_loc(write_positions - compress_ratio) + extra_data = write_overlap_loc.view(-1, 1) + else: + extra_data = None + plan = CompressorDecodePlan(compress_ratio, seq_lens.to(torch.int32)) + + return FusedCompressMetadata(write_loc=write_loc, extra_data=extra_data, plan=plan) + + +class Compressor(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + is_in_indexer: bool, + freqs_cis: torch.Tensor, + compress_ratio: Literal[0, 4, 128], + head_dim: int, + rotate: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.is_in_indexer = is_in_indexer + self.dim = config.hidden_size + self.head_dim = head_dim + self.rope_head_dim = getattr(config, "qk_rope_head_dim", 64) + assert compress_ratio != 0, "compress_ratio should not be 0" + self.ratio = compress_ratio + self.overlap = self.ratio == 4 + self.rotate = rotate + coff = 1 + self.overlap + + self.ape = nn.Parameter( + torch.empty(self.ratio, coff * self.head_dim, dtype=torch.float32) + ) + wkv_gate_dtype = torch.bfloat16 + self.wkv_gate = ReplicatedLinear( + self.dim, + 2 * coff * self.head_dim, + bias=False, + quant_config=None, + prefix=add_prefix("wkv_gate", prefix), + params_dtype=wkv_gate_dtype, + ) + self.norm = RMSNorm( + self.head_dim, eps=config.rms_norm_eps, weight_dtype=torch.float32 + ) + self.freqs_cis = freqs_cis + + self.ape_converted = False + + def apply_ape_hotfix(self): + assert not self.ape_converted + self.ape_converted = True + + if self.overlap: + ape = torch.chunk(self.ape.data, 2, dim=-1) + ape = torch.cat([ape[0], ape[1]], dim=0) + self.ape.data.copy_(ape.view(self.ratio, -1)) + + def _get_state_pool(self, forward_batch: ForwardBatch) -> CompressStatePool: + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + if self.is_in_indexer: + ret = token_to_kv_pool.get_indexer_compress_states(self.layer_id) + else: + ret = token_to_kv_pool.get_attention_compress_states(self.layer_id) + + assert isinstance(ret, CompressStatePool) + + return ret + + def forward(self, x: torch.Tensor, forward_batch: ForwardBatch) -> torch.Tensor: + if forward_batch.forward_mode.is_idle(): + assert x.shape[0] == 0 + return x.new_empty(0, self.head_dim) + + kv_score = linear_bf16_fp32(x, self.wkv_gate.weight) + if nsa_use_prefill_cp(forward_batch): + kv_score = cp_all_gather_rerange_output( + kv_score, + get_attention_cp_size(), + forward_batch, + torch.cuda.current_stream(), + ) + + backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(backend, DeepseekV4AttnBackend) + kv_score_buffer = self._get_state_pool(forward_batch) + kv_score_buffer = kv_score_buffer.kv_score_buffer.kv_score + return backend.forward_compress( + kv_score_buffer=kv_score_buffer, + kv_score_input=kv_score, + ape=self.ape.view(-1, self.head_dim), + head_dim=self.head_dim, + norm=self.norm, + freqs_cis_cache=self.freqs_cis, + rotate=self.rotate, + 
compress_ratio=self.ratio, + forward_batch=forward_batch, + is_paged=True, + ) diff --git a/python/sglang/srt/layers/attention/dsv4/index_buf_accessor.py b/python/sglang/srt/layers/attention/dsv4/index_buf_accessor.py new file mode 100644 index 000000000000..d9fdbf1aaca9 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/index_buf_accessor.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +@dataclass +class NopeFp8RopeBf16Pack: + k_nope_fp8: torch.Tensor + k_rope_bf16: torch.Tensor + scale_k_nope_ue8m0: torch.Tensor + + def __post_init__(self): + assert self.k_nope_fp8.shape[-1] == 448 + assert self.k_rope_bf16.shape[-1] == 64 + assert self.scale_k_nope_ue8m0.shape[-1] == 7 + + def slice_pack(self, _slice: Any) -> NopeFp8RopeBf16Pack: + return NopeFp8RopeBf16Pack( + k_nope_fp8=self.k_nope_fp8[_slice], + k_rope_bf16=self.k_rope_bf16[_slice], + scale_k_nope_ue8m0=self.scale_k_nope_ue8m0[_slice], + ) + + +class SetKAndS: + @classmethod + def execute(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + cls.triton(pool, buf, loc, nope_fp8_rope_bf16_pack) + + @classmethod + def torch(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + _set_k_and_s_torch(buf, loc, nope_fp8_rope_bf16_pack, pool.page_size) + + @classmethod + def triton(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + _set_k_and_s_triton(buf, loc, nope_fp8_rope_bf16_pack, pool.page_size) + + +def _set_k_and_s_triton( + buf: torch.Tensor, + loc: torch.Tensor, + nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + page_size: int, +): + num_pages, buf_numel_per_page = buf.shape + (num_tokens_to_write,) = loc.shape + + k_nope, k_rope, scale_k_nope = ( + nope_fp8_rope_bf16_pack.k_nope_fp8, + nope_fp8_rope_bf16_pack.k_rope_bf16, + nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0, + ) + + num_tokens_to_write_nope, nope_dim = k_nope.shape + num_tokens_to_write_rope, rope_dim = k_rope.shape + num_tokens_to_write_scale, scale_dim = scale_k_nope.shape + + assert ( + num_tokens_to_write + == num_tokens_to_write_nope + == num_tokens_to_write_rope + == num_tokens_to_write_scale + ) + + assert buf.dtype == torch.uint8 + assert loc.dtype in [torch.int64, torch.int32], f"{loc.dtype=}" + + assert k_nope.dtype == fp8_dtype + assert k_rope.dtype == torch.bfloat16 + assert scale_k_nope.dtype == torch.uint8, f"{scale_k_nope.dtype=}" + + assert buf.is_contiguous() + assert loc.is_contiguous() + assert k_nope.is_contiguous() + assert k_rope.is_contiguous() + assert scale_k_nope.is_contiguous() + + buf_fp8 = buf.view(fp8_dtype) + buf_bf16 = buf.view(torch.bfloat16) + buf_uint8 = buf.view(torch.uint8) + + nope_rope_bytes = nope_dim + rope_dim * 2 + s_offset_nbytes_in_page = page_size * (nope_dim + rope_dim * 2) + + _set_k_and_s_triton_kernel[(num_tokens_to_write,)]( + buf_fp8, + buf_bf16, + buf_uint8, + loc, + k_nope, + k_rope, + scale_k_nope, + k_nope.stride(0), + k_rope.stride(0), + scale_k_nope.stride(0), + PAGE_SIZE=page_size, + BUF_NUMEL_PER_PAGE=buf_numel_per_page, + NUM_NOPE_ELEMS_PER_TOKEN=nope_dim, + NUM_ROPE_ELEMS_PER_TOKEN=rope_dim, + NUM_SCALE_ELEMS_PER_TOKEN=scale_dim, + NUM_NOPE_ROPE_BYTES_PER_TOKEN=nope_rope_bytes, + PADDED_SCALE_ELEMS_PER_TOKEN=scale_dim + 1, + S_OFFSET_NBYTES_IN_PAGE=s_offset_nbytes_in_page, + 
BLOCK_NOPE=512, + BLOCK_ROPE=64, + BLOCK_SCALE=8, + ) + + +@triton.jit +def _set_k_and_s_triton_kernel( + buf_fp8_ptr, + buf_bf16_ptr, + buf_uint8_ptr, + loc_ptr, + k_nope_ptr, + k_rope_ptr, + scale_k_nope_ptr, + k_nope_ptr_stride_0, + k_rope_ptr_stride_0, + scale_k_nope_ptr_stride_0, + PAGE_SIZE: tl.constexpr, + BUF_NUMEL_PER_PAGE: tl.constexpr, + NUM_NOPE_ELEMS_PER_TOKEN: tl.constexpr, + NUM_ROPE_ELEMS_PER_TOKEN: tl.constexpr, + NUM_NOPE_ROPE_BYTES_PER_TOKEN: tl.constexpr, + NUM_SCALE_ELEMS_PER_TOKEN: tl.constexpr, + PADDED_SCALE_ELEMS_PER_TOKEN: tl.constexpr, + S_OFFSET_NBYTES_IN_PAGE: tl.constexpr, + BLOCK_NOPE: tl.constexpr, + BLOCK_ROPE: tl.constexpr, + BLOCK_SCALE: tl.constexpr, +): + token_id = tl.program_id(0) + loc = tl.load(loc_ptr + token_id) + + nope_range = tl.arange(0, BLOCK_NOPE) + nope_mask = nope_range < NUM_NOPE_ELEMS_PER_TOKEN + in_k_nope_offsets = token_id * k_nope_ptr_stride_0 + nope_range + k_nope = tl.load(k_nope_ptr + in_k_nope_offsets, mask=nope_mask, other=0.0) + + rope_range = tl.arange(0, BLOCK_ROPE) + in_k_rope_offsets = token_id * k_rope_ptr_stride_0 + rope_range + k_rope = tl.load(k_rope_ptr + in_k_rope_offsets) + + scale_range = tl.arange(0, BLOCK_SCALE) + scale_mask = scale_range < NUM_SCALE_ELEMS_PER_TOKEN + in_scale_k_offsets = token_id * scale_k_nope_ptr_stride_0 + scale_range + k_scale = tl.load(scale_k_nope_ptr + in_scale_k_offsets, mask=scale_mask, other=0) + + loc_page_index = loc // PAGE_SIZE + loc_token_offset_in_page = loc % PAGE_SIZE + + out_k_nope_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE + + loc_token_offset_in_page * NUM_NOPE_ROPE_BYTES_PER_TOKEN + + nope_range + ) + + out_k_rope_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE // 2 + + loc_token_offset_in_page * (NUM_NOPE_ROPE_BYTES_PER_TOKEN // 2) + + NUM_NOPE_ELEMS_PER_TOKEN // 2 + + rope_range + ) + + out_s_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE + + S_OFFSET_NBYTES_IN_PAGE + + loc_token_offset_in_page * PADDED_SCALE_ELEMS_PER_TOKEN + + scale_range + ) + + tl.store(buf_fp8_ptr + out_k_nope_offsets, k_nope, mask=nope_mask) + tl.store(buf_bf16_ptr + out_k_rope_offsets, k_rope) + tl.store(buf_uint8_ptr + out_s_offsets, k_scale, mask=scale_mask) + + +def _set_k_and_s_torch( + buf: torch.Tensor, + loc: torch.Tensor, + nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + page_size: int, +): + num_pages, buf_numel_per_page = buf.shape + (num_tokens_to_write,) = loc.shape + + k_nope, k_rope, scale_k_nope = ( + nope_fp8_rope_bf16_pack.k_nope_fp8, + nope_fp8_rope_bf16_pack.k_rope_bf16, + nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0, + ) + + num_tokens_to_write_nope, nope_dim = k_nope.shape + num_tokens_to_write_rope, rope_dim = k_rope.shape + num_tokens_to_write_scale, scale_dim = scale_k_nope.shape + + assert ( + num_tokens_to_write + == num_tokens_to_write_nope + == num_tokens_to_write_rope + == num_tokens_to_write_scale + ), f"{num_tokens_to_write=} {num_tokens_to_write_nope=} {num_tokens_to_write_rope=} {num_tokens_to_write_scale=}" + + assert buf.dtype == torch.uint8 + assert loc.dtype in [ + torch.int64, + torch.int32, + ], f"{loc.dtype=}" + + assert k_nope.dtype == fp8_dtype + assert k_rope.dtype == torch.bfloat16 + assert scale_k_nope.dtype == torch.uint8 + + assert buf.is_contiguous() + assert loc.is_contiguous() + assert k_nope.is_contiguous() + assert k_rope.is_contiguous() + assert scale_k_nope.is_contiguous() + + buf_fp8 = buf.view(fp8_dtype).flatten() + buf_bf16 = buf.view(torch.bfloat16).flatten() + buf_scale = buf.view(torch.uint8).flatten() + + loc_page_index = loc // 
page_size + loc_token_offset_in_page = loc % page_size + + s_offset_nbytes_in_page = page_size * (nope_dim + rope_dim * 2) + + nope_offset = loc_page_index * buf_numel_per_page + loc_token_offset_in_page * ( + nope_dim + rope_dim * 2 + ) + + rope_offset = ( + loc_page_index * buf_numel_per_page // 2 + + (loc_token_offset_in_page * (nope_dim + rope_dim * 2) + nope_dim) // 2 + ) + + s_offset = ( + loc_page_index * buf_numel_per_page + + s_offset_nbytes_in_page + + loc_token_offset_in_page * (scale_dim + 1) + ) + + for i in range(num_tokens_to_write): + buf_fp8[nope_offset[i] : nope_offset[i] + nope_dim] = k_nope[i] + buf_bf16[rope_offset[i] : rope_offset[i] + rope_dim] = k_rope[i] + buf_scale[s_offset[i] : s_offset[i] + scale_dim] = scale_k_nope[i] diff --git a/python/sglang/srt/layers/attention/dsv4/indexer.py b/python/sglang/srt/layers/attention/dsv4/indexer.py new file mode 100644 index 000000000000..3bc982446681 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/indexer.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from sglang.jit_kernel.deepseek_v4 import ( + fused_rope, + topk_transform_512, + topk_transform_512_v2, +) +from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config +from sglang.srt.environ import envs +from sglang.srt.layers.attention.dsv4.compressor import Compressor +from sglang.srt.layers.attention.dsv4.metadata import PagedIndexerMetadata +from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation +from sglang.srt.layers.attention.nsa.triton_kernel import act_quant +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.state_capturer.indexer_topk import get_global_indexer_capturer +from sglang.srt.utils import add_prefix, is_hip + +if TYPE_CHECKING: + from sglang.srt.layers.attention.deepseek_v4_backend import DeepseekV4AttnBackend + from sglang.srt.layers.attention.dsv4.compressor import ( + CompressorBackendMixin, + ) + from sglang.srt.layers.quantization import QuantizationConfig + from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +if is_hip(): + FP8_DTYPE = torch.float8_e4m3fnuz + FP8_MAX = torch.finfo(FP8_DTYPE).max +else: + FP8_DTYPE = torch.float8_e4m3fn + FP8_MAX = torch.finfo(FP8_DTYPE).max + + +def fp8_paged_mqa_logits_torch( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + + assert head_dim == 128, "torch reference impl hardcodes DSV4 indexer head_dim=128" + assert block_size == 64, "torch reference impl hardcodes block_size=64 cache layout" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + for i in range(batch_size): + q = q_fp8[i, 0] + q = q.to(torch.float32) + q_scale = weight[i] + seq_len = int(seq_lens[i].item()) 
+ assert seq_len <= max_seq_len + num_pages = (seq_len + block_size - 1) // block_size + padded_seq_len = num_pages * block_size + pages = page_table[i, :num_pages] + kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + kvcache = kvcache_fp8[pages] + SCALE_OFFSET = block_size * head_dim + kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE) + kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32) + kvcache_value = kvcache_value.to(torch.float32) + kvcache_scale = kvcache_scale.contiguous() + kvcache_value = kvcache_value.view(padded_seq_len, head_dim) + kvcache_scale = kvcache_scale.view(padded_seq_len) + score = F.linear(kvcache_value, q) + score = F.relu(score) + score *= q_scale[None, :] + score = score.sum(dim=1) + score *= kvcache_scale + logits[i, :seq_len] = score[:seq_len] + + return logits + + +def topk_transform_512_pytorch_vectorized( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + + TOPK = 512 + batch_size = scores.shape[0] + max_seq_len = scores.shape[1] + device = scores.device + + page_bits = (page_size - 1).bit_length() if page_size > 1 else 0 + page_mask = page_size - 1 + + positions = ( + torch.arange(max_seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + ) + valid_mask = positions < seq_lens.unsqueeze(1) + + masked_scores = scores.clone() + masked_scores[~valid_mask] = float("-inf") + + actual_k = min(TOPK, max_seq_len) + _, raw_indices = torch.topk( + masked_scores, k=actual_k, dim=1, largest=True, sorted=False + ) + raw_indices = raw_indices.to(torch.int32) + + if actual_k < TOPK: + padding = torch.zeros( + (batch_size, TOPK - actual_k), dtype=torch.int32, device=device + ) + raw_indices = torch.cat([raw_indices, padding], dim=1) + + batch_indices = ( + torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, TOPK) + ) + gathered_scores = scores[ + batch_indices.flatten(), raw_indices.clamp(min=0).flatten() + ].view(batch_size, TOPK) + + valid_topk = gathered_scores != float("-inf") + if actual_k < TOPK: + pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k + valid_topk = valid_topk & ~pad_mask + + needs_sequential = seq_lens <= TOPK + if needs_sequential.any(): + sequential_indices = ( + torch.arange(TOPK, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(batch_size, -1) + ) + sequential_valid = sequential_indices < seq_lens.unsqueeze(1) + + raw_indices = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), + torch.where( + sequential_valid, + sequential_indices, + torch.tensor(-1, device=device, dtype=torch.int32), + ), + raw_indices, + ) + valid_topk = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk + ) + + page_idx = raw_indices >> page_bits + offset_in_page = raw_indices & page_mask + + page_idx_clamped = torch.clamp(page_idx, min=0) + physical_pages = torch.gather(page_tables, dim=1, index=page_idx_clamped.long()) + + page_indices = (physical_pages << page_bits) | offset_in_page + page_indices = page_indices.to(torch.int32) + + page_indices = torch.where( + valid_topk, page_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) + + out_page_indices.copy_(page_indices) + + if out_raw_indices is not None: + raw_indices = torch.where( + valid_topk, raw_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) + out_raw_indices.copy_(raw_indices) + + +@triton.jit +def 
_fused_scale_kernel( + weight_ptr, + q_scale_ptr, + out_ptr, + numel, + out_scale, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < numel + + w = tl.load(weight_ptr + offs, mask=mask) + qs = tl.load(q_scale_ptr + offs, mask=mask) + + acc = w.to(tl.float32) * out_scale * qs.to(tl.float32) + tl.store(out_ptr + offs, acc.to(out_ptr.dtype.element_ty), mask=mask) + + +def fused_scale( + weight: torch.Tensor, + out_scale: float, + q_scale: torch.Tensor, +) -> torch.Tensor: + assert weight.is_contiguous() and q_scale.is_contiguous() + B, H = weight.shape + numel = B * H + out_dtype = torch.promote_types(weight.dtype, q_scale.dtype) + out = torch.empty((B, H, 1), device=weight.device, dtype=out_dtype) + BLOCK = 1024 + grid = (triton.cdiv(numel, BLOCK),) + _fused_scale_kernel[grid]( + weight, + q_scale, + out, + numel, + out_scale, + BLOCK=BLOCK, + ) + return out + + +class C4IndexerBackendMixin: + def __init__(self): + super().__init__() + self.debug_use_external_c4_sparse_indices: bool = False + + def _forward_prepare_multi_stream( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + c4_indexer: C4Indexer, + positions: torch.Tensor, + forward_batch: ForwardBatch, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + alt_streams: Optional[List[torch.cuda.Stream]] = None, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if TYPE_CHECKING: + assert isinstance(self, CompressorBackendMixin) + + assert alt_streams is not None + assert len(alt_streams) >= 2 + current_stream = torch.cuda.current_stream() + stream_q = alt_streams[0] + stream_weights = alt_streams[1] + + stream_q.wait_stream(current_stream) + stream_weights.wait_stream(current_stream) + + self.forward_indexer_compressor( + x=x, + forward_batch=forward_batch, + layer_id=c4_indexer.layer_id, + compressor=c4_indexer.compressor, + ) + c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer( + layer_id=c4_indexer.layer_id, + ) + + with torch.cuda.stream(stream_q): + if q_lora_ready is not None: + stream_q.wait_event(q_lora_ready) + q = c4_indexer.compute_q(q_lora, positions=positions) + q_fp8, q_scale = act_quant(q) + q_scale_ready = stream_q.record_event() + + with torch.cuda.stream(stream_weights): + weights = c4_indexer.compute_weights(x, skip_scale=True) + stream_weights.wait_event(q_scale_ready) + weights = fused_scale(weights, c4_indexer.weight_scale, q_scale) + + current_stream.wait_stream(stream_q) + current_stream.wait_stream(stream_weights) + + return q_fp8, weights, c4_indexer_kv_cache + + def _forward_prepare_normal( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + c4_indexer: C4Indexer, + positions: torch.Tensor, + forward_batch: ForwardBatch, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if TYPE_CHECKING: + assert isinstance(self, CompressorBackendMixin) + + q = c4_indexer.compute_q(q_lora, positions=positions) + q_fp8, q_scale = act_quant(q) + weights = c4_indexer.compute_weights(x, skip_scale=True) + weights = fused_scale(weights, c4_indexer.weight_scale, q_scale) + self.forward_indexer_compressor( + x=x, + forward_batch=forward_batch, + layer_id=c4_indexer.layer_id, + compressor=c4_indexer.compressor, + ) + c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer( + layer_id=c4_indexer.layer_id, + ) + return q_fp8, weights, c4_indexer_kv_cache + + def forward_c4_indexer( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + 
c4_indexer: C4Indexer, + forward_batch: ForwardBatch, + alt_streams: Optional[List[torch.cuda.Stream]] = None, + enable_multi_stream: bool = False, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> None: + if forward_batch.forward_mode.is_idle(): + return + # PREP_IN_CG lazy upgrade: this runs from MQALayer._forward_prepare, + # before attn_backend.forward() would trigger the upgrade. + self._maybe_upgrade_forward_metadata() + token_to_kv_pool = forward_batch.token_to_kv_pool + + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + assert isinstance(self, CompressorBackendMixin) + + metadata = self.forward_metadata + indexer_metadata = metadata.indexer_metadata + core_metadata = metadata.core_metadata + + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DSV4AttnMetadata, + ) + + assert isinstance(core_metadata, DSV4AttnMetadata) + assert isinstance(indexer_metadata, PagedIndexerMetadata) + + if enable_multi_stream: + q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_multi_stream( + x=x, + q_lora=q_lora, + c4_indexer=c4_indexer, + positions=core_metadata.positions, + forward_batch=forward_batch, + token_to_kv_pool=token_to_kv_pool, + alt_streams=alt_streams, + q_lora_ready=q_lora_ready, + ) + else: + assert q_lora_ready is None + q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_normal( + x=x, + q_lora=q_lora, + c4_indexer=c4_indexer, + positions=core_metadata.positions, + forward_batch=forward_batch, + token_to_kv_pool=token_to_kv_pool, + ) + + assert len(q_fp8.shape) == 3 + q_fp8 = q_fp8.unsqueeze(1) + assert len(c4_indexer_kv_cache.shape) == 2 + block_kv = 64 + num_heads_kv = 1 + head_dim_with_sf = 132 + + c4_indexer_kv_cache = c4_indexer_kv_cache.view( + c4_indexer_kv_cache.shape[0], block_kv, num_heads_kv, head_dim_with_sf + ) + assert len(weights.shape) == 3 + weights = weights.squeeze(2) + if envs.SGLANG_OPT_USE_TILELANG_INDEXER.get(): + from sglang.srt.layers.attention.dsv4.tilelang_kernel import ( + tilelang_fp8_paged_mqa_logits as fn, + ) + elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + fn = fp8_paged_mqa_logits_torch + else: + from deep_gemm import fp8_paged_mqa_logits as fn + + _c4sl = indexer_metadata.c4_seq_lens + if _c4sl.dim() == 1: + _c4sl = _c4sl.unsqueeze(-1) + logits = fn( + q_fp8, + c4_indexer_kv_cache, + weights, + _c4sl, + indexer_metadata.page_table, + indexer_metadata.deep_gemm_metadata, + indexer_metadata.max_c4_seq_len, + False, + ) + + assert indexer_metadata.page_table is core_metadata.page_table + if self.debug_use_external_c4_sparse_indices: + return + + indexer_capturer = get_global_indexer_capturer() + capture_enabled = indexer_capturer is not None + + hisparse_coordinator = forward_batch.hisparse_coordinator + hisparse_decode = ( + hisparse_coordinator is not None and forward_batch.forward_mode.is_decode() + ) + + raw_indices = None + if capture_enabled: + raw_indices = torch.empty_like(core_metadata.c4_sparse_page_indices) + elif hisparse_decode: + raw_indices = hisparse_coordinator.raw_indices_buffer[ + : core_metadata.c4_sparse_page_indices.size(0) + ] + + if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get(): + topk_transform_512_pytorch_vectorized( + logits, + indexer_metadata.c4_seq_lens, + core_metadata.page_table, + core_metadata.c4_sparse_page_indices, + indexer_metadata.c4_page_size, + raw_indices, + ) + elif envs.SGLANG_OPT_USE_TOPK_V2.get() and raw_indices is None: + topk_transform_512_v2( + logits, + indexer_metadata.c4_seq_lens, + core_metadata.page_table, + 
core_metadata.c4_sparse_page_indices, + indexer_metadata.c4_page_size, + indexer_metadata.topk_metadata, + ) + else: + topk_transform_512( + logits, + indexer_metadata.c4_seq_lens, + core_metadata.page_table, + core_metadata.c4_sparse_page_indices, + indexer_metadata.c4_page_size, + raw_indices, + ) + if hisparse_coordinator is not None: + if hisparse_decode: + compress_layer_id = token_to_kv_pool.layer_mapping[ + c4_indexer.layer_id + ].compress_layer_id + core_metadata.c4_sparse_page_indices = ( + hisparse_coordinator.swap_in_selected_pages( + req_pool_indices=forward_batch.req_pool_indices, + compressed_seq_lens=indexer_metadata.c4_seq_lens, + top_k_result=raw_indices, + layer_id=compress_layer_id, + ) + ) + else: + core_metadata.c4_sparse_page_indices = ( + token_to_kv_pool.c4_kv_pool.translate_loc_to_hisparse_device( + core_metadata.c4_sparse_page_indices + ) + ) + + if capture_enabled: + compress_layer_id = token_to_kv_pool.layer_mapping[ + c4_indexer.layer_id + ].compress_layer_id + indexer_capturer.capture(compress_layer_id, raw_indices) + + +class C4Indexer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + freqs_cis: torch.Tensor, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + ): + super().__init__() + self.layer_id = layer_id + self.dim = config.hidden_size + self.n_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.q_lora_rank = config.q_lora_rank + self.softmax_scale = self.head_dim**-0.5 + self.n_local_heads = self.n_heads + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + params_dtype=torch.bfloat16, + prefix=add_prefix("wq_b", prefix), + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=None, + params_dtype=torch.bfloat16, + prefix=add_prefix("weights_proj", prefix), + ) + self.compressor = Compressor( + config, + self.layer_id, + True, + freqs_cis, + compress_ratio=4, + head_dim=self.head_dim, + rotate=True, + prefix=add_prefix("compressor", prefix), + ) + self.freqs_cis = freqs_cis + self.weight_scale: float = self.softmax_scale * self.n_heads**-0.5 + self.alt_streams = alt_streams + + def compute_q(self, q_lora: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + q, _ = self.wq_b(q_lora) + q = q.view(-1, self.n_local_heads, self.head_dim) + fused_rope( + q[..., -self.rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + q = rotate_activation(q) + return q + + def compute_weights(self, x: torch.Tensor, skip_scale=False) -> torch.Tensor: + out, _ = self.weights_proj(x) + if not skip_scale: + out = out * self.weight_scale + return out + + def forward( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + forward_batch: ForwardBatch, + enable_multi_stream: bool = False, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> None: + if TYPE_CHECKING: + assert isinstance(forward_batch.attn_backend, DeepseekV4AttnBackend) + return forward_batch.attn_backend.forward_c4_indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + c4_indexer=self, + alt_streams=self.alt_streams, + enable_multi_stream=enable_multi_stream, + q_lora_ready=q_lora_ready, + ) diff --git a/python/sglang/srt/layers/attention/dsv4/metadata.py b/python/sglang/srt/layers/attention/dsv4/metadata.py new file mode 100644 index 000000000000..7995dbd959cb --- /dev/null 
+++ b/python/sglang/srt/layers/attention/dsv4/metadata.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import warnings +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, Any, List, Optional + +import torch + +from sglang.srt.environ import envs +from sglang.srt.utils import is_hip + +if TYPE_CHECKING: + pass + + +""" +Some comments on the common terms used in DeepSeekV4Backend: + +topk_lengths: + NOTE: TL;DR: topk_lengths == seq_lens + The FlashMLA sparse decode kernel will attend to `k` tokens for each query. + `topk_lengths` indicates how many tokens each query will attend to. + This would more accurately be named `seq_lens`, but we simply follow the existing naming convention. + +page_table: + The page table indicates which pages each request is assigned to. + Each value in the page table is the page index in the TokenToKVPool. + This page index is independent of the actual `page_size`. + +page_indices: + The real indices used to index into the KV cache. + This can be computed from the `page_table` and `page_size`. + e.g. page_indices[i, j] = page_table[i, j // page_size] * page_size + (j % page_size) + For sparse C4 top-512 attention, the indices will be selected from the C4 page indices. + In implementation, we don't materialize the full C4 `page_indices`, + but calculate them from `page_table` on-the-fly in the attention kernel. + +positions: + The position of the last token for each request. + For compressed tokens, the positions must be multiples of the compress ratio. + For example, for C4, raw_position=11 will trigger a compression, + but the RoPE position used during that compression must be 8 instead of 11. + +Some other notes: + c4_ / c128_: means "compressed by 4" / "compressed by 128". + c4_page_size: page_size // 4 + c4_seq_lens: seq_lens // 4, but clamped to at least 1, due to a flash_mla requirement. + c4_sparse: means "compressed by 4" but only attends to the top-512 tokens; + all related lengths will be clipped to 512.
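+ + A small worked example of the two rules above (hypothetical numbers, for illustration only): + with page_size = 4 and page_table[i] = [7, 2, 9], + page_indices[i, 5] = page_table[i, 5 // 4] * 4 + (5 % 4) = 2 * 4 + 1 = 9; + with C4 compression at raw_position = 11 (i.e. seq_len = 12 and 12 % 4 == 0), + the compression fires and the RoPE position used for the compressed token is 11 & ~3 = 8.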
+""" + + +def copy_metadata( + *, + src, + dst, + check_eq_fields: List[str], + copy_fields: List[str], + assign_fields: Optional[List[str]] = None, +): + assign_fields = assign_fields or [] + + for field_name in check_eq_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + assert src_val == dst_val, f"{field_name=} {src_val=} {dst_val=}" + + for field_name in copy_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + if src_val is None and dst_val is None: + continue + assert dst_val is not None, f"{field_name=} {src_val=} {dst_val=}" + if hasattr(dst_val, "copy_"): + dst_val.copy_(src_val) + else: + warnings.warn( + f"{field_name=} {type(dst_val)=} does not have copy_, use setattr" + ) + setattr(dst, field_name, src_val) + + for field_name in assign_fields: + setattr(dst, field_name, getattr(src, field_name)) + + provided_fields = check_eq_fields + copy_fields + assign_fields + provided_fields_unique = set(provided_fields) + assert len(provided_fields) == len( + provided_fields_unique + ), f"{provided_fields=} has dup" + all_fields = {f.name for f in fields(src)} + provided_fields = set(provided_fields) + assert ( + provided_fields == all_fields + ), f"{provided_fields - all_fields=}, {all_fields - provided_fields=}" + + +@dataclass +class PagedIndexerMetadata: + page_size: int + page_table: torch.Tensor + c4_seq_lens: torch.Tensor + deep_gemm_metadata: Any = field(init=False, repr=False) + topk_metadata: torch.Tensor = field(init=False, repr=False) + + def __post_init__(self): + if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + self.deep_gemm_metadata = None + else: + import deep_gemm + + if envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get(): + from sglang.jit_kernel.deepseek_v4 import get_paged_mqa_logits_metadata + else: + from deep_gemm import get_paged_mqa_logits_metadata + + _c4 = self.c4_seq_lens.to(torch.int32) + if _c4.dim() == 1: + _c4 = _c4.unsqueeze(-1) + self.deep_gemm_metadata = get_paged_mqa_logits_metadata( + _c4, + self.c4_page_size, + deep_gemm.get_num_sms(), + ) + + assert isinstance(self.deep_gemm_metadata, torch.Tensor) + + from sglang.jit_kernel.deepseek_v4 import plan_topk_v2 + + if envs.SGLANG_OPT_USE_TOPK_V2.get(): + self.topk_metadata = plan_topk_v2(self.c4_seq_lens) + else: + self.topk_metadata = torch.empty((0,)) + + assert self.page_size == 256, "the system hardcodes page_size=256" + + @property + def c4_page_size(self) -> int: + return self.page_size // 4 + + @property + def max_seq_len(self) -> int: + return self.page_table.shape[1] * self.page_size + + @property + def max_c4_seq_len(self) -> int: + return self.page_table.shape[1] * self.c4_page_size + + def copy_(self, other: "PagedIndexerMetadata"): + if is_hip(): + copy_fields = ["page_table", "c4_seq_lens"] + else: + copy_fields = ["page_table", "c4_seq_lens", "deep_gemm_metadata"] + copy_fields += ["topk_metadata"] + copy_metadata( + src=other, + dst=self, + check_eq_fields=["page_size"], + copy_fields=copy_fields, + ) + + +def maybe_copy_inplace(dst, *, src) -> None: + assert type(src) == type(dst) + if dst is not None: + dst.copy_(src) diff --git a/python/sglang/srt/layers/attention/dsv4/metadata_kernel.py b/python/sglang/srt/layers/attention/dsv4/metadata_kernel.py new file mode 100644 index 000000000000..ca14cb2624f1 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/metadata_kernel.py @@ -0,0 +1,200 @@ +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + + +@triton.jit +def 
_init_compressed_attn_metadata_kernel( + seq_lens_ptr, + positions_ptr, + raw_out_loc_ptr, + page_table_ptr, + c4_out_loc_ptr, + c4_positions_ptr, + c4_seq_lens_raw_ptr, + c4_seq_lens_clamp1_ptr, + c128_out_loc_ptr, + c128_positions_ptr, + c128_seq_lens_clamp1_ptr, + c128_page_indices_ptr, + bs, + max_pages, + page_size: tl.constexpr, + c128_max_seq_len: tl.constexpr, + c128_page_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + COMPUTE_PAGE_INDICES: tl.constexpr, +): + batch_id = tl.program_id(0) + if batch_id >= bs: + return + + seq_len = tl.load(seq_lens_ptr + batch_id) + position = tl.load(positions_ptr + batch_id) + raw_out_loc = tl.load(raw_out_loc_ptr + batch_id) + + c4_should_compress = (seq_len % 4) == 0 + c4_out_loc = tl.where(c4_should_compress, raw_out_loc // 4, 0) + c4_positions = position & (~3) + c4_seq_lens_raw = seq_len // 4 + c4_seq_lens_clamp1 = tl.maximum(c4_seq_lens_raw, 1) + + tl.store(c4_out_loc_ptr + batch_id, c4_out_loc) + tl.store(c4_positions_ptr + batch_id, c4_positions) + tl.store(c4_seq_lens_raw_ptr + batch_id, c4_seq_lens_raw) + tl.store(c4_seq_lens_clamp1_ptr + batch_id, c4_seq_lens_clamp1) + + c128_should_compress = (seq_len % 128) == 0 + c128_out_loc = tl.where(c128_should_compress, raw_out_loc // 128, 0) + c128_positions = position & (~127) + c128_seq_lens_raw = seq_len // 128 + c128_seq_lens_clamp1 = tl.maximum(c128_seq_lens_raw, 1) + + tl.store(c128_out_loc_ptr + batch_id, c128_out_loc) + tl.store(c128_positions_ptr + batch_id, c128_positions) + tl.store(c128_seq_lens_clamp1_ptr + batch_id, c128_seq_lens_clamp1) + + if COMPUTE_PAGE_INDICES: + page_indices_base = batch_id * c128_max_seq_len + for block_start in range(0, c128_max_seq_len, BLOCK_SIZE): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < c128_max_seq_len + + page_idx = offsets // c128_page_size + offset_in_page = offsets % c128_page_size + + page_mask = mask & (page_idx < max_pages) + page_table_vals = tl.load( + page_table_ptr + batch_id * max_pages + page_idx, + mask=page_mask, + other=0, + ) + + c_page_indices_vals = page_table_vals * c128_page_size + offset_in_page + + valid_mask = offsets < c128_seq_lens_raw + c_page_indices_vals = tl.where(valid_mask, c_page_indices_vals, -1) + + tl.store( + c128_page_indices_ptr + page_indices_base + offsets, + c_page_indices_vals, + mask=mask, + ) + + +def _init_compressed_attn_metadata_triton( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], +]: + bs = seq_lens.shape[0] + device = seq_lens.device + + c4_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c4_positions = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_raw = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + c128_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c128_positions = torch.empty(bs, dtype=torch.int32, device=device) + c128_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + if compute_page_indices: + assert ( + page_table is not None + ), "page_table required when compute_page_indices=True" + assert page_size > 0, "page_size required when compute_page_indices=True" + max_pages = page_table.shape[1] + c128_page_size = page_size // 128 + 
c128_max_seq_len = c128_page_size * max_pages + c128_page_indices = torch.empty( + bs, c128_max_seq_len, dtype=torch.int32, device=device + ) + BLOCK_SIZE = triton.next_power_of_2(max(c128_page_size, 64)) + else: + max_pages = 0 + c128_page_size = 1 + c128_max_seq_len = 0 + c128_page_indices = None + BLOCK_SIZE = 64 + if page_table is None: + page_table = torch.empty(0, dtype=torch.int32, device=device) + + grid = (bs,) + _init_compressed_attn_metadata_kernel[grid]( + seq_lens, + positions, + raw_out_loc, + page_table, + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + ( + c128_page_indices + if c128_page_indices is not None + else torch.empty(0, dtype=torch.int32, device=device) + ), + bs, + max_pages, + page_size if page_size > 0 else 128, + c128_max_seq_len, + c128_page_size, + BLOCK_SIZE, + compute_page_indices, + ) + + return ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) + + +def init_compression_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], +]: + return _init_compressed_attn_metadata_triton( + seq_lens, + positions, + raw_out_loc, + page_table, + page_size, + compute_page_indices, + ) diff --git a/python/sglang/srt/layers/attention/dsv4/quant_k_cache.py b/python/sglang/srt/layers/attention/dsv4/quant_k_cache.py new file mode 100644 index 000000000000..6370bcb8d8ce --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/quant_k_cache.py @@ -0,0 +1,120 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.dsv4.index_buf_accessor import NopeFp8RopeBf16Pack +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +@triton.jit +def _quant_k_cache_fused_kernel( + k_bf16_ptr, + k_nope_fp8_ptr, + k_rope_bf16_ptr, + scale_k_nope_uint8_ptr, + k_bf16_stride_0, + k_nope_fp8_stride_0, + k_rope_bf16_stride_0, + scale_stride_0, + DIM_NOPE: tl.constexpr, + DIM_ROPE: tl.constexpr, + TILE_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + EPS: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + if tile_id == NUM_TILES: + rope_range = tl.arange(0, TILE_SIZE) + rope_mask = rope_range < DIM_ROPE + + in_rope_offsets = token_id * k_bf16_stride_0 + DIM_NOPE + rope_range + rope_data = tl.load(k_bf16_ptr + in_rope_offsets, mask=rope_mask, other=0.0) + + out_rope_offsets = token_id * k_rope_bf16_stride_0 + rope_range + tl.store(k_rope_bf16_ptr + out_rope_offsets, rope_data, mask=rope_mask) + else: + tile_range = tl.arange(0, TILE_SIZE) + + in_tile_offsets = token_id * k_bf16_stride_0 + tile_id * TILE_SIZE + tile_range + x_bf16 = tl.load(k_bf16_ptr + in_tile_offsets) + x_fp32 = x_bf16.to(tl.float32) + + abs_x = tl.abs(x_fp32) + max_abs = tl.max(abs_x) + max_abs_clamped = tl.maximum(max_abs, EPS) + scale = max_abs_clamped / FP8_MAX + + log2_scale = tl.log2(scale) + ceil_log2 = tl.math.ceil(log2_scale) + scale_pow2_fp32 = tl.exp2(ceil_log2) + scale_inv = 1.0 / scale_pow2_fp32 + x_scaled = x_fp32 * scale_inv + x_fp8 = 
tl.clamp(x_scaled, FP8_MIN, FP8_MAX).to(k_nope_fp8_ptr.dtype.element_ty) + + out_fp8_offsets = ( + token_id * k_nope_fp8_stride_0 + tile_id * TILE_SIZE + tile_range + ) + tl.store(k_nope_fp8_ptr + out_fp8_offsets, x_fp8) + + exponent = ceil_log2.to(tl.int32) + scale_uint8 = (exponent + 127).to(tl.uint8) + + out_scale_offset = token_id * scale_stride_0 + tile_id + tl.store(scale_k_nope_uint8_ptr + out_scale_offset, scale_uint8) + + +def quant_to_nope_fp8_rope_bf16_pack_triton( + k_bf16: torch.Tensor, +) -> NopeFp8RopeBf16Pack: + assert k_bf16.dtype == torch.bfloat16 + num_tokens, hidden_dim = k_bf16.shape + assert hidden_dim == 512 + dim_nope = 448 + dim_rope = 64 + tile_size = 64 + num_tiles = dim_nope // tile_size + + k_bf16 = k_bf16.contiguous() + + k_nope_fp8 = torch.empty( + (num_tokens, dim_nope), dtype=fp8_dtype, device=k_bf16.device + ) + k_rope_bf16 = torch.empty( + (num_tokens, dim_rope), dtype=torch.bfloat16, device=k_bf16.device + ) + scale_k_nope_ue8m0 = torch.empty( + (num_tokens, num_tiles), dtype=torch.uint8, device=k_bf16.device + ) + + fp8_dtype_info = torch.finfo(fp8_dtype) + + grid = (num_tokens, num_tiles + 1) + _quant_k_cache_fused_kernel[grid]( + k_bf16, + k_nope_fp8, + k_rope_bf16, + scale_k_nope_ue8m0, + k_bf16.stride(0), + k_nope_fp8.stride(0), + k_rope_bf16.stride(0), + scale_k_nope_ue8m0.stride(0), + DIM_NOPE=dim_nope, + DIM_ROPE=dim_rope, + TILE_SIZE=tile_size, + NUM_TILES=num_tiles, + FP8_MIN=fp8_dtype_info.min, + FP8_MAX=fp8_dtype_info.max, + EPS=1e-8, + ) + + return NopeFp8RopeBf16Pack( + k_nope_fp8=k_nope_fp8, + k_rope_bf16=k_rope_bf16, + scale_k_nope_ue8m0=scale_k_nope_ue8m0, + ) diff --git a/python/sglang/srt/layers/attention/dsv4/tilelang_kernel.py b/python/sglang/srt/layers/attention/dsv4/tilelang_kernel.py new file mode 100644 index 000000000000..f94c97146e32 --- /dev/null +++ b/python/sglang/srt/layers/attention/dsv4/tilelang_kernel.py @@ -0,0 +1,123 @@ +import functools +from typing import Any + +import tilelang +import tilelang.language as T +import torch + +from sglang.srt.utils import is_hip + +if is_hip(): + FP8 = "float8_e5m2fnuz" + FP8_ = torch.float8_e5m2 +else: + FP8 = "float8_e4m3" + FP8_ = torch.float8_e4m3fn +FP32 = "float32" +INT32 = "int32" + + +@functools.cache +def fp8_paged_mqa_logits_kernel( + head_dim: int = 128, + num_heads: int = 64, + block_size: int = 64, + clear_accum: bool = True, +) -> Any: + N = T.symbolic("batch_size") + L = T.symbolic("max_table_length") + S = T.symbolic("max_seq_len") + C = T.symbolic("num_blocks") + B = block_size + D = head_dim + H = num_heads + d_0, d_1 = T.dynamic("d_0, d_1") + + assert D % 4 == 0 + assert H % 4 == 0 + assert D == 128 + + @tilelang.jit + def fp8_paged_mqa_logits( + q: T.Tensor[(N, H, D), FP8], + kvcache: T.StridedTensor[(C, B, D), (d_0, D, 1), FP8], + kvcache_scale: T.StridedTensor[(C, B), (d_1, 1), FP32], + weight: T.Tensor[(N, H), FP32], + seq_lens: T.Tensor[(N,), INT32], + page_table: T.Tensor[(N, L), INT32], + o: T.Tensor[(N, S), FP32], + ) -> None: + _ = N, L, S, C, D, H, B, d_0, d_1 + with T.Kernel(N) as bx: + seq_len = seq_lens[bx] + q_smem = T.alloc_shared((H, D), FP8) + q_s_frag = T.alloc_fragment((H,), FP32) + T.copy(q[bx, 0, 0], q_smem) + T.copy(weight[bx, 0], q_s_frag) + + for i in T.Pipelined(T.ceildiv(seq_len, B), num_stages=2): + page = page_table[bx, i] + k_smem = T.alloc_shared((B, D), FP8) + k_s_frag = T.alloc_fragment((B,), FP32) + T.copy(kvcache[page, 0, 0], k_smem) + T.copy(kvcache_scale[page, 0], k_s_frag) + + logits = T.alloc_fragment((B, H), FP32) + if not 
clear_accum: + T.fill(logits, 0.0) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=clear_accum, + ) + + for h, j in T.Parallel(H, B): + logits[j, h] = T.max(logits[j, h], 0.0) * q_s_frag[h] + logits_sum = T.alloc_fragment((B,), FP32) + T.reduce_sum(logits, logits_sum, dim=1) + for j in T.Parallel(B): + logits_sum[j] *= k_s_frag[j] + T.copy(logits_sum, o[bx, i * B]) + + return fp8_paged_mqa_logits + + +def tilelang_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + kernel = fp8_paged_mqa_logits_kernel( + head_dim=head_dim, + num_heads=num_heads, + block_size=block_size, + clear_accum=clean_logits, + ) + q_fp8 = q_fp8.view(batch_size, num_heads, head_dim) + kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + kvcache = kvcache_fp8[..., : block_size * head_dim].view(dtype=FP8_) + kvcache = kvcache.view(-1, block_size, head_dim) + kvcache_scale = kvcache_fp8[..., block_size * head_dim :].view(dtype=torch.float32) + kernel(q_fp8, kvcache, kvcache_scale, weight, seq_lens, page_table, logits) + return logits diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py index f66d6e9b6b9a..37499c524a99 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py @@ -32,6 +32,8 @@ def grouped_gemm_nt_f8f8bf16_masked( expected_m: int, overlap_args: Optional[Any] = None, max_block_n: int = 256, + recipe_a: Optional[Tuple[int, int]] = None, + recipe_b: Optional[Tuple[int, int]] = None, ): num_groups, _, k = lhs[0].shape _, n, _ = rhs[0].shape @@ -50,12 +52,19 @@ def grouped_gemm_nt_f8f8bf16_masked( overlap_args.num_sms if overlap_args is not None else None ): + fp4_kwargs = {} + if recipe_a is not None: + fp4_kwargs["recipe_a"] = recipe_a + if recipe_b is not None: + fp4_kwargs["recipe_b"] = recipe_b + return deep_gemm.fp8_m_grouped_gemm_nt_masked( lhs, rhs, out, masked_m, expected_m, + **fp4_kwargs, **( dict( enable_overlap=True, @@ -82,6 +91,8 @@ def grouped_gemm_nt_f8f8bf16_contig( rhs: Tuple[torch.Tensor, torch.Tensor], out: torch.Tensor, m_indices: torch.Tensor, + recipe_a: Optional[Tuple[int, int]] = None, + recipe_b: Optional[Tuple[int, int]] = None, ): m, k = lhs[0].shape num_groups, n, _ = rhs[0].shape @@ -93,8 +104,16 @@ def grouped_gemm_nt_f8f8bf16_contig( _sanity_check_input(lhs) _sanity_check_input(rhs) + fp4_kwargs = {} + if recipe_a is not None: + fp4_kwargs["recipe_a"] = recipe_a + if recipe_b is not None: + fp4_kwargs["recipe_b"] = recipe_b + with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type): - deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + lhs, rhs, out, 
m_indices, **fp4_kwargs + ) def gemm_nt_f8f8bf16( diff --git a/python/sglang/srt/layers/deepseek_v4_rope.py b/python/sglang/srt/layers/deepseek_v4_rope.py new file mode 100644 index 000000000000..c717850c63f7 --- /dev/null +++ b/python/sglang/srt/layers/deepseek_v4_rope.py @@ -0,0 +1,179 @@ +import math +from functools import lru_cache +from typing import Optional + +import tilelang +import torch +import triton +import triton.language as tl + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + + +@lru_cache(2) +def precompute_freqs_cis( + dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow +) -> torch.Tensor: + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if original_seq_len > 0: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +@triton.jit +def apply_rotary_emb_triton_kernel( + x_ptr, + freqs_ptr, + positions_ptr, + rope_dim, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_freq_pos, + stride_freq_dim, + USE_POS: tl.constexpr, + IS_INVERSE: tl.constexpr, + IS_3D: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_dim = tl.program_id(2) + + if USE_POS: + position = tl.load(positions_ptr + pid_batch) + else: + position = pid_batch + + if IS_3D: + base_offset = pid_batch * stride_x_batch + pid_head * stride_x_head + else: + base_offset = pid_batch * stride_x_batch + + offs_pair = pid_dim * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs_pair < (rope_dim // 2) + + offs_x_real = base_offset + offs_pair * 2 * stride_x_dim + offs_x_imag = base_offset + (offs_pair * 2 + 1) * stride_x_dim + + x_real = tl.load(x_ptr + offs_x_real, mask=mask, other=0.0).to(tl.float32) + x_imag = tl.load(x_ptr + offs_x_imag, mask=mask, other=0.0).to(tl.float32) + + offs_freq_real = position * stride_freq_pos + offs_pair * 2 * stride_freq_dim + offs_freq_imag = position * stride_freq_pos + (offs_pair * 2 + 1) * stride_freq_dim + + freq_real = tl.load(freqs_ptr + offs_freq_real, mask=mask, other=0.0) + freq_imag = tl.load(freqs_ptr + offs_freq_imag, mask=mask, other=0.0) + + if IS_INVERSE: + out_real = x_real * freq_real + x_imag * freq_imag + out_imag = x_imag * freq_real - x_real * freq_imag + else: + out_real = x_real * freq_real - x_imag * freq_imag + out_imag = x_real * freq_imag + x_imag * freq_real + + tl.store(x_ptr + offs_x_real, out_real, mask=mask) + 
tl.store(x_ptr + offs_x_imag, out_imag, mask=mask) + + +def apply_rotary_emb_triton( + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: Optional[torch.Tensor] = None, + inverse: bool = False, +) -> torch.Tensor: + is_3d = x.ndim == 3 + + if is_3d: + batch_size, n_heads, rope_dim = x.shape + else: + batch_size, rope_dim = x.shape + n_heads = 1 + + freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + + BLOCK_SIZE = 128 + + num_blocks_dim = triton.cdiv(rope_dim // 2, BLOCK_SIZE) + grid = (batch_size, n_heads if is_3d else 1, num_blocks_dim) + + if positions is not None: + assert positions.shape == ( + batch_size, + ), f"positions shape {positions.shape} != ({batch_size},)" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + positions, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=True, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + assert ( + freqs_real.shape[0] == batch_size + ), f"freqs_cis batch size {freqs_real.shape[0]} != x batch size {batch_size}" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + None, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=False, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return x diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py new file mode 100644 index 000000000000..1c27636efb5c --- /dev/null +++ b/python/sglang/srt/layers/mhc.py @@ -0,0 +1,643 @@ +import functools +import math +from typing import Tuple + +import tilelang +import tilelang.language as T +import torch + +from sglang.jit_kernel.utils import is_arch_support_pdl +from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_round_robin_split +from sglang.srt.layers.utils.common import strict_contiguous + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + + +@tilelang.jit(pass_configs=pass_configs) +def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float): + n = T.symbolic("n") + mix_hc = (2 + hc) * hc + threads = 64 + + ENABLE_PDL = is_arch_support_pdl() + + @T.prim_func + def hc_split_sinkhorn_kernel_( + mixes: T.Tensor[(n, mix_hc), FP32], + hc_scale: T.Tensor[(3,), T.float32], + hc_base: T.Tensor[(mix_hc,), T.float32], + pre: T.Tensor[(n, hc), FP32], + post: T.Tensor[(n, hc), FP32], + comb: T.Tensor[(n, hc, hc), FP32], + ): + with T.Kernel(n, threads=threads) as i: + if ENABLE_PDL: + T.pdl_sync() + + mixes_shared = T.alloc_shared(mix_hc, FP32) + comb_frag = T.alloc_fragment((hc, hc), FP32) + T.copy(mixes[i, :], mixes_shared) + + for j in T.Parallel(hc): + pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps + for j in T.Parallel(hc): + post[i, j] = 2 * T.sigmoid( + mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc] + ) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = ( + mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + + hc_base[j * hc + k + hc * 2] + ) + + row_sum = T.alloc_fragment(hc, FP32) + col_sum = T.alloc_fragment(hc, FP32) + + row_max = T.alloc_fragment(hc, FP32) + T.reduce_max(comb_frag, row_max, dim=1) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j]) + T.reduce_sum(comb_frag, row_sum, dim=1) + for j, k in 
T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps + + T.reduce_sum(comb_frag, col_sum, dim=0) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps) + + for _ in T.serial(sinkhorn_iters - 1): + T.reduce_sum(comb_frag, row_sum, dim=1) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps) + T.reduce_sum(comb_frag, col_sum, dim=0) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps) + + T.copy(comb_frag, comb[i, :, :]) + if ENABLE_PDL: + T.pdl_trigger() + + return hc_split_sinkhorn_kernel_ + + +def hc_split_sinkhorn( + mixes: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + hc_mult: int = 4, + sinkhorn_iters: int = 20, + eps: float = 1e-6, +): + b, s, _ = mixes.size() + pre = mixes.new_empty(b, s, hc_mult) + post = mixes.new_empty(b, s, hc_mult) + comb = mixes.new_empty(b, s, hc_mult, hc_mult) + kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps) + kernel( + mixes.view(-1, (2 + hc_mult) * hc_mult), + hc_scale, + hc_base, + pre.view(-1, hc_mult), + post.view(-1, hc_mult), + comb.view(-1, hc_mult, hc_mult), + ) + return pre, post, comb + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(num_tokens, threads=96) as i: + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + + if ENABLE_PDL: + T.pdl_sync() + + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = ( + T.sigmoid( + mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] + ) + * hc_post_mult_value + ) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = ( + mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + + hc_base[j * hc_mult + k + hc_mult * 2] + ) + + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, 
dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + if ENABLE_PDL: + T.pdl_trigger() + + +@tilelang.jit +def mhc_pre_gemm_sqrsum_tilelang( + x, + fn, + out, + sqrsum, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 256, +) -> tilelang.JITKernel: + assert hc_mult3 <= 32 + num_tokens = T.dynamic("num_tokens") + assert hc_hidden_size % hidden_block == 0 + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) + fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32) + out: T.Tensor((num_tokens, hc_mult3), T.float32) + sqrsum: T.Tensor((num_tokens), T.float32) + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(T.ceildiv(num_tokens, token_block)) as px: + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sqrsum_part = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sqrsum_part) + if ENABLE_PDL: + T.pdl_sync() + for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2): + x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout( + {x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)} + ) + + T.copy(x[px * token_block, pz * hidden_block], x_smem_16) + T.copy(fn[0, pz * hidden_block], fn_smem) + + x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem_16, x_frag_16) + x_frag = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_frag_16, x_frag) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j] + + T.gemm( + x_frag, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + sqrsum_l = T.alloc_fragment(token_block, T.float32) + T.reduce_sum(sqrsum_part, sqrsum_l) + for i in T.Parallel(token_block): + sqrsum[px * token_block + i] = sqrsum_l[i] + for i, j in T.Parallel(token_block, 32): + if j < hc_mult3: + out[px * token_block + i, j] = out_frag[i, j] + if ENABLE_PDL: + T.pdl_trigger() + + +@functools.cache +def 
mhc_pre_gemm_sqrsum_splitk_kernel( + hc_mult3: int, + hc_hidden_size: int, + split_k: int, + token_block: int = 32, + hidden_block: int = 256, + threads: int = 128, +) -> Tuple[tilelang.JITKernel, tilelang.JITKernel]: + assert hc_mult3 <= 32 + assert hc_hidden_size % hidden_block == 0 + assert hc_hidden_size % split_k == 0 + split_size = hc_hidden_size // split_k + assert split_size % hidden_block == 0 + + num_tokens = T.dynamic("num_tokens") + + ENABLE_PDL = is_arch_support_pdl() + + @tilelang.jit + def mhc_pre_gemm_sqrsum_splitk_stage_0( + x: T.Tensor[(num_tokens, hc_hidden_size), T.bfloat16], + fn: T.Tensor[(hc_mult3, hc_hidden_size), T.float32], + out_partial: T.Tensor[(split_k, num_tokens, 32), T.float32], + sqrsum_partial: T.Tensor[(split_k, num_tokens), T.float32], + ): + with T.Kernel(T.ceildiv(num_tokens, token_block), split_k, threads=threads) as ( + px, + bz, + ): + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sq_part4 = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sq_part4) + + k_base = bz * split_size + + if ENABLE_PDL: + T.pdl_sync() + + for pz in T.Pipelined(split_size // hidden_block, num_stages=2): + x_smem = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout( + {x_smem: tilelang.layout.make_swizzled_layout(x_smem)} + ) + + T.copy(x[px * token_block, k_base + pz * hidden_block], x_smem) + T.copy(fn[0, k_base + pz * hidden_block], fn_smem) + + x_f16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem, x_f16) + x_f = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_f16, x_f) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + v = x_f[i, jj * 4 + j] + sq_part4[i, j] += v * v + + T.gemm( + x_f, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + + sq_l = T.alloc_fragment((token_block,), T.float32) + T.reduce_sum(sq_part4, sq_l) + + for i in T.Parallel(token_block): + t = px * token_block + i + if t < num_tokens: + sqrsum_partial[bz, t] = sq_l[i] + + for i, j in T.Parallel(token_block, 32): + t = px * token_block + i + if t < num_tokens: + out_partial[bz, t, j] = out_frag[i, j] + + if ENABLE_PDL: + T.pdl_trigger() + + @tilelang.jit + def mhc_pre_gemm_sqrsum_splitk_stage_1( + out_partial: T.Tensor[(split_k, num_tokens, 32), T.float32], + sqrsum_partial: T.Tensor[(split_k, num_tokens), T.float32], + out: T.Tensor[(num_tokens, hc_mult3), T.float32], + sqrsum: T.Tensor[(num_tokens,), T.float32], + ): + warps_per_cta = threads // 32 + num_reduce = T.ceildiv(split_k, 32) + with T.Kernel(T.ceildiv(num_tokens, warps_per_cta), threads=threads) as (px,): + tx = T.get_thread_binding() + warp = tx // 32 + lane = tx % 32 + t = px * warps_per_cta + warp + s = T.alloc_local((1,), T.float32) + acc = T.alloc_local((1,), T.float32) + s[0] = 0 + acc[0] = 0 + if ENABLE_PDL: + T.pdl_sync() + + if t < num_tokens: + for r in T.serial(num_reduce): + bz = r * 32 + lane + s[0] += T.if_then_else(bz < split_k, sqrsum_partial[bz, t], 0.0) + sqrsum[t] = T.warp_reduce_sum(s[0]) + if lane < hc_mult3: + for bz in T.serial(split_k): + acc[0] += out_partial[bz, t, lane] + out[t, lane] = acc[0] + + if ENABLE_PDL: + T.pdl_trigger() + + return ( + mhc_pre_gemm_sqrsum_splitk_stage_0, + mhc_pre_gemm_sqrsum_splitk_stage_1, + ) + + +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + 
hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + n_splits_pre: int = 32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + post_mix = torch.empty( + num_tokens, hc_mult, dtype=torch.float32, device=residual.device + ) + comb_mix = torch.empty( + num_tokens, hc_mult2, dtype=torch.float32, device=residual.device + ) + layer_input = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device + ) + + gemm_out_mul = torch.empty( + n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device + ) + gemm_out_sqrsum = torch.empty( + n_splits, num_tokens, dtype=torch.float32, device=residual.device + ) + + if num_tokens <= 2048: + assert n_splits == 1 + if hc_hidden_size == 16384: + hidden_block = 256 + elif hc_hidden_size == 28672: + hidden_block = 128 + else: + raise NotImplementedError( + f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, " + f"got {hc_hidden_size}" + ) + kernel_0, kernel_1 = mhc_pre_gemm_sqrsum_splitk_kernel( + hc_mult3, + hc_hidden_size, + split_k=n_splits_pre, + token_block=32, + hidden_block=hidden_block, + ) + partial_out = gemm_out_mul.new_empty(n_splits_pre, num_tokens, 32) + partial_sqrsum = gemm_out_sqrsum.new_empty(n_splits_pre, num_tokens) + kernel_0( + residual_flat.view(num_tokens, hc_hidden_size), + fn_flat, + partial_out, + partial_sqrsum, + ) + kernel_1( + partial_out, + partial_sqrsum, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + ) + del partial_out, partial_sqrsum + else: + assert ( + n_splits == 1 + ), "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_post_tilelang( + a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024 +) -> tilelang.JITKernel: + n = T.dynamic("num_tokens") + h = hidden + + h_blk = math.gcd(hidden, h_blk) + a: T.Tensor((n, hc, hc), T.float32) + b: T.Tensor((n, hc, h), T.bfloat16) + c: T.Tensor((n, hc), T.float32) + d: T.Tensor((n, h), 
T.bfloat16) + x: T.Tensor((n, hc, h), T.bfloat16) + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(n, threads=n_thr) as i_n: + if ENABLE_PDL: + T.pdl_sync() + + x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + d_shared = T.alloc_shared(h_blk, T.bfloat16) + + x_local = T.alloc_fragment((hc, h_blk), T.float32) + b_local = T.alloc_fragment((hc, h_blk), T.float32) + d_local = T.alloc_fragment(h_blk, T.float32) + + a_local = T.alloc_fragment((hc, hc), T.float32) + c_local = T.alloc_fragment(hc, T.float32) + T.copy(a[i_n, 0, 0], a_local) + T.copy(c[i_n, 0], c_local) + + for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2): + T.copy(b[i_n, 0, i0_h * h_blk], b_shared) + T.copy(d[i_n, i0_h * h_blk], d_shared) + + T.copy(b_shared, b_local) + T.copy(d_shared, d_local) + for i_hco, i1_h in T.Parallel(hc, h_blk): + x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h] + for i_hci in T.serial(hc): + x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] + T.copy(x_local, x_shared) + + T.copy(x_shared, x[i_n, 0, i0_h * h_blk]) + + if ENABLE_PDL: + T.pdl_trigger() + + +def mhc_post( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + if is_nsa_prefill_cp_round_robin_split(): + x = strict_contiguous(x) + residual = strict_contiguous(residual) + post_layer_mix = strict_contiguous(post_layer_mix) + comb_res_mix = strict_contiguous(comb_res_mix) + out = torch.empty_like(residual) + mhc_post_tilelang( + comb_res_mix, + residual, + post_layer_mix.squeeze(-1), + x, + out, + residual.shape[-2], + residual.shape[-1], + ) + return out diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py index 23184e94b483..a2f3f845eab2 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_marlin_moe.py @@ -1,6 +1,7 @@ from typing import Optional import torch +import torch.nn.functional as F from sglang.srt.utils import is_cuda from sglang.srt.utils.custom_op import register_custom_op @@ -14,9 +15,16 @@ from sglang.jit_kernel.moe_wna16_marlin import moe_wna16_marlin_gemm -def get_scalar_type(num_bits: int, has_zp: bool): +def get_scalar_type(num_bits: int, has_zp: bool, scales: Optional[torch.Tensor] = None): from sgl_kernel.scalar_type import scalar_types + if ( + not has_zp + and num_bits == 4 + and scales is not None + and scales.dtype == torch.float8_e8m0fnu + ): + return scalar_types.float4_e2m1f if has_zp: assert num_bits == 4 return scalar_types.uint4 @@ -24,6 +32,22 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 +def swiglu_limit_func( + output: torch.Tensor, + input: torch.Tensor, # first half is gate, second half is up + swiglu_limit: float = 0.0, +) -> None: + d = input.shape[1] // 2 + gate = input[:, :d] + up = input[:, d:] + + if swiglu_limit > 0: + gate = torch.clamp(gate, max=swiglu_limit) + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + + output.copy_(F.silu(gate) * up) + + @register_custom_op(out_shape="hidden_states") def fused_marlin_moe( hidden_states: torch.Tensor, @@ -47,6 +71,7 @@ def fused_marlin_moe( is_k_full: bool = True, inplace: bool = False, routed_scaling_factor: Optional[float] = None, + clamp_limit: Optional[float] = None, ) -> torch.Tensor: """ This function computes a Mixture 
of Experts (MoE) layer using two sets of @@ -86,12 +111,29 @@ def fused_marlin_moe( assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - assert ( - hidden_states.dtype == w1_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" - assert ( - hidden_states.dtype == w2_scale.dtype - ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" + is_mxfp4_marlin = ( + num_bits == 4 + and w1_zeros is None + and w2_zeros is None + and w1_scale.dtype == torch.float8_e8m0fnu + and w2_scale.dtype == torch.float8_e8m0fnu + ) + if is_mxfp4_marlin: + assert w1_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w1_scale to be torch.float8_e8m0fnu, " + f"got {w1_scale.dtype}" + ) + assert w2_scale.dtype == torch.float8_e8m0fnu, ( + "MXFP4 Marlin expects w2_scale to be torch.float8_e8m0fnu, " + f"got {w2_scale.dtype}" + ) + else: + assert ( + hidden_states.dtype == w1_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w1_scale.dtype ({w1_scale.dtype})" + assert ( + hidden_states.dtype == w2_scale.dtype + ), f"moe_wna16_marlin_gemm assumes hidden_states.dtype ({hidden_states.dtype}) == w2_scale.dtype ({w2_scale.dtype})" assert num_bits in [4, 8] M, K = hidden_states.shape @@ -122,8 +164,8 @@ def fused_marlin_moe( max_workspace_size, dtype=torch.int, device=device, requires_grad=False ) - scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) - scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) + scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None, w1_scale) + scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None, w2_scale) intermediate_cache2 = torch.empty( (M * topk_ids.shape[1], N), @@ -143,7 +185,7 @@ def fused_marlin_moe( use_atomic_add = ( hidden_states.dtype == torch.half or torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 - ) + ) and (not is_mxfp4_marlin) intermediate_cache1 = moe_wna16_marlin_gemm( hidden_states, @@ -174,7 +216,14 @@ def fused_marlin_moe( is_zp_float=False, ) - silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) + if clamp_limit is not None: + swiglu_limit_func( + intermediate_cache2, + intermediate_cache1.view(-1, 2 * N), + clamp_limit, + ) + else: + silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2) if expert_map is not None: intermediate_cache3.zero_() @@ -210,12 +259,15 @@ def fused_marlin_moe( output = hidden_states if inplace else torch.empty_like(hidden_states) - if routed_scaling_factor is None: - routed_scaling_factor = 1.0 + if is_mxfp4_marlin: + return torch.sum(intermediate_cache3, dim=1, out=output) + else: + if routed_scaling_factor is None: + routed_scaling_factor = 1.0 - moe_sum_reduce( - intermediate_cache3, - output, - routed_scaling_factor, - ) - return output + moe_sum_reduce( + intermediate_cache3, + output, + routed_scaling_factor, + ) + return output diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index f34ffc10fea7..82543626af35 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -177,6 +177,7 @@ def __init__( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, 
gemm1_clamp_limit: Optional[float] = None, + swiglu_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, routing_method_type: Optional[RoutingMethodType] = None, @@ -262,6 +263,7 @@ def __init__( routed_scaling_factor=routed_scaling_factor, gemm1_alpha=gemm1_alpha, gemm1_clamp_limit=gemm1_clamp_limit, + swiglu_limit=swiglu_limit, is_gated=is_gated, routing_method_type=routing_method_type, ) diff --git a/python/sglang/srt/layers/moe/hash_topk.py b/python/sglang/srt/layers/moe/hash_topk.py new file mode 100644 index 000000000000..6b63b286ae62 --- /dev/null +++ b/python/sglang/srt/layers/moe/hash_topk.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import logging +from typing import Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.environ import envs +from sglang.srt.eplb.expert_location_dispatch import ( + ExpertLocationDispatchInfo, + topk_ids_logical_to_physical, +) +from sglang.srt.layers.moe.topk import ( + StandardTopKOutput, + _mask_topk_ids_padded_region, +) +from sglang.srt.utils import is_hip + +logger = logging.getLogger(__name__) + + +class HashTopK(nn.Module): + def __init__( + self, + topk, + num_experts, + num_fused_shared_experts, + vocab_size, + scoring_func="sqrtsoftplus", + routed_scaling_factor=1.5, + apply_routed_scaling_factor_on_output=False, + ): + super().__init__() + self.num_experts = num_experts + self.topk = topk + self.routed_scaling_factor = routed_scaling_factor + self.num_fused_shared_experts = num_fused_shared_experts + self.score_func = scoring_func + self.tid2eid = nn.Parameter( + torch.empty(vocab_size, topk - num_fused_shared_experts, dtype=torch.int32), + requires_grad=False, + ) + + assert not apply_routed_scaling_factor_on_output, "not implemented" + + def empty_topk_output(self, device: torch.device): + topk = self.topk - self.num_fused_shared_experts + topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) + topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) + router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) + + def _forward_torch( + self, router_logits: torch.Tensor, input_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.score_func == "softmax": + scores = router_logits.softmax(dim=-1) + elif self.score_func == "sigmoid": + scores = router_logits.sigmoid() + else: + scores = torch.nn.functional.softplus(router_logits).sqrt() + + num_token = scores.shape[0] + + topk_ids = torch.zeros( + (num_token, self.topk), dtype=torch.int32, device=scores.device + ) + topk_weights = torch.zeros( + (num_token, self.topk), dtype=scores.dtype, device=scores.device + ) + + if self.num_fused_shared_experts == 1: + topk_ids[:, :-1] = self.tid2eid[input_ids] + topk_weights[:, :-1] = scores.gather(1, topk_ids[:, :-1]) + + if self.score_func != "softmax": + topk_weights[:, :-1] /= topk_weights[:, :-1].sum(dim=-1, keepdim=True) + + topk_ids[:, -1] = torch.randint( + low=self.num_experts, + high=self.num_experts + self.num_fused_shared_experts, + size=(num_token,), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / self.routed_scaling_factor + ) + else: + topk_ids[:, :] = self.tid2eid[input_ids] + topk_weights[:, :] = scores.gather(1, topk_ids[:, :]) + if self.score_func != "softmax": + topk_weights[:, :] /= topk_weights[:, :].sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + 
+ def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ): + assert ( + input_ids.shape[0] == hidden_states.shape[0] == router_logits.shape[0] + ), f"{input_ids.shape=} {hidden_states.shape=} {router_logits.shape=}" + + if envs.SGLANG_OPT_USE_FUSED_HASH_TOPK.get(): + from sglang.jit_kernel.deepseek_v4 import hash_topk + + topk_weights, topk_ids = hash_topk( + router_logits=router_logits, + input_ids=input_ids, + tid2eid=self.tid2eid, + num_fused_shared_experts=self.num_fused_shared_experts, + routed_scaling_factor=self.routed_scaling_factor, + scoring_func=self.score_func, + ) + else: + topk_weights, topk_ids = self._forward_torch(router_logits, input_ids) + + if is_hip(): + topk_weights = topk_weights.to(torch.float32) + + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + topk_output = StandardTopKOutput( + topk_weights=topk_weights, topk_ids=topk_ids, router_logits=router_logits + ) + return topk_output diff --git a/python/sglang/srt/layers/moe/mega_moe.py b/python/sglang/srt/layers/moe/mega_moe.py new file mode 100644 index 000000000000..f94930b8a4cb --- /dev/null +++ b/python/sglang/srt/layers/moe/mega_moe.py @@ -0,0 +1,289 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Mega-MoE forward path and expert-weight prep shared by Deepseek V2/V4.""" + +from __future__ import annotations + +from contextlib import nullcontext +from typing import TYPE_CHECKING, Optional + +import torch + +from sglang.jit_kernel.deepseek_v4 import mega_moe_pre_dispatch +from sglang.srt.environ import envs +from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers.dp_attention import get_dp_global_num_tokens +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + +if TYPE_CHECKING: + from deep_gemm import SymmBuffer + + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.models.deepseek_v2 import DeepseekV2MoE + + +_MEGA_MOE_SYMM_BUFFER: dict = {} + + +def _get_mega_moe_symm_buffer( + group, + num_experts: int, + num_max_tokens_per_rank: int, + num_topk: int, + hidden: int, + intermediate_hidden: int, +) -> SymmBuffer: + import deep_gemm + + key = ( + id(group), + num_max_tokens_per_rank, + num_experts, + num_topk, + hidden, + intermediate_hidden, + ) + buf = _MEGA_MOE_SYMM_BUFFER.get(key) + if buf is None: + buf = deep_gemm.get_symm_buffer_for_mega_moe( + group, + num_experts, + num_max_tokens_per_rank, + num_topk, + hidden, + intermediate_hidden, + use_fp8_dispatch=True, + activation="swiglu", + ) + _MEGA_MOE_SYMM_BUFFER[key] = buf + return buf + + +def should_use_mega_moe(moe: "DeepseekV2MoE", hidden_states: torch.Tensor) -> bool: + if not envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + return False + if not getattr(moe.experts, "_mega_moe_weights_built", False): + return False + if get_is_capture_mode(): + return True + + global_num_tokens = get_dp_global_num_tokens() + if global_num_tokens: + max_tokens_per_rank = max(global_num_tokens) + else: + max_tokens_per_rank = hidden_states.shape[0] + cap = envs.SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK.get() + return max_tokens_per_rank <= cap + + +def forward_mega_moe( + moe: "DeepseekV2MoE", + hidden_states: torch.Tensor, + forward_batch: Optional["ForwardBatch"] = None, + input_ids_global: Optional[torch.Tensor] = None, +) -> torch.Tensor: + num_tokens = hidden_states.shape[0] + + sbo_overlap_flag = ( + moe.alt_stream is not None + and moe.num_fused_shared_experts == 0 + and num_tokens > 0 + and get_is_capture_mode() + ) + + if sbo_overlap_flag: + current_stream = torch.cuda.current_stream() + moe.alt_stream.wait_stream(current_stream) + shared_output = moe._forward_shared_experts(hidden_states) + mega_stream_ctx = torch.cuda.stream(moe.alt_stream) + else: + shared_output = moe._forward_shared_experts(hidden_states) + mega_stream_ctx = nullcontext() + + with mega_stream_ctx: + y = _run_mega_routed( + moe, hidden_states, forward_batch, input_ids_global, num_tokens + ) + + if sbo_overlap_flag: + current_stream.wait_stream(moe.alt_stream) + + if shared_output is not None: + y.add_(shared_output) + return y + + +def _run_mega_routed( + moe: "DeepseekV2MoE", + hidden_states: torch.Tensor, + forward_batch: Optional["ForwardBatch"], + input_ids_global: Optional[torch.Tensor], + num_tokens: int, +) -> torch.Tensor: + import deep_gemm + + from sglang.srt.distributed.parallel_state import get_moe_ep_group + + hidden_size = moe.config.hidden_size + + if num_tokens > 0: + router_logits = moe.gate(hidden_states, forward_batch=forward_batch) + topk_kwargs = {"input_ids": input_ids_global} if moe.is_hash else {} + topk_output = moe.topk( + 
hidden_states, + router_logits, + num_token_non_padded=( + forward_batch.num_token_non_padded + if forward_batch is not None + else None + ), + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=moe.layer_id, + ), + **topk_kwargs, + ) + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + else: + topk_ids = None + topk_weights = None + + ep_group = get_moe_ep_group().device_group + num_experts = moe.experts.num_experts + top_k = moe.config.num_experts_per_tok + moe.num_fused_shared_experts + intermediate_size = moe.config.moe_intermediate_size + num_max_tokens_per_rank = ( + envs.SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK.get() + ) + assert num_tokens <= num_max_tokens_per_rank, ( + f"mega MoE: num_tokens={num_tokens} exceeds cap " + f"SGLANG_OPT_DEEPGEMM_MEGA_MOE_NUM_MAX_TOKENS_PER_RANK=" + f"{num_max_tokens_per_rank}; raise the env var or shrink " + f"cuda_graph_max_bs / chunked_prefill_size accordingly" + ) + + buf = _get_mega_moe_symm_buffer( + ep_group, + num_experts=num_experts, + num_max_tokens_per_rank=num_max_tokens_per_rank, + num_topk=top_k, + hidden=hidden_size, + intermediate_hidden=intermediate_size, + ) + + if num_tokens > 0: + topk_ids_in = topk_ids + topk_weights_in = topk_weights + else: + topk_ids_in = hidden_states.new_empty((0, top_k), dtype=torch.int32) + topk_weights_in = hidden_states.new_empty((0, top_k), dtype=torch.float32) + mega_moe_pre_dispatch( + hidden_states, + topk_ids_in, + topk_weights_in, + buf.x, + buf.x_sf, + buf.topk_idx, + buf.topk_weights, + quant_group_size=32, + ) + + # Allocate at least one row so y has a non-null CUDA data_ptr; + # the DeepGEMM tvm-ffi binding rejects nullptr in convert_to_torch_tensor(). + y = torch.empty( + (max(num_tokens, 1), hidden_size), + dtype=torch.bfloat16, + device=hidden_states.device, + ) + swiglu_limit = getattr(moe.config, "swiglu_limit", None) + deep_gemm.fp8_fp4_mega_moe( + y, + moe.experts.mega_l1_weights, + moe.experts.mega_l2_weights, + buf, + recipe=(1, 1, 32), + activation="swiglu", + activation_clamp=swiglu_limit, + fast_math=True, + ) + y = y[:num_tokens] + + if not moe.experts.should_fuse_routed_scaling_factor_in_topk: + y.mul_(moe.routed_scaling_factor) + return y + + +def build_mega_moe_experts_weights(experts) -> None: + from deep_gemm import ( + transform_sf_into_required_layout, + transform_weights_for_mega_moe, + ) + from deep_gemm.mega import _interleave_l1_weights, _transpose_sf_for_utccp + + if getattr(experts, "_mega_moe_weights_built", False): + return + + w13 = experts.w13_weight.data + w13_sf_fp32 = experts.w13_weight_scale_inv.data + w2 = experts.w2_weight.data + w2_sf_fp32 = experts.w2_weight_scale_inv.data + + num_groups, n1, half_k1 = w13.shape + k1 = half_k1 * 2 + _, n2, half_k2 = w2.shape + k2 = half_k2 * 2 + + w13_sf = transform_sf_into_required_layout( + w13_sf_fp32, + mn=n1, + k=k1, + recipe=(1, 32), + num_groups=num_groups, + disable_ue8m0_cast=False, + ) + w2_sf = transform_sf_into_required_layout( + w2_sf_fp32, + mn=n2, + k=k2, + recipe=(1, 32), + num_groups=num_groups, + disable_ue8m0_cast=False, + ) + + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + # Build the interleaved L1 weight + scale once; share the weight buffer + # between `w13_weight.data` (normal deep-ep path) and `mega_l1_weights[0]` + # (mega moe path). Mega moe additionally needs a UTCCP-transposed scale; + # the deep-ep path consumes the non-transposed interleaved scale and a + # swizzle-aware activation kernel. 
L2 weight is untouched by the mega + # transform, so the existing `w2_weight.data` is shared directly. + w13_interleaved, w13_sf_interleaved = _interleave_l1_weights((w13, w13_sf)) + w13_sf_utccp = _transpose_sf_for_utccp(w13_sf_interleaved) + w2_sf_utccp = _transpose_sf_for_utccp(w2_sf) + + experts.w13_weight.data = w13_interleaved + experts.w13_weight_scale_inv.data = w13_sf_interleaved + experts.w2_weight_scale_inv.data = w2_sf + experts.w13_weight_scale_inv.format_ue8m0 = True + experts.w2_weight_scale_inv.format_ue8m0 = True + + experts.mega_l1_weights = (experts.w13_weight.data, w13_sf_utccp) + experts.mega_l2_weights = (experts.w2_weight.data, w2_sf_utccp) + else: + l1_pair, l2_pair = transform_weights_for_mega_moe((w13, w13_sf), (w2, w2_sf)) + + experts.mega_l1_weights = l1_pair + experts.mega_l2_weights = l2_pair + + experts._mega_moe_weights_built = True diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 9bfe4cc46e37..8412e9fba58c 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -48,6 +48,7 @@ class MoeRunnerConfig: routed_scaling_factor: Optional[float] = None gemm1_alpha: Optional[float] = None gemm1_clamp_limit: Optional[float] = None + swiglu_limit: Optional[float] = None @dataclass diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py index 3521a44bbe94..cfdee4757150 100644 --- a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py +++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py @@ -1,10 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Tuple +import einops import torch +from sglang.jit_kernel.deepseek_v4 import silu_and_mul_masked_post_quant +from sglang.srt.environ import envs from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.moe_runner.base import ( MoeQuantInfo, @@ -45,8 +48,11 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip _is_musa = is_musa() +# Imported only for the SGLANG_OPT_FIX_MEGA_MOE_MEMORY=False fallback path. if not (_is_npu or _is_hip) and _is_cuda: - from sglang.jit_kernel.activation import silu_and_mul + from sglang.jit_kernel.activation import silu_and_mul as _legacy_silu_and_mul +else: + _legacy_silu_and_mul = None _MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT") @@ -109,6 +115,8 @@ class DeepGemmMoeQuantInfo(MoeQuantInfo): w13_scale: Optional[torch.Tensor] = None w2_scale: Optional[torch.Tensor] = None block_shape: Optional[List[int]] = None + # DSV4 mxfp4 layout flag; selects recipe_a=(1,128)/recipe_b=(1,32) downstream. 
+ is_fp4_experts: bool = False class DeepGemmRunnerCore(MoeRunnerCore): @@ -116,6 +124,13 @@ def __init__(self, config: MoeRunnerConfig): super().__init__(config) assert self.config.activation == "silu" assert self.config.is_gated + self.swiglu_limit = self.config.swiglu_limit + self.use_swizzle = False + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + assert envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get() + assert envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get() + assert envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get() + self.use_swizzle = True def run( self, @@ -140,9 +155,10 @@ def _run_contiguous_gemm( quant_info: DeepGemmMoeQuantInfo, running_state: dict, ) -> torch.Tensor: + from sglang.jit_kernel.deepseek_v4 import silu_and_mul_contig_post_quant from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_fp8, + create_per_token_group_quant_fp8_output_scale, ) hidden_states = runner_input.hidden_states @@ -157,6 +173,10 @@ def _run_contiguous_gemm( K = hidden_states_shape[1] scale_block_size = 128 + recipe_a, recipe_b = ( + ((1, 128), (1, 32)) if quant_info.is_fp4_experts else (None, None) + ) + w13_weight_fp8 = ( quant_info.w13_weight, quant_info.w13_scale, @@ -176,30 +196,69 @@ def _run_contiguous_gemm( w13_weight_fp8, gateup_output, m_indices, + recipe_a=recipe_a, + recipe_b=recipe_b, ) dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) - down_input = torch.empty( - ( - all_tokens, - N // 2, - ), - device=gateup_output.device, - dtype=torch.bfloat16, - ) - silu_and_mul(gateup_output.view(-1, N), down_input) - del gateup_output + if envs.SGLANG_OPT_FIX_MEGA_MOE_MEMORY.get(): + swiglu_limit_arg: Optional[float] = self.swiglu_limit - down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( - down_input, - scale_block_size, - column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - ) - del down_input + down_input_fp8 = torch.empty( + (all_tokens, N // 2), + device=gateup_output.device, + dtype=torch.float8_e4m3fn, + ) + down_input_scale = create_per_token_group_quant_fp8_output_scale( + x_shape=(all_tokens, N // 2), + device=gateup_output.device, + group_size=scale_block_size, + column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + silu_and_mul_contig_post_quant( + input=gateup_output, + output=down_input_fp8, + output_scale=down_input_scale, + quant_group_size=scale_block_size, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + transposed=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + swiglu_limit=swiglu_limit_arg, + swizzle=self.use_swizzle, + ) + del gateup_output + else: + # Hacky byte-equal fallback that reproduces the optimize-branch + # code path exactly: bf16 silu_and_mul then a separate per-token + # group fp8 quant. Kept behind the mega-moe-memory flag. 
+ from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_fp8, + ) + + if self.swiglu_limit is not None: + gateup_output = _apply_swiglu_limit( + gateup_output, swiglu_limit=self.swiglu_limit + ) + + down_input = torch.empty( + (all_tokens, N // 2), + device=gateup_output.device, + dtype=torch.bfloat16, + ) + _legacy_silu_and_mul(gateup_output.view(-1, N), down_input) + del gateup_output + + down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8( + down_input, + scale_block_size, + column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + del down_input down_output = torch.empty( (all_tokens, K), @@ -214,6 +273,8 @@ def _run_contiguous_gemm( w2_weight_fp8, down_output, m_indices, + recipe_a=recipe_a, + recipe_b=recipe_b, ) return down_output @@ -225,12 +286,6 @@ def _run_masked_gemm( running_state: dict, ) -> torch.Tensor: from sglang.srt.layers import deep_gemm_wrapper - from sglang.srt.layers.moe.ep_moe.kernels import ( - silu_and_mul_masked_post_quant_fwd, - ) - from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_8bit, - ) hidden_states = runner_input.hidden_states hidden_states_scale = runner_input.hidden_states_scale @@ -242,6 +297,10 @@ def _run_masked_gemm( w13_scale = quant_info.w13_scale w2_scale = quant_info.w2_scale + recipe_a, recipe_b = ( + ((1, 128), (1, 32)) if quant_info.is_fp4_experts else (None, None) + ) + hidden_states_device = running_state["hidden_states_device"] # GroupGemm-0 @@ -270,51 +329,45 @@ def _run_masked_gemm( gateup_output, masked_m, expected_m, + recipe_a=recipe_a, + recipe_b=recipe_b, ) dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) + swiglu_limit_arg: Optional[float] = None + if self.swiglu_limit is not None: + # DeepSeek V4: clamped swiglu requires JIT EP activation; the + # FAST_ACT fused-quant path doesn't carry a swiglu_limit arg. 
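+            # Clamp semantics (matching _apply_swiglu_limit below): the gate half
+            # is clamped to at most swiglu_limit and the up half to
+            # [-swiglu_limit, swiglu_limit] before silu_and_mul.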
+ assert ( + not _MASKED_GEMM_FAST_ACT + ), "DeepSeek V4 does not support SGLANG_MASKED_GEMM_FAST_ACT" + assert ( + envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get() + ), "DeepSeek V4 requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + swiglu_limit_arg = self.swiglu_limit + else: + gateup_output = einops.rearrange( + gateup_output, "grp tok hidden -> (grp tok) hidden" + ) + gateup_output = _apply_swiglu_limit( + gateup_output, swiglu_limit=self.swiglu_limit + ) + gateup_output = einops.rearrange( + gateup_output, "(grp tok) hidden -> grp tok hidden", grp=num_groups + ) + # Act - scale_block_size = 128 - if _MASKED_GEMM_FAST_ACT: - down_input, down_input_scale = sglang_per_token_group_quant_8bit( - x=gateup_output, - dst_dtype=torch.float8_e4m3fn, - group_size=scale_block_size, - masked_m=masked_m, - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - fuse_silu_and_mul=True, - enable_v2=True, - ) - else: - down_input = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2, - ), - device=hidden_states_device, - dtype=torch.float8_e4m3fn, - ) - down_input_scale = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2 // scale_block_size, - ), - device=hidden_states_device, - dtype=torch.float32, - ) - silu_and_mul_masked_post_quant_fwd( - gateup_output, - down_input, - down_input_scale, - scale_block_size, - masked_m, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - ) + down_input, down_input_scale = _varlen_deep_gemm_silu_mul_quant( + gateup_output, + masked_m, + group_size=128, + topk=self.config.top_k, + swiglu_limit=swiglu_limit_arg, + swizzle=self.use_swizzle, + ) del gateup_output # GroupGemm-1 @@ -348,6 +401,8 @@ def _run_masked_gemm( down_output, masked_m, expected_m, + recipe_a=recipe_a, + recipe_b=recipe_b, **gemm_overlap_args_dict, ) meta_overlap_args = running_state.get("meta_overlap_args", None) @@ -616,3 +671,113 @@ def post_permute_deep_gemm_to_deepep_normal( topk_ids=running_state["topk_ids"], topk_weights=running_state["topk_weights"], ) + + +def _varlen_deep_gemm_silu_mul_quant( + gateup_output: torch.Tensor, + masked_m: Optional[torch.Tensor], + group_size: int, + topk: int, + swiglu_limit: Optional[float] = None, + swizzle: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_8bit, + ) + + if _MASKED_GEMM_FAST_ACT: + assert not swizzle, ( + "SGLANG_OPT_FIX_MEGA_MOE_MEMORY is incompatible with " + "SGLANG_MASKED_GEMM_FAST_ACT (swizzled layout only supported by JIT act)" + ) + assert ( + swiglu_limit is None + ), "swiglu_limit (DeepSeek V4) is not supported together with SGLANG_MASKED_GEMM_FAST_ACT" + return sglang_per_token_group_quant_8bit( + x=gateup_output, + dst_dtype=torch.float8_e4m3fn, + group_size=group_size, + masked_m=masked_m, + column_major_scales=True, + scale_tma_aligned=True, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + fuse_silu_and_mul=True, + enable_v2=True, + ) + + assert masked_m is not None + hidden_states_device = gateup_output.device + E, N, D_2 = gateup_output.shape + D = D_2 // 2 + del D_2 + G = D // group_size + down_input = torch.empty( + (E, N, D), + device=hidden_states_device, + dtype=torch.float8_e4m3fn, + ) + + if envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get(): + assert 
N % 4 == 0 and G % 4 == 0 + packed_ue8m0 = deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + down_input_scale = torch.empty( + (E, G // 4, N) if packed_ue8m0 else (E, N, G), + device=hidden_states_device, + dtype=torch.int32 if packed_ue8m0 else torch.float32, + ) + silu_and_mul_masked_post_quant( + gateup_output, + down_input, + down_input_scale, + group_size, + masked_m, + scale_ue8m0=packed_ue8m0, + topk=topk, + transposed=packed_ue8m0, + swiglu_limit=swiglu_limit, + swizzle=swizzle, + ) + if packed_ue8m0: + down_input_scale = down_input_scale.transpose(-1, -2) + else: + assert ( + swiglu_limit is None + ), "swiglu_limit (DeepSeek V4) requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + assert ( + not swizzle + ), "SGLANG_OPT_FIX_MEGA_MOE_MEMORY requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + down_input_scale = torch.empty( + (E, N, G), + device=hidden_states_device, + dtype=torch.float32, + ) + silu_and_mul_masked_post_quant_fwd( + gateup_output, + down_input, + down_input_scale, + group_size, + masked_m, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + return down_input, down_input_scale + + +def _apply_swiglu_limit( + gateup_output: torch.Tensor, swiglu_limit: float +) -> torch.Tensor: + assert swiglu_limit == 10 + + num_tokens, hidden_size_x2 = gateup_output.shape + assert gateup_output.dtype == torch.bfloat16 + + gate, up = torch.chunk(gateup_output, chunks=2, dim=-1) + assert gate.shape == (num_tokens, hidden_size_x2 // 2) + assert up.shape == (num_tokens, hidden_size_x2 // 2) + + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + gate = torch.clamp(gate, max=swiglu_limit) + + out = torch.cat([gate, up], dim=-1) + assert out.shape == (num_tokens, hidden_size_x2) + return out diff --git a/python/sglang/srt/layers/moe/moe_runner/marlin.py b/python/sglang/srt/layers/moe/moe_runner/marlin.py index 45104dd27805..4e335f330694 100644 --- a/python/sglang/srt/layers/moe/moe_runner/marlin.py +++ b/python/sglang/srt/layers/moe/moe_runner/marlin.py @@ -97,8 +97,26 @@ def fused_experts_none_to_marlin( hidden_states.device, max_blocks_per_sm=4 ) + marlin_hidden_states = hidden_states + # Avoid aliasing the MoE input buffer until Marlin output semantics are + # fully validated across shared-expert and overlap paths. + marlin_inplace = False + if ( + quant_info.weight_bits == 4 + and quant_info.w13_qzeros is None + and quant_info.w2_qzeros is None + and quant_info.w13_scales.dtype == torch.float8_e8m0fnu + and quant_info.w2_scales.dtype == torch.float8_e8m0fnu + and hidden_states.dtype == torch.float16 + ): + # MXFP4(E8M0) Marlin kernels are only numerically valid on the bf16 + # activation path. The fp16 + E8M0 path is intentionally not generated + # in sgl-kernel, so upcast activations here and cast the result back. 
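+        # The result is cast back below via `.to(hidden_states.dtype)` on the
+        # fused_marlin_moe output, so callers still receive their original dtype.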
+ marlin_hidden_states = hidden_states.to(torch.bfloat16) + marlin_inplace = False + output = fused_marlin_moe( - hidden_states=hidden_states, + hidden_states=marlin_hidden_states, w1=quant_info.w13_qweight, w2=quant_info.w2_qweight, w1_scale=quant_info.w13_scales, @@ -116,8 +134,9 @@ def fused_experts_none_to_marlin( workspace=MARLIN_MOE_WORKSPACE, num_bits=quant_info.weight_bits, is_k_full=quant_info.is_k_full, - inplace=runner_config.inplace, + inplace=marlin_inplace, routed_scaling_factor=runner_config.routed_scaling_factor, + clamp_limit=runner_config.swiglu_limit, ).to(hidden_states.dtype) return StandardCombineInput( diff --git a/python/sglang/srt/layers/moe/moe_runner/triton.py b/python/sglang/srt/layers/moe/moe_runner/triton.py index 840d70b0ba3d..96e431d4e385 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton.py @@ -126,6 +126,7 @@ def run( gemm1_limit=self.config.gemm1_clamp_limit, filter_expert=filter_expert, hooks=hooks, + swiglu_limit=self.config.swiglu_limit, ) return TritonRunnerOutput(hidden_states=out) diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..eacde3f6b8fb --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..b60f7dc039df --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py index eded81834bf5..e953615bb132 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe.py @@ -11,6 +11,7 @@ import torch.nn.functional as F import triton.language as tl +from sglang.srt.environ import envs from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.utils import get_moe_padding_size from sglang.srt.server_args import get_global_server_args @@ -28,6 +29,7 @@ from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config from .fused_moe_triton_kernels import ( + act_and_mul_triton, invoke_fused_moe_kernel, moe_sum_reduce_triton, support_tensor_descriptor, @@ -112,6 +114,7 @@ def inplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> None: fused_experts_impl( hidden_states, @@ -142,6 +145,7 @@ def inplace_fused_experts( gemm1_alpha, gemm1_limit, filter_expert, + swiglu_limit=swiglu_limit, ) @@ -174,6 +178,7 @@ def outplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -204,6 +209,7 @@ def outplace_fused_experts( gemm1_alpha=gemm1_alpha, gemm1_limit=gemm1_limit, filter_expert=filter_expert, + swiglu_limit=swiglu_limit, ) @@ -262,6 +268,7 @@ def fused_experts( moe_runner_config.gemm1_alpha, moe_runner_config.gemm1_clamp_limit, filter_expert, + swiglu_limit=moe_runner_config.swiglu_limit, ) return hidden_states else: @@ -293,6 +300,7 @@ def fused_experts( gemm1_alpha=moe_runner_config.gemm1_alpha, gemm1_limit=moe_runner_config.gemm1_clamp_limit, filter_expert=filter_expert, + swiglu_limit=moe_runner_config.swiglu_limit, ) @@ -425,6 +433,7 @@ def _fused_moe_kernel_sequence( gemm1_limit: Optional[float], filter_expert: bool, hooks: Optional[Any] = None, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: """Run the MoE kernel/activation/kernel/combine sequence in a single shot. @@ -519,6 +528,7 @@ def _fused_moe_kernel_sequence( if activation == "silu" and is_gated: # - gemm1_alpha != None: GPT-OSS-style swiglu(alpha, limit) # - gemm1_alpha == None and gemm1_limit != None: silu+clamp+mul(limit-only) + # - swiglu_limit != None: DeepSeek V4 swiglu clamp + silu_and_mul (CUDA/HIP only) if gemm1_alpha is not None: assert gemm1_limit is not None intermediate_cache2 = swiglu_gpt_oss_sigmoid_alpha( @@ -528,6 +538,55 @@ def _fused_moe_kernel_sequence( intermediate_cache2 = _swiglu_silu_clamp_mul( intermediate_cache1.view(-1, N), gemm1_limit ) + elif swiglu_limit is not None: + # DeepSeek V4: swiglu clamp before silu_and_mul. 
+ # Two paths gated by SGLANG_OPT_SWIGLU_CLAMP_FUSION: + # fusion=True: clamp fused into act_and_mul_triton or silu_and_mul_clamp + # fusion=False: explicit clamp_ on intermediate_cache1 (path checker) + assert swiglu_limit == 10 + assert intermediate_cache1.shape == (total_tokens, N) + assert _is_cuda or _is_hip, "DeepSeek V4 only supports CUDA/HIP downstream" + + swiglu_limit_for_triton: Optional[float] = None + swiglu_limit_for_silu_and_mul_clamp: Optional[float] = None + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + if filter_expert: + swiglu_limit_for_triton = swiglu_limit + else: + assert ( + _is_cuda + ), "fused silu_and_mul_clamp kernel is CUDA-only; HIP must disable SWIGLU_CLAMP_FUSION" + swiglu_limit_for_silu_and_mul_clamp = swiglu_limit + else: + half = N // 2 + intermediate_cache1[:, :half].clamp_(max=swiglu_limit) + intermediate_cache1[:, half:].clamp_( + min=-swiglu_limit, max=swiglu_limit + ) + + if not filter_expert: + if swiglu_limit_for_silu_and_mul_clamp is not None: + from sglang.jit_kernel.deepseek_v4 import silu_and_mul_clamp + + silu_and_mul_clamp( + intermediate_cache1.view(-1, N), + intermediate_cache2, + swiglu_limit_for_silu_and_mul_clamp, + ) + else: + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + act_and_mul_triton( + intermediate_cache1.view(-1, N), + intermediate_cache2, + config, + topk_ids, + expert_ids, + down_moe_use_tma, + activation, + swiglu_limit=swiglu_limit_for_triton, + ) elif _is_cuda or _is_hip or _is_xpu: if filter_expert and _is_cuda: # HIP/XPU fall through to the unfiltered path: the down kernel @@ -753,6 +812,7 @@ def fused_experts_impl( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -827,6 +887,7 @@ def fused_experts_impl( gemm1_limit=gemm1_limit, filter_expert=filter_expert, hooks=None, + swiglu_limit=swiglu_limit, ) diff --git a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py index d02a29762ea9..e50b9ed3cf1e 100644 --- a/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/moe_runner/triton_utils/fused_moe_triton_kernels.py @@ -930,6 +930,124 @@ def invoke_fused_moe_kernel( ) +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _apply_activation(x, ACTIVATION_TYPE: tl.constexpr): + """ + Apply activation function based on compile-time constant. 
+ + Args: + x: Input tensor (converted to float32 inside) + ACTIVATION_TYPE: Compile-time constant string ("silu" or "gelu") + + Returns: + Activated output in the same dtype as input + """ + x = x.to(tl.float32) + if ACTIVATION_TYPE == "silu": + return x * tl.sigmoid(x) + elif ACTIVATION_TYPE == "gelu": + kAlpha = 0.7978845608028654 + return 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x * x * x))) + else: + raise ValueError(f"Unsupported activation: {ACTIVATION_TYPE}") + + +@triton.jit +def act_and_mul_kernel( + gateup_output, + down_input, + hidden_size, + expert_ids_ptr, + expert_step: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + ACTIVATION_TYPE: tl.constexpr, + SWIGLU_LIMIT: tl.constexpr = 0.0, + HAS_SWIGLU_LIMIT: tl.constexpr = False, +): + """ + Unified activation and multiply kernel that handles both sorted and unsorted routing, + and both SiLU and GELU activations using compile-time constants. + """ + InDtype = gateup_output.dtype.element_ty + OutDtype = down_input.dtype.element_ty + + half_hidden_size = hidden_size // 2 + pid = tl.program_id(0) + + expert_id = tl.load(expert_ids_ptr + pid // expert_step) + + if expert_id == -1: + return + + gateup_output_ptr = gateup_output + pid * hidden_size + down_input_ptr = down_input + pid * half_hidden_size + gate_output_ptr = gateup_output_ptr + up_output_ptr = gateup_output_ptr + half_hidden_size + + for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < half_hidden_size + + gate_output = tl.load(gate_output_ptr + offset, mask=mask) + up_output = tl.load(up_output_ptr + offset, mask=mask) + + if HAS_SWIGLU_LIMIT: + gate_output = tl.minimum(gate_output, SWIGLU_LIMIT) + up_output = tl.maximum(tl.minimum(up_output, SWIGLU_LIMIT), -SWIGLU_LIMIT) + + gate_output_activated = _apply_activation(gate_output, ACTIVATION_TYPE) + gate_output_activated = gate_output_activated.to(InDtype) + + act_mul_output = gate_output_activated * up_output + act_mul_output = act_mul_output.to(OutDtype) + tl.store(down_input_ptr + offset, act_mul_output, mask=mask) + + +def act_and_mul_triton( + gateup_output: torch.Tensor, + down_input: torch.Tensor, + config: Dict[str, Any], + topk_ids: Optional[torch.Tensor] = None, + expert_ids: Optional[torch.Tensor] = None, + down_moe_use_tma: bool = False, + activation: str = "silu", + swiglu_limit: Optional[float] = None, +) -> None: + """ + Args: + gateup_output: Input tensor containing gate and up outputs concatenated + down_input: Output tensor for the result + config: Configuration dictionary with BLOCK_SIZE_M and BLOCK_SIZE_N + topk_ids: Expert IDs for unsorted routing (used when down_moe_use_tma=False) + expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True) + down_moe_use_tma: Whether to use sorted routing layout + activation: Activation type ("silu" or "gelu") + swiglu_limit: if not None, clamp gate to [-inf, L] and up to [-L, L] before activation + (compiles a separate kernel variant via tl.constexpr). 
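+
+    Note:
+        One Triton program is launched per row of ``down_input`` with a fixed
+        BLOCK_SIZE of 512; rows whose resolved expert id is -1 return early
+        without writing anything.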
+ """ + grid = (down_input.shape[0],) + hidden_size = gateup_output.shape[1] + expert_ids_row = topk_ids.view(-1) if not down_moe_use_tma else expert_ids + expert_step = 1 if not down_moe_use_tma else config["BLOCK_SIZE_M"] + has_swiglu_limit = swiglu_limit is not None + act_and_mul_kernel[grid]( + gateup_output, + down_input, + hidden_size, + expert_ids_row, + expert_step, + BLOCK_SIZE=512, + ACTIVATION_TYPE=activation, + SWIGLU_LIMIT=float(swiglu_limit) if has_swiglu_limit else 0.0, + HAS_SWIGLU_LIMIT=has_swiglu_limit, + ) + + # _moe_sum_reduce_kernel kernel modified from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py @triton.jit def _moe_sum_reduce_kernel( diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 721cbe629296..35ee82fed85c 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -88,15 +88,16 @@ def __init__(self, moe_runner_config: MoeRunnerConfig): self.moe_ep_size = get_moe_expert_parallel_world_size() backend = get_moe_runner_backend() self.enable_flashinfer_cutlass_moe = backend.is_flashinfer_cutlass() - # FlashInfer CUTLASS and CuteDSL handle EP internally with global expert IDs. - # Skip local expert mapping so topk_ids stay in global space. + self.enable_flashinfer_mxfp4_moe = backend.is_flashinfer_mxfp4() + self.enable_flashinfer_trtllm_routed_moe = backend.is_flashinfer_trtllm_routed() + # Skip local expert mapping when the backend handles EP with global expert IDs: + # - cutlass / cutedsl / trtllm_routed handle EP internally + # - mxfp4 dispatcher mapping is already global self.skip_local_expert_mapping = ( backend.is_flashinfer_cutlass() or backend.is_flashinfer_cutedsl() or backend.is_flashinfer_trtllm_routed() - ) - self.enable_flashinfer_trtllm_routed_moe = ( - get_moe_runner_backend().is_flashinfer_trtllm_routed() + or self.enable_flashinfer_mxfp4_moe ) self.num_experts = moe_runner_config.num_experts self.num_local_experts = moe_runner_config.num_local_experts @@ -189,7 +190,7 @@ def dispatch( ) ) - if self.local_expert_mapping is not None: + if self.local_expert_mapping is not None and not self.skip_local_expert_mapping: if _use_aiter: self.expert_mask_gpu = ( ( diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 2829a9df5eea..b5663e44be18 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -36,6 +36,7 @@ except ImportError: pass +from sglang.jit_kernel.deepseek_v4 import mask_topk_ids from sglang.srt.distributed import ( get_moe_expert_parallel_rank, get_moe_expert_parallel_world_size, @@ -44,6 +45,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.environ import envs from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( @@ -272,6 +274,7 @@ def __init__( apply_routed_scaling_factor_on_output: Optional[bool] = False, output_format: Optional[TopKOutputFormat] = None, fused_shared_experts_scaling_factor: Optional[float] = None, + is_fp4_experts: bool = False, ): # NOTE: scoring_func is not used for now, but we keep it for future use # see https://github.com/sgl-project/sglang/pull/4505 for more details @@ -281,6 +284,9 @@ def __init__( assert num_expert_group is 
not None and topk_group is not None self.layer_id = layer_id + # flashinfer_mxfp4 backend only: True -> STANDARD (Mxfp4FlashinferTrtllmMoEMethod + # consumes), False -> BYPASSED (flashinfer's own mxfp4 kernel). No-op otherwise. + self.is_fp4_experts = is_fp4_experts self.topk_config = TopKConfig( top_k=top_k, use_grouped_topk=use_grouped_topk, @@ -327,9 +333,8 @@ def forward_cuda( output_format = self.topk_config.output_format elif get_moe_runner_backend().is_triton_kernels(): output_format = TopKOutputFormat.TRITON_KERNEL - elif ( - get_moe_runner_backend().is_flashinfer_trtllm() - or get_moe_runner_backend().is_flashinfer_mxfp4() + elif get_moe_runner_backend().is_flashinfer_trtllm() or ( + get_moe_runner_backend().is_flashinfer_mxfp4() and not self.is_fp4_experts ): output_format = TopKOutputFormat.BYPASSED else: @@ -699,6 +704,101 @@ def kimi_k2_biased_topk_impl( return topk_weights, topk_ids +@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu) +def biased_topk_impl( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + if scoring_func == "sigmoid": + scores = gating_output.sigmoid() + elif scoring_func == "sqrtsoftplus": + scores = torch.nn.functional.softplus(gating_output).sqrt() + + num_token = scores.shape[0] + num_experts = scores.shape[1] + + scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0) + _, topk_ids = torch.topk( + scores_for_choice, + k=topk, + dim=-1, + sorted=(True if num_fused_shared_experts > 0 else False), + ) + topk_weights = scores.gather(1, topk_ids) + + if num_fused_shared_experts: + topk_ids[:, -1] = torch.randint( + low=num_experts, + high=num_experts + num_fused_shared_experts, + size=(topk_ids.size(0),), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + if routed_scaling_factor is not None: + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor + ) + + if renormalize: + topk_weights_sum = ( + topk_weights.sum(dim=-1, keepdim=True) + if num_fused_shared_experts == 0 + else topk_weights[:, :-1].sum(dim=-1, keepdim=True) + ) + topk_weights = topk_weights / topk_weights_sum + if apply_routed_scaling_factor_on_output: + topk_weights *= routed_scaling_factor + + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids + + +def biased_topk_jit_kernel_impl( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + from 
sglang.jit_kernel.moe_fused_gate import moe_fused_gate + + topk_weights, topk_ids = moe_fused_gate( + gating_output, + correction_bias, + topk=topk, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids + + @torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu) def biased_grouped_topk_impl( hidden_states: torch.Tensor, @@ -779,11 +879,15 @@ def is_power_of_two(n): def _mask_topk_ids_padded_region( topk_ids: torch.Tensor, num_token_non_padded: Optional[torch.Tensor] = None, -): +) -> None: if num_token_non_padded is None: return - indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device) - topk_ids[indices >= num_token_non_padded, :] = -1 + # TODO: let the kernel support other dtypes + if _is_cuda and topk_ids.dtype == torch.int32: + mask_topk_ids(topk_ids, num_token_non_padded) + else: + indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device) + topk_ids[indices >= num_token_non_padded, :] = -1 @torch.compile(dynamic=True, backend=get_compiler_backend()) @@ -1209,7 +1313,27 @@ def select_experts( ) elif custom_routing_function is None: assert not apply_routed_scaling_factor_on_output, "Not implemented" - if ( + if scoring_func == "sqrtsoftplus": + _biased_topk = ( + biased_topk_jit_kernel_impl + if envs.SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK.get() + else biased_topk_impl + ) + + topk_weights, topk_ids = _biased_topk( + hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + elif ( get_moe_runner_backend().is_flashinfer_trtllm_routed() and scoring_func == "softmax" and correction_bias is None diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index ff0deb03e5cf..7f766dcda3c4 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -34,6 +34,7 @@ def _dtype_rank(dtype: torch.dtype) -> Optional[int]: torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, + torch.float8_e8m0fnu, ): return 0 if dtype in (torch.float16, torch.bfloat16): @@ -70,6 +71,8 @@ def copy_with_check(target: torch.Tensor, loaded_weight: torch.Tensor): raise ValueError( f"Downcasting not allowed: {target.dtype=}, {loaded_weight.dtype=}" ) + if loaded_rank == torch.float8_e8m0fnu: + assert target_rank in {torch.float8_e8m0fnu, torch.float32} target.copy_(loaded_weight) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 8056802b1969..291e2c2de3bd 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -136,8 +136,12 @@ def __init__( weight_block_size: List[int] = None, packed_modules_mapping: Optional[Dict[str, 
List[str]]] = None, use_mxfp8: bool = False, + is_fp4_experts: bool = False, ) -> None: super().__init__() + # DSV4 mxfp4-packed (True) vs converted FP8 (False); injected by + # model_loader from ModelConfig. Default False off the DSV4 path. + self.is_fp4_experts = is_fp4_experts self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: log_info_on_rank0(logger, "Detected fp8 checkpoint.") @@ -247,7 +251,23 @@ def get_quant_method( return UnquantizedFusedMoEMethod( layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe ) - return Fp8MoEMethod(self) + + fp8_method = Fp8MoEMethod(self) + + if self.is_fp4_experts and get_moe_runner_backend().is_marlin(): + from sglang.srt.layers.quantization.mxfp4_marlin_moe import ( + Mxfp4MarlinMoEMethod, + ) + + return Mxfp4MarlinMoEMethod(fp8_method, prefix=prefix) + + if self.is_fp4_experts and get_moe_runner_backend().is_flashinfer_mxfp4(): + from sglang.srt.layers.quantization.mxfp4_flashinfer_trtllm_moe import ( + Mxfp4FlashinferTrtllmMoEMethod, + ) + + return Mxfp4FlashinferTrtllmMoEMethod(fp8_method, prefix=prefix) + return fp8_method elif isinstance(layer, RadixAttention): return Fp8KVCacheMethod(self) return None @@ -796,6 +816,7 @@ def __init__(self, quant_config: Fp8Config): self.block_quant = ( self.use_mxfp8 or self.quant_config.weight_block_size is not None ) + self.is_fp4_expert = self.quant_config.is_fp4_experts self.with_bias = False if get_moe_runner_backend().is_cutlass(): assert ( @@ -873,7 +894,26 @@ def create_weights( ) # WEIGHTS - if _is_hip and _use_hip_int4: + if self.is_fp4_expert: + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + elif _is_hip and _use_hip_int4: # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -945,7 +985,29 @@ def create_weights( set_weight_attrs(w2_weight_bias, extra_weight_attrs) # WEIGHT_SCALES - if self.block_quant: + if self.is_fp4_expert: + fp4_block_k = 32 + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + elif self.block_quant: scale_dtype = torch.uint8 if self.use_mxfp8 else torch.float32 scale_init = torch.zeros if scale_dtype == torch.uint8 else torch.ones w13_weight_scale = torch.nn.Parameter( @@ -1102,6 +1164,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: ) else: # For fp8 moe run with deepgemm, the expert weights and scales need be requantized to ue8m0 + from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE from sglang.srt.model_loader.utils import ( should_deepgemm_weight_requant_ue8m0, @@ -1110,8 +1173,46 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: # Check if MoE will actually use DeepGEMM runner will_use_deepgemm = 
self.is_deepgemm_moe_runner_backend_enabled() + if self.is_fp4_expert: + if get_moe_runner_backend().is_marlin(): + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + return + + layer.w13_weight.data = layer.w13_weight.data.view(torch.int8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.int8) + + if envs.SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE.get(): + from sglang.srt.layers.moe.mega_moe import ( + build_mega_moe_experts_weights, + ) + + build_mega_moe_experts_weights(layer) + return + + if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 and will_use_deepgemm: + from deep_gemm import transform_sf_into_required_layout + + for scale_param, weight_param in [ + (layer.w13_weight_scale_inv, layer.w13_weight), + (layer.w2_weight_scale_inv, layer.w2_weight), + ]: + num_experts, n, _ = scale_param.data.shape + k = weight_param.shape[2] * 2 + scale_param.data = transform_sf_into_required_layout( + scale_param.data, + mn=n, + k=k, + recipe=(1, 32), + num_groups=num_experts, + disable_ue8m0_cast=False, + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + if ( - should_deepgemm_weight_requant_ue8m0( + not self.is_fp4_expert + and should_deepgemm_weight_requant_ue8m0( weight_block_size=getattr( self.quant_config, "weight_block_size", None ), @@ -1690,6 +1791,7 @@ def apply( w13_scale=w13_scale, w2_scale=w2_scale, block_shape=block_shape, + is_fp4_experts=self.is_fp4_expert, ) elif ( self.runner.runner_backend.is_flashinfer_trtllm() diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index ac3498657a58..2a36ea8e6b41 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -547,6 +547,51 @@ def sglang_per_token_group_quant_fp8( return x_q, x_s +def sglang_per_token_group_quant_fp8_ue8m0( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert ( + x.shape[-1] % group_size == 0 + ), f"hidden ({x.shape[-1]}) must be divisible by group_size ({group_size})" + assert x.is_contiguous(), "x must be contiguous" + assert enable_sgl_per_token_group_quant_8bit, ( + "sgl_per_token_group_quant_8bit is required (v2 kernel supports " + "group_size in {16, 32, 64, 128})" + ) + + *x_batch, x_q_mn, x_q_k = x.shape + x_q = torch.empty(x.shape, device=x.device, dtype=fp8_dtype) + + x_s_mn = x_q_mn + x_s_k = x_q_k // group_size + aligned_mn = ceil_align(x_s_mn, 4) + aligned_k = ceil_align(x_s_k, 4) + x_s = torch.empty( + (*x_batch, aligned_k // 4, aligned_mn), + device=x.device, + dtype=torch.int, + ).transpose(-1, -2)[..., :x_s_mn, :] + + if x.shape[0] > 0: + sgl_per_token_group_quant_8bit( + x, + x_q, + x_s, + group_size, + eps, + fp8_min, + fp8_max, + True, # scale_ue8m0 + False, # fuse_silu_and_mul + None, # masked_m + enable_v2=True, + ) + + return x_q, x_s + + # TODO maybe unify int8 and fp8 code later def sglang_per_token_group_quant_8bit( x: torch.Tensor, @@ -1015,8 +1060,25 @@ def get_w8a8_block_fp8_configs( logger, f"Using configuration from {config_file_path} for W8A8 Block FP8 kernel.", ) - # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + raw = {int(key): val for key, val in json.load(f).items()} + + sanitized = {} + clamped_ms = [] + for m_key, cfg in raw.items(): + if cfg["BLOCK_SIZE_K"] < block_k: + clamped_ms.append((m_key, 
cfg["BLOCK_SIZE_K"])) + cfg = {**cfg, "BLOCK_SIZE_K": block_k} + sanitized[m_key] = cfg + if clamped_ms: + logger.warning( + "Clamped BLOCK_SIZE_K up to %d in tuned config %s for entries %s " + "(scale stepping requires BLOCK_SIZE_K >= block_k).", + block_k, + json_file_name, + clamped_ms, + ) + + return sanitized # If no optimized configuration is available, we will use the default # configuration diff --git a/python/sglang/srt/layers/quantization/kv_cache.py b/python/sglang/srt/layers/quantization/kv_cache.py index ef7cbe74efb0..9866530a109c 100644 --- a/python/sglang/srt/layers/quantization/kv_cache.py +++ b/python/sglang/srt/layers/quantization/kv_cache.py @@ -40,6 +40,8 @@ def create_weights(self, layer: torch.nn.Module): layer.v_scale = torch.nn.Parameter( torch.tensor(-1.0, dtype=torch.float32), requires_grad=False ) + layer.k_scale._skip_weight_check = True + layer.v_scale._skip_weight_check = True def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp4.py b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py new file mode 100644 index 000000000000..11a664c88be9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/marlin_utils_fp4.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import torch + +from sglang.srt.layers.quantization.marlin_utils import ( + marlin_make_workspace, + marlin_permute_bias, + marlin_permute_scales, +) +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() + +if _is_cuda: + from sglang.jit_kernel.gptq_marlin_repack import gptq_marlin_repack + + +def mxfp4_marlin_process_scales( + marlin_scales: torch.Tensor, + input_dtype: torch.dtype | None = None, +) -> torch.Tensor: + if input_dtype is None or input_dtype.itemsize == 2: + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1 + ) + marlin_scales = marlin_scales.to(torch.float8_e8m0fnu) + if input_dtype == torch.float8_e4m3fn: + marlin_scales = marlin_scales.view(torch.uint8) + assert marlin_scales.max() <= 249 + # exponent_bias (fp4->fp8) = 2 ** 3 - 2 ** 1 = 6 + marlin_scales = marlin_scales + 6 + marlin_scales = marlin_scales.view(torch.float8_e8m0fnu) + return marlin_scales + + +def _normalize_scale_tensor( + scales: torch.Tensor, target_dtype: torch.dtype +) -> torch.Tensor: + # The kernel consumes E8M0 exponents. Regardless of the placeholder dtype + # the loader used, we want the *numerical* value 2**e in ``target_dtype``. + # float32/bfloat16/float16 containers hold the numerical 2**e directly + # (they were filled via a dtype-promoting copy from uint8/e8m0). + # uint8/int8 containers hold the raw E8M0 byte and must be reinterpreted. 
+ if scales.dtype == torch.float8_e8m0fnu: + return scales.to(target_dtype) + if scales.dtype == torch.uint8: + return scales.view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype == torch.int8: + return scales.view(torch.uint8).view(torch.float8_e8m0fnu).to(target_dtype) + if scales.dtype in (torch.float32, torch.bfloat16, torch.float16): + return scales.to(target_dtype) + raise TypeError(f"Unsupported MXFP4 scale dtype for Marlin: {scales.dtype}") + + +def prepare_moe_mxfp4_layer_for_marlin(layer: torch.nn.Module) -> None: + group_size = 32 + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + w13_bias = getattr(layer, "w13_bias", None) + w2_bias = getattr(layer, "w2_bias", None) + + num_experts = w13.shape[0] + intermediate_size = w13.shape[1] // 2 + hidden_size = w13.shape[2] * 2 + param_dtype = getattr( + layer, + "orig_dtype", + w13_bias.dtype if w13_bias is not None else torch.bfloat16, + ) + + device = w13.device + layer.workspace = marlin_make_workspace(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + def _repack_weight(weight: torch.Tensor, is_w13: bool) -> torch.Tensor: + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, intermediate_size + assert weight.shape == (num_experts, size_n, size_k // 2) + + tensor_list = [] + for i in range(num_experts): + qweight = weight[i].view(torch.int32).T.contiguous() + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + tensor_list.append(marlin_qweight) + return torch.stack(tensor_list) + + def _permute_scales(scales: torch.Tensor, is_w13: bool) -> torch.Tensor: + scales = _normalize_scale_tensor(scales, param_dtype) + + if is_w13: + size_n, size_k = intermediate_size * 2, hidden_size + else: + size_n, size_k = hidden_size, intermediate_size + + tensor_list = [] + for i in range(num_experts): + scale = scales[i].T.contiguous() + marlin_scales = marlin_permute_scales( + s=scale, + size_k=size_k, + size_n=size_n, + group_size=group_size, + ) + tensor_list.append( + mxfp4_marlin_process_scales( + marlin_scales, + input_dtype=param_dtype, + ) + ) + return torch.stack(tensor_list) + + def _permute_bias(bias: torch.Tensor | None) -> torch.Tensor | None: + if bias is None: + return None + tensor_list = [] + for i in range(num_experts): + tensor_list.append(marlin_permute_bias(bias[i].to(param_dtype))) + return torch.stack(tensor_list) + + w13_marlin = _repack_weight(w13, True) + w2_marlin = _repack_weight(w2, False) + w13_scale_marlin = _permute_scales(w13_scale, True) + w2_scale_marlin = _permute_scales(w2_scale, False) + + layer.w13_weight = torch.nn.Parameter(w13_marlin, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_marlin, requires_grad=False) + layer.w13_weight_scale_inv = torch.nn.Parameter( + w13_scale_marlin, requires_grad=False + ) + layer.w2_weight_scale_inv = torch.nn.Parameter(w2_scale_marlin, requires_grad=False) + + if w13_bias is not None: + layer.w13_bias = torch.nn.Parameter( + _permute_bias(w13_bias), requires_grad=False + ) + if w2_bias is not None: + layer.w2_bias = torch.nn.Parameter(_permute_bias(w2_bias), requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py b/python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py new file mode 100644 index 000000000000..dc398f491905 --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +import triton +import triton.language as tl +from torch.nn import Module +from torch.nn.parameter import Parameter + +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe.utils import RoutingMethodType +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + is_flashinfer_available, + log_info_on_rank0, + set_weight_attrs, +) +from sglang.srt.utils.common import next_power_of_2 + +if is_flashinfer_available(): + from flashinfer import mxfp8_quantize, shuffle_matrix_a, shuffle_matrix_sf_a + from flashinfer.fp4_quantization import block_scale_interleave + from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput + +from sglang.srt.utils.common import get_bool_env_var + +_USE_OFFICIAL_SHUFFLE = get_bool_env_var( + "SGLANG_MXFP4_USE_OFFICIAL_SHUFFLE", default="true" +) + + +class PackTopkIds: + + @classmethod + def execute( + cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor + ) -> torch.Tensor: + return cls.triton(topk_ids, topk_weights) + + @classmethod + def vanilla( + cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor + ) -> torch.Tensor: + weight_bits = ( + topk_weights.to(torch.bfloat16).view(torch.int16).to(torch.int32) & 0xFFFF + ) + return (topk_ids.to(torch.int32) << 16) | weight_bits + + @classmethod + def triton(cls, topk_ids: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor: + assert ( + topk_ids.shape == topk_weights.shape + ), f"shape mismatch: {topk_ids.shape=} vs {topk_weights.shape=}" + assert topk_ids.ndim >= 1, f"expected >=1D, got {topk_ids.shape=}" + + assert ( + topk_ids.dtype == torch.int32 + ), f"topk_ids must be int32, got {topk_ids.dtype}" + assert ( + topk_weights.dtype == torch.float32 + ), f"topk_weights must be float32, got {topk_weights.dtype}" + + assert topk_ids.is_contiguous(), "topk_ids must be contiguous" + assert topk_weights.is_contiguous(), "topk_weights must be contiguous" + + out = torch.empty_like(topk_ids, dtype=torch.int32) + numel = out.numel() + if numel == 0: + return out + + BLOCK_SIZE = 1024 + grid = (triton.cdiv(numel, BLOCK_SIZE),) + _pack_topk_ids_triton_kernel[grid]( + topk_ids, + topk_weights, + out, + numel, + BLOCK_SIZE=BLOCK_SIZE, + ) + return out + + +@triton.jit +def _pack_topk_ids_triton_kernel( + topk_ids_ptr, + topk_weights_ptr, + out_ptr, + numel, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < numel + + ids = tl.load(topk_ids_ptr + offsets, mask=mask, other=0) + w = tl.load(topk_weights_ptr + offsets, mask=mask, other=0.0) + + w_bf16 = w.to(tl.bfloat16) + w_i16 = w_bf16.to(tl.int16, bitcast=True) + w_i32 = w_i16.to(tl.int32) & 0xFFFF + + ids_i32 = ids.to(tl.int32) + packed = (ids_i32 << 16) | w_i32 + + tl.store(out_ptr + offsets, packed, mask=mask) + + +class Mxfp4FlashinferTrtllmMoEMethod: + + def __init__(self, fp8_method, prefix: str): 
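+        # Keeps a handle on the wrapped Fp8MoEMethod: process_weights_after_loading
+        # below first delegates to it (fp4-expert view / requant handling) before
+        # the TRT-LLM-specific weight shuffling runs.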
+ self._fp8 = fp8_method + self.prefix = prefix + self.flashinfer_mxfp4_moe_precision = ( + get_global_server_args().flashinfer_mxfp4_moe_precision + ) + + def create_moe_runner(self, layer, moe_runner_config): + self.moe_runner_config = moe_runner_config + + swiglu_limit = moe_runner_config.swiglu_limit + assert ( + swiglu_limit is not None + ), f"swiglu_limit must be non-None for DeepSeek V4 (got {swiglu_limit!r})" + self._gemm1_clamp_limit_tensor = ( + torch.full( + (layer.num_local_experts,), + swiglu_limit, + dtype=torch.float32, + device=layer.w13_weight.device, + ) + if swiglu_limit is not None + else None + ) + + def create_weights( + self, + layer, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + fp4_block_k = 32 + + w13_weight = Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_weight_scale = Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w13_weight_scale.format_ue8m0 = False + w2_weight_scale.format_ue8m0 = False + scale_attrs = dict(extra_weight_attrs) + scale_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + set_weight_attrs(w13_weight_scale, scale_attrs) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + set_weight_attrs(w2_weight_scale, scale_attrs) + + def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1 + + self._fp8.process_weights_after_loading(layer) + + if getattr(layer, "_mega_moe_weights_built", False): + return + + w13_w, w13_s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale_inv.data + ) + layer.w13_weight = Parameter(w13_w, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_s, requires_grad=False) + + log_info_on_rank0( + logger, + f"Shuffling FP4 expert weights for TRT-LLM MxFP4 kernel " + f"(layer: {self.prefix})...", + ) + + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + num_experts = w13.shape[0] + + if w13_scale.dtype == torch.float32: + w13_scale = w13_scale.to(torch.float8_e8m0fnu) + w2_scale = w2_scale.to(torch.float8_e8m0fnu) + + epilogue_tile_m = 128 + g1_w, g1_s, g2_w, g2_s = [], [], [], [] + if _USE_OFFICIAL_SHUFFLE: + cache: dict = {} + for i in range(num_experts): + w13_u8 = w13[i].view(torch.uint8) + w13_s_u8 = w13_scale[i].view(torch.uint8) + w2_u8 = w2[i].view(torch.uint8) + w2_s_u8 = w2_scale[i].view(torch.uint8) + + perm = _maybe_get_cached_w3_w1_permute_indices( + cache, + w13_u8, + epilogue_tile_m, + ) + 
g1_w.append(w13_u8[perm.to(w13_u8.device)].contiguous()) + perm_sf = _maybe_get_cached_w3_w1_permute_indices( + cache, + w13_s_u8, + epilogue_tile_m, + num_elts_per_sf=16, + ) + g1_s.append( + block_scale_interleave( + w13_s_u8[perm_sf.to(w13_s_u8.device)].contiguous() + ) + ) + + perm = get_w2_permute_indices_with_cache( + cache, + w2_u8, + epilogue_tile_m, + ) + g2_w.append(w2_u8[perm.to(w2_u8.device)].contiguous()) + perm_sf = get_w2_permute_indices_with_cache( + cache, + w2_s_u8, + epilogue_tile_m, + num_elts_per_sf=16, + ) + g2_s.append( + block_scale_interleave( + w2_s_u8[perm_sf.to(w2_s_u8.device)].contiguous() + ) + ) + else: + for i in range(num_experts): + g1_w.append(shuffle_matrix_a(w13[i].view(torch.uint8), epilogue_tile_m)) + g1_s.append( + shuffle_matrix_sf_a(w13_scale[i].view(torch.uint8), epilogue_tile_m) + ) + g2_w.append(shuffle_matrix_a(w2[i].view(torch.uint8), epilogue_tile_m)) + g2_s.append( + shuffle_matrix_sf_a(w2_scale[i].view(torch.uint8), epilogue_tile_m) + ) + + layer.w13_weight = Parameter(torch.stack(g1_w), requires_grad=False) + layer.w13_weight_scale_inv = Parameter( + torch.stack(g1_s) + .view(torch.float8_e4m3fn) + .reshape(num_experts, w13.shape[1], -1), + requires_grad=False, + ) + layer.w2_weight = Parameter(torch.stack(g2_w), requires_grad=False) + layer.w2_weight_scale_inv = Parameter( + torch.stack(g2_s) + .view(torch.float8_e4m3fn) + .reshape(num_experts, w2.shape[1], -1), + requires_grad=False, + ) + + self._register_static_scale_ones(layer) + torch.cuda.empty_cache() + + def _register_static_scale_ones(self, layer: Module) -> None: + device = layer.w13_weight.device + for name in ( + "output1_scale_scalar", + "output1_scale_gate_scalar", + "output2_scale_scalar", + ): + layer.register_buffer( + name, + torch.ones(layer.num_local_experts, device=device, dtype=torch.float32), + persistent=False, + ) + + def apply( + self, + layer: Module, + dispatch_output: DispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + w13 = layer.w13_weight + w2 = layer.w2_weight + w13_scale = layer.w13_weight_scale_inv + w2_scale = layer.w2_weight_scale_inv + + intermediate_size = w2.shape[2] * 2 if w2.dtype == torch.uint8 else w2.shape[2] + hidden_size = w13.shape[2] * 2 if w13.dtype == torch.uint8 else w13.shape[2] + + num_local_experts = layer.num_local_experts + if w13_scale.dim() == 2: + w13_scale = w13_scale.reshape(num_local_experts, 2 * intermediate_size, -1) + if w2_scale.dim() == 2: + w2_scale = w2_scale.reshape(num_local_experts, hidden_size, -1) + + if TopKOutputChecker.format_is_standard(topk_output): + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + elif TopKOutputChecker.format_is_bypassed(topk_output): + raise NotImplementedError( + "the old code in this branch is WRONG. e.g. 
it does not consider HashTopK, and may miss args" + ) + else: + raise ValueError(f"Unsupported topk output format: {topk_output.format}") + + packed_topk = PackTopkIds.execute(topk_ids, topk_weights) + + precision = self.flashinfer_mxfp4_moe_precision + if precision == "bf16": + assert hidden_states.dtype == torch.bfloat16 + x_quant = hidden_states + x_scale = None + origin_dim = x_quant.shape[-1] + if hidden_size != origin_dim: + x_quant = torch.nn.functional.pad( + x_quant, + (0, hidden_size - origin_dim), + mode="constant", + value=0.0, + ) + elif precision == "default": + x_quant, x_scale = mxfp8_quantize( + hidden_states, False, alignment=hidden_size + ) + x_scale = x_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ) + else: + raise NotImplementedError(f"Unsupported mxfp4 moe precision: {precision}") + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = x_quant.shape[0] + out_hidden_size = ( + x_quant.shape[-1] * 2 + if x_quant.dtype == torch.uint8 + else x_quant.shape[-1] + ) + symm_output = torch.empty( + num_tokens, out_hidden_size, dtype=torch.bfloat16, device=x_quant.device + ) + + output = trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_topk, + routing_bias=None, + hidden_states=x_quant, + hidden_states_scale=x_scale, + gemm1_weights=w13, + gemm1_weights_scale=w13_scale, + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=self._gemm1_clamp_limit_tensor, + gemm2_weights=w2, + gemm2_weights_scale=w2_scale, + gemm2_bias=None, + output1_scale_scalar=layer.output1_scale_scalar, + output1_scale_gate_scalar=layer.output1_scale_gate_scalar, + output2_scale_scalar=layer.output2_scale_scalar, + num_experts=layer.num_experts, + top_k=packed_topk.shape[1], + n_group=1, + topk_group=1, + intermediate_size=intermediate_size, + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=num_local_experts, + routed_scaling_factor=1.0, + routing_method_type=int(RoutingMethodType.TopK), + do_finalize=True, + tune_max_num_tokens=next_power_of_2(x_quant.shape[0]), + output=symm_output, + )[0] + + return StandardCombineInput(hidden_states=output) + + +def maybe_fuse_routed_scale_and_shared_add( + experts, + routed: torch.Tensor, + shared: torch.Tensor | None, + routed_scaling_factor: float, +) -> torch.Tensor: + # When MxFP4 fusion is on, the upstream `routed *= scale` is skipped and + # the scaling is folded into the shared-add via `shared.add_(routed, + # alpha=scale)`. With no shared output, the missing scale is applied + # in-place. Otherwise `routed` is already scale-final and we just add + # `shared` (or pass through if there is none). 
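# Illustrative sketch (not part of the patch): the fused path below relies on
# torch.Tensor.add_(other, alpha=...) computing shared += alpha * routed, so the
# skipped `routed *= scale` and the shared-expert add collapse into one in-place
# op. A minimal numeric check of that equivalence (shapes and values invented):
#
#     import torch
#     routed = torch.randn(4, 8)
#     shared = torch.randn(4, 8)
#     scale = 2.5
#     reference = shared + routed * scale               # unfused: scale, then add
#     fused = shared.clone().add_(routed, alpha=scale)  # fused shared-add
#     assert torch.allclose(reference, fused)
# End of sketch.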
+ from sglang.srt.layers.quantization.mxfp4_marlin_moe import ( + Mxfp4MarlinMoEMethod, + ) + + fused = isinstance( + experts.quant_method, (Mxfp4FlashinferTrtllmMoEMethod, Mxfp4MarlinMoEMethod) + ) + if fused: + if shared is not None: + return shared.add_(routed, alpha=routed_scaling_factor) + return routed.mul_(routed_scaling_factor) + if shared is not None: + routed += shared + return routed diff --git a/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py new file mode 100644 index 000000000000..90a3de66f4aa --- /dev/null +++ b/python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from torch.nn import Module + +from sglang.srt.layers.moe.moe_runner.marlin import MarlinMoeQuantInfo +from sglang.srt.layers.moe.utils import MoeRunnerBackend +from sglang.srt.utils import log_info_on_rank0 +from sglang.srt.utils.common import is_sm90_supported + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput + +logger = logging.getLogger(__name__) + + +class Mxfp4MarlinMoEMethod: + """MXFP4 (E8M0 scales) MoE quantization method using the Marlin backend.""" + + def __init__(self, fp8_method, prefix: str): + self._fp8 = fp8_method + self.prefix = prefix + + def create_moe_runner(self, layer, moe_runner_config): + from sglang.srt.layers.moe.moe_runner import MoeRunner + + self.runner = MoeRunner(MoeRunnerBackend.MARLIN, moe_runner_config) + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Delegate to the underlying FP8 method for weight creation — + # the raw weight shapes are the same; only post-loading processing differs. + self._fp8.create_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.quantization.marlin_utils import ( + check_moe_marlin_supports_layer, + ) + from sglang.srt.layers.quantization.marlin_utils_fp4 import ( + prepare_moe_mxfp4_layer_for_marlin, + ) + + # Let the FP8 base method handle ROCm normalization, etc. + self._fp8.process_weights_after_loading(layer) + + if getattr(layer, "_mega_moe_weights_built", False): + return + + if not is_sm90_supported(): + raise RuntimeError( + "DeepSeekV4 MXFP4 Marlin fallback requires Hopper/SM90 or above." + ) + if not check_moe_marlin_supports_layer(layer, 32): + raise RuntimeError( + "Current DeepSeekV4 MoE layer does not satisfy Marlin constraints." + ) + + # NOTE: the Marlin MoE runner consumes w13 in the checkpoint's + # native ``[w1; w3]`` order -- see ``silu_and_mul`` in + # fused_marlin_moe.py which expects ``gate = intermediate[:, :N]`` + # (first half) and ``up = intermediate[:, N:]`` (second half). + # Unlike the flashinfer trtllm_fp4 kernel (which wants [w3, w1]), + # we must *not* call ``reorder_w1w3_to_w3w1`` here. 
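# Illustrative sketch (not part of the patch): why the [w1; w3] order matters for
# the Marlin path. With a fused gate/up projection of width 2*N, the activation
# splits the intermediate as gate = x[..., :N] and up = x[..., N:], then computes
# silu(gate) * up. A hypothetical plain-PyTorch reference (function name invented):
#
#     import torch
#
#     def silu_and_mul_w1w3(intermediate: torch.Tensor) -> torch.Tensor:
#         n = intermediate.shape[-1] // 2
#         gate, up = intermediate[..., :n], intermediate[..., n:]
#         return torch.nn.functional.silu(gate) * up
#
# If the expert weights were reordered to [w3; w1] (the layout the flashinfer
# trtllm_fp4 kernel expects), gate and up would be swapped and the output would be
# silently wrong, hence reorder_w1w3_to_w3w1 is skipped here.
# End of sketch.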
+ + log_info_on_rank0( + logger, + f"Preparing DeepSeekV4 MXFP4 experts for Marlin backend " + f"(layer: {self.prefix})...", + ) + prepare_moe_mxfp4_layer_for_marlin(layer) + layer._dsv4_mxfp4_backend = "marlin" + + def apply( + self, + layer: Module, + dispatch_output: DispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + topk_output = dispatch_output.topk_output + if not TopKOutputChecker.format_is_standard(topk_output): + raise ValueError(f"Unsupported topk output format: {topk_output.format}") + + quant_info = MarlinMoeQuantInfo( + w13_qweight=layer.w13_weight, + w2_qweight=layer.w2_weight, + w13_scales=layer.w13_weight_scale_inv, + w2_scales=layer.w2_weight_scale_inv, + w13_g_idx_sort_indices=None, + w2_g_idx_sort_indices=None, + weight_bits=4, + is_k_full=True, + ) + runner_output = self.runner.run(dispatch_output, quant_info=quant_info) + + return StandardCombineInput(hidden_states=runner_output.hidden_states) diff --git a/python/sglang/srt/layers/utils/common.py b/python/sglang/srt/layers/utils/common.py index b982b247e8a1..4c8080b0de80 100644 --- a/python/sglang/srt/layers/utils/common.py +++ b/python/sglang/srt/layers/utils/common.py @@ -38,6 +38,21 @@ def pad_or_narrow_weight( ) +def is_strict_contiguous(x: torch.Tensor) -> bool: + expected_stride = 1 + for size, stride in zip(reversed(x.shape), reversed(x.stride())): + if stride != expected_stride: + return False + expected_stride *= size + return True + + +def strict_contiguous(x: torch.Tensor) -> torch.Tensor: + if is_strict_contiguous(x): + return x + return x.clone(memory_format=torch.contiguous_format) + + def copy_or_rebind_param( module: torch.nn.Module, name: str, new_value: torch.Tensor ) -> None: diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py index 07c3aa7aea24..97c87f92699f 100644 --- a/python/sglang/srt/managers/hisparse_coordinator.py +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -1,12 +1,14 @@ # to be combined with the sparse coordinator class and sparse algorithm family import logging -from typing import List, NamedTuple +from typing import List, NamedTuple, Union import torch from sglang.srt.managers.schedule_batch import Req from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4HiSparseTokenToKVPoolAllocator, + DeepSeekV4SingleKVPoolHost, HiSparseNSATokenToKVPool, HiSparseTokenToKVPoolAllocator, ) @@ -15,7 +17,10 @@ device_module = get_device_module() -from sglang.jit_kernel.hisparse import load_cache_to_device_buffer_mla +from sglang.jit_kernel.hisparse import ( + load_cache_to_device_buffer_dsv4_mla, + load_cache_to_device_buffer_mla, +) from sglang.srt.mem_cache.memory_pool import ReqToTokenPool logger = logging.getLogger(__name__) @@ -38,11 +43,14 @@ class HiSparseCoordinator: def __init__( self, req_to_token_pool: ReqToTokenPool, - token_to_kv_pool_allocator: HiSparseTokenToKVPoolAllocator, + token_to_kv_pool_allocator: Union[ + HiSparseTokenToKVPoolAllocator, + DeepSeekV4HiSparseTokenToKVPoolAllocator, + ], top_k: int, device_buffer_size: int, device: str, - tp_group: torch.distributed.ProcessGroup, + tp_group, host_to_device_ratio: int = 2, ): self.req_to_token_pool = req_to_token_pool @@ -50,21 +58,43 @@ def __init__( self.top_k = top_k self.device_buffer_size = device_buffer_size self.device = device + self.compress_ratio = self.token_to_kv_pool_allocator.compress_ratio - 
self.mem_pool_device: HiSparseNSATokenToKVPool = ( - self.token_to_kv_pool_allocator.get_kvcache() - ) - self.mem_pool_host = MLATokenToKVPoolHost( - device_pool=self.mem_pool_device, - host_to_device_ratio=host_to_device_ratio, - host_size=0, - page_size=1, # for simplicity, we set page size to 1 to enable backup one token at a time - layout="layer_first", - override_kv_cache_dim=self.mem_pool_device.kv_cache_dim, + self.is_dsv4_hisparse = isinstance( + self.token_to_kv_pool_allocator, DeepSeekV4HiSparseTokenToKVPoolAllocator ) + if self.is_dsv4_hisparse: + self.mem_pool_device = self.token_to_kv_pool_allocator.hisparse_kvcache + host_size = self.token_to_kv_pool_allocator.size_full // self.compress_ratio + self.mem_pool_host = DeepSeekV4SingleKVPoolHost( + self.mem_pool_device, host_size, 1 + ) + self.item_size_bytes = ( + self.mem_pool_host.kv_cache_total_dim + * self.mem_pool_host.dtype.itemsize + ) + else: + assert isinstance( + self.token_to_kv_pool_allocator, HiSparseTokenToKVPoolAllocator + ) + self.mem_pool_device: HiSparseNSATokenToKVPool = ( + self.token_to_kv_pool_allocator.get_kvcache() + ) + self.mem_pool_host = MLATokenToKVPoolHost( + device_pool=self.mem_pool_device, + host_to_device_ratio=host_to_device_ratio, + host_size=0, + page_size=1, + layout="layer_first", + override_kv_cache_dim=self.mem_pool_device.kv_cache_dim, + ) + self.item_size_bytes = self.mem_pool_host.token_stride_size - max_num_reqs = req_to_token_pool.req_to_token.shape[0] + max_num_req_slots = req_to_token_pool.req_to_token.shape[0] max_context_len = req_to_token_pool.max_context_len + max_compressed_context_len = ( + max_context_len + self.compress_ratio - 1 + ) // self.compress_ratio # to have an extra page for new tokens self.padded_buffer_size = ( @@ -72,13 +102,15 @@ def __init__( ) self.req_to_device_buffer = torch.zeros( - (max_num_reqs, self.padded_buffer_size), dtype=torch.int64, device=device + (max_num_req_slots, self.padded_buffer_size), + dtype=torch.int64, + device=device, ) self.req_device_buffer_size = torch.zeros( - max_num_reqs, dtype=torch.int64, device="cpu" + max_num_req_slots, dtype=torch.int64, device="cpu" ) self.req_to_host_pool = torch.full( - (max_num_reqs, max_context_len), + (max_num_req_slots, max_compressed_context_len), -1, dtype=torch.int64, device=device, @@ -97,13 +129,13 @@ def __init__( # initialize data structures for swap-in kernel layer_num = self.mem_pool_device.layer_num self.req_device_buffer_tokens = torch.full( - (layer_num, max_num_reqs, self.padded_buffer_size), + (layer_num, max_num_req_slots, self.padded_buffer_size), -1, dtype=torch.int32, device=device, ) self.req_device_buffer_token_locs = torch.full( - (layer_num, max_num_reqs, self.padded_buffer_size), + (layer_num, max_num_req_slots, self.padded_buffer_size), -1, dtype=torch.int32, device=device, @@ -113,13 +145,19 @@ def __init__( ) self.lru_slots = ( self._lru_init.view(1, 1, -1) - .repeat(layer_num, max_num_reqs, 1) + .repeat(layer_num, max_num_req_slots, 1) .contiguous() ) + self._device_buffer_arange_i32 = torch.arange( + self.device_buffer_size, dtype=torch.int32, device=device + ) # Pre-allocated output buffer for swap_in_selected_pages (CUDA-graph safe) self.top_k_device_locs_buffer = torch.full( - (max_num_reqs, self.top_k), -1, dtype=torch.int32, device=device + (max_num_req_slots, self.top_k), -1, dtype=torch.int32, device=device + ) + self.raw_indices_buffer = torch.full( + (max_num_req_slots, self.top_k), -1, dtype=torch.int32, device=device ) # Scalar tensor: number of real 
(non-padded) requests in the batch. # Updated before each graph replay so padded blocks early-return. @@ -127,7 +165,7 @@ def __init__( # CPU flag: True means "skip backup on the next decode step" because # staging already backed up all prefill tokens. Cleared after one step. - self._skip_first_backup = [False] * max_num_reqs + self._skip_first_backup = [False] * max_num_req_slots def set_decode_producer_stream(self, stream) -> None: self.decode_producer_stream = stream @@ -151,11 +189,14 @@ def get_token_stats(self) -> HiSparseTokenStats: def admit_request_into_staging(self, req: Req) -> None: req.hisparse_staging = True - logical_indices = self.req_to_token_pool.req_to_token[ + + full_kv_indices = self.req_to_token_pool.req_to_token[ req.req_pool_idx, : len(req.fill_ids) - ] - device_indices = self.mem_pool_device._translate_loc_to_hisparse_device( - logical_indices + ].to(dtype=torch.int64, copy=True) + device_indices = ( + self.mem_pool_device.translate_loc_from_full_to_hisparse_device( + full_kv_indices + ) ) prefill_len = len(device_indices) @@ -178,7 +219,10 @@ def admit_request_into_staging(self, req: Req) -> None: with device_module.stream(self.write_staging_stream): start_event.wait(self.write_staging_stream) self.mem_pool_host.backup_from_device_all_layer( - self.mem_pool_device, host_indices, device_indices, io_backend="kernel" + self.mem_pool_device, + host_indices, + device_indices, + io_backend="kernel", ) finish_event.record() if host_indices.is_cuda: @@ -201,6 +245,12 @@ def admit_request_direct(self, req: Req) -> None: buffer. In the staging path this is correct (prefill filled the buffer), but here the buffer is empty. """ + if self.is_dsv4_hisparse: + # TODO(dsv4): wire PD direct-to-host. Needs (a) load_to_device_per_layer + raise NotImplementedError( + "PD direct-to-host admission is not supported for dsv4 hisparse yet." + ) + self.alloc_device_buffer(req) if req.kv_allocated_len <= self.device_buffer_size: @@ -211,12 +261,12 @@ def admit_request_direct(self, req: Req) -> None: self._preload_to_device_buffer(req) else: # Long sequence: reset device_buffer_tokens to -1 so the kernel - # sees all slots as empty → every top-k lookup is a miss → host load. + # sees all slots as empty -> every top-k lookup is a miss -> host load. self.req_device_buffer_tokens[ :, req.req_pool_idx, : self.device_buffer_size ] = -1 - req.staging = False + req.hisparse_staging = False self._skip_first_backup[req.req_pool_idx] = True logger.debug("HiSparse: admitting request %s directly", req.rid) @@ -236,74 +286,52 @@ def _preload_to_device_buffer(self, req: Req) -> None: ) def alloc_device_buffer(self, req: Req) -> None: - allocated_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : req.kv_allocated_len - ] - page_size = self.mem_pool_device.page_size - # Allocate only enough for current tokens (page-aligned). - # When prefill already fills device_buffer_size, include the reserved page. - alloc_size = min( - ((req.kv_allocated_len + page_size - 1) // page_size) * page_size, - self.device_buffer_size, - ) - if alloc_size == self.device_buffer_size: + if self.is_dsv4_hisparse: + allocated_len = len(req.fill_ids) alloc_size = self.padded_buffer_size + else: + allocated_len = req.kv_allocated_len + page_size = self.mem_pool_device.page_size + # Allocate only enough for current tokens (page-aligned). + # When prefill already fills device_buffer_size, include the reserved page. 
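# Illustrative sketch (not part of the patch): worked example of the page-aligned
# sizing computed just below, with invented numbers allocated_len = 1000,
# page_size = 64, device_buffer_size = 2048:
#
#     ((1000 + 64 - 1) // 64) * 64 == 1024
#     alloc_size = min(1024, 2048) == 1024
#
# If the rounded size already equals device_buffer_size, alloc_size is bumped to
# padded_buffer_size so the reserved page for new tokens is included.
# End of sketch.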
+ alloc_size = min( + ((allocated_len + page_size - 1) // page_size) * page_size, + self.device_buffer_size, + ) + if alloc_size == self.device_buffer_size: + alloc_size = self.padded_buffer_size + + compressed_logical_indices = ( + self.mem_pool_device.translate_loc_from_full_to_compressed( + self.req_to_token_pool.req_to_token[req.req_pool_idx, :allocated_len] + ) + ) + compressed_len = len(compressed_logical_indices) + buffer_indices = self.token_to_kv_pool_allocator.alloc_device_buffer( - allocated_indices, - alloc_size, + compressed_logical_indices, alloc_size ) if buffer_indices is None: logger.error( "HiSparse: alloc_device_buffer failed for req %s " - "(kv_allocated_len=%d, alloc_size=%d)", + "(compressed_len=%d, alloc_size=%d)", req.rid, - req.kv_allocated_len, + compressed_len, alloc_size, ) raise RuntimeError("HiSparse alloc_device_buffer returned None") + buffer_indices = buffer_indices.to(torch.int32) self.req_to_device_buffer[req.req_pool_idx, :alloc_size] = buffer_indices self.req_device_buffer_size[req.req_pool_idx] = alloc_size self.req_device_buffer_tokens[ :, req.req_pool_idx, : self.device_buffer_size - ] = torch.arange(self.device_buffer_size, device=self.device) + ] = self._device_buffer_arange_i32 self.req_device_buffer_token_locs[:, req.req_pool_idx, :alloc_size] = ( buffer_indices[:alloc_size] ) - def has_ongoing_staging(self) -> bool: - return len(self.ack_staging_queue) > 0 - - def collect_ready_reqs(self) -> List[Req]: - ready_reqs = [] - if len(self.ack_staging_queue) == 0: - return ready_reqs - - finish_count = 0 - for _, finish_event, _ in self.ack_staging_queue: - if not finish_event.query(): - break - finish_count += 1 - queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu") - if self.tp_world_size > 1: - # synchronize TP workers to make sure the same update to scheduler - torch.distributed.all_reduce( - queue_size, - op=torch.distributed.ReduceOp.MIN, - group=self.tp_group, - ) - finish_count = int(queue_size.item()) - while finish_count > 0: - _, _, req = self.ack_staging_queue.pop(0) - # prepare device buffer and update req - self.alloc_device_buffer(req) - req.hisparse_staging = False - self._skip_first_backup[req.req_pool_idx] = True - finish_count -= 1 - ready_reqs.append(req) - return ready_reqs - def _grow_device_buffers( self, seq_lens: torch.Tensor, @@ -377,6 +405,38 @@ def _grow_device_buffers( reserved_positions = (seq_lens - 1).clamp(max=self.device_buffer_size) return self.req_to_device_buffer[req_pool_indices, reserved_positions] + def has_ongoing_staging(self) -> bool: + return len(self.ack_staging_queue) > 0 + + def collect_ready_reqs(self) -> List[Req]: + ready_reqs: List[Req] = [] + if len(self.ack_staging_queue) == 0: + return ready_reqs + + finish_count = 0 + for _, finish_event, _ in self.ack_staging_queue: + if not finish_event.query(): + break + finish_count += 1 + queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu") + if self.tp_world_size > 1: + # synchronize TP workers to make sure the same update to scheduler + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + finish_count = int(queue_size.item()) + while finish_count > 0: + _, _, req = self.ack_staging_queue.pop(0) + # prepare device buffer and update req + self.alloc_device_buffer(req) + self._skip_first_backup[req.req_pool_idx] = True + req.hisparse_staging = False + finish_count -= 1 + ready_reqs.append(req) + return ready_reqs + def map_last_loc_to_buffer( self, seq_lens: 
torch.Tensor, @@ -389,17 +449,52 @@ def map_last_loc_to_buffer( self._eager_backup_previous_token( seq_lens, req_pool_indices, seq_lens_cpu, req_pool_indices_cpu ) - # Grow device buffers if needed and resolve the latest-token slot. - reserved_buffer_loc = self._grow_device_buffers( - seq_lens, req_pool_indices, seq_lens_cpu, req_pool_indices_cpu + + if not self.is_dsv4_hisparse: + # Grow device buffers if needed and resolve the latest-token slot. + reserved_buffer_loc = self._grow_device_buffers( + seq_lens, req_pool_indices, seq_lens_cpu, req_pool_indices_cpu + ) + self.req_device_buffer_token_locs[ + :, req_pool_indices, self.device_buffer_size + ] = reserved_buffer_loc.to(torch.int32) + + # No need to clear prior mappings: the only consumer of the mapping + # for past tokens is the swap-in kernel, and it goes through + # top_k_device_locs returned by swap_in_selected_pages -- not via + # mapping[old_out_cache_loc] -- so stale entries are harmless. + compressed_locs = self.token_to_kv_pool_allocator.get_last_loc_compressed( + out_cache_loc + ) + self.mem_pool_device.full_to_hisparse_device_index_mapping[ + compressed_locs + ] = reserved_buffer_loc + return + + active_reqs = seq_lens % self.compress_ratio == 0 + if not torch.any(active_reqs): + return + + active_seq_lens = seq_lens[active_reqs] + active_out_cache_loc = out_cache_loc[active_reqs] + active_req_pool_indices = req_pool_indices[active_reqs] + + compressed_seq_lens = active_seq_lens // self.compress_ratio + reserved_positions = (compressed_seq_lens - 1).clamp( + max=self.device_buffer_size ) + reserved_buffer_loc = self.req_to_device_buffer[ + active_req_pool_indices, reserved_positions + ] self.req_device_buffer_token_locs[ - :, req_pool_indices, self.device_buffer_size + :, active_req_pool_indices, self.device_buffer_size ] = reserved_buffer_loc.to(torch.int32) - # todo, clear the prior mapping as well - self.mem_pool_device.full_to_hisparse_device_index_mapping[out_cache_loc] = ( + compressed_locs = self.token_to_kv_pool_allocator.get_last_loc_compressed( + active_out_cache_loc + ) + self.mem_pool_device.full_to_hisparse_device_index_mapping[compressed_locs] = ( reserved_buffer_loc ) @@ -410,23 +505,29 @@ def _eager_backup_previous_token( seq_lens_cpu: torch.Tensor, req_pool_indices_cpu: torch.Tensor, ) -> None: - """Back up the previous decode token to host memory. + """Back up the previous compressed token to host memory. - Every decode step, the token written in the *previous* step must be - backed up to host so the swap-in kernel can later recover it. + Each newly produced compressed token (one per `compress_ratio` decode + steps) must be backed up to host so the swap-in kernel can later + recover it. - The only exception is the first decode step right after staging: all - prefill tokens were already backed up during staging, so there is nothing new to save yet. + Two cases are skipped: + - The first decode step right after staging: all prefill tokens were + already backed up during staging, so there is nothing new to save. + - Steps where `(seq_len - 1) % compress_ratio != 0`: no new compressed + token was produced this step. """ # Build the list of batch positions that need a host backup. - # Skip the first decode step after staging (prefill already backed up). + # Skip the first decode step after staging (prefill already backed up), + # and skip non-aligned steps that did not produce a new compressed token. 
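# Illustrative sketch (not part of the patch): alignment example for the check
# below, assuming compress_ratio = 4:
#
#     seq_len = 9  -> (9 - 1) % 4 == 0: a compressed token just completed;
#                     compressed_pos = (9 - 1) // 4 - 1 == 1, i.e. the block
#                     covering raw tokens 4..7 gets backed up this step.
#     seq_len = 10 -> (10 - 1) % 4 == 1: nothing new to back up, skip.
# End of sketch.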
backup_indices = [] for i in range(len(seq_lens_cpu)): req_idx = int(req_pool_indices_cpu[i]) if self._skip_first_backup[req_idx]: self._skip_first_backup[req_idx] = False continue - backup_indices.append(i) + if (int(seq_lens_cpu[i]) - 1) % self.compress_ratio == 0: + backup_indices.append(i) if not backup_indices: return @@ -434,13 +535,18 @@ def _eager_backup_previous_token( backup_indices_gpu = torch.tensor( backup_indices, dtype=torch.int64, device=self.device ) - # The previous token's position and its device buffer slot: - # - short seq: slot = seq_len - 2 (within the regular buffer) - # - long seq: slot = device_buffer_size (the reserved slot) - actual_token_pos = seq_lens[backup_indices_gpu] - 2 - buffer_slot = actual_token_pos.clamp(max=self.device_buffer_size) - backup_req_indices = req_pool_indices[backup_indices_gpu] + + # The previous compressed token's position and its device buffer slot: + # compressed_pos = (seq_len - 1) // compress_ratio - 1 + # - short: slot = compressed_pos (within the regular buffer) + # - long: slot = device_buffer_size (the reserved slot) + prev_seq_lens = seq_lens[backup_indices_gpu] - 1 + compressed_prev_seq_lens = prev_seq_lens // self.compress_ratio + actual_compressed_pos = compressed_prev_seq_lens - 1 + + buffer_slot = actual_compressed_pos.clamp(max=self.device_buffer_size) + device_locs = self.req_to_device_buffer[backup_req_indices, buffer_slot] host_locs = self.mem_pool_host.alloc(len(device_locs)) @@ -453,11 +559,9 @@ def _eager_backup_previous_token( f"HiSparse host mem pool alloc failed for {len(device_locs)} decode backup tokens" ) host_locs = host_locs.to(device=self.device) - self.req_to_host_pool[backup_req_indices, actual_token_pos] = host_locs + self.req_to_host_pool[backup_req_indices, actual_compressed_pos] = host_locs - if self._has_pending_backup: - self._backup_done_event.wait(device_module.current_stream()) - self._has_pending_backup = False + self.wait_for_pending_backup() schedule_stream = device_module.current_stream() with device_module.stream(self.decode_backup_stream): self.decode_backup_stream.wait_stream(schedule_stream) @@ -474,8 +578,8 @@ def _eager_backup_previous_token( host_locs.record_stream(self.decode_backup_stream) if backup_req_indices.is_cuda: backup_req_indices.record_stream(self.decode_backup_stream) - if actual_token_pos.is_cuda: - actual_token_pos.record_stream(self.decode_backup_stream) + if actual_compressed_pos.is_cuda: + actual_compressed_pos.record_stream(self.decode_backup_stream) if device_locs.is_cuda: device_locs.record_stream(self.decode_backup_stream) self._has_pending_backup = True @@ -486,20 +590,6 @@ def wait_for_pending_backup(self) -> None: self._backup_done_event.wait(device_module.current_stream()) self._has_pending_backup = False - def get_front_topk_tokens( - self, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - ) -> torch.Tensor: - top_k_indices = self.req_to_device_buffer[req_pool_indices, : self.top_k].to( - torch.int32 - ) - topk_col_indices = torch.arange(self.top_k, device=self.device).unsqueeze(0) - # Mask out positions beyond each request's seq_len - mask = topk_col_indices >= seq_lens.unsqueeze(1) - top_k_indices[mask] = -1 - return top_k_indices - def naive_load_topk( self, req_pool_indices: torch.Tensor, @@ -512,6 +602,10 @@ def naive_load_topk( This is a naive per-request loop implementation for debugging/validation. Production code uses swap_in_selected_pages (JIT CUDA kernel) instead. 
+ Note: dsv4 hisparse is not supported — DeepSeekV4SingleKVPoolHost has no + load_to_device_per_layer and indices live in compressed space. Currently + only used as a kernel oracle in test_hisparse_unit.py (non-dsv4 path). + Args: req_pool_indices: Pool indices for each request. Shape: (num_reqs,) seq_lens: Sequence lengths for each request. Shape: (num_reqs,) @@ -521,6 +615,9 @@ def naive_load_topk( Returns: Device KV cache indices for the selected tokens. Shape: (num_reqs, top_k) """ + assert ( + not self.is_dsv4_hisparse + ), "naive_load_topk is not implemented for dsv4 hisparse" num_reqs = req_pool_indices.size(0) top_k_indices = torch.full( (num_reqs, self.top_k), -1, dtype=torch.int32, device=self.device @@ -586,7 +683,7 @@ def naive_load_topk( return top_k_indices def abort_staging_request(self, req: Req) -> None: - """Remove a request from the staging queue and free its host resources. + """Remove a request from the staging queue and free its host + device resources. Must be called when aborting a request that has been admitted into staging but has not yet completed (i.e. req.hisparse_staging is True). @@ -598,8 +695,15 @@ def abort_staging_request(self, req: Req) -> None: # Wait for any in-flight staging DMA to complete before freeing self.write_staging_stream.synchronize() + prefill_len = len(req.fill_ids) + allocated_locs = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, :prefill_len + ] + self.token_to_kv_pool_allocator.free_hisparse(allocated_locs) + # Free host memory that was allocated during admit_request_into_staging - host_indices = self.req_to_host_pool[req.req_pool_idx, : req.kv_allocated_len] + compressed_len = prefill_len // self.compress_ratio + host_indices = self.req_to_host_pool[req.req_pool_idx, :compressed_len] host_indices = host_indices[host_indices >= 0] if host_indices.numel() > 0: self.mem_pool_host.free(host_indices) @@ -617,26 +721,38 @@ def request_finished(self, req: Req): # release resources only after the execution of a potential overlapped batch if self.decode_producer_stream is not None: device_module.current_stream().wait_stream(self.decode_producer_stream) - if self._has_pending_backup: - self._backup_done_event.wait(device_module.current_stream()) - self._has_pending_backup = False - - # release memory — only free actually-allocated buffer indices + self.wait_for_pending_backup() + + # Use kv_allocated_len (not seqlen): under speculative decoding the + # allocator can over-allocate beyond the committed seqlen, and those + # extra slots may carry stale mapping entries pointing at buffer slots + # we just freed via free_hisparse_indices(all_hi). If left set, the + # subsequent release_kv_cache -> allocator.free -> free_hisparse path + # re-frees them (double-free into the page allocator's free list). 
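# Illustrative sketch (not part of the patch): example of the over-allocation
# case described above, with invented numbers. Suppose seqlen = 8 but
# kv_allocated_len = 10 because speculative decode reserved two extra slots.
# Clearing the mapping over allocated_len = 10 (not 8) also wipes the entries
# for slots 8..9, so the later release_kv_cache -> allocator.free ->
# free_hisparse path cannot resolve them to buffer slots that were already
# freed and double-free them.
# End of sketch.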
+ allocated_len = req.kv_allocated_len + compressed_len = allocated_len // self.compress_ratio + + # release memory -- only free actually-allocated buffer indices current_cap = int(self.req_device_buffer_size[req.req_pool_idx]) - buffer_indices = self.req_to_device_buffer[req.req_pool_idx, :current_cap] - self.token_to_kv_pool_allocator.free_hisparse_indices(buffer_indices) + if current_cap > 0: + side_buf_hi = self.req_to_device_buffer[req.req_pool_idx, :current_cap] + all_hi = torch.unique(side_buf_hi[side_buf_hi > 0]) + if all_hi.numel() > 0: + self.token_to_kv_pool_allocator.free_hisparse_indices(all_hi) allocated_locs = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : req.kv_allocated_len + req.req_pool_idx, :allocated_len ] - self.token_to_kv_pool_allocator.full_to_hisparse_device_index_mapping[ + compressed_locs = self.mem_pool_device.translate_loc_from_full_to_compressed( allocated_locs - ] = 0 + ) + self.mem_pool_device.full_to_hisparse_device_index_mapping[compressed_locs] = 0 - host_indices = self.req_to_host_pool[req.req_pool_idx, : req.kv_allocated_len] + host_indices = self.req_to_host_pool[req.req_pool_idx, :compressed_len] host_indices = host_indices[host_indices >= 0] if host_indices.numel() > 0: self.mem_pool_host.free(host_indices) + # clear req info self.req_device_buffer_tokens[:, req.req_pool_idx, :] = -1 self.req_device_buffer_token_locs[:, req.req_pool_idx, :] = -1 @@ -649,31 +765,24 @@ def request_finished(self, req: Req): def swap_in_selected_pages( self, req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, + compressed_seq_lens: torch.Tensor, top_k_result: torch.Tensor, layer_id: int, ) -> torch.Tensor: """Swap selected top-k tokens into device memory and return their indices.""" - # The CUDA kernel expects req_pool_indices as int64 and seq_lens as int32 or int64. 
- if req_pool_indices.dtype != torch.int64: - raise ValueError( - f"req_pool_indices dtype {req_pool_indices.dtype} is not int64 as expected" - ) - if seq_lens.dtype not in (torch.int32, torch.int64): - raise ValueError( - f"seq_lens dtype {seq_lens.dtype} is not int32 or int64 as expected" - ) - if top_k_result.dtype != torch.int32: - raise ValueError( - f"top_k_result dtype {top_k_result.dtype} is not int32 as expected" - ) - num_reqs = req_pool_indices.size(0) + top_k_indices = self.top_k_device_locs_buffer[:num_reqs] top_k_indices.fill_(-1) + # todo, adjustable for performance block_size = 1024 - load_cache_to_device_buffer_mla( + swap_in_fn = ( + load_cache_to_device_buffer_dsv4_mla + if self.is_dsv4_hisparse + else load_cache_to_device_buffer_mla + ) + swap_in_fn( top_k_tokens=top_k_result, device_buffer_tokens=self.req_device_buffer_tokens[layer_id], host_cache_locs=self.req_to_host_pool, @@ -682,9 +791,9 @@ def swap_in_selected_pages( device_buffer=self.mem_pool_device.kv_buffer[layer_id], top_k_device_locs=top_k_indices, req_pool_indices=req_pool_indices, - seq_lens=seq_lens, + seq_lens=compressed_seq_lens, lru_slots=self.lru_slots[layer_id], - item_size_bytes=self.mem_pool_host.token_stride_size, + item_size_bytes=self.item_size_bytes, num_top_k=self.top_k, hot_buffer_size=self.device_buffer_size, page_size=1, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 40e2e0dd7f36..07a1c867dd5e 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -742,6 +742,8 @@ def __init__( self.storage_hit_length = 0 # The node to lock until for swa radix tree lock ref self.swa_uuid_for_lock: Optional[int] = None + # Whether the prefill-time SWA tree lock has been released early + self.swa_prefix_lock_released: bool = False # The prefix length that is inserted into the tree cache self.cache_protected_len: int = 0 @@ -1239,6 +1241,7 @@ def reset_for_retract(self): self.last_node = None self.cache_protected_len = 0 self.swa_uuid_for_lock = None + self.swa_prefix_lock_released = False self.extend_input_len = 0 self.is_retracted = True self.retracted_stain = True @@ -1524,7 +1527,7 @@ def init_new( if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): is_hybrid_swa = True - return cls( + batch = cls( reqs=reqs, req_to_token_pool=req_to_token_pool, token_to_kv_pool_allocator=token_to_kv_pool_allocator, @@ -1544,6 +1547,7 @@ def init_new( chunked_req=chunked_req, dllm_config=dllm_config, ) + return batch def batch_size(self): return len(self.reqs) @@ -2223,7 +2227,7 @@ def retract_decode( def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs): req = self.reqs[idx] - if self.hisparse_coordinator is not None: + if self.hisparse_coordinator is not None and not req.finished(): self.hisparse_coordinator.retract_req(req) if server_args.disaggregation_mode == "decode": @@ -2634,6 +2638,11 @@ def maybe_evict_swa(self): sliding_window_size = self.tree_cache.sliding_window_size server_args = get_global_server_args() + release_leaf_lock = ( + envs.SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW.get() + and hasattr(self.tree_cache, "dec_swa_lock_only") + ) + # Eviction_interval: trade-off between SWA token waste and eviction overhead page_size = self.tree_cache.page_size eviction_interval = max( @@ -2651,6 +2660,22 @@ def maybe_evict_swa(self): # 2. Evict swa every eviction_interval tokens to reduce the overhead. 
if req.decode_batch_idx % eviction_interval == 1: self._evict_swa(req, req.seqlen - 1) + + # Once the decode position has moved past the sliding window, + # the SWA portion of the prefill-time tree lock is no longer + # needed by this request. Convert it from protected to + # evictable so SWA LRU can reclaim it under pressure. + if ( + release_leaf_lock + and not req.swa_prefix_lock_released + and req.swa_uuid_for_lock is not None + and req.last_node is not None + and req.decode_batch_idx >= sliding_window_size + ): + self.tree_cache.dec_swa_lock_only( + req.last_node, req.swa_uuid_for_lock + ) + req.swa_prefix_lock_released = True elif self.forward_mode.is_extend() and self.tree_cache.is_chunk_cache(): pre_len = self.prefix_lens[idx] if self.enable_overlap: @@ -2680,7 +2705,8 @@ def _evict_swa(self, req: Req, pre_len: int): # Subtract an extra page_size so the eviction frontier never reaches the # radix tree insert boundary (page_floor(seq_len)). This keeps at least one # page of non-evicted SWA KV for the tree to store as a non-tombstone node, - # preserving cache reuse in multi-turn scenarios. + # preserving cache reuse in multi-turn scenarios. Without this, leaf nodes + # may become tombstoned, causing SWA memory leak. # See also: _insert_helper case 3 in swa_radix_cache.py (defensive counterpart). new_swa_evicted_seqlen = max( req.swa_evicted_seqlen, diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index fea14cc98f46..764d086d1c6f 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -43,6 +43,9 @@ InsertParams, MatchPrefixParams, ) +from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4HiSparseTokenToKVPoolAllocator, +) from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator from sglang.srt.server_args import ServerArgs @@ -444,8 +447,11 @@ def __init__( ] ) + # DeepSeek V4 HiSparse wraps an SWATokenToKVPoolAllocator internally and + # exposes the full SWA allocator interface. 
self.is_hybrid_swa = isinstance( - self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator + self.token_to_kv_pool_allocator, + (SWATokenToKVPoolAllocator, DeepSeekV4HiSparseTokenToKVPoolAllocator), ) self.is_hybrid_ssm_cache = self.tree_cache.supports_mamba() @@ -753,6 +759,13 @@ def add_req_state(r, insert_sort=False): return AddReqResult.NO_TOKEN tokens_freed += tokens_occupied + if (self.prefill_delayer_single_pass is not None) and ( + not self.prefill_delayer_single_pass.negotiate_should_allow_prefill( + local_prefillable=True + ) + ): + return AddReqResult.OTHER + if self.dllm_config is not None: if self.rem_dllm_tokens <= 0: return AddReqResult.OTHER @@ -906,6 +919,13 @@ def add_one_req( trunc_len // truncation_align_size ) + now_input_len = trunc_len + len(req.prefix_indices) + now_input_len = now_input_len // self.page_size * self.page_size + trunc_len = now_input_len - len(req.prefix_indices) + + if trunc_len <= 0: + return AddReqResult.OTHER + # Chunked prefill req.set_extend_input_len(trunc_len) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9c694321cbef..5c04dc4ee5dc 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1192,6 +1192,11 @@ def init_disaggregation(self): # todo: should we fix this when enabling mtp or it doesn't matter since we only enable mtp in decode node thus we don't transfer draft kvs between P and D? draft_token_to_kv_pool, model_config = self._get_draft_kv_pool() + # Default to the target model_config so the MetadataBuffers branches + # below can always access it; overridden by the draft model_config + # when this node runs a spec module. + if model_config is None: + model_config = self.model_config if ( self.disaggregation_mode == DisaggregationMode.DECODE @@ -2419,17 +2424,14 @@ def _build_hisparse_decode_batch(self, reqs): batch.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) batch.orig_seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device) batch.seq_lens_sum = sum(seq_lens) - # output_ids = last generated token, used as input_ids by prepare_for_decode batch.output_ids = torch.tensor( [r.output_ids[-1] for r in reqs], dtype=torch.int64, device=device ) - # Set logprob fields if any request needs them if batch.return_logprob: batch.top_logprobs_nums = [r.top_logprobs_num for r in reqs] batch.token_ids_logprobs = [list(r.origin_input_ids) for r in reqs] - # Build sampling info from scratch for these requests batch.sampling_info = SamplingBatchInfo.from_schedule_batch( batch, self.model_config.vocab_size ) @@ -3512,8 +3514,6 @@ def abort_request(self, recv_req: AbortReq): self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req) # For disaggregation decode mode, the request in the waiting queue has KV cache allocated. 
if self.disaggregation_mode == DisaggregationMode.DECODE: - if self.enable_hisparse: - self.hisparse_coordinator.request_finished(req) release_kv_cache(req, self.tree_cache) # For disaggregation prefill mode, free the metadata buffer index if self.disaggregation_mode == DisaggregationMode.PREFILL: diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index c06103969a3f..ebf929f71251 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -179,11 +179,12 @@ def get_pool_stats(self: Scheduler) -> PoolStats: if self.is_hybrid_swa: pool_stats = self._get_swa_token_info() elif self.is_hybrid_ssm: - return self._get_mamba_token_info() - elif self.enable_hisparse: - return self._get_hisparse_token_info() + pool_stats = self._get_mamba_token_info() else: - return self._get_token_info() + pool_stats = self._get_token_info() + + if self.enable_hisparse: + pool_stats = self._get_hisparse_token_info(pool_stats) # swa + ssm can coexist: overlay mamba fields onto swa stats if self.is_hybrid_ssm: @@ -208,8 +209,7 @@ def _get_token_info(self: Scheduler) -> PoolStats: full_evictable_size=evictable_size, ) - def _get_hisparse_token_info(self: Scheduler) -> PoolStats: - pool_stats = self._get_token_info() + def _get_hisparse_token_info(self: Scheduler, pool_stats: PoolStats) -> PoolStats: if self.enable_hisparse and self.hisparse_coordinator is not None: h = self.hisparse_coordinator.get_token_stats() return dataclasses.replace( @@ -266,6 +266,13 @@ def _get_swa_token_info(self: Scheduler) -> PoolStats: swa_num_used = self.swa_tokens_per_layer - ( swa_available_size + swa_evictable_size ) + # FIXME(hisparse): host-backup transiently over-releases the device pool + # counter, producing negative full_num_used / swa_num_used. We clamp to 0 + # to keep token_usage / leak checks sane, but the underlying accounting + # bug should be fixed so the clamp can go away. + if self.enable_hisparse: + full_num_used = max(0, full_num_used) + swa_num_used = max(0, swa_num_used) full_token_usage = full_num_used / self.full_tokens_per_layer swa_token_usage = swa_num_used / self.swa_tokens_per_layer @@ -546,11 +553,13 @@ def on_idle(self: Scheduler): if not self.is_fully_idle(): return - # memory leak check - has_leak, messages = self._check_all_pools(self.get_pool_stats()) - if has_leak: - self._report_leak("pool", "\n".join(messages)) - self._check_req_pool() + # memory leak check (skipped for hisparse — pool counters intentionally + # diverge during host-backup, see _get_swa_token_info clamp). 
+ if not self.enable_hisparse: + has_leak, messages = self._check_all_pools(self.get_pool_stats()) + if has_leak: + self._report_leak("pool", "\n".join(messages)) + self._check_req_pool() # tree cache sanity check self._check_tree_cache() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index ace1f19504be..60e105d93963 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed from sglang.srt.utils.hf_transformers_utils import ( @@ -249,9 +250,10 @@ def __init__( self.is_multi_layer_eagle = is_multi_layer_eagle self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator - self.memory_pool_config = memory_pool_config self.attn_cp_rank = attn_cp_rank self.moe_dp_rank = moe_dp_rank + # Draft worker: target's resolved MemoryPoolConfig (forwarded to ModelRunner). + self.memory_pool_config = memory_pool_config # MTP model runners self.model_runner_list: List[ModelRunner] = [] diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index fe93fa1d3fd3..7ec8bd2264eb 100755 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -55,6 +55,10 @@ def __init__( self.is_not_in_free_group = True self.free_group = [] + @property + def size_full(self): + return self.size + def debug_print(self) -> str: return "" diff --git a/python/sglang/srt/mem_cache/base_swa_memory_pool.py b/python/sglang/srt/mem_cache/base_swa_memory_pool.py new file mode 100644 index 000000000000..0fe0db219d2f --- /dev/null +++ b/python/sglang/srt/mem_cache/base_swa_memory_pool.py @@ -0,0 +1,33 @@ +import abc +from typing import List, Tuple + +import torch + +from sglang.srt.mem_cache.memory_pool import KVCache + + +class BaseSWAKVPool(KVCache): + """ABC for SWA-like KV pools. + + Subclasses expose a `swa_kv_pool` sub-pool plus a full -> swa index + mapping. Used by `SWATokenToKVPoolAllocator` and the disagg paths to + handle SWA state separately from the full KV state. 
+ """ + + swa_kv_pool: KVCache + + @abc.abstractmethod + def register_mapping(self, full_to_swa_index_mapping: torch.Tensor) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + @abc.abstractmethod + def set_swa_loc(self, loc: torch.Tensor) -> None: + raise NotImplementedError() + + @abc.abstractmethod + def get_state_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + raise NotImplementedError() diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 6641bcfa0a56..6d34a3aa1fc2 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -19,6 +19,9 @@ MatchPrefixParams, MatchResult, ) +from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4HiSparseTokenToKVPoolAllocator, +) from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator if TYPE_CHECKING: @@ -110,7 +113,14 @@ class SWAChunkCache(ChunkCache): """ChunkCache with support for sliding window attention.""" def __init__(self, params: CacheInitParams): - assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + # DeepSeek V4 HiSparse wraps SWATokenToKVPoolAllocator and exposes the same API. + assert isinstance( + params.token_to_kv_pool_allocator, + ( + SWATokenToKVPoolAllocator, + DeepSeekV4HiSparseTokenToKVPoolAllocator, + ), + ) super().__init__(params) self.sliding_window_size = params.sliding_window_size diff --git a/python/sglang/srt/mem_cache/deepseek_v4_compress_state.py b/python/sglang/srt/mem_cache/deepseek_v4_compress_state.py new file mode 100644 index 000000000000..3c865fe84d58 --- /dev/null +++ b/python/sglang/srt/mem_cache/deepseek_v4_compress_state.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import dataclasses +from contextlib import nullcontext + +import torch + +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE +from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool +from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter + + +@dataclasses.dataclass +class KVAndScore: + kv_score: torch.Tensor + + @property + def kv(self) -> torch.Tensor: + return self.kv_score[..., : self._item_size] + + @property + def score(self) -> torch.Tensor: + return self.kv_score[..., self._item_size :] + + def __post_init__(self): + self._item_size = self.kv_score.shape[-1] // 2 + + def __getitem__(self, index) -> KVAndScore: + return KVAndScore(self.kv_score[index]) + + def clear(self): + self.kv.zero_() + self.score.fill_(float("-inf")) + + +class CompressStatePool: + def __init__( + self, + size: int, + ring_size: int, + overlap: bool, + head_dim: int, + dtype: torch.dtype, + device: str, + enable_memory_saver: bool, + ratio: int, + online: bool = False, + ): + self.ring_size = ring_size + + if online: + assert ring_size == 1, "online compress requires ring_size=1" + self._size = size + self.ring_size + 1 + last_dim = 3 * head_dim + else: + self._size = size + self.ring_size + 1 + self._size = (self._size + ratio - 1) // ratio * ratio + last_dim = 2 * (1 + overlap) * head_dim + + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.enable_custom_mem_pool, self.custom_mem_pool, _ = ( + maybe_init_custom_mem_pool(device=device) + ) + + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + 
torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.kv_score_buffer = KVAndScore( + torch.empty( + (self._size, last_dim), + dtype=dtype, + device=device, + ) + ) + if not online: + self.kv_score_buffer[-1].clear() diff --git a/python/sglang/srt/mem_cache/deepseek_v4_memory_pool.py b/python/sglang/srt/mem_cache/deepseek_v4_memory_pool.py new file mode 100644 index 000000000000..e0e7dd56fe04 --- /dev/null +++ b/python/sglang/srt/mem_cache/deepseek_v4_memory_pool.py @@ -0,0 +1,738 @@ +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import List, Literal, NamedTuple, Optional, Tuple + +import torch + +from sglang.jit_kernel.deepseek_v4 import fused_store_cache +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE +from sglang.srt.environ import envs +from sglang.srt.layers.attention.dsv4 import ( + index_buf_accessor as dsv4_index_buf_accessor, +) +from sglang.srt.layers.attention.dsv4.index_buf_accessor import NopeFp8RopeBf16Pack +from sglang.srt.layers.attention.nsa import index_buf_accessor +from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool +from sglang.srt.mem_cache.deepseek_v4_compress_state import CompressStatePool +from sglang.srt.mem_cache.memory_pool import KVCache +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ceil_div + +logger = logging.getLogger(__name__) + +ONLINE_C128 = envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get() + + +def get_compress_state_ring_size( + compress_ratio: int, is_speculative: bool = False +) -> int: + assert compress_ratio in [4, 128], f"Unsupported {compress_ratio = }" + # Online c128 keeps a single (max, sum, kv) state per index instead of a + # 128-slot ring buffer of raw tokens, so ring_size collapses to 1. Online + # is incompatible with speculative decode for now. 
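# Illustrative sketch (not part of the patch): ring sizes produced by the
# branches below, for quick reference:
#
#     compress_ratio = 4:   8 (non-speculative),   16 (speculative)
#     compress_ratio = 128: 128 (non-speculative), 256 (speculative)
#     compress_ratio = 128 with SGLANG_OPT_USE_ONLINE_COMPRESS: 1 (MTP unsupported)
# End of sketch.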
+ if compress_ratio == 128 and ONLINE_C128: + assert not is_speculative, "online c128 does not support MTP" + return 1 + if is_speculative: + return 16 if compress_ratio == 4 else 256 + else: + return 8 if compress_ratio == 4 else 128 + + +class DeepSeekV4SingleKVPool(KVCache): + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + super().__init__( + size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + + self.scale_pad = 1 + self.quantize_block_size = 64 + self.rope_storage_dtype = torch.bfloat16 + self.k_with_scale_buffer_dtype = torch.int8 + self._create_buffers() + + def _create_buffers(self): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.kv_buffer = [ + self.create_buffer( + num_pages=(self.size + self.page_size + 1) // self.page_size, + ) + for _ in range(self.layer_num) + ] + + def get_bytes_per_token(self) -> int: + dim_per_token = ( + self.qk_nope_head_dim + + self.qk_rope_head_dim * self.rope_storage_dtype.itemsize + + self.qk_nope_head_dim // self.quantize_block_size + + self.scale_pad + ) + return dim_per_token + + def create_buffer(self, *, num_pages: int): + bytes_per_token = self.get_bytes_per_token() + self.kv_cache_total_dim = bytes_per_token + bytes_per_page_non_padded = self.page_size * bytes_per_token + self.bytes_per_page_padded = ceil_div(bytes_per_page_non_padded, 576) * 576 + + assert bytes_per_token == 448 + 64 * 2 + 8, ( + "DSV4 KV layout: qk_nope_head_dim FP8 (448) + qk_rope_head_dim BF16 " + "(64*2) + nope FP8 scales + scale_pad = 584 bytes/token" + ) + assert self.store_dtype == torch.uint8 + + return torch.zeros( + num_pages, + self.bytes_per_page_padded, + dtype=self.store_dtype, + device=self.device, + ) + + def set_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ): + dsv4_index_buf_accessor.SetKAndS.execute( + pool=self, + buf=self.kv_buffer[layer_id], + loc=loc, + nope_fp8_rope_bf16_pack=cache_nope_fp8_rope_bf16_pack, + ) + + def set_key_buffer_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + return fused_store_cache( + input=cache_k, + cache=self.kv_buffer[layer_id], + indices=loc, + page_size=self.page_size, + type="flashmla", + ) + + def get_key_buffer(self, layer_id: int): + return self.kv_buffer[layer_id] + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError("Use get_key_buffer instead.") + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("Use get_key_buffer instead.") + + +class HiSparseC4DevicePool(DeepSeekV4SingleKVPool): + + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: int | None = None, + end_layer: int | None = None, + ): + super().__init__( + size, + page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + layer_num, + 
device, + enable_memory_saver, + start_layer, + end_layer, + ) + + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.kv_buffer], + dtype=torch.uint64, + device=self.device, + ) + self.compress_ratio = 4 + + def register_mapping(self, full_to_hisparse_device_index_mapping: torch.Tensor): + self.full_to_hisparse_device_index_mapping = ( + full_to_hisparse_device_index_mapping + ) + + def translate_loc_from_full_to_compressed(self, full_indices: torch.Tensor): + mask = (full_indices + 1) % self.compress_ratio == 0 + compressed_indices = full_indices[mask] // self.compress_ratio + return compressed_indices + + def translate_loc_to_hisparse_device(self, compressed_indices: torch.Tensor): + return self.full_to_hisparse_device_index_mapping[compressed_indices].to( + torch.int32 + ) + + def _translate_loc_to_hisparse_device(self, compressed_indices: torch.Tensor): + return self.full_to_hisparse_device_index_mapping[compressed_indices] + + def translate_loc_from_full_to_hisparse_device(self, full_indices: torch.Tensor): + return self._translate_loc_to_hisparse_device( + self.translate_loc_from_full_to_compressed(full_indices) + ) + + def set_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack, + ): + loc = self.translate_loc_to_hisparse_device(loc) + super().set_key_buffer(layer_id, loc, cache_nope_fp8_rope_bf16_pack) + + def set_key_buffer_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + loc = self.translate_loc_to_hisparse_device(loc) + return super().set_key_buffer_fused(layer_id, loc, cache_k) + + def get_cpu_copy(self, indices, mamba_indices=None): + raise NotImplementedError("HiSparseC4DevicePool does not support get_cpu_copy") + + def load_cpu_copy(self, kv_cache_cpu, indices, mamba_indices=None): + raise NotImplementedError("HiSparseC4DevicePool does not support load_cpu_copy") + + +class DeepSeekV4IndexerPool(KVCache): + quant_block_size = 128 + index_k_with_scale_buffer_dtype = torch.uint8 + + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + index_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + super().__init__( + size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + self.index_head_dim = index_head_dim + + self._create_buffer() + + def _create_buffer(self): + num_scales_per_token = self.index_head_dim // self.quant_block_size + page_bytes = self.page_size * self.index_head_dim + page_bytes += self.page_size * num_scales_per_token * 4 + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.index_k_with_scale_buffer = [ + torch.zeros( + (self.size + self.page_size + 1) // self.page_size, + page_bytes, + dtype=self.index_k_with_scale_buffer_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: + return 
self.index_k_with_scale_buffer[layer_id] + + def get_index_k_scale_buffer( + self, + layer_id: int, + seq_len: int, + page_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + buf = self.index_k_with_scale_buffer[layer_id] + return index_buf_accessor.GetKAndS.execute( + self, buf, seq_len=seq_len, page_indices=page_indices + ) + + def set_index_k_scale_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + ) -> None: + buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] + index_buf_accessor.SetKAndS.execute( + pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale + ) + + def set_index_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + return fused_store_cache( + input=cache_k, + cache=self.index_k_with_scale_buffer[layer_id - self.start_layer], + indices=loc, + page_size=self.page_size, + type="indexer", + ) + + +class DeepSeekV4LayerItem(NamedTuple): + compress_ratio: Literal[0, 4, 128] + compress_layer_id: int + compress_kv_pool: Optional[DeepSeekV4SingleKVPool] = None + + +class DeepSeekV4TokenToKVPool(BaseSWAKVPool): + + def __init__( + self, + max_num_reqs: int, + swa_size: int, + c4_size: int, + c128_size: int, + c4_state_pool_size: int, + c128_state_pool_size: int, + page_size: int, + swa_page_size: int, + dtype: torch.dtype, + state_dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + indexer_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + compression_ratios: List[int], + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + enable_hisparse: bool = False, + ): + super().__init__( + swa_size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + c4_logical_size = c128_size * 32 + + logger.info( + "Initialize DeepSeekV4TokenToKVPool with " + f"{max_num_reqs=} {swa_size=} {c4_size=} " + f"{c4_logical_size=} {c128_size=} " + f"{c4_state_pool_size=} {c128_state_pool_size=}" + ) + + self.max_num_reqs = max_num_reqs + self.c4_size = c4_size + self.c4_logical_size = c4_logical_size + self.c128_size = c128_size + self.c4_state_pool_size = c4_state_pool_size + self.c128_state_pool_size = c128_state_pool_size + self.state_dtype = state_dtype + self.compression_ratios = compression_ratios + + assert page_size % swa_page_size == 0 + + self.swa_size = swa_size + self.swa_window_size = swa_page_size + self.swa_page_size = swa_page_size + self.scale_pad = 1 + + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.indexer_head_dim = indexer_head_dim + + c4_layer_num = sum(1 for r in compression_ratios if r == 4) + c128_layer_num = sum(1 for r in compression_ratios if r == 128) + c4_page_size = page_size // 4 + c128_page_size = page_size // 128 + self.swa_kv_pool = DeepSeekV4SingleKVPool( + swa_size, + swa_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + layer_num, + device, + enable_memory_saver, + ) + + c4_kv_pool_type = DeepSeekV4SingleKVPool + if enable_hisparse: + c4_kv_pool_type = HiSparseC4DevicePool + self.c4_kv_pool = c4_kv_pool_type( + c4_size, + c4_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c4_layer_num, + device, + enable_memory_saver, + ) + + self.c128_kv_pool = DeepSeekV4SingleKVPool( + c128_size, + c128_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c128_layer_num, + device, + enable_memory_saver, + ) + + self.c4_indexer_kv_pool = 
DeepSeekV4IndexerPool( + self.c4_logical_size, + c4_page_size, + dtype, + indexer_head_dim, + c4_layer_num, + device, + enable_memory_saver, + ) + + self._init_compressed_layer_mapping() + + self._init_paged_compress_states(enable_memory_saver) + + self._should_cache_swa = envs.SGLANG_OPT_CACHE_SWA_TRANSLATION.get() + + def register_mapping(self, full_to_swa_index_mapping: torch.Tensor): + self.full_to_swa_index_mapping = full_to_swa_index_mapping + + def get_ring_size(self, compress_ratio: int) -> int: + server_args = get_global_server_args() + is_speculative = server_args.speculative_algorithm is not None + return get_compress_state_ring_size(compress_ratio, is_speculative) + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + assert self.full_to_swa_index_mapping is not None + + return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) + + def set_swa_loc(self, loc: torch.Tensor) -> None: + # No-op: SWAKVPool's set_swa_loc precomputes SWA-translated loc once per + # forward batch for set_kv_buffer to read via self.swa_loc. DSV4 has its + # own equivalent cache via `_should_cache_swa + cached_loc` (in + # set_swa_key_buffer_radix_fused), so we ignore main's precomputed loc. + pass + + def get_contiguous_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for bufs in [ + self.c4_kv_pool.kv_buffer, + self.c4_indexer_kv_pool.index_k_with_scale_buffer, + self.c128_kv_pool.kv_buffer, + ]: + for buf in bufs: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + return data_ptrs, data_lens, item_lens + + def get_state_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for buf in self.swa_kv_pool.kv_buffer: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + for pools in [ + self.compress_state_pools, + self.indexer_compress_state_pools, + ]: + for pool in pools: + if pool is None: + continue + t = pool.kv_score_buffer.kv_score + assert t.ndim == 2, f"expected 2D buffer, got {t.ndim}D" + data_ptrs.append(t.data_ptr()) + data_lens.append(t.nbytes) + item_lens.append(t[0].nbytes * pool.ring_size) + + return data_ptrs, data_lens, item_lens + + def _init_paged_compress_states(self, enable_memory_saver: bool): + c4_state_pool_size = self.c4_state_pool_size + c128_state_pool_size = self.c128_state_pool_size + self.compress_state_pools: List[CompressStatePool] = [] + self.indexer_compress_state_pools: List[CompressStatePool] = [] + + for ratio in self.compression_ratios: + overlap = ratio == 4 + compress_state_pool = indexer_compress_state_pool = None + size = c4_state_pool_size if ratio == 4 else c128_state_pool_size + ring_size = self.get_ring_size(ratio) if ratio != 0 else 0 + if ratio != 0: + compress_state_pool = CompressStatePool( + size=size, + ring_size=ring_size, + overlap=overlap, + head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, + dtype=self.state_dtype, + device=self.device, + enable_memory_saver=enable_memory_saver, + ratio=ratio, + online=(ratio == 128 and ONLINE_C128), + ) + + if ratio == 4: + indexer_compress_state_pool = CompressStatePool( + size=size, + ring_size=ring_size, + overlap=overlap, + head_dim=self.indexer_head_dim, + 
device=self.device, + dtype=self.state_dtype, + enable_memory_saver=enable_memory_saver, + ratio=ratio, + ) + + self.compress_state_pools.append(compress_state_pool) + self.indexer_compress_state_pools.append(indexer_compress_state_pool) + + def _init_compressed_layer_mapping(self): + c1_cnt, c4_cnt, c128_cnt = 0, 0, 0 + self.layer_mapping: List[DeepSeekV4LayerItem] = [] + + for ratio in self.compression_ratios: + if ratio == 0: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=0, + compress_layer_id=c1_cnt, + ) + ) + c1_cnt += 1 + elif ratio == 4: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=4, + compress_layer_id=c4_cnt, + compress_kv_pool=self.c4_kv_pool, + ) + ) + c4_cnt += 1 + elif ratio == 128: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=128, + compress_layer_id=c128_cnt, + compress_kv_pool=self.c128_kv_pool, + ) + ) + c128_cnt += 1 + else: + raise ValueError(f"Unsupported compression ratio: {ratio}") + + def get_attention_compress_states(self, layer_id: int) -> CompressStatePool: + compress_state_pool = self.compress_state_pools[layer_id] + assert ( + compress_state_pool is not None + ), "Only c4/c128 layers have attention states." + return compress_state_pool + + def get_indexer_compress_states(self, layer_id: int) -> CompressStatePool: + indexer_compress_state_pool = self.indexer_compress_state_pools[layer_id] + assert ( + indexer_compress_state_pool is not None + ), "Only c4 layers have indexer states." + return indexer_compress_state_pool + + def get_swa_key_buffer(self, layer_id: int) -> torch.Tensor: + return self.swa_kv_pool.get_key_buffer(layer_id) + + def set_swa_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + self.swa_kv_pool.set_key_buffer(layer_id, loc, cache_nope_fp8_rope_bf16_pack) + + def get_extra_key_buffer(self, layer_id: int) -> torch.Tensor | None: + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + return compress_kv_pool.get_key_buffer(compress_layer_id) + + def set_extra_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + compress_kv_pool.set_key_buffer( + compress_layer_id, loc, cache_nope_fp8_rope_bf16_pack + ) + + def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.get_index_k_with_scale_buffer(compress_layer_id) + + def get_index_k_scale_buffer( + self, + layer_id: int, + seq_len: int, + page_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.get_index_k_scale_buffer( + compress_layer_id, seq_len, page_indices + ) + + def set_index_k_scale_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + ) -> None: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + self.c4_indexer_kv_pool.set_index_k_scale_buffer( + compress_layer_id, loc, 
index_k, index_k_scale + ) + + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + def set_swa_key_buffer_radix( + self, + layer_id: int, + raw_loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + swa_loc = self.translate_loc_from_full_to_swa(raw_loc) + self.swa_kv_pool.set_key_buffer( + layer_id, swa_loc, cache_nope_fp8_rope_bf16_pack + ) + + def get_swa_key_buffer_radix(self, layer_id: int) -> torch.Tensor: + return self.swa_kv_pool.get_key_buffer(layer_id) + + def set_swa_key_buffer_radix_fused( + self, + layer_id: int, + raw_loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + if self._should_cache_swa: + if layer_id == 0: + self.cached_loc = self.translate_loc_from_full_to_swa(raw_loc) + swa_loc = self.cached_loc + else: + swa_loc = self.translate_loc_from_full_to_swa(raw_loc) + return self.swa_kv_pool.set_key_buffer_fused(layer_id, swa_loc, cache_k) + + def set_extra_key_buffer_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + return compress_kv_pool.set_key_buffer_fused(compress_layer_id, loc, cache_k) + + def set_index_k_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.set_index_fused(compress_layer_id, loc, cache_k) diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py index 4ecc08a27b4f..15b315837010 100644 --- a/python/sglang/srt/mem_cache/hisparse_memory_pool.py +++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py @@ -1,8 +1,10 @@ # mapping on device memory, host memory and memory allocator +import logging import weakref from typing import Optional +import psutil import torch from sglang.srt.layers.radix_attention import RadixAttention @@ -10,10 +12,16 @@ BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseek_v4_memory_pool import ( + DeepSeekV4TokenToKVPool, + HiSparseC4DevicePool, +) from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool from sglang.srt.utils import is_cuda, is_hip from sglang.srt.utils.common import get_num_new_pages +logger = logging.getLogger(__name__) + # sgl_kernel.kvcacheio is only available in CUDA/ROCm sgl-kernel builds (not XPU/MPS/NPU/CPU). 
_is_cuda = is_cuda() _is_hip = is_hip() @@ -75,6 +83,12 @@ def translate_loc_to_hisparse_device(self, compressed_indices: torch.Tensor): def _translate_loc_to_hisparse_device(self, compressed_indices: torch.Tensor): return self.full_to_hisparse_device_index_mapping[compressed_indices] + def translate_loc_from_full_to_hisparse_device(self, full_indices: torch.Tensor): + return self._translate_loc_to_hisparse_device(full_indices) + + def translate_loc_from_full_to_compressed(self, full_indices: torch.Tensor): + return full_indices + def set_kv_buffer( self, layer: RadixAttention, @@ -128,13 +142,14 @@ def __init__( page_size: int, dtype: torch.dtype, device: torch.device, - kvcache: NSATokenToKVPool, + kvcache: HiSparseNSATokenToKVPool, need_sort: bool, host_to_device_ratio: int = 2, ): self._kvcache = kvcache self._size_full = size * host_to_device_ratio self._size_hisparse = size + self.compress_ratio = 1 self.dtype = dtype self.device = device self.page_size = page_size @@ -148,7 +163,6 @@ def __init__( kvcache, need_sort, ) - self.hisparse_attn_allocator = PagedTokenToKVPoolAllocator( self._size_hisparse, self.page_size, @@ -157,7 +171,6 @@ def __init__( kvcache, need_sort, ) - self.full_to_hisparse_device_index_mapping = torch.cat( [ torch.zeros( @@ -174,7 +187,6 @@ def __init__( self.is_not_in_free_group = True self.free_group = [] self.clear() - self._kvcache.register_mapping( weakref.proxy(self.full_to_hisparse_device_index_mapping) ) @@ -183,15 +195,23 @@ def __init__( def size_full(self) -> int: return self._size_full + @property + def size(self) -> int: + return self._size_full + def available_size(self) -> int: return min( self.logical_attn_allocator.available_size(), self.hisparse_attn_allocator.available_size(), ) + def get_kvcache(self): + return self._kvcache + def alloc(self, need_size: int): raise NotImplementedError( - "Page size = 1 is not supported in HiSparse allocator" + "HiSparse allocator does not support direct token allocation; " + "use alloc_extend or alloc_decode instead." 
) def alloc_logical_only( @@ -260,9 +280,11 @@ def free_hisparse_indices(self, buffer_indices: torch.Tensor): self.hisparse_attn_allocator.is_not_in_free_group = True self.hisparse_attn_allocator.free(buffer_indices[buffer_indices > 0]) + def get_last_loc_compressed(self, last_locs: torch.Tensor): + return last_locs + def get_last_loc_hisparse_device(self, last_locs: torch.Tensor): - hisparse_last_locs = self._kvcache._translate_loc_to_hisparse_device(last_locs) - return hisparse_last_locs + return self._kvcache._translate_loc_to_hisparse_device(last_locs) def alloc_extend( self, @@ -312,9 +334,7 @@ def alloc_extend( assert ( hisparse_indices is not None ), "Hisparse allocation failed in alloc_extend" - self.full_to_hisparse_device_index_mapping[logical_indices] = hisparse_indices - return logical_indices def alloc_decode( @@ -323,64 +343,429 @@ def alloc_decode( seq_lens_cpu: torch.Tensor, last_loc: torch.Tensor, # last_loc for full layers ): - logical_indices = self.logical_attn_allocator.alloc_decode( + return self.logical_attn_allocator.alloc_decode( seq_lens, seq_lens_cpu, last_loc ) - return logical_indices + def free_hisparse(self, free_indices: torch.Tensor): + hisparse_indices = self._kvcache._translate_loc_to_hisparse_device(free_indices) + hisparse_indices = hisparse_indices[hisparse_indices > 0] + self.free_hisparse_indices(hisparse_indices) + self.full_to_hisparse_device_index_mapping[free_indices] = 0 + + def clear(self): + self.logical_attn_allocator.clear() + self.hisparse_attn_allocator.clear() + # Note: the last item is -1, we don't clear it, see the comment in __init__ + self.full_to_hisparse_device_index_mapping[:-1].fill_(0) + self.is_not_in_free_group = True + self.free_group = [] + + def free_group_begin(self): + return + + def free_group_end(self): + return + + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + if self.is_not_in_free_group: + self.logical_attn_allocator.free(free_index) + self.free_hisparse(free_index) + else: + self.free_group.append(free_index) + assert ( + self.logical_attn_allocator.available_size() + <= self.logical_attn_allocator.size + ) + assert ( + self.hisparse_attn_allocator.available_size() + <= self.hisparse_attn_allocator.size + ) + + +class DeepSeekV4SingleKVPoolHost: + + def __init__( + self, + device_pool: HiSparseC4DevicePool, + host_size: int, + page_size: int, + pin_memory: bool = True, + device: str = "cpu", + ): + + assert host_size > 0, "Host size must be specified and greater than 0" + assert page_size == 1, "Host page size must be 1 for DeepSeekV4SingleKVPoolHost" + + self.device_pool = device_pool + self.size = host_size + self.page_size = page_size + self.num_pages = (self.size + self.page_size - 1) // self.page_size + self.pin_memory = pin_memory + self.device = device + + self.dtype = device_pool.store_dtype + self.layer_num = device_pool.layer_num + self.kv_cache_total_dim = device_pool.kv_cache_total_dim + + self.kv_buffer = self.init_kv_buffer() + self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)] + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.data_refs], + dtype=torch.uint64, + device=self.device_pool.device, + ) + self.clear() + + def clear(self): + self.free_slots = torch.arange( + 1, self.num_pages + 1, dtype=torch.int64, device="cpu" + ) + + def init_kv_buffer(self): + dims = (self.layer_num, self.size + self.page_size, self.kv_cache_total_dim) + requested_bytes = ( + self.layer_num + * (self.size + self.page_size) + * self.kv_cache_total_dim + * 
self.dtype.itemsize + ) + host_mem = psutil.virtual_memory() + # preserve at least 10GB for other usage + ten_gb = 10 * (1024**3) + available_bytes = host_mem.available - ten_gb + if requested_bytes > available_bytes: + raise ValueError( + f"Not enough host memory available. Requesting " + f"{requested_bytes / 1e9:.2f} GB but only have " + f"{available_bytes / 1e9:.2f} GB free. Please reduce the " + f"size of the hierarchical cache." + ) + else: + logger.info( + f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." + ) + + host_pool = torch.empty(dims, dtype=self.dtype, device=self.device) + assert self.pin_memory, "DeepSeekV4SingleKVPoolHost requires pin_memory=True" + if self.pin_memory: + torch.cuda.cudart().cudaHostRegister( + host_pool.data_ptr(), host_pool.numel() * host_pool.element_size(), 0 + ) + return host_pool + + def backup_from_device_all_layer( + self, device_pool, host_indices, device_indices, io_backend="kernel" + ): + if io_backend != "kernel": + raise ValueError(f"Unsupported IO backend: {io_backend}") + + from sglang.jit_kernel.deepseek_v4 import hisparse_offload_to_host + + if host_indices.device != device_indices.device: + host_indices = host_indices.to(device=device_indices.device) + host_indices_i64 = ( + host_indices.to(torch.int64) + if host_indices.dtype != torch.int64 + else host_indices + ) + device_indices_i64 = ( + device_indices.to(torch.int64) + if device_indices.dtype != torch.int64 + else device_indices + ) + hisparse_offload_to_host( + gpu_ptrs=device_pool.data_ptrs, + cpu_ptrs=self.data_ptrs, + gpu_indices=device_indices_i64, + cpu_indices=host_indices_i64, + ) + + def available_size(self): + return len(self.free_slots) + + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + if need_size > self.available_size(): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + return select_index - def alloc_decode_debug( + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices.cpu()]) + return len(indices) + + +class DeepSeekV4HiSparseTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): + + def __init__( self, + logical_attn_allocator: BaseTokenToKVPoolAllocator, + ): + assert isinstance(logical_attn_allocator._kvcache, DeepSeekV4TokenToKVPool) + assert isinstance( + logical_attn_allocator._kvcache.c4_kv_pool, HiSparseC4DevicePool + ) + self.compress_ratio = 4 + + self.hisparse_kvcache = logical_attn_allocator._kvcache.c4_kv_pool + self._size_full = logical_attn_allocator.size_full + self._size_hisparse = self.hisparse_kvcache.size + + self.dtype = self.hisparse_kvcache.dtype + self.device = self.hisparse_kvcache.device + self.page_size = self.hisparse_kvcache.page_size + + self.logical_attn_allocator = logical_attn_allocator + self._kvcache = logical_attn_allocator._kvcache + self.hisparse_attn_allocator = PagedTokenToKVPoolAllocator( + self._size_hisparse, + self.page_size, + self.dtype, + self.device, + self.hisparse_kvcache, + logical_attn_allocator.need_sort, + ) + + self.full_to_hisparse_device_index_mapping = torch.cat( + [ + torch.zeros( + self._kvcache.c4_logical_size + self.page_size, + dtype=torch.int64, + device=self.device, + ), + torch.tensor([-1], dtype=torch.int64, device=self.device), + ] + ) + + self.need_sort = logical_attn_allocator.need_sort + self.free_pages = None + self.release_pages = None + self.is_not_in_free_group = True + self.free_group = [] + self.clear() + + 
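+ # Hand the c4 device pool a weakref proxy of the full-to-device mapping so its set_key_buffer / set_key_buffer_fused paths can translate compressed locations without the pool holding a strong reference to this allocator-owned tensor.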
self.hisparse_kvcache.register_mapping( + weakref.proxy(self.full_to_hisparse_device_index_mapping) + ) + + @property + def size_full(self) -> int: + return self._size_full + + @property + def size(self) -> int: + return self.logical_attn_allocator.size + + @property + def size_swa(self) -> int: + return self.logical_attn_allocator.size_swa + + @property + def full_to_swa_index_mapping(self): + return self.logical_attn_allocator.full_to_swa_index_mapping + + def debug_print(self) -> str: + msg = self.logical_attn_allocator.debug_print() + msg += ( + f"#hisparse-available-size: " + f"{self.hisparse_attn_allocator.available_size()}, " + ) + return msg + + def get_kvcache(self): + return self._kvcache + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + return self.logical_attn_allocator.translate_loc_from_full_to_swa(kv_indices) + + def full_available_size(self): + return min( + self.logical_attn_allocator.full_available_size(), + self.hisparse_attn_allocator.available_size() * self.compress_ratio, + ) + + def swa_available_size(self): + return self.logical_attn_allocator.swa_available_size() + + def free_swa(self, free_indices: torch.Tensor): + self.logical_attn_allocator.free_swa(free_indices) + + def available_size(self) -> int: + return min( + self.logical_attn_allocator.available_size(), + self.hisparse_attn_allocator.available_size() * self.compress_ratio, + ) + + def alloc(self, need_size: int): + raise NotImplementedError( + "DeepSeek V4 HiSparse allocator does not support direct token allocation; " + "use alloc_extend or alloc_decode instead." + ) + + def alloc_device_buffer(self, allocated_indices, need_size: int): + assert need_size % self.page_size == 0 + hisparse_indices = self.full_to_hisparse_device_index_mapping[allocated_indices] + self.full_to_hisparse_device_index_mapping[allocated_indices] = 0 + + device_buffer_size = need_size - self.page_size + P = len(hisparse_indices) + if P > device_buffer_size + 1: + newest_src = hisparse_indices[P - 1].clone() + old_at_dbs = hisparse_indices[device_buffer_size].clone() + hisparse_indices[device_buffer_size] = newest_src + hisparse_indices[P - 1] = old_at_dbs + + if len(hisparse_indices) >= need_size: + buffer_indices = hisparse_indices[:need_size] + surplus = hisparse_indices[need_size:] + if surplus.numel() > 0: + buffer_pages = torch.unique(buffer_indices // self.page_size) + surplus_pages = torch.unique(surplus // self.page_size) + pure_surplus = surplus_pages[~torch.isin(surplus_pages, buffer_pages)] + if pure_surplus.numel() > 0: + self.hisparse_attn_allocator.is_not_in_free_group = True + self.hisparse_attn_allocator.free(pure_surplus * self.page_size) + else: + page_residual_length = len(hisparse_indices) % self.page_size + if page_residual_length != 0: + hisparse_indices = torch.cat( + [ + hisparse_indices, + torch.arange( + hisparse_indices[-1] + 1, + hisparse_indices[-1] + + self.page_size + - page_residual_length + + 1, + device=self.device, + ), + ] + ) + extra_indices = self.hisparse_attn_allocator.alloc( + need_size - len(hisparse_indices) + ) + assert ( + extra_indices is not None + ), "Hisparse allocation failed in alloc_device_buffer" + buffer_indices = torch.cat([hisparse_indices, extra_indices]) + return buffer_indices + + def free_hisparse_indices(self, buffer_indices: torch.Tensor): + self.hisparse_attn_allocator.is_not_in_free_group = True + self.hisparse_attn_allocator.free(buffer_indices[buffer_indices > 0]) + + def get_last_loc_compressed(self, last_locs: torch.Tensor): + return 
(last_locs - 3) // self.compress_ratio + + def get_last_loc_hisparse_device(self, last_locs: torch.Tensor): + return self.hisparse_kvcache._translate_loc_to_hisparse_device( + self.get_last_loc_compressed(last_locs) + ) + + def alloc_extend( + self, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, seq_lens: torch.Tensor, seq_lens_cpu: torch.Tensor, - last_loc: torch.Tensor, # last_loc for full layers + last_loc: torch.Tensor, + extend_num_tokens: int, ): - logical_indices = self.logical_attn_allocator.alloc_decode( - seq_lens, seq_lens_cpu, last_loc + assert self.page_size > 1 + + num_new_pages_logical = get_num_new_pages( + seq_lens=seq_lens_cpu, page_size=self.page_size, prefix_lens=prefix_lens_cpu ) + num_new_pages_hisparse = get_num_new_pages( + seq_lens=seq_lens_cpu // self.compress_ratio, + page_size=self.page_size, + prefix_lens=prefix_lens_cpu // self.compress_ratio, + ) + if ( + num_new_pages_logical + > self.logical_attn_allocator.available_size() // self.page_size + ): + return None + if ( + num_new_pages_hisparse + > self.hisparse_attn_allocator.available_size() // self.page_size + ): + return None - hisparse_last_loc = self.get_last_loc_hisparse_device(last_loc) - hisparse_indices = self.hisparse_attn_allocator.alloc_decode( + logical_indices = self.logical_attn_allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, seq_lens, seq_lens_cpu, - hisparse_last_loc, + last_loc, + extend_num_tokens, ) + assert logical_indices is not None, "Logical allocation failed in alloc_extend" - if logical_indices is None or hisparse_indices is None: - return None - - self.full_to_hisparse_device_index_mapping[logical_indices] = hisparse_indices + compressed_logical_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(logical_indices) + ) + hisparse_last_loc = self.get_last_loc_hisparse_device(last_loc) + hisparse_indices = self.hisparse_attn_allocator.alloc_extend( + prefix_lens // self.compress_ratio, + prefix_lens_cpu // self.compress_ratio, + seq_lens // self.compress_ratio, + seq_lens_cpu // self.compress_ratio, + hisparse_last_loc, + len(compressed_logical_indices), + ) + assert ( + hisparse_indices is not None + ), "Hisparse allocation failed in alloc_extend" + self.full_to_hisparse_device_index_mapping[compressed_logical_indices] = ( + hisparse_indices.to(torch.int64) + ) return logical_indices - def free_hisparse(self, free_indices: torch.Tensor): - hisparse_indices = self._kvcache._translate_loc_to_hisparse_device(free_indices) + def alloc_decode( + self, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, + ): + return self.logical_attn_allocator.alloc_decode( + seq_lens, seq_lens_cpu, last_loc + ) + + def free_compressed(self, compressed_indices: torch.Tensor): + hisparse_indices = self.hisparse_kvcache.translate_loc_to_hisparse_device( + compressed_indices + ) hisparse_indices = hisparse_indices[hisparse_indices > 0] self.free_hisparse_indices(hisparse_indices) - self.full_to_hisparse_device_index_mapping[free_indices] = 0 + self.full_to_hisparse_device_index_mapping[compressed_indices] = 0 + + def free_hisparse(self, free_indices: torch.Tensor): + compressed_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(free_indices) + ) + self.free_compressed(compressed_indices) def clear(self): self.logical_attn_allocator.clear() self.hisparse_attn_allocator.clear() - # Note: the last item is -1, we don't clear it, see the comment in __init__ self.full_to_hisparse_device_index_mapping[:-1].fill_(0) 
self.is_not_in_free_group = True self.free_group = [] - def free_group_begin(self): - return - - def free_group_end(self): - return - def free(self, free_index: torch.Tensor): if free_index.numel() == 0: return if self.is_not_in_free_group: self.logical_attn_allocator.free(free_index) - self.free_hisparse(free_index) else: self.free_group.append(free_index) assert ( diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 67d30295ce7a..393738ec31e7 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -9,6 +9,7 @@ PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.base_swa_memory_pool import BaseSWAKVPool from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool from sglang.srt.utils import is_npu @@ -25,7 +26,7 @@ GB = 1024 * 1024 * 1024 -class SWAKVPool(KVCache): +class SWAKVPool(BaseSWAKVPool): """KV cache with separate pools for full and SWA attention layers.""" def __init__( @@ -253,29 +254,32 @@ def __init__( page_size: int, dtype: torch.dtype, device: str, - kvcache: SWAKVPool, + kvcache: BaseSWAKVPool, need_sort: bool, ): - assert isinstance(kvcache, SWAKVPool) + assert isinstance(kvcache, BaseSWAKVPool) self._size_full = size self._size_swa = size_swa self.dtype = dtype self.device = device self.page_size = page_size + full_kv_pool = getattr(kvcache, "full_kv_pool", None) + swa_kv_pool = getattr(kvcache, "swa_kv_pool", None) + if page_size == 1: self.full_attn_allocator = TokenToKVPoolAllocator( size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = TokenToKVPoolAllocator( size_swa, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) else: @@ -288,7 +292,7 @@ def __init__( page_size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = PagedTokenToKVPoolAllocatorClass( @@ -296,7 +300,7 @@ def __init__( page_size, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) # Note: append one more item of value -1 in the end so -1 maps to -1. 
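For reference, the c4 index translation above (HiSparseC4DevicePool plus DeepSeekV4HiSparseTokenToKVPoolAllocator) reduces to a few lines of arithmetic: only every fourth token position owns a compressed slot, the compressed index is the full index divided by 4, and the registered mapping turns that into a physical device slot. A minimal standalone sketch, with made-up values in `mapping` standing in for `full_to_hisparse_device_index_mapping`:

import torch

compress_ratio = 4  # matches HiSparseC4DevicePool.compress_ratio

# Toy stand-in for full_to_hisparse_device_index_mapping (compressed slot -> device slot);
# in the allocator these entries are filled by alloc_extend.
mapping = torch.zeros(64, dtype=torch.int64)
mapping[:4] = torch.tensor([16, 17, 18, 19])

full_indices = torch.arange(16)                      # token positions 0..15
mask = (full_indices + 1) % compress_ratio == 0      # positions 3, 7, 11, 15 complete a group
compressed = full_indices[mask] // compress_ratio    # tensor([0, 1, 2, 3])
device_slots = mapping[compressed]                   # tensor([16, 17, 18, 19])

# get_last_loc_compressed: step back to the newest *completed* group of 4 tokens.
last_loc = torch.tensor([15])
last_compressed = (last_loc - (compress_ratio - 1)) // compress_ratio  # tensor([3])

Freeing walks the same two steps in reverse: free_hisparse first maps full indices to compressed ones, then free_compressed releases the device slots and zeroes the corresponding mapping entries.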
diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index c650588b6394..d226314791f8 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -27,6 +27,7 @@ import torch from numpy import float64 +from sglang.srt.environ import envs from sglang.srt.mem_cache.base_prefix_cache import ( BasePrefixCache, DecLockRefParams, @@ -471,8 +472,11 @@ def cache_finished_req(self, req: Req, is_insert: bool = True) -> None: # Remove req slot release the cache lock self.dec_lock_ref( - req.last_node, DecLockRefParams(swa_uuid_for_lock=req.swa_uuid_for_lock) + req.last_node, + DecLockRefParams(swa_uuid_for_lock=req.swa_uuid_for_lock), + skip_swa=req.swa_prefix_lock_released, ) + req.swa_prefix_lock_released = False def cache_unfinished_req(self, req: Req, chunked=False) -> None: """Cache request when it is unfinished.""" @@ -524,8 +528,11 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: req.cache_protected_len = len(new_indices) self.dec_lock_ref( - req.last_node, DecLockRefParams(swa_uuid_for_lock=req.swa_uuid_for_lock) + req.last_node, + DecLockRefParams(swa_uuid_for_lock=req.swa_uuid_for_lock), + skip_swa=req.swa_prefix_lock_released, ) + req.swa_prefix_lock_released = False result = self.inc_lock_ref(new_last_node) swa_uuid_for_lock = result.swa_uuid_for_lock @@ -568,12 +575,15 @@ def evict(self, params: EvictParams) -> EvictResult: # 1. free node kv indices, evict full and swa tokens self.token_to_kv_pool_allocator.free(x.value) full_num_evicted += len(x.value) - swa_num_evicted += len(x.value) + # Tombstoned leaves had their SWA freed earlier in `dec_swa_lock_only` + if not x.swa_tombstone: + swa_num_evicted += len(x.value) # 2. get the next leaf, update the lru lists x_next = self.full_lru_list.get_prev_leaf_no_lock(x) self.full_lru_list.remove_node(x) - self.swa_lru_list.remove_node(x) + if not x.swa_tombstone: + self.swa_lru_list.remove_node(x) # 3. delete the leaf node self._delete_leaf(x) @@ -610,6 +620,18 @@ def evict(self, params: EvictParams) -> EvictResult: # 3. tombstone the node self._tombstone_internal_node(x) + elif x.full_lock_ref > 0: + # Leaf still holds a full-side lock (can happen when the + # SWA leaf-lock early-release optimization revived a + # tombstoned leaf. Treat it like an internal tombstone. + self.token_to_kv_pool_allocator.free_swa(x.value) + swa_num_evicted += len(x.value) + + x_next = self.swa_lru_list.get_prev_no_lock(x) + self.swa_lru_list.remove_node(x) + + self.swa_evictable_size_ -= len(x.value) + x.swa_tombstone = True else: assert ( x.full_lock_ref == 0 @@ -679,20 +701,26 @@ def inc_lock_ref(self, node: TreeNode) -> IncLockRefResult: return IncLockRefResult(swa_uuid_for_lock=swa_uuid_for_lock) def dec_lock_ref( - self, node: TreeNode, params: Optional[DecLockRefParams] = None + self, + node: TreeNode, + params: Optional[DecLockRefParams] = None, + skip_swa: bool = False, ) -> DecLockRefResult: """ Decrement the lock reference count for the node. It unlocks the full_lock_ref for nodes between the [last node, root), exclusive. It unlocks the swa_lock_ref for nodes between the [last node, swa_uuid_for_lock], inclusive. If swa_uuid_for_lock is None, it unlocks to the root, exclusive. + + If skip_swa is True, only the full_lock_ref is decremented; the SWA lock is + assumed to have been released already (e.g. via `dec_swa_lock_only`). 
""" swa_uuid_for_lock = params.swa_uuid_for_lock if params is not None else None if self.disable: return DecLockRefResult() - dec_lock_swa = True + dec_lock_swa = not skip_swa while node != self.root_node: assert ( node.full_lock_ref > 0 @@ -721,6 +749,61 @@ def dec_lock_ref( return DecLockRefResult() + def dec_swa_lock_only( + self, node: TreeNode, swa_uuid_for_lock: Optional[int] = None + ): + """ + Decrement only the swa_lock_ref (and swa_protected_size_) along the chain + [node, swa_uuid_for_lock], inclusive. The full_lock_ref is left untouched + so the caller's full-cache protection is preserved. + + Used to early-release the SWA portion of a request's tree lock once the + request's decode position has advanced past the sliding window, so the + protected window can be reclaimed. + + For internal nodes, the standard protected -> evictable transition is + applied (node stays in swa_lru_list and may be evicted by SWA LRU later). + For leaf nodes, since `swa_lru_list` cannot contain a leaf with + `full_lock_ref > 0` (SWA-eviction would also delete the still-referenced + leaf), we instead free the SWA pool slots immediately and mark the leaf + as `swa_tombstone=True`. The full kv stays alive until the full-side + lock drops; future prefix-matches stop before this tombstoned leaf. + + Caller must ensure this is invoked at most once per (node, swa_uuid_for_lock) + pair (track via e.g. `Req.swa_prefix_lock_released`). When the request + finally releases its full lock via `dec_lock_ref`, pass `skip_swa=True` + to avoid touching SWA state again. + """ + if self.disable: + return + + while node != self.root_node: + assert ( + not node.swa_tombstone + ), f"dec_swa_lock_only on swa_tombstone node, {node.id=}" + assert ( + node.swa_lock_ref > 0 + ), f"dec_swa_lock_only on node with {node.swa_lock_ref=}, {node.id=}" + + if node.swa_lock_ref == 1: + self.swa_protected_size_ -= len(node.value) + if len(node.children) == 0: + # Leaf: free SWA pool slots and tombstone, and remove from + # swa_lru_list so SWA-eviction won't pick this tombstoned + # leaf (which still holds full_lock_ref > 0). The full kv + # stays alive until the request releases its full lock. + self.token_to_kv_pool_allocator.free_swa(node.value) + self.swa_lru_list.remove_node(node) + node.swa_tombstone = True + else: + # Internal: standard protected -> evictable. + self.swa_evictable_size_ += len(node.value) + node.swa_lock_ref -= 1 + + if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock: + break + node = node.parent + def sanity_check(self): self.full_lru_list.sanity_check(self) self.swa_lru_list.sanity_check(self) @@ -789,9 +872,13 @@ def _match_prefix_helper( match_len_since_tombstone = float("inf") best_value_len = 0 best_last_node = node + enable_compact = envs.SGLANG_OPT_SWA_RADIX_CACHE_COMPACT.get() while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] + if enable_compact: + self._compact_single_child_chain(child) + if child.swa_tombstone: # update best_value_len and best_last_node if needed if match_len_since_tombstone >= self.sliding_window_size: @@ -871,6 +958,84 @@ def _match_post_processor( last_host_node=last_node, ) + def _compact_single_child_chain(self, node: TreeNode) -> None: + # FIXME(ispobock): drifts retract pool accounting (commit 6348cb506); + # also overwrites active swa_uuid when window > page_size. Off by + # default via SGLANG_OPT_SWA_RADIX_CACHE_COMPACT. 
+ while len(node.children) == 1: + child = next(iter(node.children.values())) + if len(child.children) == 0: + break + sum_gc_full_lock_ref = sum( + gc.full_lock_ref for gc in child.children.values() + ) + if child.full_lock_ref > sum_gc_full_lock_ref: + break + if ( + child.swa_tombstone != node.swa_tombstone + or child.full_lock_ref != node.full_lock_ref + or child.swa_lock_ref != node.swa_lock_ref + ): + break + + # Preserve is_bigram: main #23106 made bigram an O(1) flag on RadixKey; + # the constructor defaults to False, so concat without explicit flag + # silently demotes EAGLE/MTP bigram keys → match() returns 0 → + # _split_node assert. + node.key = RadixKey( + node.key.token_ids + child.key.token_ids, + node.key.extra_key, + is_bigram=node.key.is_bigram, + ) + node.value = torch.cat([node.value, child.value]) + node.children = child.children + for grandchild in node.children.values(): + grandchild.parent = node + + if child.swa_uuid is not None: + node.swa_uuid = child.swa_uuid + + self.full_lru_list.remove_node(child) + if not child.swa_tombstone: + self.swa_lru_list.remove_node(child) + + def _maybe_split_leaf_for_swa_lock(self, leaf: TreeNode) -> TreeNode: + """``inc_lock_ref`` protects ``len(leaf.value)`` SWA tokens for the + leaf even though SWA only actually needs the last + ``sliding_window_size`` tokens. With chunked prefill, leaves can be + thousands of tokens long, which inflates ``swa_protected_size_`` by + ~``chunked_prefill_size / sliding_window_size`` and causes premature + SWA pool exhaustion / retract thrashing. + """ + if ( + leaf is self.root_node + or leaf.swa_lock_ref > 0 + or leaf.swa_tombstone + or len(leaf.value) == 0 + ): + return leaf + + # Smallest page-aligned size that still covers the sliding window. + tail_size = ( + (self.sliding_window_size + self.page_size - 1) + // self.page_size + * self.page_size + ) + if len(leaf.value) <= tail_size: + return leaf + + split_at = len(leaf.value) - tail_size + + if split_at <= 0 or split_at >= len(leaf.value): + return leaf + if self.page_size > 1 and ( + split_at % self.page_size != 0 or len(leaf.value) % self.page_size != 0 + ): + return leaf + + self._split_node(leaf.key, leaf, split_at) + return leaf + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: # new_node -> child new_node = TreeNode() @@ -1026,7 +1191,13 @@ def _insert_helper( key = key[swa_tombstone_len:] value = value[swa_tombstone_len:] - self._add_new_node(node, key, value, swa_tombstone=False) + new_leaf = self._add_new_node(node, key, value, swa_tombstone=False) + + if envs.SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT.get(): + # Cap the leaf at one (page-aligned) sliding window so a future + # inc_lock_ref only protects `sliding_window_size` tokens of SWA pool. 
+ self._maybe_split_leaf_for_swa_lock(new_leaf) + return total_prefix_length def _add_new_node( @@ -1074,15 +1245,15 @@ def _iteratively_delete_tombstone_leaf( return node, full_num_evicted def _delete_leaf(self, node: TreeNode) -> None: - assert ( - not node.swa_tombstone - ), f"Invariant violated: leaf node is a tombstone, {node.id=}" assert len(node.children) == 0, f"leaf node has children, {node.id=}" key = node.key.child_key(self.page_size) v = node.parent.children.pop(key, None) assert v == node, f"parent does not have child key, {key}" self.full_evictable_size_ -= len(node.key) - self.swa_evictable_size_ -= len(node.key) + # Tombstoned leaves were never (re-)added to swa_lru_list and were + # already removed from swa_evictable_size_ when they were tombstoned. + if not node.swa_tombstone: + self.swa_evictable_size_ -= len(node.key) def _tombstone_internal_node(self, node: TreeNode) -> None: assert len(node.children) != 0, f"Cannot tombstone a leaf node, {node.id=}" diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 05b64778ad39..00ca12ffff93 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -394,6 +394,12 @@ def get_is_capture_mode(): return is_capture_mode +def compile_in_capture_mode(func): + if get_is_capture_mode(): + return torch.compile(func) + return func + + def get_capture_lora_variant() -> Optional[str]: """Return the lora variant being captured, or None if not in dual capture.""" return _capture_lora_variant @@ -1163,11 +1169,13 @@ def run_once(): self.device_module.synchronize() self.model_runner.tp_group.barrier() run_once() + attn_backend.on_after_cuda_graph_warmup() if get_global_graph_memory_pool() is None: set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) + out = self._capture_graph( graph, get_global_graph_memory_pool(), stream, run_once ) @@ -1269,6 +1277,9 @@ def replay_prepare( attn_backend = self.model_runner.decode_attn_backend_group[stream_idx] else: attn_backend = self.attn_backend + # FIXME: implicit channel for backends (dsv4) that need forward_batch + # in replay metadata prep. Should become a real param on the interface. + attn_backend._replay_forward_batch = forward_batch attn_backend.init_forward_metadata_replay_cuda_graph( bs, buffers.req_pool_indices[:bs], @@ -1279,6 +1290,7 @@ def replay_prepare( forward_batch.spec_info, seq_lens_cpu=buffers.seq_lens_cpu[:bs], ) + attn_backend._replay_forward_batch = None # Store fields self.raw_bs = raw_bs @@ -1326,6 +1338,7 @@ def replay( ) with ctx: self.graphs[graph_key].replay() + output = self.output_buffers[graph_key] if isinstance(output, LogitsProcessorOutput): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 97ef0fdf2cad..9832eb615522 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -346,6 +346,10 @@ def __init__( ): # Parse args self.mem_fraction_static = mem_fraction_static + # Set on target by `_resolve_memory_pool_config`; passed in for draft + # workers so they reuse target's resolved sizes (replaces legacy + # `server_args._draft_pool_config` mutation hack). 
+ self.memory_pool_config = memory_pool_config self.device = server_args.device self.gpu_id = gpu_id self.tp_rank = tp_rank @@ -364,7 +368,6 @@ def __init__( self.dist_port = nccl_port self.server_args = server_args self.is_draft_worker = is_draft_worker - self.memory_pool_config = memory_pool_config self.is_generation = model_config.is_generation self.device_timer = None self.is_multimodal = model_config.is_multimodal @@ -378,7 +381,9 @@ def __init__( self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.is_hybrid_swa = model_config.is_hybrid_swa - self.is_hybrid_swa_compress = model_config.is_hybrid_swa_compress + self.is_hybrid_swa_compress = getattr( + model_config, "is_hybrid_swa_compress", False + ) self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size rope_scaling = getattr( @@ -555,6 +560,9 @@ def __init__( self._model_update_group = {} self._weights_send_group = {} + if not hasattr(self, "hisparse_coordinator"): + self.hisparse_coordinator = None + def _build_model_config( self, server_args, model_path=None, model_revision=None, is_draft_model=False ): @@ -738,26 +746,6 @@ def initialize(self, pre_model_load_memory: float): # Init ngram embedding token table self.maybe_init_ngram_embedding() - # Init hisparse coordinator (must happen before CUDA graph capture) - if self.enable_hisparse: - from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator - from sglang.srt.mem_cache.sparsity import parse_hisparse_config - - hisparse_cfg = parse_hisparse_config(self.server_args) - self.hisparse_coordinator = HiSparseCoordinator( - req_to_token_pool=self.req_to_token_pool, - token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, - top_k=hisparse_cfg.top_k, - device_buffer_size=hisparse_cfg.device_buffer_size, - device=self.device, - tp_group=( - self.attention_tp_group.cpu_group - if self.server_args.enable_dp_attention - else self.tp_group.cpu_group - ), - host_to_device_ratio=hisparse_cfg.host_to_device_ratio, - ) - # Init routed experts capturer self.init_routed_experts_capturer() @@ -772,6 +760,28 @@ def initialize(self, pre_model_load_memory: float): self.init_cublas() self.init_attention_backend() self.kernel_warmup() + # Init hisparse coordinator (must happen before CUDA graph capture) + if self.enable_hisparse: + from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator + from sglang.srt.mem_cache.sparsity import parse_hisparse_config + + hisparse_cfg = parse_hisparse_config(self.server_args) + hisparse_top_k = getattr( + self.model_config.hf_text_config, "index_topk", hisparse_cfg.top_k + ) + self.hisparse_coordinator = HiSparseCoordinator( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + top_k=hisparse_top_k, + device_buffer_size=hisparse_cfg.device_buffer_size, + device=self.device, + tp_group=( + self.attention_tp_group.cpu_group + if self.server_args.enable_dp_attention + else self.tp_group.cpu_group + ), + host_to_device_ratio=hisparse_cfg.host_to_device_ratio, + ) self._pre_initialize_flashinfer_allreduce_workspace() self.init_device_graphs() elif self.device in ["npu", "cpu"]: @@ -801,6 +811,9 @@ def adjust_hybrid_swa_layers_for_pp(self): if not self.is_hybrid_swa: return + if self.model_config.is_deepseek_v4_arch: + return + full_attention_layer_ids = [ layer_idx for layer_idx in range(self.start_layer, self.end_layer + 1) @@ -3313,10 +3326,12 @@ 
def _forward_raw( # Hisparse coordinator if ( - self.hisparse_coordinator is not None - and forward_batch.forward_mode.is_decode() + forward_batch.forward_mode.is_decode() + and self.hisparse_coordinator is not None ): + forward_batch.hisparse_coordinator = self.hisparse_coordinator self.hisparse_coordinator.wait_for_pending_backup() + self.hisparse_coordinator.num_real_reqs.fill_(forward_batch.batch_size) # Replay cuda graph if applicable if can_run_graph: diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 9bc7e123ca7c..dd612d0206b4 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -5,7 +5,11 @@ import torch -from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa +from sglang.srt.configs.model_config import ( + get_nsa_index_head_dim, + is_deepseek_nsa, + is_deepseek_v4, +) from sglang.srt.distributed.parallel_state import get_world_group from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -13,7 +17,9 @@ PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseek_v4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4HiSparseTokenToKVPoolAllocator, HiSparseNSATokenToKVPool, HiSparseTokenToKVPoolAllocator, ) @@ -52,7 +58,6 @@ class ModelRunnerKVCacheMixin: - def _profile_available_bytes(self: ModelRunner, pre_model_load_memory: int) -> int: post_model_load_memory = get_available_gpu_memory( self.device, @@ -282,11 +287,48 @@ def _init_pools(self: ModelRunner): # Initialize token_to_kv_pool is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) + is_dsv4_model = is_deepseek_v4(self.model_config.hf_config) - # Check out-of-tree platform (plugin system) first + # Out-of-tree platform plugin system — used by elif below from sglang.srt.platforms import current_platform - if current_platform.is_out_of_tree() and not self.mambaish_config: + if is_dsv4_model: + swa_page_size = self.page_size + assert swa_page_size == 256, "In paged swa mode, page_size must be 256." 
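+ # Draft (NextN/MTP) workers use COMPRESS_RATIO_NEXTN_LAYER for every layer; the target worker reads the per-layer list from model_config.compress_ratios.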
+ + if self.is_draft_worker: + from sglang.srt.models.deepseek_v4_nextn import ( + COMPRESS_RATIO_NEXTN_LAYER, + ) + + compression_ratios = [ + COMPRESS_RATIO_NEXTN_LAYER + ] * self.num_effective_layers + else: + compression_ratios = self.model_config.compress_ratios + self.token_to_kv_pool = DeepSeekV4TokenToKVPool( + max_num_reqs=self.max_running_requests, + swa_size=self.swa_max_total_num_tokens, + c4_size=self.c4_max_total_num_tokens, + c128_size=self.c128_max_total_num_tokens, + c4_state_pool_size=self.c4_state_pool_size, + c128_state_pool_size=self.c128_state_pool_size, + page_size=self.page_size, + swa_page_size=swa_page_size, + dtype=self.kv_cache_dtype, + state_dtype=self.state_dtype, + qk_nope_head_dim=self.model_config.qk_nope_head_dim, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + indexer_head_dim=self.model_config.index_head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + compression_ratios=compression_ratios, + start_layer=self.start_layer, + end_layer=self.end_layer, + enable_hisparse=self.enable_hisparse, + ) + elif current_platform.is_out_of_tree() and not self.mambaish_config: if self.use_mla_backend and is_nsa_model: PoolCls = current_platform.get_nsa_kv_pool_cls() self.token_to_kv_pool = PoolCls( @@ -656,15 +698,25 @@ def _init_pools(self: ModelRunner): need_sort=need_sort, ) + if self.enable_hisparse and is_dsv4_model: + assert self.is_hybrid_swa, "DeepSeek V4 HiSparse requires SWA mode." + self.token_to_kv_pool_allocator = ( + DeepSeekV4HiSparseTokenToKVPoolAllocator( + self.token_to_kv_pool_allocator + ) + ) + else: assert self.is_draft_worker if self.is_hybrid_swa: - assert ( - self.token_to_kv_pool_allocator.__class__ - == SWATokenToKVPoolAllocator + swa_allocator = getattr( + self.token_to_kv_pool_allocator, + "logical_attn_allocator", + self.token_to_kv_pool_allocator, ) + assert swa_allocator.__class__ == SWATokenToKVPoolAllocator self.token_to_kv_pool.full_to_swa_index_mapping = ( - self.token_to_kv_pool_allocator.full_to_swa_index_mapping + swa_allocator.full_to_swa_index_mapping ) def _apply_token_constraints(self: ModelRunner, token_capacity: int) -> int: @@ -725,6 +777,27 @@ def _apply_memory_pool_config(self: ModelRunner, config: MemoryPoolConfig): self.full_max_total_num_tokens = config.full_max_total_num_tokens self.swa_max_total_num_tokens = config.swa_max_total_num_tokens + # DSV4 compressed-attention pool sizes. Draft worker reuses target's + # full/swa sizes but does NOT own c4/c128/state pools (those live on + # the target rank only); zero them out regardless of what config holds. + if self.is_draft_worker: + self.c4_max_total_num_tokens = 0 + self.c128_max_total_num_tokens = 0 + self.c4_state_pool_size = 0 + self.c128_state_pool_size = 0 + else: + self.c4_max_total_num_tokens = config.c4_max_total_num_tokens + self.c128_max_total_num_tokens = config.c128_max_total_num_tokens + self.c4_state_pool_size = config.c4_state_pool_size + self.c128_state_pool_size = config.c128_state_pool_size + + # state_dtype is a DSV4 architectural constant (fp32 for c4/c128 + # state buffers); set unconditionally so draft workers have it before + # _init_pools reads it (target path also overwrites this in the + # configurator's resolve() for parity, harmless here). 
+ if is_deepseek_v4(self.model_config.hf_config): + self.state_dtype = torch.float32 + self._init_pools() def _resolve_memory_pool_config( diff --git a/python/sglang/srt/model_executor/pool_configurator.py b/python/sglang/srt/model_executor/pool_configurator.py index 659b65fd34f7..5212f2c0f910 100644 --- a/python/sglang/srt/model_executor/pool_configurator.py +++ b/python/sglang/srt/model_executor/pool_configurator.py @@ -19,8 +19,14 @@ import torch -from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa +from sglang.srt.configs.model_config import ( + get_nsa_index_head_dim, + is_deepseek_nsa, + is_deepseek_v4, +) +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.mem_cache.deepseek_v4_memory_pool import get_compress_state_ring_size from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool from sglang.srt.utils.common import is_float4_e2m1fn_x2 @@ -34,6 +40,12 @@ class MemoryPoolConfig: full_max_total_num_tokens: Optional[int] = None swa_max_total_num_tokens: Optional[int] = None + # DSV4 compressed-attention pool sizes (target only; draft workers leave at 0). + c4_max_total_num_tokens: int = 0 + c128_max_total_num_tokens: int = 0 + c4_state_pool_size: int = 0 + c128_state_pool_size: int = 0 + mem_fraction_static: Optional[float] = None def __post_init__(self): @@ -284,10 +296,177 @@ def calculate_pool_sizes_from_max_tokens( return self._solve_pool_sizes(max_total_num_tokens, page_size) +@dataclass +class _DSV4PoolSizes: + full_max_total_num_tokens: int + swa_max_total_num_tokens: int + c4_max_total_num_tokens: int + c128_max_total_num_tokens: int + c4_state_pool_size: int + c128_state_pool_size: int + + +class DSV4PoolConfigurator(MemoryPoolConfigurator): + """Configurator for DSV4 compressed-attention models. + + Splits available memory across full / swa / c4 / c128 + c4_state / c128_state + pools. coeff is bytes_per_full_token (inflated by (T+D)/T when speculative + decode reserves a draft worker, mirroring dflash's cell_size scaling); bias = 0. + """ + + def __init__(self, mr: ModelRunner): + cfg = mr.model_config + self.qk_nope_head_dim = cfg.qk_nope_head_dim + self.qk_rope_head_dim = cfg.qk_rope_head_dim + self.indexer_head_dim = cfg.index_head_dim + self.compression_ratios = cfg.compress_ratios + self.swa_page_size = cfg.window_size + self.swa_ratio = mr.server_args.swa_full_tokens_ratio + self.is_speculative = mr.server_args.speculative_algorithm is not None + if mr.enable_hisparse: + from sglang.srt.mem_cache.sparsity import parse_hisparse_config + + self.c4_shrink_factor = parse_hisparse_config( + mr.server_args + ).host_to_device_ratio + else: + self.c4_shrink_factor = 1 + assert self.c4_shrink_factor >= 1 + if self.c4_shrink_factor > 1: + logger.info(f"HiSparse c4 host-to-device ratio = {self.c4_shrink_factor}") + + self.c4_ring_size = get_compress_state_ring_size(4, self.is_speculative) + self.c128_ring_size = get_compress_state_ring_size(128, self.is_speculative) + + self.num_layers_total = len(self.compression_ratios) + self.num_layers_ca4 = sum(1 for r in self.compression_ratios if r == 4) + self.num_layers_ca128 = sum(1 for r in self.compression_ratios if r == 128) + + self.bytes_per_full_token = self._get_bytes_per_full_token() + if self.is_speculative: + # Reserve memory for the speculative draft worker by inflating + # per-token bytes by (target+draft)/target. 
Equivalent to dflash's + # scale_kv_cell_size_per_token_for_dflash but applied to + # bytes_per_full_token: tokens = avail / (bpft * (T+D)/T). + draft_layers = 1 + target_layers = self.num_layers_total + self.bytes_per_full_token *= (target_layers + draft_layers) / target_layers + + # Online c128 keeps a single in-progress (max, sum, kv) state per index + # and assumes a strict forward-only schedule. Speculative decode (MTP) + # would need rollback / replay across draft and verify, which the + # online path doesn't support yet. + if envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get(): + assert ( + mr.spec_algorithm.is_none() + ), "SGLANG_OPT_USE_ONLINE_COMPRESS does not support speculative decode (MTP) yet" + logger.info("DSV4 compressed attention: online c128 enabled (ring_size=1)") + + def _get_bytes_per_full_token(self) -> float: + kv_bytes = self.qk_nope_head_dim + self.qk_rope_head_dim * 2 + 8 + + quant_block_size = 128 + indexer_bytes = ( + self.indexer_head_dim + self.indexer_head_dim // quant_block_size * 4 + ) + + attn_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + state_dtype_size = 4 + c4_state_bytes = 2 * 2 * attn_head_dim * state_dtype_size + # Online c128 stores (max, sum, kv) per slot (3*head_dim) instead of + # raw (kv, score) (2*head_dim). Combined with ring_size=1 this still + # nets a large reduction (~3/256x) but the per-slot bytes go up. + c128_online = envs.SGLANG_OPT_USE_ONLINE_COMPRESS.get() + c128_state_bytes = ( + (3 if c128_online else 2 * 1) * attn_head_dim * state_dtype_size + ) + c4_indexer_state_bytes = 2 * 2 * self.indexer_head_dim * state_dtype_size + + c4_state_ratio = self.c4_ring_size / self.swa_page_size + c128_state_ratio = self.c128_ring_size / self.swa_page_size + + c4_frac = 1 / (4 * self.c4_shrink_factor) + return ( + self.swa_ratio * kv_bytes * self.num_layers_total + + c4_frac * kv_bytes * self.num_layers_ca4 + + 1 / 128 * kv_bytes * self.num_layers_ca128 + + 1 / 4 * indexer_bytes * self.num_layers_ca4 + + self.swa_ratio * c4_state_ratio * c4_state_bytes * self.num_layers_ca4 + + self.swa_ratio + * c128_state_ratio + * c128_state_bytes + * self.num_layers_ca128 + + self.swa_ratio + * c4_state_ratio + * c4_indexer_state_bytes + * self.num_layers_ca4 + ) + + def _compute_dsv4_sizes(self, full_token: int, page_size: int) -> _DSV4PoolSizes: + full_token = full_token // page_size * page_size + swa_tokens = int(full_token * self.swa_ratio) // page_size * page_size + return _DSV4PoolSizes( + full_max_total_num_tokens=full_token, + swa_max_total_num_tokens=swa_tokens, + c4_max_total_num_tokens=full_token // (4 * self.c4_shrink_factor), + c128_max_total_num_tokens=full_token // 128, + c4_state_pool_size=swa_tokens // self.swa_page_size * self.c4_ring_size, + c128_state_pool_size=swa_tokens // self.swa_page_size * self.c128_ring_size, + ) + + def _to_config(self, sizes: _DSV4PoolSizes) -> MemoryPoolConfig: + full = sizes.full_max_total_num_tokens + swa = sizes.swa_max_total_num_tokens + logger.info( + f"DSV4 pool sizes: full={full}, swa={swa}, " + f"c4={sizes.c4_max_total_num_tokens}, " + f"c128={sizes.c128_max_total_num_tokens}, " + f"c4_state={sizes.c4_state_pool_size}, " + f"c128_state={sizes.c128_state_pool_size}" + ) + return MemoryPoolConfig( + max_total_num_tokens=full, + full_max_total_num_tokens=full, + swa_max_total_num_tokens=swa, + c4_max_total_num_tokens=sizes.c4_max_total_num_tokens, + c128_max_total_num_tokens=sizes.c128_max_total_num_tokens, + c4_state_pool_size=sizes.c4_state_pool_size, + c128_state_pool_size=sizes.c128_state_pool_size, 
+ ) + + def calculate_pool_sizes( + self, available_bytes: int, page_size: int + ) -> MemoryPoolConfig: + assert ( + page_size % 128 == 0 + ), "page_size must be multiple of 128 for compressed attention" + + full_token = int(available_bytes / self.bytes_per_full_token) + sizes = self._compute_dsv4_sizes(full_token, page_size) + logger.info( + f"DSV4 memory calculation: " + f"bytes_per_full_token={self.bytes_per_full_token:.2f}, " + f"available_bytes={available_bytes / (1 << 30):.2f} GB, " + f"full_token={sizes.full_max_total_num_tokens}" + ) + return self._to_config(sizes) + + def calculate_pool_sizes_from_max_tokens( + self, max_total_num_tokens: int, page_size: int + ) -> MemoryPoolConfig: + assert ( + page_size % 128 == 0 + ), "page_size must be multiple of 128 for compressed attention" + sizes = self._compute_dsv4_sizes(max_total_num_tokens, page_size) + return self._to_config(sizes) + + def create_memory_pool_configurator( mr: ModelRunner, ) -> MemoryPoolConfigurator: """Factory: select the right configurator for the model architecture.""" + if is_deepseek_v4(mr.model_config.hf_config) and mr.is_hybrid_swa: + return DSV4PoolConfigurator(mr) if mr.is_hybrid_swa: return HybridSWAPoolConfigurator(mr) # Future: MambaPoolConfigurator diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 8bfd57d33cf9..28996e48955e 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -234,6 +234,11 @@ def _get_quantization_config( # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3 if quant_config is None: return None + # Carry DSV4 expert layout into Fp8Config so downstream readers don't read env. + from sglang.srt.layers.quantization.fp8 import Fp8Config + + if isinstance(quant_config, Fp8Config): + quant_config.is_fp4_experts = model_config.is_fp4_experts if not _is_npu: major, minor = get_device_capability() diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 084fa02ec4d6..1b43ccc0538d 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -11,8 +11,11 @@ import json import logging import os +import re +import struct import tempfile from collections import defaultdict +from pathlib import Path from typing import ( Any, Callable, @@ -71,6 +74,60 @@ RUNAI_STREAMER_TENSOR_ATTR = "_sglang_runai_streamer_tensor" +# Matches routed-expert weight keys in both HF-style layouts +# (``...mlp.experts..{gate,up,down}_proj.weight``) and DeepSeek V4 +# layouts (``...ffn.experts..w{1,2,3}.weight``). ``shared_experts`` is +# excluded because the index segment requires a digit after ``.experts.``. +_ROUTED_EXPERT_KEY_RE = re.compile( + r"\.experts\.\d+\.(?:w[123]|down_proj|up_proj|gate_proj)\.weight$" +) + + +def probe_routed_expert_weight_dtype(model_path: str) -> Optional[str]: + """Return the safetensors dtype string (e.g. ``F8_E4M3``, ``U8``) of one + routed-expert weight tensor, or ``None`` if the checkpoint is remote or has + no matching key. Reads only the safetensors header of the relevant shard. 
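+
+    A safetensors shard begins with an 8-byte little-endian header length
+    followed by a JSON header mapping each tensor name to its dtype, shape,
+    and data offsets, so probing the dtype never touches the tensor payload.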
+ """ + if not os.path.isdir(model_path): + return None + + index_file = os.path.join(model_path, "model.safetensors.index.json") + target_key = None + target_shard_path = None + + if os.path.exists(index_file): + with open(index_file) as f: + index = json.load(f) + weight_map = index.get("weight_map", {}) or {} + for k, shard in weight_map.items(): + if _ROUTED_EXPERT_KEY_RE.search(k): + target_key = k + target_shard_path = os.path.join(model_path, shard) + break + if target_key is None: + return None + else: + shards = sorted(Path(model_path).glob("*.safetensors")) + if not shards: + return None + target_shard_path = str(shards[0]) + + with open(target_shard_path, "rb") as f: + (header_len,) = struct.unpack(" None: super().__init__() self.tp_size = tp_size + self.swiglu_limit = swiglu_limit self.gate_up_proj = MergedColumnParallelLinear( hidden_size, @@ -263,6 +269,12 @@ def forward( x = (x, None, y) gate_up, _ = self.gate_up_proj(x) + if self.swiglu_limit is not None: + _g, _u = gate_up.chunk(2, dim=-1) + _lim = float(self.swiglu_limit) + gate_up = torch.cat( + [_g.clamp(max=_lim), _u.clamp(min=-_lim, max=_lim)], dim=-1 + ) x = self.act_fn(gate_up) x, _ = self.down_proj( x, @@ -278,13 +290,17 @@ def __init__( quant_config, prefix: str = "", is_nextn: bool = False, + is_hash_moe: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.is_nextn = is_nextn + self.is_deepseek_v4 = is_deepseek_v4 self.weight = nn.Parameter( torch.empty((config.n_routed_experts, config.hidden_size)) ) - if config.topk_method == "noaux_tc": + + if config.topk_method == "noaux_tc" and not is_hash_moe: correction_bias_dtype = torch.float32 if quant_config is not None: if ( @@ -324,7 +340,11 @@ def forward( if get_global_server_args().enable_deterministic_inference: return F.linear(hidden_states, self.weight, None) - if forward_batch is not None and nsa_use_prefill_cp(forward_batch): + if ( + not self.is_deepseek_v4 + and forward_batch is not None + and nsa_use_prefill_cp(forward_batch) + ): logits = F.linear(hidden_states, self.weight, None) else: # NOTE: For some unknown reason, router_gemm seems degrade accept length. @@ -352,7 +372,13 @@ def forward( elif _use_aiter: logits = aiter_dsv3_router_gemm(hidden_states, self.weight) else: - logits = F.linear(hidden_states, self.weight, None) + if self.is_deepseek_v4: + from sglang.jit_kernel.deepseek_v4 import linear_bf16_fp32 + + logits = linear_bf16_fp32(hidden_states, self.weight) + else: + # After testing, we may use the faster code in `if deepseek v4` branch + logits = F.linear(hidden_states, self.weight, None) return logits @@ -367,6 +393,7 @@ def __init__( prefix: str = "", alt_stream: Optional[torch.cuda.Stream] = None, is_nextn: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -406,6 +433,9 @@ def __init__( self.alt_stream = alt_stream self.is_nextn = is_nextn + n_hash_layers = getattr(config, "num_hash_layers", 0) + self.is_hash = layer_id < n_hash_layers and not (is_deepseek_v4 and is_nextn) + if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -423,6 +453,8 @@ def __init__( quant_config=quant_config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn, + is_hash_moe=self.is_hash, + is_deepseek_v4=is_deepseek_v4, ) # scaling factor for fused shared experts on AMD-platform. 
@@ -452,31 +484,52 @@ def __init__( routing_method_type=getattr( config, "routing_method_type", RoutingMethodType.DeepSeekV3 ), + swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("experts", prefix), ) - self.topk = TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, - layer_id=self.layer_id, - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - quant_config=quant_config, - routed_scaling_factor=self.routed_scaling_factor, - apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, - fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, - # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized - # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. - output_format=( - TopKOutputFormat.STANDARD - if (quant_config is None) - and (not get_moe_runner_backend().is_flashinfer_trtllm()) - else None - ), - ) + if self.is_hash and not (is_nextn and is_deepseek_v4): + self.topk = HashTopK( + topk=config.num_experts_per_tok + self.num_fused_shared_experts, + num_experts=config.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + vocab_size=config.vocab_size, + scoring_func=config.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + ) + else: + # Default: grouped noaux_tc top-k. Covers V3/V3.2/GLM-5/Glm4MoeLite. + topk_kwargs = dict( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + quant_config=quant_config, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, + # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized + # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. + output_format=( + TopKOutputFormat.STANDARD + if (quant_config is None) + and (not get_moe_runner_backend().is_flashinfer_trtllm()) + else None + ), + ) + # DSV4 override: ungrouped sqrtsoftplus + fp4 expert layout flag. 
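+            # is_fp4_experts is stamped onto Fp8Config by the model loader
+            # (_get_quantization_config), so TopK reads the fp4 expert layout
+            # from quant_config rather than from an environment variable.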
+ if is_deepseek_v4: + topk_kwargs.update( + use_grouped_topk=False, + scoring_func=config.scoring_func, + is_fp4_experts=getattr(quant_config, "is_fp4_experts", False), + ) + self.topk = TopK(**topk_kwargs) self.shared_experts_is_int8 = False self.shared_experts_is_fp8 = False @@ -497,6 +550,7 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("shared_experts", prefix), **( dict(tp_rank=0, tp_size=1) @@ -593,7 +647,19 @@ def forward( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: + from sglang.srt.layers.moe.mega_moe import forward_mega_moe, should_use_mega_moe + + if should_use_mega_moe(self, hidden_states): + return forward_mega_moe( + self, + hidden_states, + forward_batch, + input_ids_global=input_ids_global, + ) + if not self._enable_a2a_moe: if ( self.alt_stream is not None @@ -612,6 +678,8 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, + input_ids, + input_ids_global=input_ids_global, ) else: return self.forward_normal( @@ -619,9 +687,13 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, + input_ids, + input_ids_global=input_ids_global, ) else: - return self.forward_deepep(hidden_states, forward_batch) + return self.forward_deepep( + hidden_states, forward_batch, input_ids_global=input_ids_global + ) def forward_normal_dual_stream( self, @@ -629,6 +701,8 @@ def forward_normal_dual_stream( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) @@ -644,10 +718,16 @@ def forward_normal_dual_stream( with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) + topk_kwargs = ( + {"input_ids": input_ids_global} + if getattr(self, "is_hash", False) + else {} + ) topk_output = self.topk( hidden_states, router_logits, expert_location_dispatch_info=dispatch_info, + **topk_kwargs, ) final_hidden_states = self.experts(hidden_states, topk_output) if not (_is_cuda or _is_musa) or isinstance( @@ -656,7 +736,14 @@ def forward_normal_dual_stream( final_hidden_states *= self.routed_scaling_factor current_stream.wait_stream(self.alt_stream) - final_hidden_states += shared_output + + final_hidden_states = maybe_fuse_routed_scale_and_shared_add( + self.experts, + final_hidden_states, + shared_output, + self.routed_scaling_factor, + ) + if self.tp_size > 1 and not should_skip_post_experts_all_reduce( is_tp_path=True, use_reduce_scatter=use_reduce_scatter, @@ -671,6 +758,8 @@ def forward_normal( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -691,10 +780,16 @@ def forward_normal( ) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, 
gemm_output_zero_allocator) + topk_kwargs = ( + {"input_ids": input_ids_global} + if getattr(self, "is_hash", False) + else {} + ) topk_output = self.topk( hidden_states, router_logits, expert_location_dispatch_info=dispatch_info, + **topk_kwargs, ) else: shared_output = None @@ -743,8 +838,14 @@ def _post_combine_hook( ): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor - if shared_output is not None: - final_hidden_states += shared_output + + final_hidden_states = maybe_fuse_routed_scale_and_shared_add( + self.experts, + final_hidden_states, + shared_output, + self.routed_scaling_factor, + ) + if self.tp_size > 1 and not should_skip_post_experts_all_reduce( is_tp_path=True, use_reduce_scatter=use_reduce_scatter, @@ -813,6 +914,7 @@ def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: shared_output = None sbo_enabled_flag = self._fuse_shared_experts_inside_sbo and not self.is_nextn @@ -835,6 +937,11 @@ def forward_deepep( shared_event = self.alt_stream.record_event() else: shared_output = self._forward_shared_experts(hidden_states) + topk_kwargs = ( + {"input_ids": input_ids_global} + if getattr(self, "is_hash", False) + else {} + ) topk_output = self.topk( hidden_states, router_logits, @@ -842,6 +949,7 @@ def forward_deepep( expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, ), + **topk_kwargs, ) else: topk_output = self.topk.empty_topk_output(hidden_states.device) @@ -1696,6 +1804,7 @@ def __init__( prefix=add_prefix("mlp", prefix), tp_rank=mlp_tp_rank, tp_size=mlp_tp_size, + swiglu_limit=getattr(config, "swiglu_limit", None), ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py new file mode 100644 index 000000000000..b1c225051967 --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4.py @@ -0,0 +1,1528 @@ +from __future__ import annotations + +import concurrent.futures +import logging +from typing import TYPE_CHECKING, Iterable, List, Literal, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +import sglang.srt.models.deepseek_v2 as deepseek_v2 +from sglang.jit_kernel.deepseek_v4 import fused_rope, rmsnorm_self +from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.environ import envs +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.attention.dsv4.compressor import Compressor +from sglang.srt.layers.attention.dsv4.indexer import C4Indexer +from sglang.srt.layers.attention.nsa.utils import ( + can_nsa_cp_split, + is_nsa_enable_prefill_cp, + is_nsa_prefill_cp_round_robin_split, + nsa_use_prefill_cp, +) +from sglang.srt.layers.communicator import get_attn_tp_context +from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + attn_tp_all_gather, + dp_gather_partial, + dp_scatter, + get_attention_cp_rank, + get_attention_cp_size, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, + get_global_dp_buffer, + get_local_dp_buffer, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from 
sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.utils.cp_utils import ( + cp_all_gather_rerange_output, + cp_split_and_rebuild_data, + cp_split_and_rebuild_position, + prepare_context_parallel_metadata, +) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.mem_cache.memory_pool import RadixAttention +from sglang.srt.model_executor.cuda_graph_runner import ( + compile_in_capture_mode, + get_is_capture_mode, +) +from sglang.srt.model_loader.utils import maybe_executor_submit, should_async_load +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.dbrx import ReplicatedLinear +from sglang.srt.models.deepseek_v2 import ParallelLMHead, _is_cuda, _is_hip, _is_npu +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + LazyValue, + add_prefix, + log_info_on_rank0, + make_layers, +) +from sglang.srt.utils.hf_transformers_utils import get_rope_config + +logger = logging.getLogger(__name__) + +_FP8_WO_A_GEMM = envs.SGLANG_OPT_FP8_WO_A_GEMM.get() + + +if TYPE_CHECKING: + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DeepseekV4AttnBackend, + ) + from sglang.srt.layers.quantization import QuantizationConfig + from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + ) + + +@triton.jit +def _rms_normalize_kernel( + x_ptr, + weight_ptr, + eps, + stride_row, + dim, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, +): + pid = tl.program_id(0) + + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < dim + + base = pid * stride_row + x = tl.load(x_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + + mean_sq = tl.sum(x * x, axis=0) / dim + rms_inv = tl.rsqrt(mean_sq + eps) + out = x * rms_inv + + if HAS_WEIGHT: + weight = tl.load(weight_ptr + offs, mask=mask, other=0.0) + out = out * weight + + tl.store(x_ptr + base + offs, out, mask=mask) + + +def rms_normalize_triton( + x: torch.Tensor, eps: float, weight: torch.Tensor = None +) -> torch.Tensor: + dim = x.shape[-1] + x_flat = x.view(-1, dim) + num_rows = x_flat.shape[0] + + BLOCK_SIZE = triton.next_power_of_2(dim) + grid = (num_rows,) + + _rms_normalize_kernel[grid]( + x_flat, + weight, + eps, + x_flat.stride(0), + dim, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=(weight is not None), + ) + return x + + +class MQALayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + compress_ratio_override: Optional[int] = None, + ) -> None: + super().__init__() + self.tp_rank = attn_tp_rank = get_attention_tp_rank() + self.tp_size = attn_tp_size = get_attention_tp_size() + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_cp_size() + self.tp_rank = attn_tp_rank = 0 + self.tp_size = attn_tp_size = 1 + self.layer_id = layer_id + self.dim = config.hidden_size + self.qk_rope_head_dim = config.qk_rope_head_dim + self.qk_nope_head_dim = config.head_dim - config.qk_rope_head_dim + self.head_dim = 
self.qk_rope_head_dim + self.qk_nope_head_dim + self.n_heads = config.num_attention_heads + self.n_local_heads = self.n_heads // attn_tp_size + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // attn_tp_size + self.rope_head_dim = config.qk_rope_head_dim + self.softmax_scale = self.head_dim**-0.5 + self.hidden_size = config.hidden_size + self.q_lora_rank = config.q_lora_rank + self.o_lora_rank = config.o_lora_rank + self.eps = config.rms_norm_eps + compress_ratio = ( + compress_ratio_override + if compress_ratio_override is not None + else config.compress_ratios[layer_id] + ) + assert compress_ratio in [0, 4, 128] + self.compress_ratio: Literal[0, 4, 128] = compress_ratio + + assert self.head_dim == config.head_dim + assert config.num_key_value_heads == 1 + + rope_theta, rope_scaling = get_rope_config(config) + if rope_scaling: + rope_scaling["rope_type"] = "deepseek_yarn" + + rope_base = config.compress_rope_theta if self.compress_ratio else rope_theta + + from sglang.srt.layers.deepseek_v4_rope import precompute_freqs_cis + + assert self.compress_ratio in {0, 4, 128} + if self.compress_ratio: + original_seq_len = rope_scaling["original_max_position_embeddings"] + else: + original_seq_len = 0 + + freqs_cis = precompute_freqs_cis( + dim=self.qk_rope_head_dim, + seqlen=config.max_position_embeddings, + original_seq_len=original_seq_len, + base=rope_base, + factor=rope_scaling["factor"], + beta_fast=rope_scaling["beta_fast"], + beta_slow=rope_scaling["beta_slow"], + ) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) + self.freqs_cis: torch.Tensor + + if envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() and alt_streams is not None: + self.alt_streams = alt_streams[:3] + self.alt_streams_indexer = alt_streams[-2:] + else: + self.alt_streams = None + self.alt_streams_indexer = None + + from sglang.srt.utils import is_blackwell_supported + + self._multi_stream_bs_limit = 128 if is_blackwell_supported() else 64 + + self.compressor = None + self.indexer = None + if self.compress_ratio: + self.compressor = Compressor( + config, + layer_id=self.layer_id, + is_in_indexer=False, + freqs_cis=freqs_cis, + compress_ratio=self.compress_ratio, + head_dim=self.head_dim, + rotate=False, + prefix=add_prefix("compressor", prefix), + ) + if self.compress_ratio == 4: + self.indexer = C4Indexer( + config, + freqs_cis=freqs_cis, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("indexer", prefix), + alt_streams=self.alt_streams_indexer, + ) + + self.attn_sink = nn.Parameter(torch.empty(self.n_heads, dtype=torch.float32)) + self.fuse_wqa_wkv = envs.SGLANG_OPT_FUSE_WQA_WKV.get() + if self.fuse_wqa_wkv: + self.wqkv_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wqkv_a", prefix), + ) + else: + self.wq_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_a", prefix), + ) + self.wkv = ReplicatedLinear( + self.hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wkv", prefix), + ) + self.q_norm = RMSNorm(self.q_lora_rank, eps=self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_b", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + self.kv_norm = RMSNorm(self.head_dim, eps=self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * 
self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config if _FP8_WO_A_GEMM else None, + prefix=add_prefix("wo_a", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + **({} if _FP8_WO_A_GEMM else {"params_dtype": torch.bfloat16}), + ) + if _FP8_WO_A_GEMM: + assert hasattr( + self.wo_a, "weight_scale_inv" + ), "FP8 quant_config must create weight_scale_inv" + self.wo_a.weight_scale_inv.format_ue8m0 = True + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=attn_tp_size > 1, + prefix=add_prefix("wo_b", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + + self.attn_mqa = RadixAttention( + self.n_local_heads, + self.head_dim, + self.softmax_scale, + num_kv_heads=1, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn_mqa", prefix), + ) + + self.overlap_store_cache = envs.SGLANG_OPT_USE_OVERLAP_STORE_CACHE.get() + self.use_jit_norm = envs.SGLANG_OPT_USE_JIT_NORM.get() + + def _compute_q_a( + self, + x: torch.Tensor, + qkv_a: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if qkv_a is not None: + q = qkv_a[..., : self.q_lora_rank] + else: + q, _ = self.wq_a(x) + q = self.q_norm(q) + q_lora = q + return q_lora + + def _compute_q_b( + self, + q: torch.Tensor, + positions: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + if self.use_jit_norm: + q = rmsnorm_self(q, self.eps) + else: + q = rms_normalize_triton(q, self.eps) + if positions is not None: + fused_rope( + q[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(q[..., -self.qk_rope_head_dim :], self.freqs_cis) + return q + + def _compute_kv( + self, + x: torch.Tensor, + positions: Optional[torch.Tensor] = None, + qkv_a: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if qkv_a is not None: + kv = qkv_a[..., self.q_lora_rank :] + else: + kv, _ = self.wkv(x) + kv = self.kv_norm(kv) + if positions is not None: + fused_rope( + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(kv[..., -self.qk_rope_head_dim :], self.freqs_cis) + return kv + + def _forward_prepare_multi_stream( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4AttnBackend, + q_out: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.alt_streams is not None + assert len(self.alt_streams) >= 3 + + current_stream = torch.cuda.current_stream() + stream_kv = self.alt_streams[0] + stream_compressor = self.alt_streams[1] + stream_indexer = self.alt_streams[2] + + stream_kv.wait_stream(current_stream) + stream_compressor.wait_stream(current_stream) + stream_indexer.wait_stream(current_stream) + + qkv_a: Optional[torch.Tensor] = None + qkv_a_ready: Optional[torch.cuda.Event] = None + if self.fuse_wqa_wkv: + qkv_a, _ = self.wqkv_a(x) + qkv_a_ready = current_stream.record_event() + + q_lora = self._compute_q_a(x, qkv_a=qkv_a) + q_lora_ready = current_stream.record_event() + + if self.indexer is not None: + with torch.cuda.stream(stream_indexer): + self.indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + enable_multi_stream=True, + q_lora_ready=q_lora_ready, + ) + + with torch.cuda.stream(stream_kv): + if qkv_a_ready is not None: + 
stream_kv.wait_event(qkv_a_ready) + kv = self._compute_kv(x, positions, qkv_a=qkv_a) + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + del qkv_a + + if self.compressor is not None: + with torch.cuda.stream(stream_compressor): + attn_backend.forward_core_compressor( + x, forward_batch, self.layer_id, self.compressor + ) + + q = self._compute_q_b(q_lora, positions) + if q_out is not None: + q_out.copy_(q) + + current_stream.wait_stream(stream_kv) + current_stream.wait_stream(stream_compressor) + current_stream.wait_stream(stream_indexer) + + return q, kv + + def _forward_prepare( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4AttnBackend, + q_out: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.fuse_wqa_wkv: + qkv_a, _ = self.wqkv_a(x) + q = qkv_a[..., : self.q_lora_rank] + kv = qkv_a[..., self.q_lora_rank :] + del qkv_a + else: + kv, _ = self.wkv(x) + q, _ = self.wq_a(x) + q = self.q_norm(q) + q_lora = q + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + if self.use_jit_norm: + q = rmsnorm_self(q, self.eps) + else: + q = rms_normalize_triton(q, self.eps) + + kv = self.kv_norm(kv) + + fused_rope( + q[..., -self.qk_rope_head_dim :], + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + self.freqs_cis, + positions=positions, + ) + + if self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch): + kv = cp_all_gather_rerange_output( + kv.contiguous(), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + if self.indexer is not None: + self.indexer(x=x, q_lora=q_lora, forward_batch=forward_batch) + if self.compressor is not None: + attn_backend.forward_core_compressor( + x, + forward_batch, + self.layer_id, + self.compressor, + ) + + if q_out is not None: + q_out.copy_(q) + return q, kv + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if not get_attn_tp_context().input_scattered and x.shape[0] == 0: + assert ( + not self.wo_b.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return x + + attn_backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(attn_backend, DeepseekV4AttnBackend) + + enable_multi_stream = ( + envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() + and self.alt_streams is not None + and get_is_capture_mode() + and x.shape[0] <= self._multi_stream_bs_limit + and not (self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch)) + ) + + tp_slice, q_padded, q_out = slice(None), None, None + if self.tp_size > 1: + q_padded = x.new_empty(x.shape[0], self.n_heads, self.head_dim) + rank = self.tp_rank + tp_slice = slice(rank * self.n_local_heads, (rank + 1) * self.n_local_heads) + q_out = q_padded[:, tp_slice, :] + + if enable_multi_stream: + q, kv = self._forward_prepare_multi_stream( + x, positions, forward_batch, attn_backend, q_out + ) + else: + q, kv = self._forward_prepare( + x, positions, forward_batch, attn_backend, q_out + ) + + o = attn_backend.forward( + q=q_padded if q_padded is not None else q, + k=kv, + v=kv, + layer=self.attn_mqa, + forward_batch=forward_batch, + compress_ratio=self.compress_ratio, + attn_sink=self.attn_sink, + save_kv_cache=not self.overlap_store_cache, + ) + o = o[:, tp_slice, :] 
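+        # Apply the inverse rotary embedding (inverse=True) to the rope-dim
+        # slice of the attention output before the grouped low-rank output
+        # projection (wo_a einsum below, then wo_b).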
+ fused_rope( + o[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + inverse=True, + ) + + o = o.view(o.shape[0], self.n_local_groups, -1) + + if _FP8_WO_A_GEMM: + import deep_gemm + + T, G, D = o.shape + R = self.o_lora_rank + o_fp8, o_s = sglang_per_token_group_quant_fp8( + o.reshape(T * G, D).contiguous(), + group_size=128, + ) + output = torch.empty(T, G, R, device=o.device, dtype=torch.bfloat16) + deep_gemm.fp8_einsum( + "bhr,hdr->bhd", + (o_fp8.view(T, G, D), o_s.view(T, G, -1)), + (self.wo_a.weight.view(G, R, D), self.wo_a.weight_scale_inv.data), + output, + recipe=(1, 1, 128), + ) + o = output + else: + wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) + o = torch.einsum("tgd,grd->tgr", o, wo_a) + + o, _ = self.wo_b(o.flatten(1)) + + return o + + +class DeepseekV4DecoderLayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + moe_quant_config_override: Optional[QuantizationConfig] = None, + is_nextn: bool = False, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + compress_ratio_override: Optional[int] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.self_attn = MQALayer( + config=config, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + alt_streams=alt_streams, + compress_ratio_override=compress_ratio_override, + ) + self.mlp = deepseek_v2.DeepseekV2MoE( + config=config, + quant_config=moe_quant_config_override or quant_config, + prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, + alt_stream=alt_streams[0] if alt_streams is not None else None, + is_nextn=is_nextn, + is_deepseek_v4=True, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.hc_mult = hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + self.hc_eps = config.hc_eps + mix_hc = (2 + hc_mult) * hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_attn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_attn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.hc_ffn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.rms_norm_eps = config.rms_norm_eps + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + + def hc_pre( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + @compile_in_capture_mode + def hc_pre_torch_impl(x, hc_fn): + x_flat = x.flatten(1).float() + rsqrt = torch.rsqrt( + x_flat.square().mean(-1, keepdim=True) + self.rms_norm_eps + ) + mixes = (F.linear(x_flat, hc_fn) * rsqrt).unsqueeze(1) + return x_flat, mixes + + shape, dtype = x.size(), x.dtype + + if x.shape[0] == 0: + y = torch.empty((0, shape[-1]), dtype=dtype, device=x.device) + post = torch.empty((0, self.hc_mult), dtype=dtype, device=x.device) + comb = torch.empty( + (0, self.hc_mult, self.hc_mult), dtype=dtype, device=x.device + ) + return y, post, comb + + if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get(): + from sglang.srt.layers.mhc import 
mhc_pre + + post, comb, y = mhc_pre( + residual=x, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=self.rms_norm_eps, + hc_pre_eps=self.hc_eps, + hc_sinkhorn_eps=self.hc_eps, + hc_post_mult_value=2.0, + sinkhorn_repeat=self.hc_sinkhorn_iters, + ) + return y, post.squeeze(-1), comb + + if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): + import deep_gemm + + x_flat = x.flatten(1).bfloat16() + + m, k = x_flat.shape + mix_hc = hc_fn.size(0) + d_out = torch.empty((m, mix_hc), dtype=torch.float, device=x.device) + s_out = torch.empty((m,), dtype=torch.float, device=x.device) + deep_gemm.tf32_hc_prenorm_gemm( + x_flat, hc_fn.float().contiguous(), d_out, s_out, num_splits=None + ) + rsqrt = torch.rsqrt(s_out / k + self.rms_norm_eps) + mixes = (d_out * rsqrt.unsqueeze(1)).unsqueeze(1) + else: + x_flat, mixes = hc_pre_torch_impl(x, hc_fn) + + from sglang.srt.layers.mhc import hc_split_sinkhorn + + pre, post, comb = hc_split_sinkhorn( + mixes, + hc_scale, + hc_base, + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) + y = (pre.squeeze(1).unsqueeze(-1) * x_flat.view(shape)).sum(dim=1) + return y.to(dtype), post.squeeze(1), comb.squeeze(1) + + def hc_post( + self, + x: torch.Tensor, + residual: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, + ): + + if x.shape[0] == 0: + return torch.empty( + (0, self.hc_mult, x.shape[-1]), dtype=x.dtype, device=x.device + ) + + if envs.SGLANG_OPT_USE_TILELANG_MHC_POST.get(): + from sglang.srt.layers.mhc import mhc_post + + return mhc_post(x, residual, post, comb) + + assert residual.shape == (x.shape[0], self.hc_mult, x.shape[-1]) + assert post.shape == (x.shape[0], self.hc_mult) + assert comb.shape == (x.shape[0], self.hc_mult, self.hc_mult) + + @compile_in_capture_mode + def hc_post_torch_impl(x, residual, post, comb): + return ( + post.unsqueeze(-1) * x.unsqueeze(1) + + (comb.unsqueeze(-1) * residual.unsqueeze(2)).sum(dim=1) + ).type_as(x) + + return hc_post_torch_impl(x, residual, post, comb) + + def forward( + self, + positions: torch.tensor, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + forward_batch: ForwardBatch, + input_ids_global: torch.Tensor, + ) -> torch.Tensor: + residual = hidden_states + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + x=hidden_states, + positions=positions, + forward_batch=forward_batch, + ) + + hidden_states = self.hc_post(hidden_states, residual, post, comb) + residual = hidden_states + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) + _use_tp_moe_gather = ( + not _use_cp + and get_attention_dp_size() > 1 + and get_moe_a2a_backend().is_none() + ) + _use_tp_attn_a2a_scatter = ( + not _use_cp + and envs.SGLANG_DSV4_FIX_TP_ATTN_A2A_SCATTER.get() + and get_attention_tp_size() > 1 + and not get_moe_a2a_backend().is_none() + ) + if _use_cp: + assert get_moe_a2a_backend().is_deepep(), ( + "CP requires DeepEP (moe_a2a_backend == deepep). " + "Only DeepEP is tested with CP's per-rank token split." 
+ ) + cp_rank = get_attention_cp_rank() + cp_size = get_attention_cp_size() + input_ids = input_ids[cp_rank::cp_size].contiguous() + input_ids_global = input_ids + elif _use_tp_moe_gather: + hidden_states, local_hidden_states = get_global_dp_buffer(), hidden_states + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + _a2a_scatter_chunks: Optional[List[torch.Tensor]] = None + if _use_tp_attn_a2a_scatter: + s, r = get_attention_tp_size(), get_attention_tp_rank() + _a2a_scatter_chunks = list(hidden_states.tensor_split(s)) + hidden_states = _a2a_scatter_chunks[r].contiguous() + input_ids = input_ids.tensor_split(s)[r].contiguous() + input_ids_global = input_ids_global.tensor_split(s)[r].contiguous() + hidden_states = self.mlp( + hidden_states, + forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + if _use_tp_moe_gather: + hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states + dp_scatter(hidden_states, global_hidden_states, forward_batch) + if _use_tp_attn_a2a_scatter: + assert _a2a_scatter_chunks is not None + gathered = [torch.empty_like(t) for t in _a2a_scatter_chunks] + attn_tp_all_gather(gathered, hidden_states.contiguous()) + hidden_states = torch.cat(gathered) + + hidden_states = self.hc_post(hidden_states, residual, post, comb) + + return hidden_states + + +class DeepseekV4Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.pp_group = get_pp_group() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + ) + self.rms_norm_eps = config.rms_norm_eps + self.alt_streams = ( + [torch.cuda.Stream() for _ in range(5)] if (_is_cuda or _is_hip) else None + ) + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: DeepseekV4DecoderLayer( + config=config, + layer_id=idx, + quant_config=quant_config, + prefix=prefix, + alt_streams=self.alt_streams, + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix=add_prefix("layers", prefix), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gemm_output_zero_allocator_size = 0 + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + self.norm_eps = config.rms_norm_eps + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_cp_size() + + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states = 
self.embed_tokens(input_ids) + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = input_ids_global.squeeze(-1) + else: + input_ids_global = input_ids + + if nsa_use_prefill_cp(forward_batch): + hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + positions = cp_split_and_rebuild_position(forward_batch, positions) + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + + if nsa_use_prefill_cp(forward_batch): + hidden_states = cp_all_gather_rerange_output( + hidden_states, + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + + pre_hc_head = hidden_states.flatten(1) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.norm(hidden_states) + + return hidden_states, pre_hc_head + + +class DeepseekV4ForCausalLM(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.determine_num_fused_shared_experts() + self.model = DeepseekV4Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.pp_group = get_pp_group() + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False + get_attn_tp_context().init_context(config.q_lora_rank, is_nsa=True) + + self._routed_experts_weights_of_layer = LazyValue( + lambda: { + layer_id: layer.mlp.get_moe_weights() + for layer_id, layer in enumerate(self.model.layers) + if isinstance(layer.mlp, deepseek_v2.DeepseekV2MoE) + } + ) + + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_rank = get_attention_cp_rank() + self.cp_size = get_attention_cp_size() + + @property + def routed_experts_weights_of_layer(self): + return self._routed_experts_weights_of_layer.value + + def determine_num_fused_shared_experts(self): + self.num_fused_shared_experts = 0 + if get_global_server_args().disable_shared_experts_fusion: + return + + get_global_server_args().disable_shared_experts_fusion = True + log_info_on_rank0( + logger, + "DeepSeek V4 requires different clamping for shared and routed experts. 
" + "Shared experts fusion optimization is disabled.", + ) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + if self.nsa_enable_prefill_cp: + if can_nsa_cp_split(len(input_ids), self.cp_size, True, forward_batch): + forward_batch.attn_cp_metadata = prepare_context_parallel_metadata( + len(input_ids), + self.cp_rank, + self.cp_size, + forward_batch.seq_lens_cpu.tolist(), + ) + if is_nsa_prefill_cp_round_robin_split(): + metadata = forward_batch.attn_backend.forward_metadata + core_meta = metadata.core_attn_metadata + core_meta.apply_cp_reindex() + core_meta.init_flashmla_related() + if metadata.indexer_metadata is not None: + metadata.indexer_metadata = ( + forward_batch.attn_backend.init_forward_metadata_indexer( + core_meta + ) + ) + + with get_attn_tp_context().maybe_input_scattered(forward_batch): + hidden_states = self.model.forward( + input_ids, positions, forward_batch, input_embeds + ) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + hidden_states, pre_hc_head = hidden_states + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + hidden_states_before_norm=pre_hc_head, + ) + + def _setup_fp8_wo_a_scales(self, is_nextn: bool) -> None: + from deep_gemm import transform_sf_into_required_layout + + layers = self.model.layers + for layer in layers: + attn = layer.self_attn + G = attn.n_local_groups + R = attn.o_lora_rank + D = attn.wo_a.weight.shape[1] + + raw_scale = attn.wo_a.weight_scale_inv.data.view(G, R // 128, D // 128) + attn.wo_a.weight_scale_inv.data = transform_sf_into_required_layout( + raw_scale, + mn=R, + k=D, + recipe=(1, 128, 128), + num_groups=G, + is_sfa=False, + ) + + def post_load_weights(self, is_nextn=False, weight_names=None): + if _FP8_WO_A_GEMM: + self._setup_fp8_wo_a_scales(is_nextn) + + if is_nextn: + return + for layer in self.model.layers: + self_attn = layer.self_attn + if self_attn.compress_ratio != 0 and not self_attn.compressor.ape_converted: + self_attn.compressor.apply_ape_hotfix() + if ( + self_attn.compress_ratio == 4 + and not self_attn.indexer.compressor.ape_converted + ): + self_attn.indexer.compressor.apply_ape_hotfix() + + @staticmethod + def remap_weight_name_to_dpsk_hf_format( + name: str, is_nextn: bool = False, num_hidden_layers: Optional[int] = None + ) -> str: + if name == "embed.weight": + return "model.embed_tokens.weight" + if name == "head.weight": + return "lm_head.weight" + if name == "norm.weight": + return "model.norm.weight" + if name.startswith("hc_head_"): + return "model." + name + + if is_nextn and name.startswith("mtp."): + parts = name.split(".", 2) + if len(parts) >= 3: + rest = parts[2] + nextn_spec_prefixes = [ + "e_proj", + "h_proj", + "emb", + "enorm", + "hnorm", + "norm", + "head", + "hc_head", + ] + is_nextn_spec = any(rest.startswith(p) for p in nextn_spec_prefixes) + if is_nextn_spec: + if rest.startswith("emb.tok_emb"): + rest = rest.replace("emb.tok_emb", "embed_tokens") + elif rest == "norm.weight": + rest = "shared_head.norm.weight" + elif rest.startswith("head."): + rest = "shared_head.head.weight" + elif rest == "e_proj.scale": + rest = "e_proj.weight_scale_inv" + elif rest == "h_proj.scale": + rest = "h_proj.weight_scale_inv" + name = f"model.layers.{num_hidden_layers}." 
+ rest + + if name.startswith("layers."): + name = "model." + name + name = name.replace(".attn.", ".self_attn.") + name = name.replace(".ffn.", ".mlp.") + name = name.replace(".attn_norm.", ".input_layernorm.") + name = name.replace(".ffn_norm.", ".post_attention_layernorm.") + + if "self_attn" in name: + name = name.replace(".scale", ".weight_scale_inv") + + name = name.replace(".gate.tid2eid", ".topk.tid2eid") + name = name.replace(".gate.bias", ".gate.e_score_correction_bias") + name = name.replace(".w1.", ".gate_proj.") + name = name.replace(".w2.", ".down_proj.") + name = name.replace(".w3.", ".up_proj.") + if "mlp" in name: + name = name.replace(".scale", ".weight_scale_inv") + + return name + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + raise ValueError("num_nextn_predict_layers is not in the config") + + if not envs.SGLANG_OPT_FP8_WO_A_GEMM.get(): + weights = list(weights) + exists_wo_a_scale = any(n.endswith(".wo_a.scale") for n, t in weights) + if exists_wo_a_scale: + logger.info("Execute dequant fp8 wo_a") + weights = _dequant_fp8_wo_a(weights) + else: + logger.info("Skip dequant fp8 wo_a") + + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + + cache_compressor_weight = {} + COMPRESSOR_PART = ".compressor.w" + + fuse_wqa_wkv = envs.SGLANG_OPT_FUSE_WQA_WKV.get() + cache_wqkv_a_weight: dict[str, dict[str, torch.Tensor]] = {} + + def auto_weight_loader(module): + return getattr(module, "weight_loader", default_weight_loader) + + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names_out_of_layer = [ + "shared_head.norm", + "shared_head.head", + "embed_tokens", + ".e_proj", + "h_proj", + "enorm", + "hnorm", + "hc_head_base", + "hc_head_fn", + "hc_head_scale", + ] + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + weight_names = [] + for name, loaded_weight in weights: + try: + use_async_loading = should_async_load(loaded_weight) + + name = self.remap_weight_name_to_dpsk_hf_format( + name, + is_nextn=is_nextn, + num_hidden_layers=self.config.num_hidden_layers, + ) + + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if ( + self.num_fused_shared_experts > 0 + and "mlp.shared_experts" in name + ): + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + 
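+                    # Track every (possibly remapped) checkpoint name; the
+                    # full list is handed to post_load_weights once loading
+                    # finishes.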
+ weight_names.append(name) + + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) + >= self.config.num_hidden_layers + ): + continue + + if name.startswith("mtp"): + continue + else: + if "shared_head.head" in name or "embed_tokens" in name: + continue + + if not name.startswith(nextn_layer_prefix): + continue + + in_decoder = True + for weight_name in nextn_spec_weight_names_out_of_layer: + if weight_name in name: + in_decoder = False + name = name.replace(nextn_layer_prefix, "model") + break + + if in_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict and name.startswith("mtp"): + break + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight, shard_id), + ) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=( + param, + loaded_weight, + name, + ), + func_kwargs={ + "shard_id": shard_id, + "expert_id": expert_id, + }, + ) + loaded_params.add(name) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if ( + ".embed_tokens." in name + and not self.pp_group.is_first_rank + ): + continue + if ".norm." 
in name and not self.pp_group.is_last_rank: + continue + elif COMPRESSOR_PART in name: + is_kv = name.endswith(".wkv.weight") + is_wgate = name.endswith(".wgate.weight") + assert is_kv != is_wgate + key = name.rsplit(".", 2)[0] + assert key.endswith(".compressor") + if key not in cache_compressor_weight: + cache_compressor_weight[key] = ( + is_kv, + loaded_weight, + ) + else: + assert key in cache_compressor_weight + cached_is_kv, cached_weight = ( + cache_compressor_weight[key] + ) + assert cached_is_kv != is_kv + kv = loaded_weight if is_kv else cached_weight + wgate = loaded_weight if is_wgate else cached_weight + fused_weight = torch.cat([kv, wgate], dim=0) + param_name = key + ".wkv_gate.weight" + param = params_dict[param_name] + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + loaded_params.add(param_name) + cache_compressor_weight.pop(key) + elif fuse_wqa_wkv and ( + name.endswith(".wq_a.weight") + or name.endswith(".wq_a.weight_scale_inv") + or name.endswith(".wkv.weight") + or name.endswith(".wkv.weight_scale_inv") + ): + is_q = ".wq_a." in name + param_name = name.replace( + ".wq_a." if is_q else ".wkv.", ".wqkv_a." + ) + bucket = cache_wqkv_a_weight.setdefault(param_name, {}) + shard_key = "q" if is_q else "kv" + assert ( + shard_key not in bucket + ), f"duplicate shard {shard_key} for {param_name}" + bucket[shard_key] = loaded_weight + if len(bucket) == 2: + fused_weight = torch.cat( + [bucket["q"], bucket["kv"]], dim=0 + ) + param = params_dict[param_name] + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + loaded_params.add(param_name) + cache_wqkv_a_weight.pop(param_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + if not name.startswith("mtp"): + logger.warning( + f"{name} not found in params_dict." 
+ ) + continue + param = params_dict[name] + + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight), + ) + loaded_params.add(name) + except Exception as e: + e.add_note(f"{name=} {loaded_weight.shape=}") + raise + + for future in concurrent.futures.as_completed(futures): + future.result() + + assert len(cache_compressor_weight) == 0 + assert len(cache_wqkv_a_weight) == 0, cache_wqkv_a_weight.keys() + unloaded_params = params_dict.keys() - loaded_params + + skipped_checking_patterns = ["attn_mqa.k_scale", "attn_mqa.v_scale"] + if is_nextn: + skipped_checking_patterns.extend(["lm_head", "embed_tokens"]) + unloaded_params = { + p + for p in unloaded_params + if all( + skipped_checking_pattern not in p + for skipped_checking_pattern in skipped_checking_patterns + ) + } + if unloaded_params: + logger.warning( + f"Some weights are not initialized from checkpoints: {unloaded_params}" + ) + + self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, + num_groups=None, + ) + + +EntryClass = [DeepseekV4ForCausalLM] + + +def _dequant_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + from einops import rearrange + + assert ( + weight.dtype == torch.float8_e4m3fn + ), f"expected fp8_e4m3fn, got {weight.dtype}" + assert scale.dtype in ( + torch.float8_e8m0fnu, + torch.float32, + ), f"expected fp8_e8m0fnu or float32, got {scale.dtype}" + + weight_f32 = rearrange( + weight.float(), "(sn bn) (sk bk) -> sn bn sk bk", bn=128, bk=128 + ) + result = rearrange( + weight_f32 * scale.float()[:, None, :, None], "sn bn sk bk -> (sn bn) (sk bk)" + ) + + return result.to(torch.bfloat16) + + +def _dequant_fp8_wo_a( + weights: Iterable[Tuple[str, torch.Tensor]], +) -> Iterable[Tuple[str, torch.Tensor]]: + weights_dict = dict(weights) + + for name in list(weights_dict.keys()): + if name not in weights_dict: + continue + if not name.endswith(".wo_a.weight"): + continue + scale_name = name.replace(".wo_a.weight", ".wo_a.scale") + assert scale_name in weights_dict + weight = weights_dict.pop(name) + scale = weights_dict.pop(scale_name) + yield name, _dequant_fp8(weight, scale) + + yield from weights_dict.items() diff --git a/python/sglang/srt/models/deepseek_v4_nextn.py b/python/sglang/srt/models/deepseek_v4_nextn.py new file mode 100644 index 000000000000..9b220b184a21 --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4_nextn.py @@ -0,0 +1,216 @@ +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + dp_gather_partial, + get_attention_dp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear 
import ReplicatedLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.utils import get_moe_a2a_backend +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.deepseek_v4 import DeepseekV4DecoderLayer, DeepseekV4ForCausalLM +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + +COMPRESS_RATIO_NEXTN_LAYER = 0 + + +class DeepseekV4ModelNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rms_norm_eps = config.rms_norm_eps + + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.e_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("e_proj", prefix), + ) + self.h_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("h_proj", prefix), + ) + + layer_name = "decoder" + + self.decoder = DeepseekV4DecoderLayer( + config, + layer_id=0, + quant_config=quant_config, + is_nextn=True, + prefix=add_prefix(layer_name, prefix), + alt_streams=None, + compress_ratio_override=COMPRESS_RATIO_NEXTN_LAYER, + ) + + self.shared_head = nn.Module() + self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.rms_norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + n_tokens = hidden_states.shape[0] + d = self.config.hidden_size + hc_flat = forward_batch.spec_info.hidden_states.view( + n_tokens * self.hc_mult, d + ) + h_proj_out, _ = self.h_proj(self.hnorm(hc_flat)) + h_proj_hidden_states = h_proj_out.view(n_tokens, self.hc_mult, d) + + e_proj_hidden_states, _ = self.e_proj(self.enorm(hidden_states)) + hidden_states = e_proj_hidden_states[:, None, :] + h_proj_hidden_states + else: + hidden_states = 
hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = input_ids_global.squeeze(-1) + else: + input_ids_global = input_ids + + hidden_states = self.decoder( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + + pre_hc_head = hidden_states.flatten(1) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.shared_head.norm(hidden_states) + + return hidden_states, pre_hc_head + + +class DeepseekV4ForCausalLMNextN(DeepseekV4ForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.pp_group = get_pp_group() + self.quant_config = quant_config + self.determine_num_fused_shared_experts() + + self.model = DeepseekV4ModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states, pre_hc_head = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + hidden_states_before_norm=pre_hc_head, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights, is_nextn=True) + + def post_load_weights(self, is_nextn=False, weight_names=None): + super().post_load_weights(is_nextn=True, weight_names=weight_names) + + +EntryClass = [DeepseekV4ForCausalLMNextN] diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 9ad15c8d187a..b91427eeff39 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -601,6 +601,7 @@ class ReasoningParser: DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = { "deepseek-r1": DeepSeekR1Detector, "deepseek-v3": _DeepSeekV3Detector, + "deepseek-v4": _DeepSeekV3Detector, "glm45": Glm45Detector, "hunyuan": HunyuanDetector, "gpt-oss": GptOssDetector, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index de7f06468d9c..a99c4d07535f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -144,6 +144,8 @@ "torch_native", "flex_attention", "nsa", + "dsv4", + "compressed", # Deprecated alias for "dsv4" # NVIDIA specific "cutlass_mla", "fa3", @@ -1076,6 +1078,20 @@ def _handle_deprecated_args(self): envs.SGLANG_SPEC_NAN_DETECTION.set(True) envs.SGLANG_SPEC_OOB_DETECTION.set(True) + # Deprecated attention-backend alias: "compressed" -> "dsv4". 
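+        # The alias may appear in any of these backend fields, including the
+        # speculative draft backend, so rewrite all of them.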
+ for attr in ( + "attention_backend", + "decode_attention_backend", + "prefill_attention_backend", + "speculative_draft_attention_backend", + ): + if getattr(self, attr, None) == "compressed": + logger.warning( + "--%s=compressed is deprecated; use 'dsv4' instead.", + attr.replace("_", "-"), + ) + setattr(self, attr, "dsv4") + # Native gRPC flags — env-only for now, not exposed as CLI args. # Set as instance attributes (not dataclass fields) to avoid # argparse namespace lookup in from_cli_args. @@ -1638,25 +1654,16 @@ def _set_default_nsa_kv_cache_dtype(self, major: int, quantization: str) -> str: ], "DeepSeek DSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype" def _set_default_nsa_backends(self, kv_cache_dtype: str, major: int) -> str: + from sglang.srt.arg_groups.hisparse_hook import ( + apply_hisparse_nsa_backend_defaults, + ) + user_set_prefill = self.nsa_prefill_backend is not None user_set_decode = self.nsa_decode_backend is not None - # HiSparse: BF16 KV -> flashmla_sparse (native BF16 sparse). - # FP8 KV -> flashmla_kv (native FP8 + sparse via is_fp8_kvcache=True + indices=...). - # flashmla_sparse does not accept FP8, and flashmla_kv does not accept BF16 sparse, - # so the KV dtype determines the backend when the user does not override. - if self.enable_hisparse: - hisparse_default_backend = ( - "flashmla_kv" if kv_cache_dtype == "fp8_e4m3" else "flashmla_sparse" - ) - if not user_set_prefill: - self.nsa_prefill_backend = hisparse_default_backend - if not user_set_decode: - self.nsa_decode_backend = hisparse_default_backend - logger.warning( - f"HiSparse enabled ({kv_cache_dtype}): using NSA backends " - f"prefill={self.nsa_prefill_backend}, decode={self.nsa_decode_backend}." - ) + if apply_hisparse_nsa_backend_defaults( + self, user_set_prefill, user_set_decode, kv_cache_dtype + ): return if not user_set_prefill and not user_set_decode and is_hip(): @@ -1718,6 +1725,15 @@ def _handle_model_specific_adjustments(self): ]: self.dtype = "bfloat16" + if model_arch in [ + "DeepseekV4ForCausalLM", + ]: + from sglang.srt.arg_groups.deepseek_v4_hook import ( + apply_deepseek_v4_defaults, + ) + + apply_deepseek_v4_defaults(self, model_arch) + if model_arch in [ "DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM", @@ -1773,8 +1789,8 @@ def _handle_model_specific_adjustments(self): self.dp_size == 1 ), "For round-robin split mode, dp attention is not supported." assert ( - self.tp_size == 8 - ), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)" + self.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." self.attn_cp_size = self.tp_size // self.dp_size logger.warning( @@ -1921,6 +1937,13 @@ def _handle_model_specific_adjustments(self): "Use triton fused moe by default for bf16 nextn layer in deepseek fp4 checkpoint." 
) + elif model_arch in [ + "DeepseekV4ForCausalLM", + ]: + from sglang.srt.arg_groups.deepseek_v4_hook import validate_deepseek_v4_cp + + validate_deepseek_v4_cp(self) + elif model_arch in ["GptOssForCausalLM"]: # Set attention backend for GPT-OSS if self.is_attention_backend_not_set(): @@ -3529,6 +3552,7 @@ def _handle_speculative_decoding(self): if model_arch in [ "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekV4ForCausalLM", "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", "GlmMoeDsaForCausalLM", @@ -5926,13 +5950,14 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable hierarchical sparse attention", ) - parser.add_argument( "--hisparse-config", + "--hierarchical-sparse-attention-extra-config", + dest="hisparse_config", type=str, default=ServerArgs.hisparse_config, help="A dictionary in JSON string format for hierarchical sparse attention configuration. " - 'Example: \'{"top_k": 2048, "device_buffer_size": 4096}\'', + 'Example: \'{"top_k": 2048, "device_buffer_size": 4096, "host_to_device_ratio": 2}\'', ) # LMCache @@ -6943,43 +6968,9 @@ def check_server_args(self): ) # Check hisparse - if self.enable_hisparse: - from sglang.srt.configs.model_config import is_deepseek_nsa - - hf_config = self.get_model_config().hf_config - assert is_deepseek_nsa(hf_config), ( - "--enable-hisparse is only supported for DSA (DeepSeek Sparse Attention) models now" - "(e.g., DeepSeek V3.2, GLM-5). " - ) + from sglang.srt.arg_groups.hisparse_hook import validate_hisparse - assert ( - self.disable_radix_cache - ), "Hierarchical sparse attention currently requires --disable-radix-cache." - if self.kv_cache_dtype not in ("bfloat16", "auto", "fp8_e4m3"): - raise ValueError( - f"HiSparse requires bfloat16 or fp8_e4m3 KV cache, " - f"but got --kv-cache-dtype={self.kv_cache_dtype}. " - f"Please use --kv-cache-dtype=bfloat16 or fp8_e4m3." - ) - - # Backend/dtype pairing: flashmla_sparse only takes BF16 KV; - # flashmla_kv only supports FP8 (it always reads KV as FP8 via - # is_fp8_kvcache=True, inline-quantizing BF16 would defeat HiSparse). - allowed_backends_for_dtype = { - "bfloat16": {"flashmla_sparse"}, - "fp8_e4m3": {"flashmla_kv"}, - }.get(self.kv_cache_dtype, {"flashmla_sparse", "flashmla_kv"}) - for attr, label in [ - ("nsa_prefill_backend", "prefill"), - ("nsa_decode_backend", "decode"), - ]: - backend = getattr(self, attr) - if backend is not None and backend not in allowed_backends_for_dtype: - raise ValueError( - f"HiSparse with --kv-cache-dtype={self.kv_cache_dtype} requires " - f"--nsa-{label}-backend in {sorted(allowed_backends_for_dtype)}, " - f"but got {backend}." 
- ) + validate_hisparse(self) assert ( self.schedule_conservativeness >= 0 diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py index 94c85ac903fd..4da59b72a933 100644 --- a/python/sglang/srt/speculative/draft_utils.py +++ b/python/sglang/srt/speculative/draft_utils.py @@ -56,6 +56,7 @@ def create_decode_backend(self): "nsa": self._create_nsa_decode_backend, "ascend": self._create_ascend_decode_backend, "fa4": self._create_fa4_decode_backend, + "dsv4": self._create_dsv4_decode_backend, } return self._create_backend( @@ -81,6 +82,7 @@ def create_draft_extend_backend(self): "nsa": self._create_nsa_prefill_backend, "ascend": self._create_ascend_prefill_backend, "fa4": self._create_fa4_prefill_backend, + "dsv4": self._create_dsv4_prefill_backend, } backend_name = ( "decode_attention_backend" @@ -205,6 +207,15 @@ def _create_ascend_decode_backend(self): self.draft_model_runner, self.topk, self.speculative_num_steps ) + def _create_dsv4_decode_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DeepseekV4MultiStepBackend, + ) + + return DeepseekV4MultiStepBackend( + self.draft_model_runner, self.topk, self.speculative_num_steps + ) + def _create_flashinfer_prefill_backend(self): if not get_global_server_args().use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( @@ -275,3 +286,10 @@ def _create_flashmla_prefill_backend(self): "flashmla prefill backend is not yet supported for draft extend." ) return None + + def _create_dsv4_prefill_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend import ( + DeepseekV4AttnBackend, + ) + + return DeepseekV4AttnBackend(self.draft_model_runner, skip_prefill=False) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index a5fc8523b7ff..804d421b1db2 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -211,6 +211,13 @@ def _capture_init(self, run_once_fn): torch.cuda.synchronize() self.model_runner.tp_group.barrier() run_once_fn() + hook = getattr( + self.model_runner.draft_attn_backend, + "on_after_cuda_graph_warmup", + None, + ) + if hook is not None: + hook() def _capture_graph(self, graph, pool, stream, run_once_fn): with torch.cuda.graph(graph, pool=pool, stream=stream): diff --git a/python/sglang/srt/utils/hf_transformers/common.py b/python/sglang/srt/utils/hf_transformers/common.py index ecd024f4c2b3..9067e6fa6746 100644 --- a/python/sglang/srt/utils/hf_transformers/common.py +++ b/python/sglang/srt/utils/hf_transformers/common.py @@ -101,16 +101,22 @@ ] } -# DeepSeek V3.2 reuses the V3 config schema. Subclass the upstream transformers -# class with the V3.2 model_type so AutoConfig.register passes its consistency -# check (which requires class.model_type == registered key). +# DeepSeek V3.2 / V4 reuse the V3 config schema. Subclass the upstream +# transformers class with each model_type so AutoConfig.register passes its +# consistency check (which requires class.model_type == registered key). +# Default-value divergences (e.g. V4's topk_group) are handled in +# model_config.py post-load. 
try: from transformers import DeepseekV3Config as _HFDeepseekV3Config class _DeepseekV32ConfigAlias(_HFDeepseekV3Config): model_type = "deepseek_v32" + class _DeepseekV4ConfigAlias(_HFDeepseekV3Config): + model_type = "deepseek_v4" + _CONFIG_REGISTRY["deepseek_v32"] = _DeepseekV32ConfigAlias + _CONFIG_REGISTRY["deepseek_v4"] = _DeepseekV4ConfigAlias except ImportError: pass diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index 6fd74a4e05f2..55566973af2c 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -1,6 +1,6 @@ import logging import time -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Set, Tuple import torch import torch.distributed as dist @@ -83,22 +83,37 @@ def _reset_tensors(self): def _compare(self): assert self._snapshot_tensors is not None + skip_compare_names = { + name + for name, param in self._model_state() + if getattr(param, "_skip_weight_check", False) + } _check_tensors( - expect_tensors=_postprocess_tensors(self._snapshot_tensors), - actual_tensors=_postprocess_tensors(dict(self._model_state())), + expect_tensors=_postprocess_tensors( + self._snapshot_tensors, skip_compare_names + ), + actual_tensors=_postprocess_tensors( + dict(self._model_state()), skip_compare_names + ), ) def _compute_checksum(self) -> Dict: torch.cuda.synchronize() start = time.perf_counter() + skip_compare_names = { + name + for name, param in self._model_state() + if getattr(param, "_skip_weight_check", False) + } + # Reuse the snapshot/compare postprocess pipeline so fp8 weights are # dequantized to bf16 before hashing — two (qweight, scale) pairs that # produce the same bf16 must produce the same checksum. checksums = { name: _hash_tensor(tensor.data) for name, should_compare, tensor in _postprocess_tensors( - dict(self._model_state()) + dict(self._model_state()), skip_compare_names ) if should_compare } @@ -202,16 +217,17 @@ def _random_like(t: torch.Tensor): def _postprocess_tensors( raw: Dict[str, torch.Tensor], + skip_compare_names: Set[str], ) -> Iterable[Tuple[str, bool, torch.Tensor]]: from sglang.srt.debug_utils.dumper import get_tensor_info - skip_compare_names = [] + skip_compare_names = set(skip_compare_names) # Skip non-persistent buffers (registered with persistent=False; recomputed # after weight load and not part of the synced payload). 
for name in raw: if _is_non_persistent_buffer_name(name): - skip_compare_names.append(name) + skip_compare_names.add(name) logger.info(f"[check_tensors] Skipping non-persistent buffer: {name}") # dequant fp8 @@ -221,10 +237,11 @@ def _postprocess_tensors( # Match: `something.weight`, `something.experts.w2_weight` if name.endswith("weight") and name.replace("weight", "weight_scale_inv") in raw ] - skip_compare_names += quant_names - skip_compare_names += [ + quant_scale_names = [ name.replace("weight", "weight_scale_inv") for name in quant_names ] + skip_compare_names.update(quant_names) + skip_compare_names.update(quant_scale_names) for name in quant_names: w_q = raw[name] w_s = raw[name.replace("weight", "weight_scale_inv")] @@ -232,10 +249,13 @@ def _postprocess_tensors( try: if w_s.dtype == torch.int32: # UE8M0 packed format (Blackwell DeepGEMM) - w_s = inverse_transform_scale_ue8m0(w_s, mn=w_q.shape[-2]) + w_s_for_dequant = inverse_transform_scale_ue8m0(w_s, mn=w_q.shape[-2]) + else: + w_s_for_dequant = w_s + w_dequant = block_quant_dequant( w_q, - w_s, + w_s_for_dequant, # TODO do not hardcode block_size=[128, 128], dtype=torch.bfloat16, diff --git a/python/sglang/test/kits/server_sanity_kit.py b/python/sglang/test/kits/server_sanity_kit.py new file mode 100644 index 000000000000..af9cf0167dd3 --- /dev/null +++ b/python/sglang/test/kits/server_sanity_kit.py @@ -0,0 +1,228 @@ +"""Black-box server sanity prompts: cheap checks that catch silent +correctness regressions (gibberish / repetition collapse / encoding), +streaming/concurrent path bugs, and endpoint health. + +Mix into any ``CustomTestCase`` subclass that exposes ``self.base_url`` +and ``self.process``. Each test is independent and fast (≤ 5 s after +warmup); the whole kit completes in < 1 min.""" + +import json +import threading + +import requests + +_REQUEST_TIMEOUT = 120 + +# Shared prefix forces all concurrent requests through the same radix +# match path; per-request suffix branches the tail so the model still +# has to predict different tokens (otherwise outputs would be identical +# and we'd be testing 1 request 8 times instead of 8 independent reqs). +_CONCURRENT_PREFIX = "You are a helpful assistant. Answer with a single word.\n" +_CONCURRENT_QA = [ + ("Q: What is the capital of France?\nA:", "paris"), + ("Q: What is the capital of Germany?\nA:", "berlin"), + ("Q: What is the capital of Italy?\nA:", "rome"), + ("Q: What is the capital of Japan?\nA:", "tokyo"), + ("Q: What is the capital of Spain?\nA:", "madrid"), + ("Q: What is the capital of Egypt?\nA:", "cairo"), + ("Q: What is the capital of Russia?\nA:", "moscow"), + ("Q: What is the capital of Australia?\nA:", "canberra"), +] + + +class ServerSanityMixin: + """12 cheap black-box probes for silent-correctness / hang / endpoint + regressions.""" + + sanity_max_new_tokens_short: int = 64 + sanity_max_new_tokens_long: int = 128 + + def _sanity_generate(self, prompt: str, max_new_tokens: int, stop=None) -> str: + sampling_params = { + "temperature": 0.0, + "max_new_tokens": max_new_tokens, + } + if stop is not None: + sampling_params["stop"] = stop + resp = requests.post( + self.base_url + "/generate", + json={"text": prompt, "sampling_params": sampling_params}, + timeout=_REQUEST_TIMEOUT, + ) + self.assertEqual(resp.status_code, 200) + return resp.json()["text"] + + def test_health(self): + # Cheapest possible alive check; FastAPI route alone. 
+ resp = requests.get(self.base_url + "/health", timeout=10) + self.assertEqual(resp.status_code, 200) + + def test_health_generate(self): + # sglang's built-in minimal-forward sanity. 200 only if the + # scheduler can complete one prefill+decode end to end. + resp = requests.get(self.base_url + "/health_generate", timeout=60) + self.assertEqual(resp.status_code, 200) + + def test_capital_france(self): + out = self._sanity_generate( + "Q: What is the capital of France?\nA:", + self.sanity_max_new_tokens_short, + ) + self.assertIn("paris", out.lower()) + + def test_basic_math(self): + out = self._sanity_generate( + "Q: What is 17 multiplied by 23? Reply with just the number.\nA:", + self.sanity_max_new_tokens_short, + ) + self.assertIn("391", out) + + def test_color_completion(self): + out = self._sanity_generate( + "Q: The three primary colors are red, blue, and ___. " + "Fill in the blank.\nA:", + self.sanity_max_new_tokens_short, + ) + self.assertIn("yellow", out.lower()) + + def test_ascii_ratio(self): + # Language-agnostic gibberish detector. Healthy English output is + # >90% printable ASCII; multilingual token salad / Unicode noise + # from broken weight load drops well below 50%. + out = self._sanity_generate( + "Write a single sentence about a sunny day in the park.", + self.sanity_max_new_tokens_long, + ) + printable = sum(1 for c in out if 32 <= ord(c) < 127 or c in "\n\t") + ratio = printable / max(len(out), 1) + self.assertGreater( + ratio, + 0.85, + f"output looks like gibberish (printable ASCII ratio={ratio:.2f}): {out!r}", + ) + + def test_no_repetition_blowup(self): + # KV-cache / attn corruption often manifests as the model getting + # stuck looping the same n-gram. + out = self._sanity_generate( + "Briefly explain what gravity is.", + self.sanity_max_new_tokens_long, + ) + if len(out) >= 50: + windows = [out[i : i + 5] for i in range(len(out) - 5)] + most_common_count = max((windows.count(w) for w in set(windows)), default=0) + ratio = most_common_count / len(windows) + self.assertLess( + ratio, + 0.25, + f"output appears to repeat heavily (top 5-gram ratio={ratio:.2f}): {out!r}", + ) + + def test_max_token_one(self): + # Degenerate spec step. cuda-graph capture path bugs that only + # fire on minimal-output requests. + out = self._sanity_generate( + "Q: What is the capital of France? Just one word.\nA:", + max_new_tokens=1, + ) + self.assertGreater(len(out), 0) + + def test_streaming_response(self): + # SSE streaming exercises a different return path than non-stream + # /generate. Catches token-by-token streaming corruption and SSE + # framing bugs without changing the model. + with requests.post( + self.base_url + "/generate", + json={ + "text": "Q: What is the capital of France?\nA:", + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": self.sanity_max_new_tokens_short, + }, + "stream": True, + }, + stream=True, + timeout=_REQUEST_TIMEOUT, + ) as resp: + self.assertEqual(resp.status_code, 200) + chunks_seen = 0 + last_text = "" + for raw in resp.iter_lines(decode_unicode=True): + if not raw or not raw.startswith("data:"): + continue + payload = raw[len("data:") :].strip() + if payload == "[DONE]": + break + obj = json.loads(payload) + last_text = obj.get("text", last_text) + chunks_seen += 1 + self.assertGreater(chunks_seen, 0) + self.assertIn("paris", last_text.lower()) + + def test_concurrent_requests(self): + # 8 parallel reqs share a system prefix but each has a distinct + # question suffix. 
Shared prefix exercises radix prefix caching + # across concurrent reqs; per-request suffix forces independent + # decode tails (different canonical answers). Catches concurrent + # scheduler hangs and prefix-cache cross-contamination. + results = [None] * len(_CONCURRENT_QA) + + def worker(idx, suffix, expected): + try: + out = self._sanity_generate( + _CONCURRENT_PREFIX + suffix, + self.sanity_max_new_tokens_short, + ) + results[idx] = expected in out.lower() + except Exception: + results[idx] = False + + threads = [ + threading.Thread(target=worker, args=(i, suffix, expected)) + for i, (suffix, expected) in enumerate(_CONCURRENT_QA) + ] + for t in threads: + t.start() + for t in threads: + t.join(timeout=_REQUEST_TIMEOUT) + + passed = sum(1 for r in results if r) + # Tolerate one stochastic miss; gibberish would fail all 8. + self.assertGreaterEqual( + passed, + len(_CONCURRENT_QA) - 1, + f"concurrent answers correct: {passed}/{len(_CONCURRENT_QA)}; results={results}", + ) + + def test_long_prompt(self): + # ~8k-token filler drives the chunked-prefill path through + # multiple chunks. Catches DeepEP / large-prompt kernel crashes + # that only fire on multi-chunk prefill. + filler = "the quick brown fox jumps over the lazy dog. " * 800 + out = self._sanity_generate( + f"Read the following text and then answer.\n{filler}\n\n" + "Q: What is the capital of France?\nA:", + self.sanity_max_new_tokens_short, + ) + # Long-prompt substring match is best-effort (model may get + # distracted); primary assertion is the 200 + non-empty inside + # _sanity_generate. + self.assertGreater(len(out), 0) + + def test_determinism_temp_zero(self): + # temp=0 must be byte-identical across runs. Stop on "\n" so we + # only compare the answer word; long continuations drift on + # near-tie tokens (EP MoE / EAGLE spec) and aren't the point. + prompt = "Q: What is the capital of France? Reply in one word.\nA:" + out1 = self._sanity_generate( + prompt, self.sanity_max_new_tokens_short, stop=["\n"] + ) + # Second call exercises cache-hit path. + out2 = self._sanity_generate( + prompt, self.sanity_max_new_tokens_short, stop=["\n"] + ) + self.assertEqual( + out1.strip(), + out2.strip(), + f"temp=0 outputs diverged:\n out1={out1!r}\n out2={out2!r}", + ) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 6abe16079846..6ef153bd81a0 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1016,6 +1016,12 @@ def popen_launch_pd_server( print(f"command={' '.join(command)}") + # Merge with os.environ so caller-supplied env adds to (not replaces) + # PATH / PYTHONPATH / HF_HOME / etc. When env is None, Popen inherits + # parent's environment automatically. + if env is not None: + env = {**os.environ, **env} + process = subprocess.Popen(command, stdout=None, stderr=None, env=env) return process diff --git a/scripts/ci/cuda/ci_install_flash_mla.sh b/scripts/ci/cuda/ci_install_flash_mla.sh new file mode 100755 index 000000000000..10bd61ab2d72 --- /dev/null +++ b/scripts/ci/cuda/ci_install_flash_mla.sh @@ -0,0 +1,35 @@ +#!/bin/bash +set -euxo pipefail + +source scripts/ci/cuda/ci_install_dependency.sh + +if [ -z "${PIP_CMD:-}" ]; then + echo "FATAL:PIP_CMD is unset after sourcing ci_install_dependency.sh" + exit 1 +fi + +export CUDA_HOME=/usr/local/cuda + +if [ "${FORCE_REBUILD_FLASH_MLA:-0}" = "1" ]; then + echo "FORCE_REBUILD_FLASH_MLA=1; uninstalling any cached flash_mla before rebuild." 
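+    # Fall back to a plain "pip uninstall -y" (and an empty suffix) when the
+    # sourced helper does not export PIP_UNINSTALL_CMD / PIP_UNINSTALL_SUFFIX;
+    # "|| true" keeps set -e from aborting when flash_mla was never installed.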
+ ${PIP_UNINSTALL_CMD:-pip uninstall -y} flash_mla ${PIP_UNINSTALL_SUFFIX:-} || true +elif python3 -c "import flash_mla" >/dev/null 2>&1; then + echo "flash_mla is already installed or importable. Skipping installation." + exit 0 +fi + +# CUDA 13.0 puts CCCL headers under /usr/local/cuda/include/cccl/cuda but +# FlashMLA's build expects them at /usr/local/cuda/include/cuda. Symlink so +# the compiler finds them. Idempotent: skip if the link/dir already exists. +if [ ! -e /usr/local/cuda/include/cuda ] && [ -d /usr/local/cuda/include/cccl/cuda ]; then + ln -s /usr/local/cuda/include/cccl/cuda /usr/local/cuda/include/cuda +fi + +# Install FlashMLA +FLASH_MLA_DIR=/root/.cache/flash-mla +rm -rf ${FLASH_MLA_DIR} +git clone https://github.com/deepseek-ai/FlashMLA.git ${FLASH_MLA_DIR} +pushd ${FLASH_MLA_DIR} +git submodule update --init --recursive +${PIP_CMD:-pip} install --no-build-isolation -v . ${PIP_INSTALL_SUFFIX:-} +popd diff --git a/test/manual/dsv4/__init__.py b/test/manual/dsv4/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test/manual/dsv4/_common.py b/test/manual/dsv4/_common.py new file mode 100644 index 000000000000..f694f49de472 --- /dev/null +++ b/test/manual/dsv4/_common.py @@ -0,0 +1,317 @@ +"""Shared fixture for DeepSeek-V4 cookbook launch-command tests. + +Each sibling ``test__.py`` declares ONE +``hardware x model_size`` cell from the cookbook (e.g. B200 x Flash) +and contains one ``CustomTestCase`` subclass per recipe +(Low-Latency / Balanced / Max-Throughput / CP, where supported). + +Each subclass launches the server with the cookbook's exact flags and +runs two sgl-eval evaluations (https://github.com/sgl-project/sgl-eval): +- ``test_smoke_gsm8k`` — short, cheap GSM8K pass to verify the server + can produce coherent math answers at all (smoke gate). +- ``test_aime25`` — full AIME25 accuracy run (heavy; 16 repeats default). + +Cookbook reference: + https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4 + +These are MANUAL tests (not CI). ``sgl-eval`` must be on PATH. + +Per-variant defaults (set on the Flash/Pro intermediate base classes): + Flash recipes -> AIME25 score threshold 0.93 + Pro recipes -> AIME25 score threshold 0.95 +GSM8K smoke threshold (0.93) is shared across Flash and Pro. + +AIME25 knobs (env vars): + DSV4_AIME25_NUM_REPEATS (default 16 -> --n-repeats) + DSV4_AIME25_TEMPERATURE (default 1.0 -> --temperature) + DSV4_AIME25_TOP_P (default 1.0 -> --top-p) + DSV4_AIME25_MAX_TOKENS (default 65536 -> --max-tokens) + DSV4_AIME25_NUM_THREADS (default 512 -> --num-threads) + DSV4_AIME25_SCORE_METRIC (default "score"; sgl-eval JSON key under "aggregate") + DSV4_AIME25_SCORE_THRESHOLD (default 0; >0 overrides per-variant default) + +GSM8K smoke knobs (env vars): + DSV4_GSM8K_NUM_EXAMPLES (default 50 -> --num-examples) + DSV4_GSM8K_N_REPEATS (default 1 -> --n-repeats) + DSV4_GSM8K_TEMPERATURE (default 0.6 -> --temperature) + DSV4_GSM8K_TOP_P (default 0.95 -> --top-p) + DSV4_GSM8K_MAX_TOKENS (default 8192 -> --max-tokens) + DSV4_GSM8K_NUM_THREADS (default 64 -> --num-threads) + DSV4_GSM8K_SCORE_METRIC (default "score"; sgl-eval JSON key under "aggregate") + DSV4_GSM8K_SCORE_THRESHOLD (default 0.93; set to 0 to skip the assertion) + +Shared knobs: + DSV4_SGL_EVAL_OUT_DIR (default /tmp/sgl-eval-out -> --out-dir) + DSV4_SGL_EVAL_BIN (default "sgl-eval"; override path to the CLI) + DSV4_SERVER_LAUNCH_TIMEOUT (default 3600s; the sglang 600s default is + too short for DSV4 model load + DeepGEMM + warmup. 
1800s is also tight for the heavier + recipes (DP-attn + DeepEP); 3600s is the + safe default. Bump again for first-run + model downloads if needed.) + +Multi-node knobs (only consumed by multi-node test classes; if either +is unset, those classes ``SkipTest``): + DSV4_NODE_RANK (per-node rank for --node-rank) + DSV4_DIST_INIT_ADDR (e.g. 10.0.0.1:20000 for --dist-init-addr) + +Always-on env (set by the base class for every recipe; per-recipe EXTRA_ENV +wins on key conflict): + SGLANG_JIT_DEEPGEMM_FAST_WARMUP=1 skip the slow DeepGEMM warmup grid +""" + +import json +import os +import shutil +import subprocess +import unittest +from pathlib import Path +from typing import ClassVar, Dict, List, Optional + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +SGL_EVAL_BIN = os.environ.get("DSV4_SGL_EVAL_BIN", "sgl-eval") +SGL_EVAL_OUT_DIR = os.environ.get("DSV4_SGL_EVAL_OUT_DIR", "/tmp/sgl-eval-out") + +# DSV4 server launch needs more than the 600s sglang default: model load alone +# can take 5+ min and DeepGEMM warmup another ~5 min. First-run model download +# adds ~10-30 min on top. 1800s covers steady-state; bump via env for downloads. +SERVER_LAUNCH_TIMEOUT = int(os.environ.get("DSV4_SERVER_LAUNCH_TIMEOUT", "3600")) + +# Defaults applied to every recipe's EXTRA_ENV. Per-recipe EXTRA_ENV wins on key +# conflict. +BASE_ENV: Dict[str, str] = { + # Skip the slow exhaustive DeepGEMM warmup grid; covers the shapes DSV4 + # actually hits and shaves several minutes off server startup. + "SGLANG_JIT_DEEPGEMM_FAST_WARMUP": "1", +} + +AIME25_NUM_REPEATS = int(os.environ.get("DSV4_AIME25_NUM_REPEATS", "16")) +AIME25_TEMPERATURE = float(os.environ.get("DSV4_AIME25_TEMPERATURE", "1.0")) +AIME25_TOP_P = float(os.environ.get("DSV4_AIME25_TOP_P", "1.0")) +AIME25_MAX_TOKENS = int(os.environ.get("DSV4_AIME25_MAX_TOKENS", "65536")) +AIME25_NUM_THREADS = int(os.environ.get("DSV4_AIME25_NUM_THREADS", "512")) +AIME25_SCORE_METRIC = os.environ.get("DSV4_AIME25_SCORE_METRIC", "score") +AIME25_SCORE_THRESHOLD = float(os.environ.get("DSV4_AIME25_SCORE_THRESHOLD", "0.0")) + +GSM8K_NUM_EXAMPLES = int(os.environ.get("DSV4_GSM8K_NUM_EXAMPLES", "50")) +GSM8K_N_REPEATS = int(os.environ.get("DSV4_GSM8K_N_REPEATS", "1")) +GSM8K_TEMPERATURE = float(os.environ.get("DSV4_GSM8K_TEMPERATURE", "0.6")) +GSM8K_TOP_P = float(os.environ.get("DSV4_GSM8K_TOP_P", "0.95")) +GSM8K_MAX_TOKENS = int(os.environ.get("DSV4_GSM8K_MAX_TOKENS", "8192")) +GSM8K_NUM_THREADS = int(os.environ.get("DSV4_GSM8K_NUM_THREADS", "64")) +GSM8K_SCORE_METRIC = os.environ.get("DSV4_GSM8K_SCORE_METRIC", "score") +GSM8K_SCORE_THRESHOLD = float(os.environ.get("DSV4_GSM8K_SCORE_THRESHOLD", "0.93")) + +# DeepEP "large SMS" config — appears as `--deepep-config '{...}'` in every +# DeepEP recipe except multi-node ones (where it is gated off in the JSX). +DEEPEP_LARGE_SMS_CONFIG = ( + '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}' +) + + +def multinode_args(nnodes: int) -> List[str]: + """Return CLI args for a multi-node launch, or skip the test. + + Reads DSV4_NODE_RANK and DSV4_DIST_INIT_ADDR from the env. Raises + ``unittest.SkipTest`` when either is missing — call from inside + ``setUpClass`` so the whole class skips cleanly. 
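+
+    A multi-node recipe class would typically splice these flags in from
+    ``setUpClass`` before delegating to the base launcher, e.g.
+    ``cls.OTHER_ARGS = [*cls.OTHER_ARGS, *multinode_args(nnodes=2)]`` followed
+    by ``super().setUpClass()`` (the ``nnodes=2`` here is only illustrative).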
+ """ + rank = os.environ.get("DSV4_NODE_RANK") + addr = os.environ.get("DSV4_DIST_INIT_ADDR") + if rank is None or addr is None: + raise unittest.SkipTest( + "multi-node test requires DSV4_NODE_RANK and DSV4_DIST_INIT_ADDR" + ) + return [ + "--nnodes", + str(nnodes), + "--node-rank", + rank, + "--dist-init-addr", + addr, + ] + + +class DSV4Aime25TestBase(CustomTestCase): + """Subclass via ``DSV4FlashAime25TestBase`` or ``DSV4ProAime25TestBase``, + not directly. Per-recipe subclasses set MODEL / OTHER_ARGS / EXTRA_ENV. + + SCORE_THRESHOLD is set by the Flash/Pro intermediate base classes: + Flash 0.93, Pro 0.95. + """ + + MODEL: ClassVar[str] = "" + OTHER_ARGS: ClassVar[List[str]] = [] + EXTRA_ENV: ClassVar[Dict[str, str]] = {} + + SCORE_THRESHOLD: ClassVar[float] = 0.0 + + _BASE_CLASSES: ClassVar[set] = set() + + @classmethod + def setUpClass(cls): + if cls in cls._BASE_CLASSES: + raise unittest.SkipTest("base class; subclass to run") + if not cls.MODEL or not cls.OTHER_ARGS: + raise unittest.SkipTest(f"{cls.__name__}: MODEL and OTHER_ARGS must be set") + cls.base_url = DEFAULT_URL_FOR_TEST + env: Optional[Dict[str, str]] = {**BASE_ENV, **(cls.EXTRA_ENV or {})} + cls.process = popen_launch_server( + cls.MODEL, + cls.base_url, + timeout=SERVER_LAUNCH_TIMEOUT, + other_args=list(cls.OTHER_ARGS), + env=env, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_smoke_gsm8k(self): + """Quick GSM8K pass to verify the server is producing math answers.""" + self._run_sgl_eval( + eval_name="gsm8k", + n_repeats=GSM8K_N_REPEATS, + temperature=GSM8K_TEMPERATURE, + top_p=GSM8K_TOP_P, + max_tokens=GSM8K_MAX_TOKENS, + num_threads=GSM8K_NUM_THREADS, + num_examples=GSM8K_NUM_EXAMPLES, + metric=GSM8K_SCORE_METRIC, + threshold=GSM8K_SCORE_THRESHOLD, + ) + + def test_aime25(self): + """Full AIME25 accuracy run; threshold gated by Flash vs Pro base.""" + threshold = ( + AIME25_SCORE_THRESHOLD + if AIME25_SCORE_THRESHOLD > 0 + else self.SCORE_THRESHOLD + ) + self._run_sgl_eval( + eval_name="aime25", + n_repeats=AIME25_NUM_REPEATS, + temperature=AIME25_TEMPERATURE, + top_p=AIME25_TOP_P, + max_tokens=AIME25_MAX_TOKENS, + num_threads=AIME25_NUM_THREADS, + num_examples=None, + metric=AIME25_SCORE_METRIC, + threshold=threshold, + ) + + def _run_sgl_eval( + self, + eval_name, + n_repeats, + temperature, + top_p, + max_tokens, + num_threads, + num_examples, + metric, + threshold, + ): + if shutil.which(SGL_EVAL_BIN) is None: + self.skipTest(f"{SGL_EVAL_BIN!r} not found on PATH") + + out_dir = Path(SGL_EVAL_OUT_DIR) + out_dir.mkdir(parents=True, exist_ok=True) + glob_pattern = f"sgl_eval_{eval_name}_*.json" + before = set(out_dir.glob(glob_pattern)) + + cmd = [ + SGL_EVAL_BIN, + "run", + eval_name, + "--base-url", + f"{self.base_url}/v1", + "--n-repeats", + str(n_repeats), + "--temperature", + str(temperature), + "--top-p", + str(top_p), + "--max-tokens", + str(max_tokens), + "--num-threads", + str(num_threads), + "--out-dir", + str(out_dir), + ] + if num_examples is not None: + cmd += ["--num-examples", str(num_examples)] + + print(f"[{type(self).__name__}] + {' '.join(cmd)}", flush=True) + subprocess.run(cmd, check=True) + + new = sorted(set(out_dir.glob(glob_pattern)) - before) + if not new: + self.fail(f"sgl-eval produced no new {eval_name} JSON in {out_dir}") + result_path = new[-1] + with open(result_path) as f: + result = json.load(f) + print( + f"[{type(self).__name__}] sgl-eval {eval_name} result " + 
f"({result_path.name}): {json.dumps(result, indent=2)}", + flush=True, + ) + + score = self._extract_score(result, metric) + if threshold > 0: + self.assertGreaterEqual( + score, + threshold, + f"{eval_name} {metric}={score} below threshold {threshold}", + ) + + @staticmethod + def _extract_score(result, metric): + """Find ``metric`` (e.g. "pass@1") anywhere in the sgl-eval JSON tree.""" + + def walk(o): + if isinstance(o, dict): + if metric in o and isinstance(o[metric], (int, float)): + return float(o[metric]) + for v in o.values(): + s = walk(v) + if s is not None: + return s + elif isinstance(o, list): + for v in o: + s = walk(v) + if s is not None: + return s + return None + + score = walk(result) + if score is None: + raise AssertionError(f"metric {metric!r} not found in sgl-eval result JSON") + return score + + +class DSV4FlashAime25TestBase(DSV4Aime25TestBase): + """Base for DeepSeek-V4-Flash recipes: AIME25 threshold 0.93.""" + + SCORE_THRESHOLD = 0.93 + + +class DSV4ProAime25TestBase(DSV4Aime25TestBase): + """Base for DeepSeek-V4-Pro recipes: AIME25 threshold 0.95.""" + + SCORE_THRESHOLD = 0.95 + + +DSV4Aime25TestBase._BASE_CLASSES = { + DSV4Aime25TestBase, + DSV4FlashAime25TestBase, + DSV4ProAime25TestBase, +} diff --git a/test/manual/dsv4/test_b200_flash.py b/test/manual/dsv4/test_b200_flash.py new file mode 100644 index 000000000000..3c828116190b --- /dev/null +++ b/test/manual/dsv4/test_b200_flash.py @@ -0,0 +1,109 @@ +"""B200 (FP4) x DeepSeek-V4-Flash. + +Covers the four cookbook recipes for this hardware x model_size cell: +Low-Latency, Balanced, Max-Throughput, Context-Parallel (CP). +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4FlashAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" + + +class TestB200FlashLowLatency(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + ] + EXTRA_ENV = {} + + +class TestB200FlashBalanced(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestB200FlashMaxThroughput(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestB200FlashCP(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--max-running-requests", + "1024", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + 
EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_b200_pro.py b/test/manual/dsv4/test_b200_pro.py new file mode 100644 index 000000000000..67eaadd21c74 --- /dev/null +++ b/test/manual/dsv4/test_b200_pro.py @@ -0,0 +1,125 @@ +"""B200 (FP4) x DeepSeek-V4-Pro. + +Covers the four cookbook recipes for this hardware x model_size cell: +Low-Latency, Balanced, Max-Throughput, Context-Parallel (CP). +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4ProAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Pro" + + +class TestB200ProLowLatency(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +class TestB200ProBalanced(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.82", + "--cuda-graph-max-bs", + "64", + "--max-running-requests", + "128", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestB200ProMaxThroughput(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--mem-fraction-static", + "0.82", + "--cuda-graph-max-bs", + "64", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestB200ProCP(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--cuda-graph-max-bs", + "256", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_b300_flash.py b/test/manual/dsv4/test_b300_flash.py new file mode 100644 index 000000000000..261bac4764db --- /dev/null +++ b/test/manual/dsv4/test_b300_flash.py @@ -0,0 +1,111 @@ +"""B300 x DeepSeek-V4-Flash. + +The cookbook generator aliases B300 to B200, so the launch flags +are identical to the B200(FP4) Flash cell. Kept as a separate file +because the hardware target (and therefore the runtime environment) +is different. Covers Low-Latency, Balanced, Max-Throughput, CP. 
+""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4FlashAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" + + +class TestB300FlashLowLatency(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + ] + EXTRA_ENV = {} + + +class TestB300FlashBalanced(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestB300FlashMaxThroughput(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestB300FlashCP(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--max-running-requests", + "1024", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_b300_pro.py b/test/manual/dsv4/test_b300_pro.py new file mode 100644 index 000000000000..1ed254e5253d --- /dev/null +++ b/test/manual/dsv4/test_b300_pro.py @@ -0,0 +1,127 @@ +"""B300 x DeepSeek-V4-Pro. + +The cookbook generator aliases B300 to B200, so the launch flags +are identical to the B200(FP4) Pro cell. Kept as a separate file +because the hardware target (and therefore the runtime environment) +is different. Covers Low-Latency, Balanced, Max-Throughput, CP. 
+""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4ProAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Pro" + + +class TestB300ProLowLatency(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +class TestB300ProBalanced(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.82", + "--cuda-graph-max-bs", + "64", + "--max-running-requests", + "128", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestB300ProMaxThroughput(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--dp", + "8", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--mem-fraction-static", + "0.82", + "--cuda-graph-max-bs", + "64", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestB300ProCP(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--cuda-graph-max-bs", + "256", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_flash_mtp_dp4.py b/test/manual/dsv4/test_dsv4_flash_mtp_dp4.py new file mode 100644 index 000000000000..716a90ecebaa --- /dev/null +++ b/test/manual/dsv4/test_dsv4_flash_mtp_dp4.py @@ -0,0 +1,175 @@ +"""DSV4 Flash MTP test using EAGLE speculative algorithm. + +DSV4 Flash MTP shares the EAGLE wire path: EAGLE algo + NextN head built +into the target model weights. No separate draft model is needed (sglang +auto-falls back `--speculative-draft-model-path` to the target model). + +Test matrix mirrors test_eagle_infer_b.TestEAGLEServerBasic to maximize +cuda-graph + buffer-pool coverage on the DSV4 path: + - test_gsm8k (accuracy + spec path full forward) + - test_max_token_one (degenerate spec step, still cuda-graph captured) + - test_request_abort (cuda-graph buffer pool survives abort+restart) + +Server launch matches `run_flash_dp4.sh`: tp=4, dp=4, deepep MoE backend, +DSV4 FP8 (FP4 experts disabled). 
+""" + +import random +import threading +import time +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_gsm8k_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +DSV4_FLASH_ENV = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + # MTP runs ~num_draft_tokens forward passes per step, so the deepep + # dispatch input size scales by that factor. Default 256 (used by the + # plain server) overflows once cuda-graph-max-bs * num_draft_tokens + # > 256. 1024 covers bs=128 * 4 draft tokens with headroom. + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", +} + +DEEPEP_CONFIG = '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}' + +PROMPTS = [ + "[INST] You are a helpful assistant.\\nWhere are you from? [/INST]", + "[INST] You are a helpful assistant.\\nSummarize gradient descent in 2 sentences. [/INST]", + "[INST] You are a helpful assistant.\\nWhat is 17*23? [/INST]", + "[INST] You are a helpful assistant.\\nList three primary colors. [/INST]", +] + + +class DSV4FlashMTPServerBase(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DSV4_FLASH_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_CONFIG, + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--mem-fraction-static", + "0.7", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + env=DSV4_FLASH_ENV, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def send_request(self): + time.sleep(random.uniform(0, 2)) + for prompt in PROMPTS: + resp = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": {"temperature": 0, "max_new_tokens": 256}, + }, + ) + assert resp.status_code == 200 + + def send_requests_abort(self): + for prompt in PROMPTS: + try: + time.sleep(random.uniform(0, 2)) + requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": {"temperature": 0, "max_new_tokens": 256}, + }, + timeout=0.5, + ) + except requests.exceptions.Timeout: + pass + + +class TestDSV4FlashMTPBasic(DSV4FlashMTPServerBase): + def test_gsm8k(self): + """Accuracy + spec path full forward.""" + requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_gsm8k_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.95) + + def test_max_token_one(self): + """Degenerate spec step (still cuda-graph captured).""" + requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=100, + max_new_tokens=1, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_gsm8k_eval(args) + 
self.assertGreater(metrics["output_throughput"], 50) + + def test_request_abort(self): + """Cuda-graph buffer pool must survive abort+restart cycles.""" + concurrency = 4 + threads = [ + threading.Thread(target=self.send_request) for _ in range(concurrency) + ] + [ + threading.Thread(target=self.send_requests_abort) + for _ in range(concurrency) + ] + for t in threads: + t.start() + for t in threads: + t.join() + self.assertIsNone(self.process.poll()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_flash_mtp_tp8.py b/test/manual/dsv4/test_dsv4_flash_mtp_tp8.py new file mode 100644 index 000000000000..4913f0bb96bd --- /dev/null +++ b/test/manual/dsv4/test_dsv4_flash_mtp_tp8.py @@ -0,0 +1,125 @@ +"""DSV4-Flash 285B MTP performance tests on H200 TP=8. + +Manual test (8× H200, 285B FP8 weights). Not registered in CI. +""" + +import os +import tempfile +import unittest + +import requests + +from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs +from sglang.bench_one_batch_server import run_benchmark as run_one_batch_benchmark +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +DSV4_FLASH_BASE_ENV = { + "SGLANG_ENABLE_SPEC_V2": "1", + "SGLANG_OPT_USE_TOPK_V2": "1", + "SGLANG_DSV4_FP4_EXPERTS": "0", + "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", +} + +DSV4_FLASH_SERVER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--max-running-requests", + "8", +] + + +def _launch_dsv4_flash_server(extra_env=None): + env = dict(DSV4_FLASH_BASE_ENV) + if extra_env: + env.update(extra_env) + return popen_launch_server( + DSV4_FLASH_MODEL_PATH, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 4, + other_args=DSV4_FLASH_SERVER_ARGS, + env=env, + ) + + +class TestDSV4FlashMTPSimulatedAcc(CustomTestCase): + """bs=1 latency at isl=4096 / 900000 with `SGLANG_SIMULATE_ACC_LEN=3`. 
+ + Reference (H200 Flash TP8): + - isl=4096 → output 258.1 tok/s, accept 2.94 + - isl=900000 → output 222.9 tok/s, accept 2.90 + """ + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch_dsv4_flash_server( + extra_env={"SGLANG_SIMULATE_ACC_LEN": "3"} + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def _run_one_batch(self, input_len): + requests.get(self.base_url + "/flush_cache") + server_args = ServerArgs(model_path=DSV4_FLASH_MODEL_PATH) + bench_args = OneBatchBenchArgs( + run_name=f"dsv4_flash_simacc_isl{input_len}", + batch_size=(1,), + input_len=(input_len,), + output_len=(1024,), + base_url=self.base_url, + skip_warmup=True, + result_filename=os.path.join( + tempfile.gettempdir(), f"dsv4_flash_simacc_isl{input_len}.jsonl" + ), + append_to_github_summary=False, + ) + results, _ = run_one_batch_benchmark(server_args, bench_args) + self.assertTrue(results, "bench_one_batch_server returned no results") + return results[0] + + def test_isl_4096(self): + r = self._run_one_batch(4096) + print( + f"[flash simacc isl=4096] output_throughput={r.output_throughput:.2f} tok/s " + f"latency={r.latency:.2f}s last_ttft={r.last_ttft:.2f}s " + f"acc_length={r.acc_length:.2f}" + ) + # Reference 258.1 tok/s / acc=2.94. + self.assertGreater(r.output_throughput, 232.0) + self.assertGreater(r.acc_length, 2.85) + + def test_isl_900k(self): + r = self._run_one_batch(900_000) + print( + f"[flash simacc isl=900k] output_throughput={r.output_throughput:.2f} tok/s " + f"latency={r.latency:.2f}s last_ttft={r.last_ttft:.2f}s " + f"acc_length={r.acc_length:.2f}" + ) + # Reference 222.9 tok/s / acc=2.90. + self.assertGreater(r.output_throughput, 200.0) + self.assertGreater(r.acc_length, 2.85) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_flash_sanity_dp4.py b/test/manual/dsv4/test_dsv4_flash_sanity_dp4.py new file mode 100644 index 000000000000..c9641cde2b2f --- /dev/null +++ b/test/manual/dsv4/test_dsv4_flash_sanity_dp4.py @@ -0,0 +1,151 @@ +"""DSV4-Flash 4-GPU server sanity matrix (TP4 variants).""" + +import unittest + +from sglang.srt.utils import kill_process_tree +from sglang.test.kits.server_sanity_kit import ServerSanityMixin +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +DSV4_FLASH_ENV = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", +} + +DEEPEP_CONFIG = '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}' + + +def _launch(other_args, env_extra=None, timeout_mult=1): + env = dict(DSV4_FLASH_ENV) + if env_extra: + env.update(env_extra) + return popen_launch_server( + DSV4_FLASH_MODEL_PATH, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * timeout_mult, + other_args=other_args, + env=env, + ) + + +_EAGLE_SPEC_ARGS = [ + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", +] + + +class TestDSV4FlashTP4DP4(ServerSanityMixin, CustomTestCase): + """TP4 + DP4 + deepep + EAGLE MTP.""" + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch( + [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", 
+ "deepep", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_CONFIG, + "--mem-fraction-static", + "0.7", + *_EAGLE_SPEC_ARGS, + ] + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + +class TestDSV4FlashTP4EP(ServerSanityMixin, CustomTestCase): + """TP attn + EP MoE (no DP attn) — exercises the DeepEP + TP-attn path.""" + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch( + [ + "--trust-remote-code", + "--tp", + "4", + "--ep", + "4", + # No --enable-dp-attention by design: covers TP-attn path. + "--moe-a2a-backend", + "deepep", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "64", + "--deepep-config", + DEEPEP_CONFIG, + "--mem-fraction-static", + "0.7", + *_EAGLE_SPEC_ARGS, + ] + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + +class TestDSV4FlashTP4DP4ChunkedPrefillLarge(ServerSanityMixin, CustomTestCase): + """TP4 + DP4 with --chunked-prefill-size 16384 — large chunked prefill.""" + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch( + [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--chunked-prefill-size", + "16384", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_CONFIG, + "--mem-fraction-static", + "0.7", + *_EAGLE_SPEC_ARGS, + ] + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_flash_sanity_tp8.py b/test/manual/dsv4/test_dsv4_flash_sanity_tp8.py new file mode 100644 index 000000000000..06930334a698 --- /dev/null +++ b/test/manual/dsv4/test_dsv4_flash_sanity_tp8.py @@ -0,0 +1,50 @@ +"""DSV4-Flash 8-GPU server sanity (TP8, no spec decoding).""" + +import unittest + +from sglang.srt.utils import kill_process_tree +from sglang.test.kits.server_sanity_kit import ServerSanityMixin +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +DSV4_FLASH_ENV = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", +} + + +class TestDSV4FlashTP8NoSpec(ServerSanityMixin, CustomTestCase): + """TP8, no spec decoding.""" + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + DSV4_FLASH_MODEL_PATH, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "8", + "--max-running-requests", + "8", + "--mem-fraction-static", + "0.85", + ], + env=DSV4_FLASH_ENV, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_pd_disagg_nixl.py b/test/manual/dsv4/test_dsv4_pd_disagg_nixl.py new file mode 100644 index 000000000000..74095c09b52e --- /dev/null +++ b/test/manual/dsv4/test_dsv4_pd_disagg_nixl.py @@ -0,0 +1,148 @@ +"""DSV4 Flash PD-disagg with NIXL backend. Both sides run dp-attention ++ deepep + EAGLE MTP so attn_tp_size and the V4 state pool layout are +fully symmetric: same SWA item_len under matching attn_tp, and same +NSA c4/c128 indexer ring buffer size under matching spec status. 
nixl +`send_state` is page-by-index and has no V4 TP-slice / spec-asymmetric +path, so any layout mismatch would trip the item_len assert in +`nixl/conn.py`.""" + +import unittest +from types import SimpleNamespace + +from sglang.test.few_shot_gsm8k import run_eval as run_gsm8k_eval +from sglang.test.server_fixtures.disaggregation_fixture import ( + PDDisaggregationServerBase, +) +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + popen_launch_pd_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +DSV4_FLASH_ENV = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + # MTP num_draft_tokens=4 scales dispatch by ~4x; 256 overflows at bs=128. + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", +} + +DEEPEP_CONFIG = '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}' + + +class TestDSV4FlashPDDisaggNIXL(PDDisaggregationServerBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.transfer_backend = ["--disaggregation-transfer-backend", "nixl"] + cls.rdma_devices = [] + cls.model = DSV4_FLASH_MODEL_PATH + + cls.start_prefill() + cls.start_decode() + + cls.wait_server_ready(cls.prefill_url + "/health", process=cls.process_prefill) + cls.wait_server_ready(cls.decode_url + "/health", process=cls.process_decode) + cls.launch_lb() + + @classmethod + def start_prefill(cls): + prefill_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "prefill", + "--base-gpu-id", + "0", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_CONFIG, + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--mem-fraction-static", + "0.7", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + *cls.transfer_backend, + *cls.rdma_devices, + ] + cls.process_prefill = popen_launch_pd_server( + cls.model, + cls.prefill_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=prefill_args, + env=DSV4_FLASH_ENV, + ) + + @classmethod + def start_decode(cls): + decode_args = [ + "--trust-remote-code", + "--disaggregation-mode", + "decode", + "--base-gpu-id", + "4", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_CONFIG, + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--mem-fraction-static", + "0.7", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + *cls.transfer_backend, + *cls.rdma_devices, + ] + cls.process_decode = popen_launch_pd_server( + cls.model, + cls.decode_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=decode_args, + env=DSV4_FLASH_ENV, + ) + + def test_gsm8k(self): + """End-to-end PD-disagg accuracy through the LB.""" + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=64, + host=f"http://{self.base_host}", + port=int(self.lb_port), + ) + metrics = run_gsm8k_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.95) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_pro_mtp.py b/test/manual/dsv4/test_dsv4_pro_mtp.py new file mode 100644 index 000000000000..d989f22db4e0 --- /dev/null +++ b/test/manual/dsv4/test_dsv4_pro_mtp.py @@ -0,0 +1,281 @@ +"""DSV4-Pro 1.6T MTP performance tests on B200 
TP=8. + +1. TestDSV4ProMTPSimulatedAcc — `SGLANG_SIMULATE_ACC_LEN=3` pins EAGLE accept + length so latency comparisons are apples-to-apples. Runs `bench_one_batch_server` + at bs=1 for isl=4096 and isl=900000 (osl=1024). + +2. TestDSV4ProMTPHongloumeng — real EAGLE accept (no SIMULATE) on Chinese + long-context input (`hongloumeng.txt`, ~627k DSV4 tokens). Builds a one-line + custom JSONL dataset on the fly and drives `bench_serving --dataset-name custom` + with one short slice (30k tokens) and the full long prompt. + +Manual test (8× B200, 1.6T weights). Not registered in CI. +""" + +import json +import os +import tempfile +import unittest +from types import SimpleNamespace + +import requests + +from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs +from sglang.bench_one_batch_server import run_benchmark as run_one_batch_benchmark +from sglang.bench_serving import run_benchmark as run_serving_benchmark +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_PRO_MODEL_PATH = "deepseek-ai/DeepSeek-V4-Pro" + +HONGLOUMENG_PATH = os.environ.get( + "SGLANG_HONGLOUMENG_PATH", + os.path.join(os.path.dirname(__file__), "hongloumeng.txt"), +) + +DSV4_PRO_BASE_ENV = { + "SGLANG_ENABLE_SPEC_V2": "1", + "SGLANG_OPT_USE_TOPK_V2": "1", + "SGLANG_OPT_USE_CUSTOM_ALL_REDUCE_V2": "1", + "SGLANG_JIT_DEEPGEMM_PRECOMPILE": "0", +} + +DSV4_PRO_SERVER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + "--mem-fraction-static", + "0.82", + "--max-running-requests", + "8", +] + + +def _launch_dsv4_pro_server(extra_env=None): + env = dict(DSV4_PRO_BASE_ENV) + if extra_env: + env.update(extra_env) + return popen_launch_server( + DSV4_PRO_MODEL_PATH, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 4, + other_args=DSV4_PRO_SERVER_ARGS, + env=env, + ) + + +class TestDSV4ProMTPSimulatedAcc(CustomTestCase): + """bs=1 latency at isl=4096 / 900000 with `SGLANG_SIMULATE_ACC_LEN=3`. 
+ + Reference (B200 Pro TP8): + - isl=4096 → output 194.6 tok/s, accept 2.96 + - isl=900000 → output 174.6 tok/s, accept 2.93 + """ + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch_dsv4_pro_server( + extra_env={"SGLANG_SIMULATE_ACC_LEN": "3"} + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def _run_one_batch(self, input_len): + requests.get(self.base_url + "/flush_cache") + server_args = ServerArgs(model_path=DSV4_PRO_MODEL_PATH) + bench_args = OneBatchBenchArgs( + run_name=f"dsv4_pro_simacc_isl{input_len}", + batch_size=(1,), + input_len=(input_len,), + output_len=(1024,), + base_url=self.base_url, + skip_warmup=True, + result_filename=os.path.join( + tempfile.gettempdir(), f"dsv4_pro_simacc_isl{input_len}.jsonl" + ), + append_to_github_summary=False, + ) + results, _ = run_one_batch_benchmark(server_args, bench_args) + self.assertTrue(results, "bench_one_batch_server returned no results") + return results[0] + + def test_isl_4096(self): + r = self._run_one_batch(4096) + print( + f"[pro simacc isl=4096] output_throughput={r.output_throughput:.2f} tok/s " + f"latency={r.latency:.2f}s last_ttft={r.last_ttft:.2f}s " + f"acc_length={r.acc_length:.2f}" + ) + # Reference 194.6 tok/s / acc=2.96 — give 10% throughput margin and a + # generous accept-length floor to absorb run-to-run jitter. + self.assertGreater(r.output_throughput, 175.0) + self.assertGreater(r.acc_length, 2.85) + + def test_isl_900k(self): + r = self._run_one_batch(900_000) + print( + f"[pro simacc isl=900k] output_throughput={r.output_throughput:.2f} tok/s " + f"latency={r.latency:.2f}s last_ttft={r.last_ttft:.2f}s " + f"acc_length={r.acc_length:.2f}" + ) + # Reference 174.6 tok/s / acc=2.93. + self.assertGreater(r.output_throughput, 155.0) + self.assertGreater(r.acc_length, 2.85) + + +def _build_hongloumeng_jsonl(num_tokens, tokenizer, out_path): + """Slice the first `num_tokens` DSV4 tokens of hongloumeng.txt into a + one-line CustomDataset JSONL. Pass num_tokens=None to keep the full text. + """ + with open(HONGLOUMENG_PATH, "r", encoding="utf-8") as f: + text = f.read() + if num_tokens is not None: + ids = tokenizer.encode(text) + text = tokenizer.decode(ids[:num_tokens]) + with open(out_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + {"conversations": [{"value": text}, {"value": "x"}]}, + ensure_ascii=False, + ) + + "\n" + ) + return out_path + + +class TestDSV4ProMTPHongloumeng(CustomTestCase): + """Real EAGLE accept on Chinese long-context (hongloumeng.txt). + + Reference (B200 Pro TP8, no SIMULATE): + - isl=30000 → output 124.4 tok/s, decode peak 184 tok/s, accept 2.47 + - isl=627059 → output 125.7 tok/s, decode peak 179 tok/s, accept 2.52 + """ + + SHORT_TOKENS = 30_000 + LONG_TOKENS = None # full file (~627k DSV4 tokens) + OUTPUT_TOKENS = 4096 + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = _launch_dsv4_pro_server() + + # Resolve tokenizer once; the server reports its own tokenizer path so + # on-the-fly token-level slicing matches what the server will see. 
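+        # If /server_info does not report a tokenizer_path, fall back to the
+        # model repo id (the same path the server itself was launched with).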
+ info = requests.get(cls.base_url + "/server_info", timeout=60).json() + tokenizer_path = info.get("tokenizer_path") or DSV4_PRO_MODEL_PATH + from sglang.srt.utils.hf_transformers_utils import get_tokenizer + + cls.tokenizer = get_tokenizer(tokenizer_path) + + cls.tmpdir = tempfile.mkdtemp(prefix="dsv4_hongloumeng_") + cls.short_jsonl = _build_hongloumeng_jsonl( + cls.SHORT_TOKENS, + cls.tokenizer, + os.path.join(cls.tmpdir, "hongloumeng_30k.jsonl"), + ) + cls.long_jsonl = _build_hongloumeng_jsonl( + cls.LONG_TOKENS, + cls.tokenizer, + os.path.join(cls.tmpdir, "hongloumeng_full.jsonl"), + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def _run_custom_bench(self, dataset_path): + requests.get(self.base_url + "/flush_cache") + args = SimpleNamespace( + backend="sglang", + base_url=self.base_url, + host=None, + port=None, + dataset_name="custom", + dataset_path=dataset_path, + model=None, + tokenizer=None, + num_prompts=1, + sharegpt_output_len=self.OUTPUT_TOKENS, + sharegpt_context_len=None, + random_input_len=4096, + random_output_len=2048, + random_range_ratio=0.0, + request_rate=float("inf"), + max_concurrency=1, + warmup_requests=0, + flush_cache=True, + multi=None, + output_file=None, + disable_tqdm=False, + disable_stream=False, + return_logprob=False, + return_routed_experts=False, + seed=0, + disable_ignore_eos=False, + extra_request_body=None, + apply_chat_template=False, + profile=None, + lora_name=None, + lora_request_distribution="uniform", + lora_zipf_alpha=1.5, + prompt_suffix="", + device="cuda", + pd_separated=False, + ready_check_timeout_sec=0, + ) + return run_serving_benchmark(args) + + def test_short_30k(self): + res = self._run_custom_bench(self.short_jsonl) + print( + f"[hongloumeng 30k] output_throughput={res['output_throughput']:.2f} tok/s " + f"accept_length={res['accept_length']:.2f} " + f"mean_ttft_ms={res['mean_ttft_ms']:.0f} " + f"mean_tpot_ms={res['mean_tpot_ms']:.2f}" + ) + # Reference 124 tok/s / accept 2.47. + self.assertGreater(res["output_throughput"], 105.0) + self.assertGreater(res["accept_length"], 2.30) + + def test_long_full(self): + res = self._run_custom_bench(self.long_jsonl) + print( + f"[hongloumeng full] output_throughput={res['output_throughput']:.2f} tok/s " + f"accept_length={res['accept_length']:.2f} " + f"mean_ttft_ms={res['mean_ttft_ms']:.0f} " + f"mean_tpot_ms={res['mean_tpot_ms']:.2f}" + ) + # Reference 125 tok/s / accept 2.52. Cold prefill takes ~85s on 627k + # tokens so the run is dominated by prefill, but decode steady-state + # accept_length is the metric we care about. + self.assertGreater(res["output_throughput"], 105.0) + self.assertGreater(res["accept_length"], 2.30) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_dsv4_swa_radix_retract.py b/test/manual/dsv4/test_dsv4_swa_radix_retract.py new file mode 100644 index 000000000000..7482eec62131 --- /dev/null +++ b/test/manual/dsv4/test_dsv4_swa_radix_retract.py @@ -0,0 +1,165 @@ +"""DSV4 stress test for SWA radix cache + tombstone + retract interaction. + +Reproduces the assert in `swa_radix_cache.cache_unfinished_req`: + assert old_prefix_len <= len(new_indices) + +Trip conditions (all required): + 1. Fork-only SWA leaf early-release on (`SGLANG_OPT_SWA_RELEASE_LEAF_LOCK_AFTER_WINDOW=1`) + 2. Multiple requests share a long prefix (so one req's tombstoned leaf + poisons match_prefix for others walking the same radix path). + 3. 
Memory pressure forces retract while at least one req has tombstoned + its leaf (decode_batch_idx >= sliding_window_size at retract time). + +After main #19427 changed `old_prefix_len = req.cache_protected_len` +(stable), tombstone-induced shrinks in match's `best_value_len` across +chunked-prefill rounds can make stale `cache_protected_len` exceed +current matchable length -> assert trips. + +Test passes iff the scheduler does not crash under this stress workload. +""" + +import random +import threading +import time +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +DSV4_FLASH_MODEL_PATH = "sgl-project/DeepSeek-V4-Flash-FP8" + +# Long shared prefix forces multi-chunk prefill and ensures cross-request +# prefix-cache hits so one req's tombstone affects later reqs. +SHARED_PREFIX_BLOCK = ( + "You are a careful, expert assistant. Answer concisely.\n" + "Context: " + ("the quick brown fox jumps over the lazy dog. " * 600) +) + +QUESTION_TAILS = [ + " Q: What is 17*23?\n", + " Q: List three primary colors.\n", + " Q: Where is Mount Everest?\n", + " Q: Summarize gradient descent in two sentences.\n", + " Q: Name two bodies of water in Africa.\n", + " Q: What language is spoken in Brazil?\n", +] + + +class TestDSV4FlashSWARadixRetract(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DSV4_FLASH_MODEL_PATH + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}', + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + # Tight static memory so SWA pool fills up under load and + # retract is forced. + "--mem-fraction-static", + "0.7", + ] + env = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", + "SGLANG_OPT_SWA_RADIX_CACHE_COMPACT": "0", + "SGLANG_TEST_RETRACT": "1", + "SGLANG_TEST_RETRACT_INTERVAL": "3", + } + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + env=env, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def _send_req(self, prompt: str, max_new_tokens: int): + try: + resp = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + # Vary outputs slightly so reqs don't share decode + # paths perfectly; we want some to finish, some to + # be retracted under pressure. + "temperature": 0.7, + "max_new_tokens": max_new_tokens, + }, + }, + timeout=600, + ) + # Per-request success is not the gate; some requests are + # expected to be retracted/aborted under heavy pressure. + return resp.status_code == 200 + except Exception: + return False + + def test_swa_tombstone_retract_does_not_crash(self): + """Stress: 64 concurrent long-prompt reqs with long generation force + retract under SWA pool pressure. Reqs share a 30k+ token prefix so + tombstoned leaves from retracted reqs are on the radix path of new + reqs. 
Scheduler must not crash on the swa_radix_cache assert.""" + + random.seed(0) + concurrency = 64 + # Long enough generation to push past sliding_window_size -> fires + # `dec_swa_lock_only` -> tombstones leaves. Combined with SWA pool + # pressure this guarantees retract while tombstones are live. + max_new_tokens = 1024 + + threads = [] + for i in range(concurrency): + tail = QUESTION_TAILS[i % len(QUESTION_TAILS)] + # Add a small per-req suffix so reqs don't dedup at radix root + # but still share the bulk of the prefix. + prompt = SHARED_PREFIX_BLOCK + tail + f"(seed={i})" + t = threading.Thread(target=self._send_req, args=(prompt, max_new_tokens)) + threads.append(t) + t.start() + # Stagger so requests enter prefill in waves; some are still in + # decode (and have tombstoned leaves) when later waves of + # chunked-prefill reqs walk the same radix path. + time.sleep(0.05) + + for t in threads: + t.join(timeout=600) + + # The only invariant: scheduler survived. Per-request completion is + # best-effort under retract pressure. + self.assertIsNone(self.process.poll()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_gb300_flash.py b/test/manual/dsv4/test_gb300_flash.py new file mode 100644 index 000000000000..4e7a7ec5e582 --- /dev/null +++ b/test/manual/dsv4/test_gb300_flash.py @@ -0,0 +1,109 @@ +"""GB300 x DeepSeek-V4-Flash. + +Single-node TP=4 path on the deepseek-ai MXFP4 repo. Covers +Low-Latency, Balanced, Max-Throughput, Context-Parallel (CP). +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4FlashAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" + + +class TestGB300FlashLowLatency(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + ] + EXTRA_ENV = {} + + +class TestGB300FlashBalanced(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestGB300FlashMaxThroughput(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024"} + + +class TestGB300FlashCP(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--max-running-requests", + "1024", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", 
+ } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_gb300_pro.py b/test/manual/dsv4/test_gb300_pro.py new file mode 100644 index 000000000000..2e186591b3a6 --- /dev/null +++ b/test/manual/dsv4/test_gb300_pro.py @@ -0,0 +1,127 @@ +"""GB300 x DeepSeek-V4-Pro. + +Single-node TP=4 path. Note that GB300 Pro CP bumps +mem-fraction-static to 0.88 (1.6T weights at TP=4 on 273 GB don't +fit at the default 0.78). Covers Low-Latency, Balanced, +Max-Throughput, Context-Parallel (CP). +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4ProAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Pro" + + +class TestGB300ProLowLatency(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +class TestGB300ProBalanced(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.9", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestGB300ProMaxThroughput(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--mem-fraction-static", + "0.9", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = {"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256"} + + +class TestGB300ProCP(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.88", + "--cuda-graph-max-bs", + "256", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_h200_fp4_flash.py b/test/manual/dsv4/test_h200_fp4_flash.py new file mode 100644 index 000000000000..aa5b91ea8664 --- /dev/null +++ b/test/manual/dsv4/test_h200_fp4_flash.py @@ -0,0 +1,71 @@ +"""H200 (FP4 / Marlin) x DeepSeek-V4-Flash. + +The cookbook disables Context-Parallel for the H200 FP4 (Marlin) +hardware, so this file only covers Low-Latency, Balanced, and +Max-Throughput. 
+""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DSV4FlashAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" + + +class TestH200Fp4FlashLowLatency(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "marlin", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + ] + EXTRA_ENV = {} + + +class TestH200Fp4FlashBalanced(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "marlin", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + ] + EXTRA_ENV = {} + + +class TestH200Fp4FlashMaxThroughput(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "marlin", + ] + EXTRA_ENV = {} + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_h200_fp4_pro.py b/test/manual/dsv4/test_h200_fp4_pro.py new file mode 100644 index 000000000000..ce87c984fe32 --- /dev/null +++ b/test/manual/dsv4/test_h200_fp4_pro.py @@ -0,0 +1,77 @@ +"""H200 (FP4 / Marlin) x DeepSeek-V4-Pro. + +The cookbook disables Context-Parallel for the H200 FP4 (Marlin) +hardware, so this file only covers Low-Latency, Balanced, and +Max-Throughput. +""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DSV4ProAime25TestBase + +MODEL = "deepseek-ai/DeepSeek-V4-Pro" + + +class TestH200Fp4ProLowLatency(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "marlin", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +class TestH200Fp4ProBalanced(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "marlin", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +class TestH200Fp4ProMaxThroughput(DSV4ProAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "8", + "--moe-runner-backend", + "marlin", + "--mem-fraction-static", + "0.88", + ] + EXTRA_ENV = {} + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_h200_fp8_flash.py b/test/manual/dsv4/test_h200_fp8_flash.py new file mode 100644 index 000000000000..2aacca9ae0b0 --- /dev/null +++ b/test/manual/dsv4/test_h200_fp8_flash.py @@ -0,0 +1,121 @@ +"""H200 (FP8) x DeepSeek-V4-Flash. + +Uses the FP8-repackaged repo (sgl-project/DeepSeek-V4-Flash-FP8) and +the SGLANG_DSV4_FP4_EXPERTS=0 env that the cookbook generator emits +for H200 FP8 cells. Covers Low-Latency, Balanced, Max-Throughput, CP. 
+""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DEEPEP_LARGE_SMS_CONFIG, DSV4FlashAime25TestBase + +MODEL = "sgl-project/DeepSeek-V4-Flash-FP8" +H200_FP8_ENV = {"SGLANG_DSV4_FP4_EXPERTS": "0"} + + +class TestH200Fp8FlashLowLatency(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + ] + EXTRA_ENV = dict(H200_FP8_ENV) + + +class TestH200Fp8FlashBalanced(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "128", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + **H200_FP8_ENV, + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", + } + + +class TestH200Fp8FlashMaxThroughput(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + **H200_FP8_ENV, + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "256", + } + + +class TestH200Fp8FlashCP(DSV4FlashAime25TestBase): + MODEL = MODEL + OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "4", + "--moe-a2a-backend", + "deepep", + "--enable-nsa-prefill-context-parallel", + "--nsa-prefill-cp-mode", + "round-robin-split", + "--chunked-prefill-size", + "16384", + "--mem-fraction-static", + "0.78", + "--max-running-requests", + "1024", + "--deepep-config", + DEEPEP_LARGE_SMS_CONFIG, + ] + EXTRA_ENV = { + **H200_FP8_ENV, + "SGLANG_OPT_USE_JIT_INDEXER_METADATA": "1", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", + } + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_h200_fp8_pro.py b/test/manual/dsv4/test_h200_fp8_pro.py new file mode 100644 index 000000000000..cdf95ff6058b --- /dev/null +++ b/test/manual/dsv4/test_h200_fp8_pro.py @@ -0,0 +1,126 @@ +"""H200 (FP8) x DeepSeek-V4-Pro. + +The cookbook ships this cell as a multi-node (2 nodes, TP=16) launch +using the FP8-repackaged repo (sgl-project/DeepSeek-V4-Pro-FP8). +Each test class skips itself unless DSV4_NODE_RANK and +DSV4_DIST_INIT_ADDR are exported. Runtime expectation: + + On every node: + DSV4_NODE_RANK=<0 or 1> \\ + DSV4_DIST_INIT_ADDR=:20000 \\ + python test/manual/models/dsv4/test_h200_fp8_pro.py + +Context-Parallel is marked TBD in the cookbook for this cell, so it +is intentionally omitted. 
+""" + +import os +import sys +import unittest + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from _common import DSV4ProAime25TestBase, multinode_args + +MODEL = "sgl-project/DeepSeek-V4-Pro-FP8" +H200_FP8_PRO_ENV = { + "SGLANG_DSV4_FP4_EXPERTS": "0", + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "128", +} + + +class TestH200Fp8ProLowLatency(DSV4ProAime25TestBase): + MODEL = MODEL + EXTRA_ENV = dict(H200_FP8_PRO_ENV) + + @classmethod + def setUpClass(cls): + cls.OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "16", + "--dp", + "16", + "--enable-dp-attention", + *multinode_args(2), + "--moe-a2a-backend", + "deepep", + "--cuda-graph-max-bs", + "8", + "--max-running-requests", + "32", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--mem-fraction-static", + "0.88", + ] + super().setUpClass() + + +class TestH200Fp8ProBalanced(DSV4ProAime25TestBase): + MODEL = MODEL + EXTRA_ENV = dict(H200_FP8_PRO_ENV) + + @classmethod + def setUpClass(cls): + cls.OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "16", + "--dp", + "16", + "--enable-dp-attention", + *multinode_args(2), + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--mem-fraction-static", + "0.88", + "--cuda-graph-max-bs", + "8", + "--max-running-requests", + "32", + ] + super().setUpClass() + + +class TestH200Fp8ProMaxThroughput(DSV4ProAime25TestBase): + MODEL = MODEL + EXTRA_ENV = dict(H200_FP8_PRO_ENV) + + @classmethod + def setUpClass(cls): + cls.OTHER_ARGS = [ + "--trust-remote-code", + "--tp", + "16", + "--dp", + "16", + "--enable-dp-attention", + *multinode_args(2), + "--moe-a2a-backend", + "deepep", + "--mem-fraction-static", + "0.88", + "--cuda-graph-max-bs", + "128", + "--max-running-requests", + "256", + ] + super().setUpClass() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_swa_alloc_extend_page_estimation.py b/test/manual/dsv4/test_swa_alloc_extend_page_estimation.py new file mode 100644 index 000000000000..5fc6a9829e83 --- /dev/null +++ b/test/manual/dsv4/test_swa_alloc_extend_page_estimation.py @@ -0,0 +1,136 @@ +"""Regression for SWA alloc_extend page estimation. + +Old gate in SWATokenToKVPoolAllocator.alloc_extend added one full page_size +per request unconditionally, refusing extends that fit inside the request's +last partial page. Fix replaces with get_num_new_pages-based gating. 
+""" + +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock + +import torch + +from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator +from sglang.test.test_utils import CustomTestCase + + +def _make_self(*, page_size: int, full_available: int, swa_available: int): + full_indices = torch.tensor([10, 11], dtype=torch.int64) + swa_indices = torch.tensor([20, 21], dtype=torch.int64) + return SimpleNamespace( + page_size=page_size, + full_attn_allocator=SimpleNamespace( + available_size=lambda: full_available, + alloc_extend=MagicMock(return_value=full_indices), + ), + swa_attn_allocator=SimpleNamespace( + available_size=lambda: swa_available, + alloc_extend=MagicMock(return_value=swa_indices), + ), + translate_loc_from_full_to_swa=lambda last_loc: last_loc, + full_to_swa_index_mapping=torch.zeros(64, dtype=torch.int64), + ) + + +def _call(stub, *, prefix_lens_cpu, seq_lens_cpu, extend_num_tokens): + return SWATokenToKVPoolAllocator.alloc_extend( + stub, + prefix_lens=prefix_lens_cpu, + prefix_lens_cpu=prefix_lens_cpu, + seq_lens=seq_lens_cpu, + seq_lens_cpu=seq_lens_cpu, + last_loc=torch.tensor( + [int(p) - 1 for p in prefix_lens_cpu.tolist()], dtype=torch.int64 + ), + extend_num_tokens=extend_num_tokens, + ) + + +class TestSWAAllocExtendPageEstimation(CustomTestCase): + def test_zero_new_pages_must_succeed(self): + # Old: 2 + 2*8 = 18 > 16 -> would refuse. + # New: prefix 5 -> 6 stays in page 0, 0 new pages. + stub = _make_self(page_size=8, full_available=16, swa_available=16) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([5, 5], dtype=torch.int64), + seq_lens_cpu=torch.tensor([6, 6], dtype=torch.int64), + extend_num_tokens=2, + ) + self.assertIsNotNone(result) + stub.full_attn_allocator.alloc_extend.assert_called_once() + stub.swa_attn_allocator.alloc_extend.assert_called_once() + + def test_one_new_page_fits(self): + # Old: 6 + 2*8 = 22 > 16. New: 2 new pages == 16 // 8. 
+ stub = _make_self(page_size=8, full_available=16, swa_available=16) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([7, 7], dtype=torch.int64), + seq_lens_cpu=torch.tensor([10, 10], dtype=torch.int64), + extend_num_tokens=6, + ) + self.assertIsNotNone(result) + + def test_full_pool_genuinely_insufficient(self): + stub = _make_self(page_size=8, full_available=8, swa_available=64) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([8, 8, 8, 8, 8], dtype=torch.int64), + seq_lens_cpu=torch.tensor([9, 9, 9, 9, 9], dtype=torch.int64), + extend_num_tokens=5, + ) + self.assertIsNone(result) + stub.full_attn_allocator.alloc_extend.assert_not_called() + + def test_swa_pool_genuinely_insufficient(self): + stub = _make_self(page_size=8, full_available=64, swa_available=8) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([8, 8, 8, 8, 8], dtype=torch.int64), + seq_lens_cpu=torch.tensor([9, 9, 9, 9, 9], dtype=torch.int64), + extend_num_tokens=5, + ) + self.assertIsNone(result) + stub.swa_attn_allocator.alloc_extend.assert_not_called() + + def test_exactly_at_capacity_succeeds(self): + stub = _make_self(page_size=8, full_available=16, swa_available=16) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([8, 8], dtype=torch.int64), + seq_lens_cpu=torch.tensor([9, 9], dtype=torch.int64), + extend_num_tokens=2, + ) + self.assertIsNotNone(result) + + def test_one_over_capacity_refuses(self): + stub = _make_self(page_size=8, full_available=16, swa_available=16) + result = _call( + stub, + prefix_lens_cpu=torch.tensor([8, 8, 8], dtype=torch.int64), + seq_lens_cpu=torch.tensor([9, 9, 9], dtype=torch.int64), + extend_num_tokens=3, + ) + self.assertIsNone(result) + + def test_zero_new_pages_across_page_sizes(self): + # Over-estimation gap grows with page_size; sweep to confirm fix + # doesn't depend on the page_size=8 numbers above. + for page_size in (16, 32, 64, 128): + stub = _make_self( + page_size=page_size, + full_available=page_size * 2, + swa_available=page_size * 2, + ) + prefix = torch.tensor([page_size - 2] * 4, dtype=torch.int64) + seq = torch.tensor([page_size - 1] * 4, dtype=torch.int64) + result = _call( + stub, prefix_lens_cpu=prefix, seq_lens_cpu=seq, extend_num_tokens=4 + ) + self.assertIsNotNone(result, f"page_size={page_size}") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/manual/dsv4/test_swa_lock_release_lifecycle.py b/test/manual/dsv4/test_swa_lock_release_lifecycle.py new file mode 100644 index 000000000000..4be0247e342c --- /dev/null +++ b/test/manual/dsv4/test_swa_lock_release_lifecycle.py @@ -0,0 +1,463 @@ +"""Regression for SWA lock release lifecycle. + +Hybrid-SWA early-release protocol: once a request's decode position passes +the sliding window, drop its prefill SWA lock without touching the full +lock, freeing SWA pages back to LRU. 
+ +Covers: +- SWARadixCache.dec_swa_lock_only (leaf tombstone + free, internal protected->evictable) +- SWARadixCache.dec_lock_ref(skip_swa=True) +- SWARadixCache.evict swa branch for leaf with full_lock_ref > 0 +- SWARadixCache._delete_leaf skipping swa_evictable_size_ on tombstoned leaves +""" + +import unittest + +import torch + +from sglang.srt.mem_cache.base_prefix_cache import ( + DecLockRefParams, + EvictParams, + InsertParams, + MatchPrefixParams, +) +from sglang.srt.mem_cache.cache_init_params import CacheInitParams +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.radix_cache import RadixKey +from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache +from sglang.srt.utils import get_device +from sglang.test.test_utils import CustomTestCase + + +def _build_tree( + *, + sliding_window_size: int = 4, + page_size: int = 1, + kv_size: int = 128, + kv_size_swa: int = 64, +): + head_num, head_dim, num_layers, global_interval = 8, 128, 24, 4 + dtype = torch.bfloat16 + device = get_device() + full_ids = list(range(0, num_layers, global_interval)) + swa_ids = [i for i in range(num_layers) if i not in set(full_ids)] + + pool = ReqToTokenPool( + size=8, max_context_len=256, device=device, enable_memory_saver=False + ) + kv_pool = SWAKVPool( + size=kv_size, + size_swa=kv_size_swa, + page_size=page_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + swa_attention_layer_ids=swa_ids, + full_attention_layer_ids=full_ids, + enable_kvcache_transpose=False, + device=device, + ) + allocator = SWATokenToKVPoolAllocator( + size=kv_size, + size_swa=kv_size_swa, + page_size=page_size, + dtype=dtype, + device=device, + kvcache=kv_pool, + need_sort=False, + ) + tree = SWARadixCache( + params=CacheInitParams( + req_to_token_pool=pool, + token_to_kv_pool_allocator=allocator, + page_size=page_size, + disable=False, + is_eagle=False, + sliding_window_size=sliding_window_size, + ), + ) + return tree, allocator, pool + + +def _swa_alloc(allocator, need_size): + """Allocate from SWA allocator for any page_size. + + SWATokenToKVPoolAllocator.alloc() asserts page_size == 1; for page_size > 1 + we drive the underlying paged allocators directly (mirrors the helper in + test_swa_eviction_boundary.py). Required: need_size is a multiple of + page_size when page_size > 1. 
+ """ + if allocator.page_size == 1: + return allocator.alloc(need_size) + + assert need_size % allocator.page_size == 0, ( + f"page_size > 1 requires page-aligned alloc, got {need_size=} " + f"with {allocator.page_size=}" + ) + if need_size > allocator.full_attn_allocator.available_size(): + return None + if need_size > allocator.swa_attn_allocator.available_size(): + return None + full_indices = allocator.full_attn_allocator.alloc(need_size) + swa_indices = allocator.swa_attn_allocator.alloc(need_size) + assert full_indices is not None and swa_indices is not None + allocator.full_to_swa_index_mapping[full_indices] = swa_indices + return full_indices + + +def _insert_chain(tree, allocator, token_ids): + indices = _swa_alloc(allocator, len(token_ids)) + assert indices is not None + tree.insert(InsertParams(key=RadixKey(token_ids), value=indices)) + match = tree.match_prefix(MatchPrefixParams(key=RadixKey(token_ids))) + return match.last_device_node + + +def _release_swa_lock_chain_in_place(tree, leaf, swa_uuid_for_lock): + # Mirrors dec_swa_lock_only's non-tombstone arm (protected->evictable on + # internal nodes) but skips the leaf-free + tombstone step, to construct + # the post-revival state where SWA was already early-released yet the + # leaf is back in swa_lru_list with full_lock_ref still > 0. + node = leaf + while node is not tree.root_node: + if node.swa_lock_ref > 0: + if node.swa_lock_ref == 1: + tree.swa_protected_size_ -= len(node.value) + tree.swa_evictable_size_ += len(node.value) + node.swa_lock_ref -= 1 + if swa_uuid_for_lock and node.swa_uuid == swa_uuid_for_lock: + break + node = node.parent + + +class TestSWALockReleaseLifecycle(CustomTestCase): + """Each test pins one component of the early-release fix; method names + are prefixed with the API surface they exercise so pytest output groups + them naturally.""" + + def test_dec_swa_lock_only_leaf_tombstones_and_frees(self): + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + self.assertEqual(len(leaf.value), 8) + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + self.assertIsNotNone(swa_uuid) + + swa_avail_before = allocator.swa_available_size() + full_avail_before = allocator.full_available_size() + self.assertEqual(leaf.swa_lock_ref, 1) + self.assertEqual(leaf.full_lock_ref, 1) + self.assertFalse(leaf.swa_tombstone) + self.assertTrue(tree.swa_lru_list.in_list(leaf)) + + tree.dec_swa_lock_only(leaf, swa_uuid_for_lock=swa_uuid) + + self.assertTrue(leaf.swa_tombstone) + self.assertFalse(tree.swa_lru_list.in_list(leaf)) + self.assertEqual(leaf.swa_lock_ref, 0) + self.assertEqual( + allocator.swa_available_size(), swa_avail_before + len(leaf.value) + ) + self.assertEqual(leaf.full_lock_ref, 1) + self.assertEqual(allocator.full_available_size(), full_avail_before) + + # sanity_check forbids live locks; release the full half before checking. + tree.dec_lock_ref( + leaf, DecLockRefParams(swa_uuid_for_lock=swa_uuid), skip_swa=True + ) + tree.sanity_check() + + def test_dec_swa_lock_only_internal_no_tombstone_no_free(self): + # Two siblings force an internal node at the shared prefix. + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf_a = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 9]) + + # Post-split: leaf_a now carries [8] only, parent holds the shared 7. 
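+        # (The second insert shares prefix [1..7] and diverges only at the
+        # final token, so the radix tree splits the shared prefix into an
+        # internal node with two single-token leaves.)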
+ self.assertEqual(len(leaf_a.value), 1) + internal = leaf_a.parent + self.assertGreater(len(internal.children), 1) + self.assertEqual(len(internal.value), 7) + + inc_res = tree.inc_lock_ref(leaf_a) + swa_uuid = inc_res.swa_uuid_for_lock + # window=4, value 1 (leaf) + 7 (internal): swa lock chain ends at internal. + self.assertEqual(swa_uuid, internal.swa_uuid) + + swa_protected_before = tree.swa_protected_size_ + swa_evictable_before = tree.swa_evictable_size_ + swa_avail_before = allocator.swa_available_size() + + tree.dec_swa_lock_only(leaf_a, swa_uuid_for_lock=swa_uuid) + + self.assertFalse(internal.swa_tombstone) + self.assertTrue(tree.swa_lru_list.in_list(internal)) + self.assertEqual(internal.swa_lock_ref, 0) + self.assertEqual( + tree.swa_protected_size_, swa_protected_before - (len(leaf_a.value) + 7) + ) + self.assertEqual(tree.swa_evictable_size_, swa_evictable_before + 7) + self.assertEqual( + allocator.swa_available_size(), swa_avail_before + len(leaf_a.value) + ) + + tree.dec_lock_ref( + leaf_a, DecLockRefParams(swa_uuid_for_lock=swa_uuid), skip_swa=True + ) + tree.sanity_check() + + def test_dec_lock_ref_skip_swa_true_drops_full_only(self): + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + + tree.dec_swa_lock_only(leaf, swa_uuid_for_lock=swa_uuid) + self.assertTrue(leaf.swa_tombstone) + self.assertEqual(leaf.full_lock_ref, 1) + + swa_avail_after_release = allocator.swa_available_size() + swa_protected_after_release = tree.swa_protected_size_ + + # Without skip_swa, dec_lock_ref would assert on the swa_tombstone leaf. + tree.dec_lock_ref( + leaf, DecLockRefParams(swa_uuid_for_lock=swa_uuid), skip_swa=True + ) + + self.assertEqual(leaf.full_lock_ref, 0) + self.assertEqual(allocator.swa_available_size(), swa_avail_after_release) + self.assertEqual(tree.swa_protected_size_, swa_protected_after_release) + tree.sanity_check() + + def test_dec_lock_ref_skip_swa_false_drops_both(self): + # Default skip_swa=False must keep legacy behavior intact. + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + + full_avail_before = allocator.full_available_size() + swa_avail_before = allocator.swa_available_size() + + tree.dec_lock_ref(leaf, DecLockRefParams(swa_uuid_for_lock=swa_uuid)) + + self.assertEqual(leaf.full_lock_ref, 0) + self.assertEqual(leaf.swa_lock_ref, 0) + self.assertEqual(tree.full_protected_size_, 0) + self.assertEqual(tree.swa_protected_size_, 0) + # dec_lock_ref releases locks but doesn't free; eviction does. + self.assertEqual(allocator.full_available_size(), full_avail_before) + self.assertEqual(allocator.swa_available_size(), swa_avail_before) + tree.sanity_check() + + def test_evict_swa_leaf_with_full_lock_tombstones_in_place(self): + # Large window so inc_lock_ref locks the entire SWA chain. 
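+        # _release_swa_lock_chain_in_place below then hand-builds the
+        # post-early-release state: SWA locks dropped in place while the
+        # full lock stays held.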
+ tree, allocator, _ = _build_tree(sliding_window_size=64) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4]) + self.assertEqual(len(leaf.value), 4) + + inc_res = tree.inc_lock_ref(leaf) + _release_swa_lock_chain_in_place(tree, leaf, inc_res.swa_uuid_for_lock) + + self.assertEqual(leaf.full_lock_ref, 1) + self.assertEqual(leaf.swa_lock_ref, 0) + self.assertFalse(leaf.swa_tombstone) + self.assertTrue(tree.swa_lru_list.in_list(leaf)) + + swa_avail_before = allocator.swa_available_size() + swa_evictable_before = tree.swa_evictable_size_ + + # num_tokens=0 skips the full eviction loop; swa loop hits the new branch. + evict_res = tree.evict(EvictParams(num_tokens=0, swa_num_tokens=4)) + + self.assertGreaterEqual(evict_res.swa_num_tokens_evicted, 4) + self.assertTrue(leaf.swa_tombstone) + self.assertFalse(tree.swa_lru_list.in_list(leaf)) + self.assertEqual(leaf.full_lock_ref, 1) + self.assertEqual( + allocator.swa_available_size(), swa_avail_before + len(leaf.value) + ) + # Full lock prevents _delete_leaf, so the node stays attached. + self.assertIs(leaf.parent.children[leaf.key.child_key(tree.page_size)], leaf) + self.assertEqual( + tree.swa_evictable_size_, swa_evictable_before - len(leaf.value) + ) + + tree.dec_lock_ref( + leaf, + DecLockRefParams(swa_uuid_for_lock=inc_res.swa_uuid_for_lock), + skip_swa=True, + ) + tree.sanity_check() + + def test_delete_leaf_skips_swa_size_on_tombstone(self): + # Tombstone removes the count once; _delete_leaf must not subtract again. + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + + tree.dec_swa_lock_only(leaf, swa_uuid_for_lock=swa_uuid) + self.assertTrue(leaf.swa_tombstone) + + swa_evictable_before_delete = tree.swa_evictable_size_ + tree.full_lru_list.remove_node(leaf) + tree._delete_leaf(leaf) + + self.assertEqual(tree.swa_evictable_size_, swa_evictable_before_delete) + + def test_dec_swa_lock_only_leaf_page_size_variants(self): + """Single-leaf tombstone+free across all (page_size, window) regimes. + + Sweep covers: + - window multiple of page_size (page_size=2, window=4) + - page_size > window (page_size=8, window=4) + - window not multiple of page (page_size=4, window=6) + + With page_size > 1, _swa_alloc routes through the paged allocators; + free_swa(leaf.value) must release exactly len(leaf.value) tokens + (page-aligned) regardless of how page_size relates to the window. 
+ """ + for page_size, window in [(2, 4), (8, 4), (4, 6)]: + with self.subTest(page_size=page_size, window=window): + tree, allocator, _ = _build_tree( + sliding_window_size=window, + page_size=page_size, + kv_size=max(128, 32 * page_size), + kv_size_swa=max(64, 16 * page_size), + ) + n_tokens = max(window, 2 * page_size) + n_tokens = (n_tokens + page_size - 1) // page_size * page_size + leaf = _insert_chain(tree, allocator, list(range(1, n_tokens + 1))) + self.assertEqual(len(leaf.value), n_tokens) + self.assertEqual(len(leaf.value) % page_size, 0) + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + self.assertIsNotNone( + swa_uuid, + f"inc_lock_ref must reach the window with leaf.value=" + f"{len(leaf.value)} >= window={window}", + ) + + swa_avail_before = allocator.swa_available_size() + full_avail_before = allocator.full_available_size() + + tree.dec_swa_lock_only(leaf, swa_uuid_for_lock=swa_uuid) + + self.assertTrue(leaf.swa_tombstone) + self.assertFalse(tree.swa_lru_list.in_list(leaf)) + self.assertEqual(leaf.swa_lock_ref, 0) + self.assertEqual( + allocator.swa_available_size(), + swa_avail_before + len(leaf.value), + "free_swa must release the leaf's full page-aligned slot count", + ) + self.assertEqual(leaf.full_lock_ref, 1) + self.assertEqual(allocator.full_available_size(), full_avail_before) + + tree.dec_lock_ref( + leaf, + DecLockRefParams(swa_uuid_for_lock=swa_uuid), + skip_swa=True, + ) + tree.sanity_check() + + def test_dec_swa_lock_only_internal_page_size_gt_1(self): + """Internal-node chain release with page_size > 1. + + Two siblings sharing a page-aligned prefix force a radix split on a + page boundary. The swa lock chain therefore spans leaf -> internal, + and dec_swa_lock_only must: + - tombstone the leaf and free len(leaf.value) SWA tokens + - flip the internal node from protected -> evictable (no free, + no tombstone) + """ + page_size, window = 2, 6 + tree, allocator, _ = _build_tree( + sliding_window_size=window, page_size=page_size + ) + # Shared prefix len 4 (2 pages); divergent suffix len 2 (1 page each). + leaf_a = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6]) + _insert_chain(tree, allocator, [1, 2, 3, 4, 7, 8]) + + self.assertEqual(len(leaf_a.value), 2) + internal = leaf_a.parent + self.assertGreater(len(internal.children), 1) + self.assertEqual(len(internal.value), 4) + + inc_res = tree.inc_lock_ref(leaf_a) + swa_uuid = inc_res.swa_uuid_for_lock + # leaf_a (2) + internal (4) = 6 >= window=6, so uuid stops at internal. + self.assertEqual(swa_uuid, internal.swa_uuid) + + swa_protected_before = tree.swa_protected_size_ + swa_evictable_before = tree.swa_evictable_size_ + swa_avail_before = allocator.swa_available_size() + + tree.dec_swa_lock_only(leaf_a, swa_uuid_for_lock=swa_uuid) + + # Leaf side: tombstoned and pages freed. + self.assertTrue(leaf_a.swa_tombstone) + self.assertFalse(tree.swa_lru_list.in_list(leaf_a)) + self.assertEqual( + allocator.swa_available_size(), + swa_avail_before + len(leaf_a.value), + ) + # Internal side: protected -> evictable, still in lru, no free. 
+ self.assertFalse(internal.swa_tombstone) + self.assertTrue(tree.swa_lru_list.in_list(internal)) + self.assertEqual(internal.swa_lock_ref, 0) + self.assertEqual( + tree.swa_protected_size_, + swa_protected_before - (len(leaf_a.value) + len(internal.value)), + ) + self.assertEqual( + tree.swa_evictable_size_, + swa_evictable_before + len(internal.value), + ) + + tree.dec_lock_ref( + leaf_a, DecLockRefParams(swa_uuid_for_lock=swa_uuid), skip_swa=True + ) + tree.sanity_check() + + def test_full_lifecycle_inc_dec_swa_dec_lock_balances(self): + tree, allocator, _ = _build_tree(sliding_window_size=4) + leaf = _insert_chain(tree, allocator, [1, 2, 3, 4, 5, 6, 7, 8]) + + full_protected0 = tree.full_protected_size_ + swa_protected0 = tree.swa_protected_size_ + full_avail0 = allocator.full_available_size() + swa_avail0 = allocator.swa_available_size() + + inc_res = tree.inc_lock_ref(leaf) + swa_uuid = inc_res.swa_uuid_for_lock + + self.assertGreater(tree.full_protected_size_, full_protected0) + self.assertGreater(tree.swa_protected_size_, swa_protected0) + + tree.dec_swa_lock_only(leaf, swa_uuid_for_lock=swa_uuid) + + self.assertEqual(tree.swa_protected_size_, swa_protected0) + self.assertGreater(tree.full_protected_size_, full_protected0) + + tree.dec_lock_ref( + leaf, DecLockRefParams(swa_uuid_for_lock=swa_uuid), skip_swa=True + ) + + self.assertEqual(tree.full_protected_size_, full_protected0) + self.assertEqual(tree.swa_protected_size_, swa_protected0) + self.assertEqual(allocator.full_available_size(), full_avail0) + self.assertEqual(allocator.swa_available_size(), swa_avail0 + len(leaf.value)) + + tree.sanity_check() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py b/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py new file mode 100644 index 000000000000..b750d04fda1b --- /dev/null +++ b/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200.py @@ -0,0 +1,82 @@ +"""B200 per-commit CI: DeepSeek-V4-Flash FP4 (LowLatency recipe). + +Launches TP=4 with flashinfer_mxfp4 MoE runner + EAGLE speculative decoding. +Runs 12 ServerSanity probes (correctness, streaming, concurrency, determinism) +plus a GSM8K accuracy gate. 
+ +Registry: stage-c-test-dsv4-4-gpu-b200 (per-commit, 4x B200) +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.server_sanity_kit import ServerSanityMixin +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, + try_cached_model, +) + +register_cuda_ci(est_time=900, suite="stage-c-test-dsv4-4-gpu-b200") + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" +SERVER_LAUNCH_TIMEOUT = 3600 + + +class TestDSV4FlashFP4B200(ServerSanityMixin, CustomTestCase): + """LowLatency recipe: TP=4, FP4 (mxfp4), EAGLE spec decoding.""" + + @classmethod + def setUpClass(cls): + cls.model = try_cached_model(MODEL) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=SERVER_LAUNCH_TIMEOUT, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "flashinfer_mxfp4", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + "--chunked-prefill-size", + "4096", + "--disable-flashinfer-autotune", + ], + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"[DSV4 Flash FP4 B200] GSM8K {metrics=}") + self.assertGreater(metrics["score"], 0.93) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200_nightly.py b/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200_nightly.py new file mode 100644 index 000000000000..a7ee2bf3d858 --- /dev/null +++ b/test/registered/4-gpu-models/test_deepseek_v4_flash_fp4_b200_nightly.py @@ -0,0 +1,133 @@ +"""B200 nightly CI: DeepSeek-V4-Flash FP4 (Balanced + MaxThroughput recipes). + +Two server configurations exercise the DeepEP all-to-all + DP-attention path +that the per-commit LowLatency test does not cover. + + Balanced: TP=4, DP=4, DeepEP, EAGLE (1 step) + MaxThroughput: TP=4, DP=4, DeepEP, no speculation + +Each class inherits 12 ServerSanity probes plus a GSM8K accuracy gate. 
+ +Registry: nightly-4-gpu-b200 +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.server_sanity_kit import ServerSanityMixin +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, + try_cached_model, +) + +register_cuda_ci(est_time=3600, suite="nightly-4-gpu-b200", nightly=True) + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" +SERVER_LAUNCH_TIMEOUT = 3600 +DEEPEP_CONFIG = '{"normal_dispatch":{"num_sms":96},"normal_combine":{"num_sms":96}}' + +_DEEPEP_ENV = { + "SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK": "1024", +} + + +def _gsm8k_check(test_case): + args = SimpleNamespace( + base_url=test_case.base_url, + model=test_case.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"[{type(test_case).__name__}] GSM8K {metrics=}") + test_case.assertGreater(metrics["score"], 0.93) + + +class TestDSV4FlashFP4B200Balanced(ServerSanityMixin, CustomTestCase): + """Balanced recipe: TP=4, DP=4, DeepEP, EAGLE (1-step spec).""" + + @classmethod + def setUpClass(cls): + cls.model = try_cached_model(MODEL) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=SERVER_LAUNCH_TIMEOUT, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "1", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "2", + "--deepep-config", + DEEPEP_CONFIG, + ], + env=_DEEPEP_ENV, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + _gsm8k_check(self) + + +class TestDSV4FlashFP4B200MaxThroughput(ServerSanityMixin, CustomTestCase): + """MaxThroughput recipe: TP=4, DP=4, DeepEP, no speculation.""" + + @classmethod + def setUpClass(cls): + cls.model = try_cached_model(MODEL) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=SERVER_LAUNCH_TIMEOUT, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--dp", + "4", + "--enable-dp-attention", + "--moe-a2a-backend", + "deepep", + "--deepep-config", + DEEPEP_CONFIG, + ], + env=_DEEPEP_ENV, + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + _gsm8k_check(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/8-gpu-models/test_deepseek_v4_flash_fp4_h200.py b/test/registered/8-gpu-models/test_deepseek_v4_flash_fp4_h200.py new file mode 100644 index 000000000000..c8c3f32029b4 --- /dev/null +++ b/test/registered/8-gpu-models/test_deepseek_v4_flash_fp4_h200.py @@ -0,0 +1,79 @@ +"""H200 per-commit CI: DeepSeek-V4-Flash FP4 Marlin (LowLatency recipe). + +Launches TP=4 with Marlin FP4 MoE runner + EAGLE speculative decoding. +Runs 12 ServerSanity probes (correctness, streaming, concurrency, determinism) +plus a GSM8K accuracy gate. 
+ +Registry: stage-c-test-dsv4-8-gpu-h200 (per-commit, 8x H200 — only 4 used by TP=4) +""" + +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.server_sanity_kit import ServerSanityMixin +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, + try_cached_model, +) + +register_cuda_ci(est_time=900, suite="stage-c-test-dsv4-8-gpu-h200") + +MODEL = "deepseek-ai/DeepSeek-V4-Flash" +SERVER_LAUNCH_TIMEOUT = 3600 + + +class TestDSV4FlashFP4H200(ServerSanityMixin, CustomTestCase): + """LowLatency recipe: TP=4, Marlin FP4, EAGLE spec decoding.""" + + @classmethod + def setUpClass(cls): + cls.model = try_cached_model(MODEL) + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=SERVER_LAUNCH_TIMEOUT, + other_args=[ + "--trust-remote-code", + "--tp", + "4", + "--moe-runner-backend", + "marlin", + "--speculative-algorithm", + "EAGLE", + "--speculative-num-steps", + "3", + "--speculative-eagle-topk", + "1", + "--speculative-num-draft-tokens", + "4", + ], + ) + + @classmethod + def tearDownClass(cls): + if hasattr(cls, "process") and cls.process: + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"[DSV4 Flash FP4 Marlin H200] GSM8K {metrics=}") + self.assertGreater(metrics["score"], 0.93) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/amd/test_deepseek_v4_fp4.py b/test/registered/amd/test_deepseek_v4_fp4.py index 0cf9cbaf5862..dabca68b5a92 100644 --- a/test/registered/amd/test_deepseek_v4_fp4.py +++ b/test/registered/amd/test_deepseek_v4_fp4.py @@ -56,7 +56,7 @@ "SGLANG_TOPK_TRANSFORM_512_TORCH": "1", "SGLANG_OPT_USE_TILELANG_INDEXER": "true", "SGLANG_HACK_FLASHMLA_BACKEND": "tilelang", - "SGLANG_REASONING_EFFORT": "max", + "SGLANG_DSV4_REASONING_EFFORT": "max", } # FP4 variant: FP4 mixed-precision experts. @@ -82,7 +82,7 @@ def setUpClass(cls): "8", "--disable-radix-cache", "--attention-backend", - "compressed", + "dsv4", "--max-running-requests", "256", "--page-size", diff --git a/test/registered/amd/test_deepseek_v4_fp8.py b/test/registered/amd/test_deepseek_v4_fp8.py index 12d7cd034b88..61803f87d646 100644 --- a/test/registered/amd/test_deepseek_v4_fp8.py +++ b/test/registered/amd/test_deepseek_v4_fp8.py @@ -56,7 +56,7 @@ "SGLANG_TOPK_TRANSFORM_512_TORCH": "1", "SGLANG_OPT_USE_TILELANG_INDEXER": "true", "SGLANG_HACK_FLASHMLA_BACKEND": "tilelang", - "SGLANG_REASONING_EFFORT": "max", + "SGLANG_DSV4_REASONING_EFFORT": "max", } # FP8 variant: dense-FP8 experts via the Triton MoE FP8 path. 
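# --- Illustrative aside, not part of the patch ---
# The per-commit H200 test above encodes the DeepSeek-V4-Flash "LowLatency"
# recipe as popen_launch_server() args. A minimal sketch of launching the same
# recipe standalone for local reproduction is below: the server flags are
# copied from TestDSV4FlashFP4H200.setUpClass, while the port value and the
# use of `python3 -m sglang.launch_server` as the entry point are assumptions
# about the local environment, not something this patch introduces.
import subprocess

subprocess.run(
    [
        "python3", "-m", "sglang.launch_server",
        "--model-path", "deepseek-ai/DeepSeek-V4-Flash",
        "--trust-remote-code",
        "--tp", "4",
        "--moe-runner-backend", "marlin",
        "--speculative-algorithm", "EAGLE",
        "--speculative-num-steps", "3",
        "--speculative-eagle-topk", "1",
        "--speculative-num-draft-tokens", "4",
        "--port", "30000",  # assumed local port
    ],
    check=True,
)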
@@ -82,7 +82,7 @@ def setUpClass(cls): "8", "--disable-radix-cache", "--attention-backend", - "compressed", + "dsv4", "--max-running-requests", "256", "--page-size", diff --git a/test/registered/amd/test_deepseek_v4_pro_fp4.py b/test/registered/amd/test_deepseek_v4_pro_fp4.py index 7ee91300a1c8..9997e12ad96e 100644 --- a/test/registered/amd/test_deepseek_v4_pro_fp4.py +++ b/test/registered/amd/test_deepseek_v4_pro_fp4.py @@ -58,7 +58,7 @@ "SGLANG_TOPK_TRANSFORM_512_TORCH": "1", "SGLANG_OPT_USE_TILELANG_INDEXER": "true", "SGLANG_HACK_FLASHMLA_BACKEND": "tilelang", - "SGLANG_REASONING_EFFORT": "max", + "SGLANG_DSV4_REASONING_EFFORT": "max", } # FP4 variant: FP4 mixed-precision experts. @@ -84,7 +84,7 @@ def setUpClass(cls): "8", "--disable-radix-cache", "--attention-backend", - "compressed", + "dsv4", "--max-running-requests", "256", "--page-size", diff --git a/test/registered/amd/test_deepseek_v4_pro_fp8.py b/test/registered/amd/test_deepseek_v4_pro_fp8.py index c4595aa3dc65..e0ed05f8561f 100644 --- a/test/registered/amd/test_deepseek_v4_pro_fp8.py +++ b/test/registered/amd/test_deepseek_v4_pro_fp8.py @@ -58,7 +58,7 @@ "SGLANG_TOPK_TRANSFORM_512_TORCH": "1", "SGLANG_OPT_USE_TILELANG_INDEXER": "true", "SGLANG_HACK_FLASHMLA_BACKEND": "tilelang", - "SGLANG_REASONING_EFFORT": "max", + "SGLANG_DSV4_REASONING_EFFORT": "max", } # FP8 variant: dense-FP8 experts via the Triton MoE FP8 path. @@ -84,7 +84,7 @@ def setUpClass(cls): "8", "--disable-radix-cache", "--attention-backend", - "compressed", + "dsv4", "--max-running-requests", "256", "--page-size", diff --git a/test/registered/unit/entrypoints/openai/test_serving_chat.py b/test/registered/unit/entrypoints/openai/test_serving_chat.py index fa13c354f2d4..69f3ff4d4893 100644 --- a/test/registered/unit/entrypoints/openai/test_serving_chat.py +++ b/test/registered/unit/entrypoints/openai/test_serving_chat.py @@ -46,7 +46,7 @@ def __init__(self): reasoning_parser=None, stream_response_default_include_usage=False, ) - # Mock hf_config for _use_dpsk_v32_encoding check + # Mock hf_config for _resolve_chat_encoding_spec check mock_hf_config = Mock() mock_hf_config.architectures = ["LlamaForCausalLM"] self.model_config.hf_config = mock_hf_config @@ -685,21 +685,204 @@ def test_dpsk_v32_encoding_path(self): mock_hf_config.architectures = ["DeepseekV32ForCausalLM"] tm.model_config.hf_config = mock_hf_config - # Case 1: No chat template + DeepSeek V3.2 arch -> should use dpsk encoding + # Case 1: No chat template + DeepSeek V3.2 arch -> should use dsv32 encoding tm.tokenizer.chat_template = None serving_chat = OpenAIServingChat(tm, TemplateManager()) - self.assertTrue(serving_chat.use_dpsk_v32_encoding) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv32") - # Case 2: Chat template exists -> should NOT use dpsk encoding + # Case 2: Chat template exists -> should NOT use dsv32 encoding tm.tokenizer.chat_template = "some template" serving_chat = OpenAIServingChat(tm, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) - # Case 3: Not DeepSeek V3.2 architecture -> should NOT use dpsk encoding + # Case 3: Not DeepSeek V3.2 architecture -> should NOT use dsv32 encoding tm.tokenizer.chat_template = None mock_hf_config.architectures = ["LlamaForCausalLM"] serving_chat = OpenAIServingChat(tm, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) + + # Case 4: DeepseekV4 arch -> always dsv4, even with chat_template 
+ # (release ships a stale V3 jinja we deliberately override). + mock_hf_config.architectures = ["DeepseekV4ForCausalLM"] + tm.tokenizer.chat_template = "stale v3 jinja" + serving_chat = OpenAIServingChat(tm, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") + + tm.tokenizer.chat_template = None + serving_chat = OpenAIServingChat(tm, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") + + # ------------- dsv4 task + latest_reminder ------------- + def test_dsv4_task_field_schema(self): + """Top-level `task` accepts the 6 DS task tokens and rejects others.""" + for valid in ("action", "query", "authority", "domain", "title", "read_url"): + req = ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "hi"}], + task=valid, + ) + self.assertEqual(req.task, valid) + + # None / unset is fine + self.assertIsNone(self.basic_req.task) + + # Bogus value rejected at validation time + from pydantic import ValidationError + + with self.assertRaises(ValidationError): + ChatCompletionRequest( + model="x", + messages=[{"role": "user", "content": "hi"}], + task="bogus", + ) + + def test_latest_reminder_role_accepted(self): + """`latest_reminder` is a first-class message role on generic param.""" + from sglang.srt.entrypoints.openai.protocol import ( + ChatCompletionMessageGenericParam, + ) + + msg = ChatCompletionMessageGenericParam( + role="latest_reminder", content="Be terse." + ) + self.assertEqual(msg.role, "latest_reminder") + + # Full request with reminder before user parses cleanly. + req = ChatCompletionRequest( + model="x", + messages=[ + {"role": "latest_reminder", "content": "Be terse."}, + {"role": "user", "content": "Hi"}, + ], + ) + self.assertEqual(req.messages[0].role, "latest_reminder") + self.assertEqual(req.messages[1].role, "user") + + def test_attach_task_to_last_user_message(self): + """Helper attaches task to the nearest user/developer message.""" + from sglang.srt.entrypoints.openai import encoding_dsv4 + + messages = [{"role": "user", "content": "Hi"}] + encoding_dsv4.attach_task_to_last_user_message(messages, "domain") + self.assertEqual(messages[0]["task"], "domain") + + # Prefers the LAST user message across a multi-turn conversation. + messages = [ + {"role": "user", "content": "first"}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "second"}, + ] + encoding_dsv4.attach_task_to_last_user_message(messages, "query") + self.assertNotIn("task", messages[0]) + self.assertEqual(messages[2]["task"], "query") + + # `developer` role is treated like `user` (matches encoder semantics). + messages = [{"role": "developer", "content": "dev"}] + encoding_dsv4.attach_task_to_last_user_message(messages, "authority") + self.assertEqual(messages[0]["task"], "authority") + + # No user/developer present -> raises. 
+        with self.assertRaises(ValueError):
+            encoding_dsv4.attach_task_to_last_user_message(
+                [{"role": "system", "content": "s"}], "domain"
+            )
+
+    def test_dsv4_content_parts_list_normalized(self):
+        """OpenAI list-of-parts content flattens to text before reaching the encoder."""
+        from sglang.srt.entrypoints.openai import encoding_dsv4
+        from sglang.srt.parser.jinja_template_utils import (
+            process_content_for_template_format,
+        )
+
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[
+                {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "say hi"}],
+                }
+            ],
+        )
+        messages = [m.model_dump() for m in req.messages]
+        # Mirror the boundary normalization _process_messages does for any
+        # non-None chat_encoding_spec.
+        for i, msg in enumerate(messages):
+            if isinstance(msg.get("content"), list):
+                messages[i] = process_content_for_template_format(
+                    msg, "string", [], [], [], []
+                )
+        out = encoding_dsv4.encode_messages(messages, thinking_mode="chat")
+        self.assertIn("<|User|>say hi", out)
+
+        # Multiple text parts are concatenated with a single space; non-text parts dropped.
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "describe"},
+                    {"type": "image_url", "image_url": {"url": "x"}},
+                ],
+            }
+        ]
+        for i, msg in enumerate(messages):
+            if isinstance(msg.get("content"), list):
+                messages[i] = process_content_for_template_format(
+                    msg, "string", [], [], [], []
+                )
+        out = encoding_dsv4.encode_messages(messages, thinking_mode="chat")
+        self.assertIn("<|User|>describe", out)
+        self.assertNotIn("image_url", out)
+
+    def test_dsv4_task_and_reminder_encode_end_to_end(self):
+        """Task + latest_reminder plumb through to the dsv4 encoder correctly."""
+        from sglang.srt.entrypoints.openai import encoding_dsv4
+
+        # 1) task='domain' in chat mode -> `<|domain|>` appended, no Assistant
+        #    prefix (this is a single-shot classification, not a chat turn).
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[{"role": "user", "content": "What is SGLang?"}],
+            task="domain",
+        )
+        messages = [m.model_dump() for m in req.messages]
+        encoding_dsv4.attach_task_to_last_user_message(messages, req.task)
+        out = encoding_dsv4.encode_messages(messages, thinking_mode="chat")
+        self.assertIn("<|domain|>", out)
+        self.assertTrue(out.rstrip().endswith("<|domain|>"))
+        self.assertNotIn("<|Assistant|>", out)
+
+        # 2) task='action' in thinking mode -> Assistant + <think> + <|action|>
+        #    (action is the one task that still runs a reasoning pass).
+        req = ChatCompletionRequest(
+            model="x",
+            messages=[{"role": "user", "content": "Hi"}],
+            task="action",
+        )
+        messages = [m.model_dump() for m in req.messages]
+        encoding_dsv4.attach_task_to_last_user_message(messages, req.task)
+        out = encoding_dsv4.encode_messages(messages, thinking_mode="thinking")
+        self.assertIn("<|Assistant|>", out)
+        self.assertIn("<think>", out)
+        self.assertTrue(out.rstrip().endswith("<|action|>"))
+
+        # 3) latest_reminder preceding user -> reminder renders before user,
+        #    Assistant prefix still comes after user.
+ req = ChatCompletionRequest( + model="x", + messages=[ + {"role": "latest_reminder", "content": "Be terse."}, + {"role": "user", "content": "Hello"}, + ], + ) + messages = [m.model_dump() for m in req.messages] + out = encoding_dsv4.encode_messages(messages, thinking_mode="chat") + self.assertIn("<|latest_reminder|>Be terse.", out) + self.assertIn("<|User|>Hello", out) + self.assertLess( + out.index("<|latest_reminder|>"), + out.index("<|User|>"), + ) + self.assertIn("<|Assistant|>", out) def test_streaming_abort_yields_error(self): """Test that an abort finish reason during streaming correctly yields an error and stops.""" diff --git a/test/registered/unit/managers/test_hisparse_unit.py b/test/registered/unit/managers/test_hisparse_unit.py index e3f3d54a5376..b0eba08a2788 100644 --- a/test/registered/unit/managers/test_hisparse_unit.py +++ b/test/registered/unit/managers/test_hisparse_unit.py @@ -45,6 +45,7 @@ def _make_req(rid="test-req-0", origin_input_ids=None, output_ids=None): origin_input_ids=origin_input_ids, output_ids=output_ids, fill_ids=origin_input_ids + output_ids, + seqlen=len(origin_input_ids) + len(output_ids), req_pool_idx=None, kv_allocated_len=0, kv_committed_len=0, diff --git a/test/registered/unit/managers/test_prefill_adder.py b/test/registered/unit/managers/test_prefill_adder.py index 182b78938673..22cc1f598a78 100644 --- a/test/registered/unit/managers/test_prefill_adder.py +++ b/test/registered/unit/managers/test_prefill_adder.py @@ -443,18 +443,25 @@ def test_mixed_chunk_prefill_budgets(self): self.assertEqual(result3, AddReqResult.OTHER) def _build_hybrid_swa_chunked_req( - self, *, page_size, rem_swa, rem_chunk=2048, extend_input_len=500 + self, + *, + page_size, + rem_swa, + rem_chunk=2048, + extend_input_len=500, + is_hybrid_swa=True, + full_available=100_000, ): self.mock_token_allocator.swa_available_size.return_value = rem_swa - self.mock_token_allocator.full_available_size.return_value = 100_000 - self.mock_token_allocator.available_size.return_value = 100_000 + self.mock_token_allocator.full_available_size.return_value = full_available + self.mock_token_allocator.available_size.return_value = full_available self.mock_tree_cache.sliding_window_size = 128 adder = self.create_adder( self.create_running_batch(), page_size=page_size, rem_chunk_tokens=rem_chunk, ) - adder.is_hybrid_swa = True + adder.is_hybrid_swa = is_hybrid_swa req = self.create_mock_req("chunked", priority=0, max_new_tokens=128) req.extend_input_len = extend_input_len @@ -499,6 +506,44 @@ def test_add_chunked_req_hybrid_swa_defers_when_swa_below_page(self): self.assertEqual(req.extend_input_len, original_len) self.assertEqual(len(adder.can_run_list), 0) + def test_swa_budget_for_req(self): + cases = [ + # (extend, rem_chunk, window, page, expected, label) + (64, None, 128, 16, 128 + 16, "no_cap_floor_active"), + (200, None, 256, 32, 256 + 32, "no_cap_floor_active_other_dims"), + (300, None, 128, 16, 300 + 16, "no_cap_floor_inactive"), + (200, 50, 64, 8, 64 + 8, "cap_binds_then_floor"), + (300, 500, 64, 64, 300 + 64, "cap_does_not_bind"), + (0, None, 128, 16, 128 + 16, "extend_zero_floor_only"), + ] + for extend, rem_chunk, window, page, expected, label in cases: + with self.subTest(label=label): + self.mock_tree_cache.sliding_window_size = window + adder = self.create_adder( + self.create_running_batch(), + page_size=page, + rem_chunk_tokens=rem_chunk, + ) + self.assertEqual(adder._swa_budget_for_req(extend), expected) + + def test_add_chunked_req_non_hybrid_no_swa_reservation(self): + # 
Non-hybrid path: the SWA-pool reservation must NOT apply, otherwise + # the fix would regress non-SWA models. + PAGE_SIZE = 16 + adder, req = self._build_hybrid_swa_chunked_req( + page_size=PAGE_SIZE, + rem_swa=10, + rem_chunk=500, + extend_input_len=200, + is_hybrid_swa=False, + full_available=300, + ) + + result = adder.add_chunked_req(req) + self.assertIsNone(result) + req.set_extend_input_len.assert_called_once_with(200) + self.assertIn(req, adder.can_run_list) + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/mem_cache/test_swa_eviction_boundary.py b/test/registered/unit/mem_cache/test_swa_eviction_boundary.py index 751dccee2d99..2802f8b0a209 100644 --- a/test/registered/unit/mem_cache/test_swa_eviction_boundary.py +++ b/test/registered/unit/mem_cache/test_swa_eviction_boundary.py @@ -110,6 +110,7 @@ def _make_req(req_pool_idx, token_ids, cache_protected_len, tree): extra_key=None, last_node=tree.root_node, swa_uuid_for_lock=None, + swa_prefix_lock_released=False, prefix_indices=torch.tensor([], dtype=torch.int64, device=tree.device), _kv_committed_len=len(token_ids), ) diff --git a/test/registered/unit/mem_cache/test_swa_unittest.py b/test/registered/unit/mem_cache/test_swa_unittest.py index d3a6e2c421fd..154a869a5bef 100644 --- a/test/registered/unit/mem_cache/test_swa_unittest.py +++ b/test/registered/unit/mem_cache/test_swa_unittest.py @@ -2,7 +2,9 @@ import torch +from sglang.srt.environ import envs from sglang.srt.mem_cache.base_prefix_cache import ( + DecLockRefParams, EvictParams, EvictResult, InsertParams, @@ -16,80 +18,113 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci +from sglang.test.test_utils import CustomTestCase register_cuda_ci(est_time=9, suite="stage-b-test-1-gpu-large") register_amd_ci(est_time=10, suite="stage-b-test-1-gpu-small-amd") -class TestSWA(unittest.TestCase): - class _DummyReq: - def __init__(self): - self._kv_committed_len = 0 - - def pop_committed_kv_cache(self): - return self._kv_committed_len - - def _build_swa_tree( - self, - is_eagle: bool, - page_size: int = 1, - req_size: int = 8, - max_context_len: int = 64, - kv_size: int = 64, - kv_size_swa: int = 32, - sliding_window_size: int = 4, - ): - head_num = 8 - head_dim = 128 - num_layers = 24 - global_interval = 4 - dtype = torch.bfloat16 - device = get_device() - full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] - full_attention_layer_ids_set = set(full_attention_layer_ids) - swa_attention_layer_ids = [ - i for i in range(num_layers) if i not in full_attention_layer_ids_set - ] - - req_to_token_pool = ReqToTokenPool( - size=req_size, - max_context_len=max_context_len, - device=device, - enable_memory_saver=False, - ) - kv_pool = SWAKVPool( - size=kv_size, - size_swa=kv_size_swa, - page_size=page_size, - dtype=dtype, - head_num=head_num, - head_dim=head_dim, - swa_attention_layer_ids=swa_attention_layer_ids, - full_attention_layer_ids=full_attention_layer_ids, - enable_kvcache_transpose=False, - device=device, - ) - allocator = SWATokenToKVPoolAllocator( - size=kv_size, - size_swa=kv_size_swa, +class _DummyReq: + def __init__(self): + self._kv_committed_len = 0 + self.swa_prefix_lock_released = False + + def pop_committed_kv_cache(self): + return self._kv_committed_len + + +def _build_swa_tree( + is_eagle: bool, + page_size: int = 1, + req_size: int = 8, + max_context_len: int = 64, + kv_size: int = 64, + 
kv_size_swa: int = 32, + sliding_window_size: int = 4, +): + head_num = 8 + head_dim = 128 + num_layers = 24 + global_interval = 4 + dtype = torch.bfloat16 + device = get_device() + full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] + full_attention_layer_ids_set = set(full_attention_layer_ids) + swa_attention_layer_ids = [ + i for i in range(num_layers) if i not in full_attention_layer_ids_set + ] + + req_to_token_pool = ReqToTokenPool( + size=req_size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=False, + ) + kv_pool = SWAKVPool( + size=kv_size, + size_swa=kv_size_swa, + page_size=page_size, + dtype=dtype, + head_num=head_num, + head_dim=head_dim, + swa_attention_layer_ids=swa_attention_layer_ids, + full_attention_layer_ids=full_attention_layer_ids, + enable_kvcache_transpose=False, + device=device, + ) + allocator = SWATokenToKVPoolAllocator( + size=kv_size, + size_swa=kv_size_swa, + page_size=page_size, + dtype=dtype, + device=device, + kvcache=kv_pool, + need_sort=False, + ) + tree = SWARadixCache( + params=CacheInitParams( + req_to_token_pool=req_to_token_pool, + token_to_kv_pool_allocator=allocator, page_size=page_size, - dtype=dtype, - device=device, - kvcache=kv_pool, - need_sort=False, - ) - tree = SWARadixCache( - params=CacheInitParams( - req_to_token_pool=req_to_token_pool, - token_to_kv_pool_allocator=allocator, - page_size=page_size, - disable=False, - is_eagle=is_eagle, - sliding_window_size=sliding_window_size, - ), - ) - return tree, allocator, req_to_token_pool + disable=False, + is_eagle=is_eagle, + sliding_window_size=sliding_window_size, + ), + ) + return tree, allocator, req_to_token_pool + + +def _swa_alloc(allocator, need_size): + """SWA-pool alloc that also works for page_size > 1 (built-in alloc asserts page_size == 1).""" + if allocator.page_size == 1: + return allocator.alloc(need_size) + assert need_size % allocator.page_size == 0 + full_indices = allocator.full_attn_allocator.alloc(need_size) + swa_indices = allocator.swa_attn_allocator.alloc(need_size) + assert full_indices is not None and swa_indices is not None + allocator.full_to_swa_index_mapping[full_indices] = swa_indices + return full_indices + + +def _insert(tree, allocator, token_ids): + indices = _swa_alloc(allocator, len(token_ids)) + assert indices is not None + tree.insert(InsertParams(key=RadixKey(token_ids), value=indices)) + + +def _insert_chain(tree, allocator, token_ids): + _insert(tree, allocator, token_ids) + match = tree.match_prefix(MatchPrefixParams(key=RadixKey(token_ids))) + return match.last_device_node + + +def _expected_tail_size(window: int, page_size: int) -> int: + """Mirror of _maybe_split_leaf_for_swa_lock's tail_size formula.""" + return (window + page_size - 1) // page_size * page_size + + +class TestSWA(unittest.TestCase): @classmethod def setUpClass(cls): pass @@ -475,10 +510,10 @@ def test_swa_radix_cache_eagle(self): self.assertEqual(list(last_node.key), [(5, 60), (60, 70)]) def test_swa_cache_finished_req_eagle_uses_cache_protected_len_and_bigram_key(self): - tree, allocator, req_to_token_pool = self._build_swa_tree(is_eagle=True) + tree, allocator, req_to_token_pool = _build_swa_tree(is_eagle=True) # Case 1: is_insert=True should pass bigram key and use cache_protected_len. 
- req = self._DummyReq() + req = _DummyReq() req.req_pool_idx = 0 req.origin_input_ids = [1, 2, 3, 4, 5, 6] req.output_ids = [] @@ -513,7 +548,7 @@ def wrapped_insert(params): # Case 2: is_insert=False should free [cache_protected_len:page_aligned_len] # even when len(prefix_indices) is intentionally larger. - req2 = self._DummyReq() + req2 = _DummyReq() req2.req_pool_idx = 1 req2.origin_input_ids = [11, 12, 13, 14, 15, 16] req2.output_ids = [] @@ -546,5 +581,112 @@ def wrapped_free(indices): self.assertEqual(freed_lens, [4, 1]) +# Optimization: SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT. +# Splits a freshly-inserted leaf at the (page-aligned) sliding-window +# boundary so a future inc_lock_ref protects only ~sliding_window_size SWA +# tokens instead of the whole chunked-prefill chain. +class TestSWASplitLeafOnInsert(CustomTestCase): + def _insert_and_lock(self, *, window, page_size, leaf_len, flag_on): + tree, allocator, _ = _build_swa_tree( + is_eagle=False, + kv_size=128, + kv_size_swa=64, + sliding_window_size=window, + page_size=page_size, + ) + token_ids = list(range(leaf_len)) + with envs.SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT.override(flag_on): + leaf = _insert_chain(tree, allocator, token_ids) + result = tree.inc_lock_ref(leaf) + return tree, leaf, result + + def test_flag_off_protects_full_leaf(self): + tree, leaf, _ = self._insert_and_lock( + window=4, page_size=1, leaf_len=12, flag_on=False + ) + self.assertEqual(len(leaf.value), 12) + self.assertEqual(tree.swa_protected_size_, 12) + + def test_flag_on_caps_protection_at_window(self): + # (window, page_size, leaf_len, expected_tail_size); leaf_len picked + # > tail_size and page-aligned for page_size > 1. + cases = [ + (4, 1, 12, 4), + (4, 1, 5, 4), + (1, 1, 5, 1), + (4, 2, 12, 4), + (8, 2, 12, 8), + (4, 4, 12, 4), + # window NOT page-aligned -> tail rounds up to page boundary. + (3, 2, 12, 4), + (5, 4, 12, 8), + (3, 4, 12, 4), + ] + for window, page_size, leaf_len, expected_tail in cases: + with self.subTest(window=window, page_size=page_size, leaf_len=leaf_len): + self.assertEqual(_expected_tail_size(window, page_size), expected_tail) + tree, leaf, _ = self._insert_and_lock( + window=window, + page_size=page_size, + leaf_len=leaf_len, + flag_on=True, + ) + self.assertEqual(len(leaf.value), expected_tail) + self.assertEqual(tree.swa_protected_size_, expected_tail) + + def test_flag_on_no_split_when_leaf_within_window(self): + # leaf_len <= tail_size: split must no-op. 
+ cases = [ + (4, 1, 4), + (4, 1, 3), + (4, 2, 4), + (3, 2, 4), + (8, 2, 4), + (4, 4, 4), + ] + for window, page_size, leaf_len in cases: + with self.subTest(window=window, page_size=page_size, leaf_len=leaf_len): + tree, leaf, _ = self._insert_and_lock( + window=window, + page_size=page_size, + leaf_len=leaf_len, + flag_on=True, + ) + self.assertEqual(len(leaf.value), leaf_len) + self.assertEqual(tree.swa_protected_size_, leaf_len) + + def test_match_prefix_returns_full_chain_after_split(self): + tree, allocator, _ = _build_swa_tree( + is_eagle=False, + kv_size=128, + kv_size_swa=64, + sliding_window_size=4, + page_size=1, + ) + token_ids = list(range(12)) + with envs.SGLANG_OPT_SWA_SPLIT_LEAF_ON_INSERT.override(True): + inserted_leaf = _insert_chain(tree, allocator, token_ids) + self.assertEqual(len(inserted_leaf.value), 4) + match = tree.match_prefix(MatchPrefixParams(key=RadixKey(token_ids))) + self.assertEqual(match.device_indices.shape[0], 12) + self.assertIs(match.last_device_node, inserted_leaf) + + def test_dec_lock_ref_after_split_balances_to_zero(self): + tree, leaf, result = self._insert_and_lock( + window=4, page_size=1, leaf_len=12, flag_on=True + ) + self.assertEqual(tree.swa_protected_size_, 4) + self.assertEqual(tree.full_protected_size_, 12) + + tree.dec_lock_ref( + leaf, + params=DecLockRefParams(swa_uuid_for_lock=result.swa_uuid_for_lock), + ) + + self.assertEqual(tree.swa_protected_size_, 0) + self.assertEqual(tree.full_protected_size_, 0) + tree.sanity_check() + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/unit/utils/test_weight_checker.py b/test/registered/unit/utils/test_weight_checker.py index 17ec362719df..1ee0442ed61b 100644 --- a/test/registered/unit/utils/test_weight_checker.py +++ b/test/registered/unit/utils/test_weight_checker.py @@ -188,14 +188,14 @@ def test_no_quant_yields_raw_with_should_compare_true(self): b = torch.randn(4) raw = {"a.weight": a, "b.bias": b} _assert_triples_close( - _postprocess_tensors(raw), + _postprocess_tensors(raw, set()), [("a.weight", True, a), ("b.bias", True, b)], ) def test_weight_alone_without_scale_inv_does_not_trigger_dequant(self): w = torch.randn(4) raw = {"x.weight": w} - _assert_triples_close(_postprocess_tensors(raw), [("x.weight", True, w)]) + _assert_triples_close(_postprocess_tensors(raw, set()), [("x.weight", True, w)]) # --- non-persistent buffer skip --- @@ -207,7 +207,7 @@ def test_skips_cos_sin_cache_substring(self): "model.layers.0.weight": plain, } _assert_triples_close( - _postprocess_tensors(raw), + _postprocess_tensors(raw, set()), [ ("model.rotary_emb.cos_sin_cache", False, cache), ("model.layers.0.weight", True, plain), @@ -217,14 +217,14 @@ def test_skips_cos_sin_cache_substring(self): def test_skips_inv_freq_substring(self): t = torch.randn(4) _assert_triples_close( - _postprocess_tensors({"model.rotary_emb.inv_freq": t}), + _postprocess_tensors({"model.rotary_emb.inv_freq": t}, set()), [("model.rotary_emb.inv_freq", False, t)], ) def test_skips_weight_fp32_substring(self): t = torch.randn(4) _assert_triples_close( - _postprocess_tensors({"model.layers.0.mlp.gate._weight_fp32": t}), + _postprocess_tensors({"model.layers.0.mlp.gate._weight_fp32": t}, set()), [("model.layers.0.mlp.gate._weight_fp32", False, t)], ) @@ -232,7 +232,7 @@ def test_substring_match_not_endswith(self): # Pattern can appear anywhere in the name, not just at the end. 
t = torch.randn(4) _assert_triples_close( - _postprocess_tensors({"weird.cos_sin_cache.foo.bar": t}), + _postprocess_tensors({"weird.cos_sin_cache.foo.bar": t}, set()), [("weird.cos_sin_cache.foo.bar", False, t)], ) @@ -248,7 +248,7 @@ def test_fp8_quant_pair_with_int32_scale_dequants_via_ue8m0(self): qweight, sf_fp32, block_size=[128, 128], dtype=torch.bfloat16 ) _assert_triples_close( - _postprocess_tensors(raw), + _postprocess_tensors(raw, set()), [ ("x.weight", True, expected_dequant), ("x.weight", False, qweight), @@ -264,7 +264,7 @@ def test_fp8_quant_pair_with_fp32_scale_dequants_directly(self): qweight, sf_fp32, block_size=[128, 128], dtype=torch.bfloat16 ) _assert_triples_close( - _postprocess_tensors(raw), + _postprocess_tensors(raw, set()), [ ("x.weight", True, expected_dequant), ("x.weight", False, qweight), @@ -285,7 +285,7 @@ def test_fp8_quant_pair_yield_order_alongside_other_entries(self): ) # All dequant entries come first, then a raw pass over every key. _assert_triples_close( - _postprocess_tensors(raw), + _postprocess_tensors(raw, set()), [ ("x.weight", True, expected_dequant), ("x.weight", False, qweight), @@ -299,7 +299,7 @@ def test_only_scale_without_weight_does_not_trigger_dequant(self): # through as a normal entry with should_compare=True. s = torch.zeros(1, 1, dtype=torch.int32) _assert_triples_close( - _postprocess_tensors({"x.weight_scale_inv": s}), + _postprocess_tensors({"x.weight_scale_inv": s}, set()), [("x.weight_scale_inv", True, s)], ) diff --git a/test/run_suite.py b/test/run_suite.py index 7fe58ad9167a..c09b1fae8fc9 100644 --- a/test/run_suite.py +++ b/test/run_suite.py @@ -54,6 +54,8 @@ "stage-c-test-8-gpu-b200", "stage-c-test-deepep-4-gpu-h100", "stage-c-test-deepep-8-gpu-h200", + "stage-c-test-dsv4-4-gpu-b200", + "stage-c-test-dsv4-8-gpu-h200", ], HWBackend.NPU: [ "stage-a-test-1-gpu-small",