diff --git a/python/sglang/jit_kernel/.clang-format b/python/sglang/jit_kernel/.clang-format index 56acfb8b8f5c..690cc3fea0d7 100644 --- a/python/sglang/jit_kernel/.clang-format +++ b/python/sglang/jit_kernel/.clang-format @@ -17,7 +17,7 @@ PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name IncludeCategories: - Regex: '^$' Priority: 0 - - Regex: '^$' + - Regex: '^$' Priority: 2 - Regex: '^$' Priority: 1 diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh new file mode 100644 index 000000000000..d91ad1d0e685 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c128.cuh @@ -0,0 +1,524 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan128 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int32_t kTileElements = 2; +/// \brief Each warp will handle this many elements (split along 128) +constexpr int32_t kElementsPerWarp = 8; +constexpr uint32_t kNumWarps = 128 / kElementsPerWarp; +constexpr uint32_t kBlockSize = device::kWarpThreads * kNumWarps; + +/// \brief Need to reduce register usage to increase occupancy +#define C128_KERNEL __global__ __launch_bounds__(kBlockSize, 2) + +struct Compress128DecodeParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress128PrefillParams { + /** + * \brief Shape: `[num_indices, 128, head_dim * 2]` \n + * last dimension layout: + * | kv current | score current | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 2]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[128, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]`*/ + const int32_t* __restrict__ load_indices; + /** \brief The following part is plan info. 
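+   *
+   * NOTE: judging from the structured binding
+   * `[ragged_id, global_bid, position, window_len]` in the prefill kernels and
+   * from `plan_prefill_host`, each `PrefillPlan` entry is assumed to carry:
+   *   - `ragged_id`: row in the ragged `kv_score_input` / output tensors
+   *     (0xFFFFFFFF marks an unused, padded slot)
+   *   - `batch_id`: which request the token belongs to (indexes `indices`)
+   *   - `position`: absolute position of the token within its sequence
+   *   - `window_len`: number of the 128 window slots that are read from the
+   *     persistent buffer rather than from the current ragged input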
*/ + + const Plan128* __restrict__ compress_plan; + const Plan128* __restrict__ write_plan; + + uint32_t num_compress; + uint32_t num_write; +}; + +struct Compress128SharedBuffer { + using Storage = device::AlignedVector; + Storage data[kNumWarps][device::kWarpThreads + 1]; // padding to avoid bank conflict + SGL_DEVICE Storage& operator()(uint32_t warp_id, uint32_t lane_id) { + return data[warp_id][lane_id]; + } + SGL_DEVICE float& operator()(uint32_t warp_id, uint32_t lane_id, uint32_t tile_id) { + return data[warp_id][lane_id][tile_id]; + } +}; + +template +SGL_DEVICE void c128_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos, + const uint32_t lane_id) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 2; + const auto gmem = tile::Memory{lane_id, kWarpThreads}; + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv | [1] = score | + Storage kv_score[2]; +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 2; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c128_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t window_len, + const uint32_t warp_id, + const uint32_t lane_id) { + using namespace device; + + const auto element_size = head_dim * 2; + const auto score_offset = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory{lane_id, kWarpThreads}; + StorageIn kv[kElementsPerWarp]; + StorageIn score[kElementsPerWarp]; + StorageIn bias[kElementsPerWarp]; + const int32_t warp_offset = warp_id * kElementsPerWarp; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const int32_t j = i + warp_offset; + bias[i] = gmem_in.load(score_bias + j * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < kElementsPerWarp; ++i) { + const int32_t j = i + warp_offset; + const InFloat* src; + __builtin_assume(j < 128); + if (j < window_len) { + src = kv_score_buf + j * element_size; + } else { + /// NOTE: k in [-127, 0]. 
We'll load from the ragged `kv_score_src` + const int32_t k = j - 127; + src = kv_score_src + k * element_size; + } + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + /// NOTE: part 2: safe online softmax + weighted sum + using TmpStorage = typename Compress128SharedBuffer::Storage; + __shared__ Compress128SharedBuffer s_local_val_max; + __shared__ Compress128SharedBuffer s_local_exp_sum; + __shared__ Compress128SharedBuffer s_local_product; + + TmpStorage tmp_val_max; + TmpStorage tmp_exp_sum; + TmpStorage tmp_product; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[kElementsPerWarp]; + +#pragma unroll + for (int32_t j = 0; j < kElementsPerWarp; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < kElementsPerWarp; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + tmp_val_max[i] = max_value; + tmp_exp_sum[i] = sum_exp_value; + tmp_product[i] = sum_product; + } + + // naturally aligned, so no bank conflict + s_local_val_max(warp_id, lane_id) = tmp_val_max; + s_local_exp_sum(warp_id, lane_id) = tmp_exp_sum; + s_local_product(warp_id, lane_id) = tmp_product; + + __syncthreads(); + + /// NOTE: part 3: online softmax + /// NOTE: We have `kTileElements * kWarpThreads * kNumWarps` values to reduce + /// each reduce will consume `kNumWarps` threads (use partial warp reduction) + constexpr uint32_t kReductionCount = kTileElements * kWarpThreads * kNumWarps; + constexpr uint32_t kIteration = kReductionCount / kBlockSize; + +#pragma unroll + for (uint32_t i = 0; i < kIteration; ++i) { + /// NOTE: Range `[0, kTileElements * kWarpThreads * kNumWarps)` + const uint32_t j = i * kBlockSize + warp_id * kWarpThreads + lane_id; + /// NOTE: Range `[0, kNumWarps)` + const uint32_t local_warp_id = j % kNumWarps; + /// NOTE: Range `[0, kTileElements * kWarpThreads)` + const uint32_t local_elem_id = j / kNumWarps; + /// NOTE: Range `[0, kTileElements)` + const uint32_t local_tile_id = local_elem_id % kTileElements; + /// NOTE: Range `[0, kWarpThreads)` + const uint32_t local_lane_id = local_elem_id / kTileElements; + /// NOTE: each warp will access the whole tile (all `kTileElements`) + /// and for different lanes, the memory access only differ in `local_warp_id` + /// so there's no bank conflict in shared memory access. 
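+    /// NOTE: as a worked example (assuming kWarpThreads = 32, hence
+    /// kNumWarps = 16, kTileElements = 2, kBlockSize = 512, kIteration = 2):
+    /// the thread with warp_id = 3, lane_id = 5 handles j = i * 512 + 101;
+    /// for i = 0 that gives local_warp_id = 101 % 16 = 5,
+    /// local_elem_id = 101 / 16 = 6, local_tile_id = 0, local_lane_id = 3.
+    /// The merge below is the usual online-softmax identity: with per-warp
+    /// partials m_w (max), s_w = sum_j exp(x_j - m_w) and
+    /// p_w = sum_j v_j * exp(x_j - m_w), the result is
+    /// sum_w p_w * exp(m_w - M) / sum_w s_w * exp(m_w - M) with M = max_w m_w;
+    /// `rescale` is exp(m_w - M) and `final_scale` folds in 1 / global_exp_sum.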
+ static_assert(kTileElements * kNumWarps == kWarpThreads, "TODO: support other configs"); + const auto local_val_max = s_local_val_max(local_warp_id, local_lane_id, local_tile_id); + const auto local_exp_sum = s_local_exp_sum(local_warp_id, local_lane_id, local_tile_id); + const auto local_product = s_local_product(local_warp_id, local_lane_id, local_tile_id); + const auto global_val_max = warp::reduce_max(local_val_max); + const auto rescale = expf(local_val_max - global_val_max); + const auto global_exp_sum = warp::reduce_sum(local_exp_sum * rescale); + const auto final_scale = rescale / global_exp_sum; + const auto global_product = warp::reduce_sum(local_product * final_scale); + kv_out[local_elem_id] = cast(global_product); + } +} + +template +C128_KERNEL void flash_c128_decode(const __grid_constant__ Compress128DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, batch_size // decode info + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + const uint32_t global_bid = blockIdx.x / kNumSplit; // batch id + const uint32_t global_sid = blockIdx.x % kNumSplit; // split id + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: the write must be visible to the subsequent c128_forward, + /// so only the last warp can write to HBM + /// In addition, `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + 127` + if (warp_id == kNumWarps - 1) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 127) % 128, lane_id); + } + if (seq_len % 128 == 0) { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, /*window_len=*/128, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +// compress kernel +template +C128_KERNEL void flash_c128_prefill(const __grid_constant__ Compress128PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 64 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 2; + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, load_indices, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + const uint32_t warp_id = threadIdx.x / kWarpThreads; + const uint32_t lane_id = threadIdx.x % kWarpThreads; + + uint32_t global_id; + if constexpr (kWrite) { + // for write kernel, we use global warp_id to dispatch work + global_id = (blockIdx.x * blockDim.x + threadIdx.x) / kWarpThreads; + } else { + // for compress kernel, we use block id to dispatch work + global_id = blockIdx.x; // block id + } + const uint32_t global_pid = global_id / kNumSplit; // plan id + const uint32_t global_sid = global_id % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const auto indices_ptr = kWrite ? 
indices : load_indices; + + const int64_t split_offset = global_sid * kTileDim; + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + const int32_t index = indices_ptr[global_bid]; + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 128) + split_offset; + + PDLWaitPrimary(); + + // only responsible for the compress part + if constexpr (kWrite) { + c128_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 128, lane_id); + } else { + c128_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, window_len, warp_id, lane_id); + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress128Kernel { + static constexpr auto decode_kernel = flash_c128_decode; + template + static constexpr auto prefill_kernel = flash_c128_prefill; + static constexpr auto prefill_c_kernel = prefill_kernel; + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr int64_t kTileDim = kTileElements * device::kWarpThreads; // 64 + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWriteBlockSize = 128; + static constexpr uint32_t kWarpsPerWriteBlock = kWriteBlockSize / device::kWarpThreads; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional /* UNUSED */) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device) + .verify(seq_lens); + + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress128DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .batch_size = batch_size, + }; + + const uint32_t num_blocks = batch_size * kNumSplit; + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(decode_kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const 
tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({-1, 128, kHeadDim * 2}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 2}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({128, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + // might be needed for prefill write + const auto load_indices = extra.value_or(indices); + TensorMatcher({B}) // [read_positions] + .with_dtype() + .with_device(device_) + .verify(load_indices); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress128PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .load_indices = static_cast(load_indices.data_ptr()), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + + constexpr auto kBlockSize_C = kBlockSize; + constexpr auto kBlockSize_W = kWriteBlockSize; + if (const auto num_c_blocks = num_c * kNumSplit) { + LaunchKernel(num_c_blocks, kBlockSize_C, device) // + .enable_pdl(kUsePDL)(prefill_c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerWriteBlock)) { + LaunchKernel(num_w_blocks, kBlockSize_W, device) // + .enable_pdl(kUsePDL)(prefill_w_kernel, params); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh new file mode 100644 index 000000000000..145ab1fb081e --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/c4.cuh @@ -0,0 +1,549 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +#include + +namespace { + +using Plan4 = device::compress::PrefillPlan; +using IndiceT = int32_t; + +/// \brief Each thread will handle this many elements (split along head_dim) +constexpr int kTileElements = 4; + +/// \brief Need to improve register usage to reduce latency +#define C4_KERNEL __global__ __launch_bounds__(128, 4) + +enum class PageMode { + 
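+  /// NOTE: the enumerator value is assumed to double as the page size of
+  /// `kv_score_buffer` (slots per page): RingBuffer keeps an 8-slot ring per
+  /// sequence, Page4Align keeps 4-slot pages aligned to the compress ratio,
+  /// matching `page_size = extra_ptr != nullptr ? 4 : 8` in the host wrappers.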
RingBuffer = 8, + Page4Align = 4, +}; + +struct alignas(16) C4IndexBundle { + int32_t load_first_page; + int32_t load_second_page; + int32_t write_first_page; + int32_t last_position; +}; + +struct Compress4DecodeParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[batch_size, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[batch_size, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, ]` */ + const IndiceT* __restrict__ seq_lens; + /** \brief Shape: `[batch_size, 1]` */ + const int32_t* __restrict__ extra; + /** \NOTE: `batch_size` <= `num_indices` */ + uint32_t batch_size; +}; + +struct Compress4PrefillParams { + /** + * \brief Shape: `[num_indices, 8, head_dim * 4]` \n + * last dimension layout: + * | kv overlap | kv | score overlap | score | + */ + void* __restrict__ kv_score_buffer; + /** \brief Shape: `[num_q_tokens, head_dim * 4]` */ + const void* __restrict__ kv_score_input; + /** \brief Shape: `[num_q_tokens, head_dim]` */ + void* __restrict__ kv_compressed_output; + /** \brief Shape: `[8, head_dim]` (called `ape`) */ + const void* __restrict__ score_bias; + /** \brief Shape: `[batch_size, ]`*/ + const IndiceT* __restrict__ indices; + /** \brief Shape: `[batch_size, 4]` */ + const C4IndexBundle* __restrict__ extra; + /** \brief The following part is plan info. */ + + const Plan4* __restrict__ compress_plan; + const Plan4* __restrict__ write_plan; + uint32_t num_compress; + uint32_t num_write; +}; + +template +SGL_DEVICE void c4_write( + T* kv_score_buf, // + const T* kv_score_src, + const int64_t head_dim, + const int32_t write_pos) { + using namespace device; + + using Storage = AlignedVector; + const auto element_size = head_dim * 4; + const auto gmem = tile::Memory::warp(); + kv_score_buf += write_pos * element_size; + + /// NOTE: Layout | [0] = kv overlap | [1] = kv | [2] = score overlap | [3] = score | + Storage kv_score[4]; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv_score[i] = gmem.load(kv_score_src + head_dim * i); + } +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + gmem.store(kv_score_buf + head_dim * i, kv_score[i]); + } +} + +template +SGL_DEVICE void c4_forward( + const InFloat* kv_score_buf, + const InFloat* kv_score_src, + OutFloat* kv_out, + const InFloat* score_bias, + const int64_t head_dim, + const int32_t seq_len, + const int32_t window_len, + [[maybe_unused]] const InFloat* kv_score_overlap_buf = nullptr) { + using namespace device; + + const auto element_size = head_dim * 4; + const auto score_offset = head_dim * 2; + const auto overlap_stride = head_dim; + + /// NOTE: part 1: load kv + score + using StorageIn = AlignedVector; + const auto gmem_in = tile::Memory::warp(); + StorageIn kv[8]; + StorageIn score[8]; + StorageIn bias[8]; + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + bias[i] = gmem_in.load(score_bias + i * head_dim); + } + +#pragma unroll + for (int32_t i = 0; i < 8; ++i) { + const bool is_overlap = i < 4; + const InFloat* src; + if (i < window_len) { + /// NOTE: `seq_len` must be a multiple of 4 here + if constexpr (kPaged) { + const auto kv_score_ptr = is_overlap ? 
kv_score_overlap_buf : kv_score_buf; + const int32_t k = i % 4; + src = kv_score_ptr + k * element_size; + } else { + const int32_t k = (seq_len + i) % 8; + src = kv_score_buf + k * element_size; + } + } else { + /// NOTE: k in [-7, 0]. We'll load from the ragged `kv_score_src` + const int32_t k = i - 7; + src = kv_score_src + k * element_size; + } + src += (is_overlap ? 0 : overlap_stride); + kv[i] = gmem_in.load(src); + score[i] = gmem_in.load(src + score_offset); + } + + if (seq_len == 4) { + [[unlikely]]; + constexpr float kFloatNegInf = -1e9f; +#pragma unroll + for (int32_t i = 0; i < 4; ++i) { + kv[i].fill(cast(0.0f)); + score[i].fill(cast(kFloatNegInf)); + } + } + + /// NOTE: part 2: safe online softmax + weighted sum + using StorageOut = AlignedVector; + const auto gmem_out = tile::Memory::warp(); + StorageOut result; + +#pragma unroll + for (int32_t i = 0; i < kTileElements; ++i) { + float score_fp32[8]; + +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + score_fp32[j] = cast(score[j][i]) + cast(bias[j][i]); + } + + float max_value = score_fp32[0]; + float sum_exp_value = 0.0f; + +#pragma unroll + for (int32_t j = 1; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + max_value = fmaxf(max_value, fp32_score); + } + + float sum_product = 0.0f; +#pragma unroll + for (int32_t j = 0; j < 8; ++j) { + const auto fp32_score = score_fp32[j]; + const auto exp_score = expf(fp32_score - max_value); + sum_product += cast(kv[j][i]) * exp_score; + sum_exp_value += exp_score; + } + + result[i] = cast(sum_product / sum_exp_value); + } + + gmem_out.store(kv_out, result); +} + +template +C4_KERNEL void flash_c4_decode(const __grid_constant__ Compress4DecodeParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, seq_lens, extra, batch_size // decode info + ] = params; + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_bid = global_wid / kNumSplit; // batch id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + if (global_bid >= batch_size) return; + + const int32_t index = indices[global_bid]; + const int32_t seq_len = seq_lens[global_bid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + global_bid * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + global_bid * kHeadDim + split_offset; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + + PDLWaitPrimary(); + + /// NOTE: `position` = `seq_len - 1`. 
To avoid underflow, we use `seq_len + page_size - 1` + if constexpr (kMode == PageMode::Page4Align) { + const auto index_prev = extra[global_bid]; + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 3) % 4); + if (seq_len % 4 == 0) { + const auto kv_overlap = kv_buf + (index_prev - index) * (kElementSize * 4); + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, 8, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/(seq_len + 7) % 8); + if (seq_len % 4 == 0) { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, /*window_size=*/8); + } + } + + PDLTriggerSecondary(); +} + +template +C4_KERNEL void flash_c4_prefill(const __grid_constant__ Compress4PrefillParams params) { + using namespace device; + + constexpr int64_t kTileDim = kTileElements * kWarpThreads; // 128 + constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + constexpr int64_t kElementSize = kHeadDim * 4; // `* 4` due to overlap transform + score + static_assert(kHeadDim % kTileDim == 0, "Head dim must be multiple of tile dim"); + + const auto& [ + _kv_score_buffer, _kv_score_input, _kv_compressed_output, _score_bias, // kv score + indices, extra, compress_plan, write_plan, num_compress, num_write // prefill plan + ] = params; + + const uint32_t global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t global_wid = global_tid / kWarpThreads; // warp id + const uint32_t global_pid = global_wid / kNumSplit; // plan id + const uint32_t global_sid = global_wid % kNumSplit; // split id + + /// NOTE: compiler can optimize this if-else at compile time + const auto num_plans = kWrite ? num_write : num_compress; + const auto plan_ptr = kWrite ? 
write_plan : compress_plan; + if (global_pid >= num_plans) return; + + const auto& [ragged_id, global_bid, position, window_len] = plan_ptr[global_pid]; + const int64_t split_offset = global_sid * kTileDim; + + // kv score + const auto kv_score_buffer = static_cast(_kv_score_buffer); + + // kv input + const auto kv_score_input = static_cast(_kv_score_input); + const auto kv_src = kv_score_input + ragged_id * kElementSize + split_offset; + + // kv output + const auto kv_compressed_output = static_cast(_kv_compressed_output); + const auto kv_out = kv_compressed_output + ragged_id * kHeadDim + split_offset; + + if (ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + + // score bias (ape) + const auto score_bias = static_cast(_score_bias) + split_offset; + const auto seq_len = position + 1; + const int32_t index = indices[global_bid]; + + PDLWaitPrimary(); + + if constexpr (kMode == PageMode::Page4Align) { + const auto write_second_page = index; + const auto [load_first_page, load_second_page, write_first_page, last_pos] = extra[global_bid]; + if constexpr (kWrite) { + int32_t index; + if (position < static_cast(last_pos)) { + index = write_first_page; + } else { + index = write_second_page; + } + const auto kv_buf = kv_score_buffer + index * (kElementSize * 4) + split_offset; + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 4); + } else { + int32_t index_overlap, index_normal; + if (window_len <= 4) { + index_overlap = load_second_page; + index_normal = load_second_page; // not used + } else { + index_overlap = load_first_page; + index_normal = load_second_page; + } + const auto kv_buf = kv_score_buffer + index_normal * (kElementSize * 4) + split_offset; + const auto kv_overlap = kv_score_buffer + index_overlap * (kElementSize * 4) + split_offset; + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len, kv_overlap); + } + } else { + static_assert(kMode == PageMode::RingBuffer, "Unsupported PageMode"); + const auto kv_buf = kv_score_buffer + index * (kElementSize * 8) + split_offset; + if constexpr (kWrite) { + c4_write(kv_buf, kv_src, kHeadDim, /*write_pos=*/position % 8); + } else { + c4_forward(kv_buf, kv_src, kv_out, score_bias, kHeadDim, seq_len, window_len); + } + } + + PDLTriggerSecondary(); +} + +template +struct FlashCompress4Kernel { + template + static constexpr auto decode_kernel = flash_c4_decode; + template + static constexpr auto prefill_kernel = flash_c4_prefill; + template + static constexpr auto prefill_c_kernel = prefill_kernel; + template + static constexpr auto prefill_w_kernel = prefill_kernel; + static constexpr uint32_t kBlockSize = 128; + static constexpr uint32_t kTileDim = kTileElements * device::kWarpThreads; + static constexpr uint32_t kNumSplit = kHeadDim / kTileDim; + static constexpr uint32_t kWarpsPerBlock = kBlockSize / device::kWarpThreads; + + using Self = FlashCompress4Kernel; + + static void run_decode( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::Optional extra) { + using namespace host; + + // this should not happen in practice + auto B = SymbolicSize{"batch_size"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({B, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({B, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({B}) // seq lens + .with_dtype() + .with_device(device_) + .verify(seq_lens); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto params = Compress4DecodeParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extra = static_cast(extra_ptr), + .batch_size = batch_size, + }; + const auto kernel = extra_ptr != nullptr ? decode_kernel // + : decode_kernel; + const uint32_t num_blocks = div_ceil(batch_size * kNumSplit, kWarpsPerBlock); + LaunchKernel(num_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(kernel, params); + } + + static void run_prefill( + const tvm::ffi::TensorView kv_score_buffer, + const tvm::ffi::TensorView kv_score_input, + const tvm::ffi::TensorView kv_compressed_output, + const tvm::ffi::TensorView ape, + const tvm::ffi::TensorView indices, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const tvm::ffi::Optional extra) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto N = SymbolicSize{"num_q_tokens"}; + auto X = SymbolicSize{"compress_tokens"}; + auto Y = SymbolicSize{"write_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + const auto extra_ptr = _get_extra_pointer(B, device_, extra, /*is_prefill=*/true); + const auto page_size = extra_ptr != nullptr ? 
4 : 8; + + TensorMatcher({-1, page_size, kHeadDim * 4}) // kv score + .with_dtype() + .with_device(device_) + .verify(kv_score_buffer); + TensorMatcher({N, kHeadDim * 4}) // kv score input + .with_dtype() + .with_device(device_) + .verify(kv_score_input); + TensorMatcher({N, kHeadDim}) // kv compressed output + .with_dtype() + .with_device(device_) + .verify(kv_compressed_output); + TensorMatcher({8, kHeadDim}) // ape + .with_dtype() + .with_device(device_) + .verify(ape); + TensorMatcher({B}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + TensorMatcher({X, compress::kPrefillPlanDim}) // compress plan + .with_dtype() + .with_device(device_) + .verify(compress_plan); + TensorMatcher({Y, compress::kPrefillPlanDim}) // write plan + .with_dtype() + .with_device(device_) + .verify(write_plan); + + const auto device = device_.unwrap(); + const auto batch_size = static_cast(B.unwrap()); + const auto num_q_tokens = static_cast(N.unwrap()); + const auto num_c = static_cast(X.unwrap()); + const auto num_w = static_cast(Y.unwrap()); + const auto params = Compress4PrefillParams{ + .kv_score_buffer = kv_score_buffer.data_ptr(), + .kv_score_input = kv_score_input.data_ptr(), + .kv_compressed_output = kv_compressed_output.data_ptr(), + .score_bias = ape.data_ptr(), + .indices = static_cast(indices.data_ptr()), + .extra = static_cast(extra_ptr), + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .num_compress = num_c, + .num_write = num_w, + }; + RuntimeCheck(num_q_tokens >= batch_size, "num_q_tokens must be >= batch_size"); + RuntimeCheck(num_q_tokens >= std::max(num_c, num_w), "invalid prefill plan"); + if (const auto num_c_blocks = div_ceil(num_c * kNumSplit, kWarpsPerBlock)) { + const auto c_kernel = extra_ptr != nullptr ? prefill_c_kernel // + : prefill_c_kernel; + LaunchKernel(num_c_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(c_kernel, params); + } + if (const auto num_w_blocks = div_ceil(num_w * kNumSplit, kWarpsPerBlock)) { + const auto w_kernel = extra_ptr != nullptr ? prefill_w_kernel // + : prefill_w_kernel; + LaunchKernel(num_w_blocks, kBlockSize, device) // + .enable_pdl(kUsePDL)(w_kernel, params); + } + } + + // some auxiliary functions + private: + static const void* _get_extra_pointer( + host::SymbolicSize& B, // batch_size + host::SymbolicDevice& device, + const tvm::ffi::Optional& extra, + bool is_prefill = false) { + // only have value when using page-aligned mode + if (!extra.has_value()) return nullptr; + const auto& extra_tensor = extra.value(); + /// NOTE: the metadata layout is different for prefill and decode: + /// for prefill, last 4 are: + /// load overlap | load normal | write overlap | last written page + /// for decode, last 1 is the write (also load) overlap + host::TensorMatcher({B, is_prefill ? 
4 : 1}) // extra tensor + .with_dtype() + .with_device(device) + .verify(extra_tensor); + const auto data_ptr = extra_tensor.data_ptr(); + host::RuntimeCheck(data_ptr != nullptr, "extra tensor data ptr is null"); + if (is_prefill) { + static_assert(alignof(C4IndexBundle) == 16); + host::RuntimeCheck(std::bit_cast(data_ptr) % 16 == 0, "extra tensor is not properly aligned"); + } + return data_ptr; + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh new file mode 100644 index 000000000000..46acaa9c46b3 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/common.cuh @@ -0,0 +1,208 @@ +#include +#include + +#include + +#include + +namespace host::compress { + +using PlanResult = tvm::ffi::Tuple; + +struct CompressParams { + PrefillPlan* __restrict__ compress_plan; + PrefillPlan* __restrict__ write_plan; + const int64_t* __restrict__ seq_lens; + const int64_t* __restrict__ extend_lens; + uint32_t batch_size; + uint32_t num_tokens; + uint32_t compress_ratio; + bool is_overlap; +}; + +inline constexpr uint32_t kBlockSize = 1024; + +#define PLAN_KERNEL __global__ __launch_bounds__(kBlockSize, 1) inline + +PLAN_KERNEL void plan_prefill_cuda(const __grid_constant__ CompressParams params) { + const auto &[ + compress_plan, write_plan, seq_lens, extend_lens, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + __shared__ uint32_t compress_counter; + __shared__ uint32_t write_counter; + + uint32_t batch_id = 0; + uint32_t counter = 0; + uint32_t extend_len = extend_lens[0]; + + const auto tid = threadIdx.x; + if (tid == 0) { + compress_counter = 0; + write_counter = 0; + } + __syncthreads(); + + for (uint32_t i = tid; i < num_tokens; i += blockDim.x) { + const uint32_t ragged_id = i; + uint32_t j = ragged_id - counter; + while (j >= extend_len) { + j -= extend_len; + batch_id += 1; + if (batch_id >= batch_size) [[unlikely]] + break; + counter += extend_len; + extend_len = extend_lens[batch_id]; + } + if (batch_id >= batch_size) [[unlikely]] + break; + const uint32_t seq_len = seq_lens[batch_id]; + const uint32_t extend_len = extend_lens[batch_id]; + const uint32_t prefix_len = seq_len - extend_len; + const uint32_t ratio = compress_ratio * (1 + is_overlap); + const uint32_t window_len = j + 1 < ratio ? ratio - (j + 1) : 0; + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = ragged_id, + .batch_id = batch_id, + .position = position, + .window_len = window_len, + }; + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + return pos >= compress_ratio ? 
pos - compress_ratio : 0; + }(); + if ((position + 1) % compress_ratio == 0) { + const auto write_pos = atomicAdd(&compress_counter, 1); + compress_plan[write_pos] = plan; + } + if (position >= start_write_pos) { + const auto write_pos = atomicAdd(&write_counter, 1); + write_plan[write_pos] = plan; + } + } + __syncthreads(); + constexpr auto kInvalid = static_cast(-1); + const auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + const auto compress_count = compress_counter; + const auto write_count = write_counter; + for (uint32_t i = compress_count + tid; i < num_tokens; i += blockDim.x) { + compress_plan[i] = kInvalidPlan; + } + for (uint32_t i = write_count + tid; i < num_tokens; i += blockDim.x) { + write_plan[i] = kInvalidPlan; + } +} + +inline PlanResult plan_prefill_host(const CompressParams& params, const bool use_cuda_graph) { + const auto &[ + compress_ptr, write_ptr, seq_lens_ptr, extend_lens_ptr, // pointers + batch_size, num_tokens, compress_ratio, is_overlap // values + ] = params; + + uint32_t counter = 0; + uint32_t compress_counter = 0; + uint32_t write_counter = 0; + const auto ratio = compress_ratio * (1 + is_overlap); + for (const auto i : irange(batch_size)) { + const uint32_t seq_len = seq_lens_ptr[i]; + const uint32_t extend_len = extend_lens_ptr[i]; + const uint32_t prefix_len = seq_len - extend_len; + RuntimeCheck(0 < extend_len && extend_len <= seq_len); + /// NOTE: `start_write_pos` must be a multiple of `compress_ratio` + const uint32_t start_write_pos = [seq_len, compress_ratio, is_overlap] { + const uint32_t pos = seq_len / compress_ratio * compress_ratio; + if (!is_overlap) return pos; + /// NOTE: to avoid unsigned integer underflow, don't use `pos - compress_ratio` + return pos >= compress_ratio ? 
pos - compress_ratio : 0; + }(); + /// NOTE: `position` is within [prefix_len, seq_len) + for (const auto j : irange(extend_len)) { + const uint32_t position = prefix_len + j; + const auto plan = PrefillPlan{ + .ragged_id = counter + j, + .batch_id = i, + .position = position, + .window_len = ratio - std::min(j + 1, ratio), + }; + RuntimeCheck(plan.is_valid(compress_ratio, is_overlap), "Internal error!"); + if ((position + 1) % compress_ratio == 0) { + compress_ptr[compress_counter++] = plan; + } + if (position >= start_write_pos) { + write_ptr[write_counter++] = plan; + } + } + counter += extend_len; + } + RuntimeCheck(counter == num_tokens, "input size ", counter, " != num_q_tokens ", num_tokens); + if (!use_cuda_graph) return PlanResult{compress_counter, write_counter}; + constexpr auto kInvalid = static_cast(-1); + constexpr auto kInvalidPlan = PrefillPlan{kInvalid, kInvalid, kInvalid, kInvalid}; + for (const auto i : irange(compress_counter, num_tokens)) { + compress_ptr[i] = kInvalidPlan; + } + for (const auto i : irange(write_counter, num_tokens)) { + write_ptr[i] = kInvalidPlan; + } + return PlanResult{num_tokens, num_tokens}; +} + +inline PlanResult plan_prefill( + const tvm::ffi::TensorView extend_lens, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView compress_plan, + const tvm::ffi::TensorView write_plan, + const uint32_t compress_ratio, + const bool is_overlap, // for overlap transform, we have to keep 1 more extra window + const bool use_cuda_graph) { + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_tokens"}; + auto device = SymbolicDevice{}; + const bool is_cuda = [&] { + if (extend_lens.device().device_type == kDLCUDA) { + device.set_options(); + return true; + } else { + device.set_options(); + return false; + } + }(); + TensorMatcher({N}) // extend_lens and seq_lens + .with_dtype() + .with_device(device) + .verify(extend_lens) + .verify(seq_lens); + TensorMatcher({M, kPrefillPlanDim}) // compress_plan and write_plan + .with_dtype() + .with_device(device) + .verify(compress_plan) + .verify(write_plan); + + const auto params = CompressParams{ + .compress_plan = static_cast(compress_plan.data_ptr()), + .write_plan = static_cast(write_plan.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .extend_lens = static_cast(extend_lens.data_ptr()), + .batch_size = static_cast(N.unwrap()), + .num_tokens = static_cast(M.unwrap()), + .compress_ratio = compress_ratio, + .is_overlap = is_overlap, + }; + + if (!is_cuda) return plan_prefill_host(params, use_cuda_graph); + /// NOTE: cuda kernel plan is naturally compatible with cuda graph + LaunchKernel(1, kBlockSize, device.unwrap())(plan_prefill_cuda, params); + return PlanResult{params.num_tokens, params.num_tokens}; +} + +} // namespace host::compress + +namespace { + +[[maybe_unused]] +constexpr auto& plan_compress_prefill = host::compress::plan_prefill; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh new file mode 100644 index 000000000000..d3953578b925 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/fused_norm_rope.cuh @@ -0,0 +1,254 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include + +namespace { + +using Plan = device::compress::PrefillPlan; + +/// \brief common block size for memory-bound kernel +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / 
device::kWarpThreads; + +struct FusedNormRopeParams { + void* __restrict__ input; + const void* __restrict__ weight; + float eps; + uint32_t num_works; + const void* __restrict__ handle; + const float* __restrict__ freqs_cis; + uint32_t compress_ratio; +}; + +enum class ForwardMode { + CompressExtend = 0, + CompressDecode = 1, + DefaultForward = 2, +}; + +template +__global__ void fused_norm_rope(const __grid_constant__ FusedNormRopeParams params) { + using namespace device; + using enum ForwardMode; + + constexpr int64_t kMaxVecSize = 16 / sizeof(DType); + constexpr int64_t kVecSize = std::min(kMaxVecSize, kHeadDim / kWarpThreads); + constexpr int64_t kLocalSize = kHeadDim / (kWarpThreads * kVecSize); + constexpr int64_t kRopeVecSize = kRopeDim / (kWarpThreads * 2); + constexpr uint32_t kRopeSize = kRopeDim / kVecSize; + static_assert(kHeadDim % (kWarpThreads * kVecSize) == 0); + static_assert(kLocalSize * kVecSize * kWarpThreads == kHeadDim); + static_assert(kRopeDim % (kWarpThreads * 2) == 0); + static_assert(kRopeDim % (kVecSize * kLocalSize) == 0); + static_assert(kRopeSize <= kWarpThreads); + static_assert(kRopeVecSize == 1, "only support rope dim = 64"); + + const auto& [ + _input, _weight, eps, num_works, // norm + handle, freqs_cis, compress_ratio // rope + ] = params; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto work_id = blockIdx.x * kNumWarps + warp_id; + + if (work_id >= num_works) return; + + DType* input; + int32_t position; + if constexpr (kMode == CompressExtend) { + const auto plan = static_cast(handle)[work_id]; + input = static_cast(_input) + plan.ragged_id * kHeadDim; + position = plan.position + 1 - compress_ratio; + if (plan.ragged_id == 0xFFFFFFFF) [[unlikely]] + return; + } else if constexpr (kMode == CompressDecode) { + input = static_cast(_input) + work_id * kHeadDim; + const auto seq_len = static_cast(handle)[work_id]; + if (seq_len % compress_ratio != 0) return; + position = seq_len - compress_ratio; + } else if constexpr (kMode == DefaultForward) { + input = static_cast(_input) + work_id * kHeadDim; + position = static_cast(handle)[work_id]; + } else { + static_assert(host::dependent_false_v, "Unsupported Mode"); + } + + using Storage = AlignedVector; + __shared__ Storage s_rope_input[kNumWarps][kRopeSize]; + + // prefetch freq + const auto mem_freq = tile::Memory::warp(); + const auto freq = mem_freq.load(freqs_cis + position * kRopeDim); + + PDLWaitPrimary(); + + // part 1: norm + { + const auto gmem = tile::Memory::warp(); + Storage input_vec[kLocalSize]; + Storage weight_vec[kLocalSize]; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + input_vec[i] = gmem.load(input, i); + } + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + weight_vec[i] = gmem.load(_weight, i); + } + + float sum_of_squares = 0.0f; +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + sum_of_squares += fp32_input * fp32_input; + } + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto norm_factor = math::rsqrt(sum_of_squares / kHeadDim + eps); + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { +#pragma unroll + for (int j = 0; j < kVecSize; ++j) { + const auto fp32_input = cast(input_vec[i][j]); + const auto fp32_weight = cast(weight_vec[i][j]); + input_vec[i][j] = cast(fp32_input * norm_factor * fp32_weight); + } + } + + const bool is_rope_lane = lane_id >= 
kWarpThreads - kRopeSize; + +#pragma unroll + for (int i = 0; i < kLocalSize; ++i) { + if (i == kLocalSize - 1 && is_rope_lane) { + const auto rope_id = lane_id - (kWarpThreads - kRopeSize); + s_rope_input[warp_id][rope_id] = input_vec[i]; + } else { + gmem.store(input, input_vec[i], i); + } + } + + __syncwarp(); + } + + // part 2: rope + { + // mem elem = DType x 2 + using DTypex2_t = packed_t; + const auto mem_elem = tile::Memory::warp(); + const auto elem = mem_elem.load(s_rope_input[warp_id]); + const auto [x_real, x_imag] = cast(elem); + const auto [freq_real, freq_imag] = freq; + const fp32x2_t output = { + x_real * freq_real - x_imag * freq_imag, + x_real * freq_imag + x_imag * freq_real, + }; + mem_elem.store(input + (kHeadDim - kRopeDim), cast(output)); + } + + PDLTriggerSecondary(); +} + +template +struct FusedNormRopeKernel { + template + static constexpr auto fused_kernel = fused_norm_rope; + + static void forward( + const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView handle, + const tvm::ffi::TensorView freqs_cis, + int32_t _mode, + float eps, + uint32_t compress_ratio) { + using namespace host; + using enum ForwardMode; + + const auto mode = static_cast(_mode); + + auto B = SymbolicSize{"num_q_tokens"}; + auto N = SymbolicSize{"num_compress_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, kHeadDim}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({kHeadDim}) // weight + .with_dtype() + .with_device(device_) + .verify(weight); + TensorMatcher({-1, kRopeDim}) // freqs_cis + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + switch (mode) { + case CompressExtend: + TensorMatcher({N, compress::kPrefillPlanDim}) // plan + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case CompressDecode: + TensorMatcher({N}) // seq_len + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio > 0); + break; + case DefaultForward: + TensorMatcher({N}) // position + .with_dtype() + .with_device(device_) + .verify(handle); + RuntimeCheck(compress_ratio == 0); + break; + default: + Panic("unsupported forward mode: ", static_cast(mode)); + } + + // launch kernel + const auto num_compress_tokens = static_cast(N.unwrap()); + if (num_compress_tokens == 0) return; + const auto params = FusedNormRopeParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .eps = eps, + .num_works = num_compress_tokens, + .handle = handle.data_ptr(), + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .compress_ratio = compress_ratio, + }; + const auto num_blocks = div_ceil(num_compress_tokens, kNumWarps); + using KernelType = std::decay_t)>; + static constexpr KernelType kernel_table[3] = { + [static_cast(CompressExtend)] = fused_kernel, + [static_cast(CompressDecode)] = fused_kernel, + [static_cast(DefaultForward)] = fused_kernel, + }; + const auto kernel = kernel_table[static_cast(mode)]; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh new file mode 100644 index 000000000000..cf422c58bc1e --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hash_topk.cuh @@ -0,0 +1,137 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + 
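+/// NOTE: `act_sqrt_softplus` below evaluates sqrt(softplus(x)) using the
+/// numerically stable form softplus(x) = log(1 + e^x) = max(x, 0) + log1p(e^{-|x|}),
+/// which avoids overflow for large positive x and precision loss for large
+/// negative x. An illustrative host-side reference (an assumption for testing,
+/// not part of this kernel) could be:
+///
+///   float ref_act_sqrt_softplus(float x) {
+///     return std::sqrt(std::max(x, 0.0f) + std::log1p(std::exp(-std::fabs(x))));
+///   }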
+[[maybe_unused]] +SGL_DEVICE float act_sqrt_softplus(float x) { + const float softplus = fmaxf(x, 0.0f) + log1pf(expf(-fabsf(x))); + return sqrtf(softplus); +} + +struct MoEHashTopKParams { + const float* __restrict__ router_logits; + const int64_t* __restrict__ input_id; + const int32_t* __restrict__ tid2eid; + int32_t* __restrict__ topk_ids; + float* __restrict__ topk_weights; + uint32_t num_tokens; + uint32_t topk; + uint32_t num_routed_experts; + uint32_t num_shared_experts; + float routed_scaling_factor; +}; + +template +__global__ void moe_hash_topk_fused(const MoEHashTopKParams __grid_constant__ params) { + using namespace device; + const auto& [ + router_logits, input_id, tid2eid, topk_ids, topk_weights, // pointers + num_tokens, topk, num_routed_experts, num_shared_experts, routed_scaling_factor] = + params; + + const uint32_t topk_fused = topk + num_shared_experts; + const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t warp_id = tid / kWarpThreads; + const uint32_t lane_id = tid % kWarpThreads; + if (warp_id >= num_tokens) return; + // we can safely prefetch the token id + const auto token_id = input_id[warp_id]; + + PDLWaitPrimary(); + + float routed_weight = 0.0f; + int32_t expert_id = 0; + if (lane_id < topk) { + expert_id = tid2eid[token_id * topk + lane_id]; + routed_weight = Fn(router_logits[warp_id * num_routed_experts + expert_id]); + } + + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk_fused) { + const bool is_shared = lane_id >= topk; + const auto output_offset = warp_id * topk_fused + lane_id; + topk_ids[output_offset] = is_shared ? num_routed_experts + lane_id - topk : expert_id; + topk_weights[output_offset] = is_shared ? 1.0f / routed_scaling_factor : routed_weight / routed_sum; + } + + PDLTriggerSecondary(); +} + +template +struct HashTopKKernel { + static constexpr auto kernel = moe_hash_topk_fused; + + static void + run(const tvm::ffi::TensorView router_logits, + const tvm::ffi::TensorView input_id, + const tvm::ffi::TensorView tid2eid, + const tvm::ffi::TensorView topk_weights, + const tvm::ffi::TensorView topk_ids, + float routed_scaling_factor) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto E = SymbolicSize{"num_routed_experts"}; + auto K = SymbolicSize{"topk_fused"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, E}) // + .with_dtype() + .with_device(device) + .verify(router_logits); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(input_id); + TensorMatcher({-1, -1}) // + .with_dtype() + .with_device(device) + .verify(tid2eid); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_weights); + TensorMatcher({N, K}) // + .with_dtype() + .with_device(device) + .verify(topk_ids); + + const auto num_tokens = static_cast(N.unwrap()); + const auto topk_fused = static_cast(K.unwrap()); + const auto topk = static_cast(tid2eid.size(1)); + const auto shared_experts = topk_fused - topk; + RuntimeCheck(topk <= topk_fused, "HashTopKKernel requires topk <= topk_fused"); + RuntimeCheck(topk_fused <= device::kWarpThreads, "HashTopKKernel requires topk_fused <= warp size"); + + const auto params = MoEHashTopKParams{ + .router_logits = static_cast(router_logits.data_ptr()), + .input_id = static_cast(input_id.data_ptr()), + .tid2eid = static_cast(tid2eid.data_ptr()), + .topk_ids = static_cast(topk_ids.data_ptr()), + .topk_weights = static_cast(topk_weights.data_ptr()), + .num_tokens = num_tokens, + .topk = 
topk, + .num_routed_experts = static_cast(E.unwrap()), + .num_shared_experts = shared_experts, + .routed_scaling_factor = routed_scaling_factor, + }; + const auto kBlockSize = 128u; + const auto kNumWarps = kBlockSize / device::kWarpThreads; + const auto num_blocks = div_ceil(num_tokens, kNumWarps); + LaunchKernel(num_blocks, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh new file mode 100644 index 000000000000..aefec24372a8 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/hisparse_transfer.cuh @@ -0,0 +1,82 @@ +#include +#include + +#include + +#include + +#include +#include + +#include + +namespace { + +/// NOTE: for offload to cpu kernel, we use persistent kernel +inline constexpr uint32_t kBlockSize = 1024; +inline constexpr uint32_t kBlockQuota = 4; + +#define OFFLOAD_KERNEL __global__ __launch_bounds__(kBlockSize, 1) + +struct OffloadParams { + void** gpu_caches; + void** cpu_caches; + const int64_t* gpu_indices; + const int64_t* cpu_indices; + uint32_t num_items; + uint32_t num_layers; +}; + +OFFLOAD_KERNEL void offload_to_cpu(const __grid_constant__ OffloadParams params) { + using namespace device::hisparse; + const auto [gpu_caches, cpu_caches, gpu_indices, cpu_indices, num_items, num_layers] = params; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + constexpr auto kNumWarps = (kBlockSize / 32) * kBlockQuota; + for (auto i = global_tid / 32; i < num_items; i += kNumWarps) { + const int32_t gpu_index = gpu_indices[i]; + const int32_t cpu_index = cpu_indices[i]; + for (auto j = 0u; j < num_layers; ++j) { + const auto gpu_cache = gpu_caches[j]; + const auto cpu_cache = cpu_caches[j]; + transfer_item( + /*dst_cache=*/cpu_cache, + /*src_cache=*/gpu_cache, + /*dst_index=*/cpu_index, + /*src_index=*/gpu_index); + } + } +} + +[[maybe_unused]] +void hisparse_transfer( + tvm::ffi::TensorView gpu_ptrs, + tvm::ffi::TensorView cpu_ptrs, + tvm::ffi::TensorView gpu_indices, + tvm::ffi::TensorView cpu_indices) { + using namespace host; + auto N = SymbolicSize{"num_items"}; + auto L = SymbolicSize{"num_layers"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({L}) // 1D cache pointers + .with_dtype() + .with_device(device_) + .verify(gpu_ptrs) + .verify(cpu_ptrs); + TensorMatcher({N}) // 1D indices + .with_dtype() + .with_device(device_) + .verify(gpu_indices) + .verify(cpu_indices); + const auto params = OffloadParams{ + .gpu_caches = static_cast(gpu_ptrs.data_ptr()), + .cpu_caches = static_cast(cpu_ptrs.data_ptr()), + .gpu_indices = static_cast(gpu_indices.data_ptr()), + .cpu_indices = static_cast(cpu_indices.data_ptr()), + .num_items = static_cast(N.unwrap()), + .num_layers = static_cast(L.unwrap()), + }; + LaunchKernel(kBlockQuota, kBlockSize, device_.unwrap())(offload_to_cpu, params); +} + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh new file mode 100644 index 000000000000..38be97555853 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/paged_mqa_metadata.cuh @@ -0,0 +1,119 @@ +#include +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kBlockSize = 1024; +constexpr uint32_t kSplitKV = 256; // const for both SM90 and SM100 + +struct MetadataParams { + /// NOTE: batch_size > 0 + 
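+  /// (the scheduling kernel reads `context_lens[0]` before its main loop, so an
+  /// empty batch would index out of bounds)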
uint32_t batch_size; + uint32_t num_sm; + const uint32_t* __restrict__ context_lens; + uint32_t* __restrict__ schedule_metadata; + bool use_smem = true; +}; + +__global__ __launch_bounds__(kBlockSize, 1) // + void smxx_paged_mqa_logits_metadata(const MetadataParams params) { + using namespace device; + extern __shared__ uint32_t s_length[]; + static constexpr auto kNumWarps = kBlockSize / kWarpThreads; + static_assert(kNumWarps == kWarpThreads); + + const auto tx = threadIdx.x; + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + __shared__ uint32_t s_warp_sum[kNumWarps]; + + uint32_t local_sum = 0; + for (uint32_t i = tx; i < params.batch_size; i += kBlockSize) { + const auto length = params.context_lens[i]; + local_sum += (length + kSplitKV - 1) / kSplitKV; + if (params.use_smem) s_length[i] = length; + } + + s_warp_sum[warp_id] = warp::reduce_sum(local_sum); + __syncthreads(); + + const auto global_sum = warp::reduce_sum(s_warp_sum[lane_id]); + if (lane_id != 0) return; + + const auto length_ptr = params.use_smem ? s_length : params.context_lens; + + const auto avg = global_sum / params.num_sm; + const auto ret = global_sum % params.num_sm; + uint32_t q = 0; + uint32_t num_work = (length_ptr[0] + kSplitKV - 1) / kSplitKV; + uint32_t sum_work = num_work; + for (auto i = warp_id; i <= params.num_sm; i += kNumWarps) { + const auto target = i * avg + min(i, ret); + while (sum_work <= target) { + if (++q >= params.batch_size) break; + num_work = (length_ptr[q] + kSplitKV - 1) / kSplitKV; + sum_work += num_work; + } + if (q >= params.batch_size) { + params.schedule_metadata[2 * i + 0] = params.batch_size; + params.schedule_metadata[2 * i + 1] = 0; + } else { + // sum > target && (sum - length) <= target + params.schedule_metadata[2 * i + 0] = q; + params.schedule_metadata[2 * i + 1] = target - (sum_work - num_work); + } + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct IndexerMetadataKernel { + static constexpr auto kMaxBatchSizeInSmem = 16384 * 2; // 128 KB smeme + static void run(tvm::ffi::TensorView seq_lens, tvm::ffi::TensorView metadata) { + using namespace host; + auto N = SymbolicSize{"batch_size"}; + auto M = SymbolicSize{"num_sm"}; + auto device = SymbolicDevice{}; + device.set_options(); + TensorMatcher({N}) // + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({M, 2}) // + .with_dtype() + .with_device(device) + .verify(metadata); + const auto batch_size = static_cast(N.unwrap()); + const auto num_sm = static_cast(M.unwrap()) - 1; + RuntimeCheck(num_sm <= 1024); + const auto use_smem = batch_size <= kMaxBatchSizeInSmem; + const auto params = MetadataParams{ + .batch_size = batch_size, + .num_sm = num_sm, + .context_lens = static_cast(seq_lens.data_ptr()), + .schedule_metadata = static_cast(metadata.data_ptr()), + .use_smem = use_smem, + }; + constexpr auto kernel = smxx_paged_mqa_logits_metadata; + setup_kernel_smem_once(); + const auto smem = use_smem ? 
(batch_size + 1) * sizeof(uint32_t) : 0; + LaunchKernel(1, kBlockSize, device.unwrap(), smem)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh new file mode 100644 index 000000000000..2239d3972d64 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/rope.cuh @@ -0,0 +1,169 @@ +#include +#include + +#include +#include +#include + +#include + +#include + +namespace { + +using DType = bf16_t; +constexpr int64_t kRopeDim = 64; +constexpr uint32_t kBlockSize = 128; +constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; + +struct FusedQKRopeParams { + void* __restrict__ q; + void* __restrict__ k; + const float* __restrict__ freqs_cis; + const void* __restrict__ positions; + int64_t q_stride_batch; + int64_t k_stride_batch; + int64_t q_stride_head; + int64_t k_stride_head; + uint32_t num_q_heads; + uint32_t num_k_heads; + uint32_t batch_size; +}; + +template +__global__ __launch_bounds__(kBlockSize, 16) // + void deepseek_rope_kernel(const __grid_constant__ FusedQKRopeParams param) { + using namespace device; + using DType2 = packed_t; + + const auto warp_id = threadIdx.x / kWarpThreads; + const auto lane_id = threadIdx.x % kWarpThreads; + const auto global_warp_id = blockIdx.x * kNumWarps + warp_id; + + const auto& [ + q, k, freqs_cis, positions, // + q_stride_batch, k_stride_batch, q_stride_head, k_stride_head, // + num_q_heads, num_k_heads, batch_size + ] = param; + + const auto num_total_heads = num_q_heads + num_k_heads; + const auto head_id = global_warp_id % num_total_heads; + const auto batch_id = global_warp_id / num_total_heads; + if (batch_id >= batch_size) return; + + const auto position = static_cast(positions)[batch_id]; + const auto is_q = head_id < num_q_heads; + const auto local_head = is_q ? head_id : (head_id - num_q_heads); + const auto stride_batch = is_q ? q_stride_batch : k_stride_batch; + const auto stride_head = is_q ? q_stride_head : k_stride_head; + const auto base_ptr = is_q ? 
q : k; + const auto input = static_cast(pointer::offset(base_ptr, batch_id * stride_batch, local_head * stride_head)); + + const auto freq_ptr = reinterpret_cast(freqs_cis + position * kRopeDim); + const auto [f_real, f_imag] = freq_ptr[lane_id]; + PDLWaitPrimary(); + + const auto data = input[lane_id]; + const auto [x_real, x_imag] = cast(data); + fp32x2_t output; + if constexpr (kInverse) { + // (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + output = { + x_real * f_real + x_imag * f_imag, + x_imag * f_real - x_real * f_imag, + }; + } else { + // (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + output = { + x_real * f_real - x_imag * f_imag, + x_real * f_imag + x_imag * f_real, + }; + } + input[lane_id] = cast(output); + + PDLTriggerSecondary(); +} + +template +struct FusedQKRopeKernel { + // 4 kernel variants: {forward, inverse} x {int32, int64} + static constexpr auto kernel_fwd_i32 = deepseek_rope_kernel; + static constexpr auto kernel_fwd_i64 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i32 = deepseek_rope_kernel; + static constexpr auto kernel_inv_i64 = deepseek_rope_kernel; + + static void forward( + const tvm::ffi::TensorView q, + const tvm::ffi::Optional k, + const tvm::ffi::TensorView freqs_cis, + const tvm::ffi::TensorView positions, + bool inverse) { + using namespace host; + + auto B = SymbolicSize{"batch_size"}; + auto Q = SymbolicSize{"num_q_heads"}; + auto K = SymbolicSize{"num_k_heads"}; + constexpr auto D = kRopeDim; + auto device_ = SymbolicDevice{}; + device_.set_options(); + + TensorMatcher({B, Q, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(q); + if (k.has_value()) { + TensorMatcher({B, K, D}) // + .with_strides({-1, -1, 1}) + .with_dtype() + .with_device(device_) + .verify(k.value()); + } else { + K.set_value(0); + } + TensorMatcher({-1, D}) // + .with_dtype() + .with_device(device_) + .verify(freqs_cis); + + auto pos_dtype = SymbolicDType{}; + TensorMatcher({B}) // + .with_dtype(pos_dtype) + .with_device(device_) + .verify(positions); + const bool pos_i32 = pos_dtype.is_type(); + + const auto batch_size = static_cast(B.unwrap()); + if (batch_size == 0) return; + + const auto num_q_heads = static_cast(Q.unwrap()); + const auto num_k_heads = static_cast(K.unwrap()); + const auto num_total_heads = num_q_heads + num_k_heads; + const auto total_warps = batch_size * num_total_heads; + const auto num_blocks = div_ceil(total_warps, kNumWarps); + + const auto elem_size = static_cast(sizeof(DType)); + const auto params = FusedQKRopeParams{ + .q = q.data_ptr(), + .k = k ? k.value().data_ptr() : nullptr, + .freqs_cis = static_cast(freqs_cis.data_ptr()), + .positions = positions.data_ptr(), + .q_stride_batch = q.stride(0) * elem_size, + .k_stride_batch = k ? k.value().stride(0) * elem_size : 0, + .q_stride_head = q.stride(1) * elem_size, + .k_stride_head = k ? k.value().stride(1) * elem_size : 0, + .num_q_heads = num_q_heads, + .num_k_heads = num_k_heads, + .batch_size = batch_size, + }; + + // dispatch: {inverse} x {pos_i32} + using KernelType = decltype(kernel_fwd_i32); + const KernelType kernel = + inverse ? (pos_i32 ? kernel_inv_i32 : kernel_inv_i64) : (pos_i32 ? 
kernel_fwd_i32 : kernel_fwd_i64); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh new file mode 100644 index 000000000000..e6fd97b6a248 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/silu_and_mul_masked_post_quant.cuh @@ -0,0 +1,300 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +namespace { + +[[maybe_unused]] +SGL_DEVICE int32_t cast_to_ue8m0(float x) { + uint32_t u = __float_as_uint(x); + int32_t exp = int32_t((u >> 23) & 0xFF); + uint32_t mant = u & 0x7FFFFF; + return exp + (mant != 0); +} + +[[maybe_unused]] +SGL_DEVICE float fp8_e4m3_clip(float val) { + return fmaxf(fminf(val, device::math::FP8_E4M3_MAX), -device::math::FP8_E4M3_MAX); +} + +[[maybe_unused]] +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} + +struct SiluMulQuantParams { + const bf16_t* __restrict__ input; + fp8_e4m3_t* __restrict__ output; + float* __restrict__ output_scale; + const int32_t* __restrict__ masked_m; + float swiglu_limit; // only read when kApplySwigluLimit=true + int64_t hidden_dim; + uint32_t num_tokens; + uint32_t num_experts; +}; + +constexpr uint32_t kMaxExperts = 256; + +struct alignas(16) CTAWork { + uint32_t expert_id; + uint32_t expert_token_id; + bool valid; +}; + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +[[maybe_unused]] +SGL_DEVICE CTAWork get_work(const SiluMulQuantParams& params) { + // Preconditions: + // 1. blockDim.x >= params.num_experts + // 2. params.num_experts <= kMaxExperts + using namespace device; + static_assert(kWarpThreads == 32); + + static __shared__ uint32_t s_warp_sum[32]; + static __shared__ CTAWork result; + + result.valid = false; + + const uint32_t tx = threadIdx.x; + const uint32_t lane_id = tx % kWarpThreads; + const uint32_t warp_id = tx / kWarpThreads; + + const uint32_t val = tx < params.num_experts ? params.masked_m[tx] : 0u; + + // Per-warp inclusive scan of masked_m. + const uint32_t warp_inclusive = warp_inclusive_sum(lane_id, val); + const uint32_t warp_exclusive = warp_inclusive - val; + + // Write each warp total. + if (lane_id == kWarpThreads - 1) s_warp_sum[warp_id] = warp_inclusive; + __syncthreads(); + const auto tmp_val = lane_id < warp_id ? 
s_warp_sum[lane_id] : 0u; + const auto prefix_exclusive = warp::reduce_sum(tmp_val) + warp_exclusive; + const auto bx = blockIdx.x; + if (prefix_exclusive <= bx && bx < prefix_exclusive + val) { + result = {tx, bx - prefix_exclusive, true}; + } + __syncthreads(); + return result; +} + +template +__global__ __launch_bounds__(1024, 2) void // maximize occupancy + silu_mul_quant_kernel(const SiluMulQuantParams __grid_constant__ params) { + using namespace device; + + constexpr uint32_t kGroupSize = 128u; + constexpr uint32_t kWorkThreads = 16u; + // each thread will handle 8 elements + using InputVec = AlignedVector; + using OutputVec = AlignedVector; + static_assert(8 * kWorkThreads == 128, "Invalid tiling"); + static_assert(!(kTransposed && !kScaleUE8M0), "transposed layout only supports ue8m0"); + + const auto [expert_id, token_id, valid] = get_work(params); + + if (!valid) return; + + const auto work_id = threadIdx.x / kWorkThreads; + + const auto offset = expert_id * params.num_tokens + token_id; + const auto input = params.input + offset * params.hidden_dim * 2; + const auto output = params.output + offset * params.hidden_dim; + [[maybe_unused]] + const auto output_scale = [&] { + const auto num_groups = params.hidden_dim / kGroupSize; + if constexpr (kTransposed) { + const auto base = reinterpret_cast(params.output_scale); + // Physical layout is [E, G//4, N] int32. Each int32 packs 4 consecutive + // group scales for the same token, so the byte address is: + // expert_offset + (group/4)*N*4 + token*4 + group%4 + return base + expert_id * num_groups * params.num_tokens + (work_id / 4u) * (params.num_tokens * 4u) + + token_id * 4u + (work_id % 4u); + } else { + return params.output_scale + offset * num_groups + work_id; + } + }(); + + PDLWaitPrimary(); + + InputVec gate_vec, up_vec; + gate_vec.load(input, threadIdx.x); + up_vec.load(input, threadIdx.x + blockDim.x); + + float local_max = 0.0f; + float results[8]; + +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + if constexpr (kApplySwigluLimit) { + // Fused fp32 path: bf16 load → fp32 clamp → fp32 silu → fp32 mul → fp32 result. + // Avoids the silu→bf16→mul→fp32 round-trip of the non-fused path since we already + // have gate/up in fp32 registers after clamp. + const float limit = params.swiglu_limit; + + const auto [g0_raw, g1_raw] = cast(gate_vec[i]); + const float g0 = fminf(g0_raw, limit); + const float g1 = fminf(g1_raw, limit); + + const float silu0 = g0 / (1.0f + expf(-g0)); + const float silu1 = g1 / (1.0f + expf(-g1)); + + const auto [u0_raw, u1_raw] = cast(up_vec[i]); + const float u0 = fmaxf(fminf(u0_raw, limit), -limit); + const float u1 = fmaxf(fminf(u1_raw, limit), -limit); + + const float val0 = u0 * silu0; + const float val1 = u1 * silu1; + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } else { + // original code path — must stay byte-equal to pre-fusion kernel. 
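+        // silu(g) is rounded to bf16 before the multiply (see the cast below), i.e.
+        //     val = fp32( bf16(silu(g)) * u_bf16 )
+        // which reproduces the pre-fusion rounding exactly, unlike the all-fp32
+        // kApplySwigluLimit branch above.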
+ const auto [g0, g1] = cast(gate_vec[i]); + + float silu0 = g0 / (1.0f + expf(-g0)); + float silu1 = g1 / (1.0f + expf(-g1)); + + bf16x2_t silu_d = cast(fp32x2_t{silu0, silu1}); + auto [val0, val1] = cast(up_vec[i] * silu_d); + results[2 * i + 0] = val0; + results[2 * i + 1] = val1; + local_max = fmaxf(local_max, fmaxf(fabsf(val0), fabsf(val1))); + } + } + + local_max = warp::reduce_max(local_max); + + const float absmax = fmaxf(local_max, 1e-10f); + float scale; + uint32_t ue8m0_exp; + + if constexpr (kScaleUE8M0) { + const float raw_scale = absmax / math::FP8_E4M3_MAX; + ue8m0_exp = cast_to_ue8m0(raw_scale); + scale = __uint_as_float(ue8m0_exp << 23); + } else { + scale = absmax / math::FP8_E4M3_MAX; + } + const auto inv_scale = 1.0f / scale; + + OutputVec out_vec; +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + const float scaled_val0 = results[2 * i + 0] * inv_scale; + const float scaled_val1 = results[2 * i + 1] * inv_scale; + out_vec[i] = pack_fp8(scaled_val0, scaled_val1); + } + + PDLTriggerSecondary(); + + out_vec.store(output, threadIdx.x); + if constexpr (kTransposed) { + *output_scale = ue8m0_exp; + } else { + *output_scale = scale; + } +} + +// ---- Host wrapper +// ------------------------------------------------------------------------------------------------------------------------ + +template +struct SiluAndMulMaskedPostQuantKernel { + static_assert(kGroupSize == 128); + static constexpr auto kernel_normal = silu_mul_quant_kernel; + static constexpr auto kernel_transposed = silu_mul_quant_kernel; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView output_scale, + const tvm::ffi::TensorView masked_m, + const uint32_t topk, + const bool transposed, + const double swiglu_limit) { + using namespace host; + + auto device = SymbolicDevice{}; + auto E = SymbolicSize{"num_experts"}; + auto T = SymbolicSize{"num_tokens_padded"}; + auto D = SymbolicSize{"hidden_dim x 2"}; + auto N = SymbolicSize{"hidden_dim"}; + auto G = SymbolicSize{"num_groups"}; + device.set_options(); + + TensorMatcher({E, T, D}) // input + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({E, T, N}) // output + .with_dtype() + .with_device(device) + .verify(output); + if (!transposed) { + TensorMatcher({E, T, G}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + } else { + RuntimeCheck(kScaleUE8M0, "transposed layout only supports scale_ue8m0=true"); + auto G_ = SymbolicSize{"G // 4"}; + TensorMatcher({E, G_, T}) // + .with_dtype() + .with_device(device) + .verify(output_scale); + G.set_value(G_.unwrap() * 4); + } + TensorMatcher({E}) // + .with_dtype() + .with_device(device) + .verify(masked_m); + + const auto num_experts = static_cast(E.unwrap()); + const auto num_tokens = static_cast(T.unwrap()); + const auto num_groups = static_cast(G.unwrap()); + const auto hidden_dim = N.unwrap(); + + RuntimeCheck(D.unwrap() == 2 * hidden_dim, "invalid dimension"); + RuntimeCheck(hidden_dim % kGroupSize == 0); + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum (256)"); + RuntimeCheck(num_groups * kGroupSize == hidden_dim, "invalid num_groups"); + + const auto params = SiluMulQuantParams{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .output_scale = static_cast(output_scale.data_ptr()), + .masked_m = static_cast(masked_m.data_ptr()), + .swiglu_limit = static_cast(swiglu_limit), + .hidden_dim = hidden_dim, + .num_tokens = num_tokens, + 
.num_experts = num_experts, + }; + + const auto num_threads = hidden_dim / 8; + RuntimeCheck(num_threads % device::kWarpThreads == 0); + RuntimeCheck(num_threads >= num_experts); + const auto kernel = transposed ? kernel_transposed : kernel_normal; + LaunchKernel(num_tokens * topk, num_threads, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh new file mode 100644 index 000000000000..4076fe6aa253 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/store.cuh @@ -0,0 +1,223 @@ +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace { + +struct FusedStoreCacheParam { + const void* __restrict__ input; + void* __restrict__ cache; + const void* __restrict__ indices; + uint32_t num_tokens; +}; + +[[maybe_unused]] +SGL_DEVICE int32_t cast_to_ue8m0(float x) { + uint32_t u = __float_as_uint(x); + int32_t exp = int32_t((u >> 23) & 0xFF); + uint32_t mant = u & 0x7FFFFF; + return exp + (mant != 0); +} + +[[maybe_unused]] +SGL_DEVICE float inv_scale_ue8m0(int32_t exp) { + return __uint_as_float((127 + 127 - exp) << 23); +} + +[[maybe_unused]] +SGL_DEVICE float fp8_e4m3_clip(float val) { + namespace math = device::math; + return math::max(math::min(val, math::FP8_E4M3_MAX), -math::FP8_E4M3_MAX); +} + +[[maybe_unused]] +SGL_DEVICE fp8x2_e4m3_t pack_fp8(float x, float y) { + return fp8x2_e4m3_t{fp32x2_t{fp8_e4m3_clip(x), fp8_e4m3_clip(y)}}; +} + +template +__global__ void fused_store_flashmla_cache(const __grid_constant__ FusedStoreCacheParam param) { + using namespace device; + + /// NOTE: 584 = 576 + 8 + constexpr int64_t kPageBytes = host::div_ceil(584 << kPageBits, 576) * 576; + + // each warp handles 64 elements, 8 warps, each block handles 1 row + const auto& [input, cache, indices, num_tokens] = param; + const uint32_t bid = blockIdx.x; + const uint32_t tid = threadIdx.x; + const uint32_t wid = tid / 32; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[bid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + const auto elems = static_cast(input)[tid + bid * 256]; + if (wid != 7) { + const auto [x, y] = cast(elems); + const auto abs_max = warp::reduce_max(fmaxf(fabs(x), fabs(y))); + const auto scale_raw = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto scale_ue8m0 = cast_to_ue8m0(scale_raw); + const auto inv_scale = inv_scale_ue8m0(scale_ue8m0); + const auto result = pack_fp8(x * inv_scale, y * inv_scale); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576); + const auto scale_ptr = pointer::offset(page_ptr, 576 << kPageBits, offset * 8); + static_cast(value_ptr)[tid] = result; + static_cast(scale_ptr)[wid] = scale_ue8m0; + } else { + const auto result = cast(elems); + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 576, 448); + static_cast(value_ptr)[tid - 7 * 32] = result; + } + + PDLTriggerSecondary(); +} + +template +__global__ void fused_store_indexer_cache(const __grid_constant__ FusedStoreCacheParam param) { + 
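+  // Per-page layout written by this kernel (one warp quantizes one 128-element row):
+  //   [ 2^kPageBits rows of 128 fp8 bytes | 2^kPageBits 4-byte scales, one per row ]
+  // i.e. 128 + 4 = 132 bytes per row; for a hypothetical 64-token page that is
+  // 128*64 + 4*64 = 8448 bytes, matching kPageBytes = 132 << kPageBits below.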
using namespace device; + + /// NOTE: 132 = 128 + 4 + constexpr int64_t kPageBytes = 132 << kPageBits; + + // each warp handles 128 elements, 1 warp, each block handles multiple rows + const auto& [input, cache, indices, num_tokens] = param; + const auto global_tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto global_wid = global_tid / 32; + const auto lane_id = threadIdx.x % 32; + + if (global_wid >= num_tokens) return; + + PDLWaitPrimary(); + + // prefetch the index + const auto index = static_cast(indices)[global_wid]; + // always load the value from input (don't store if invalid) + using Float2 = packed_t; + using InStorage = AlignedVector; + using OutStorage = AlignedVector; + const auto elems = static_cast(input)[global_tid]; + const auto [x0, x1] = cast(elems[0]); + const auto [y0, y1] = cast(elems[1]); + const auto local_max = fmaxf(fmaxf(fabs(x0), fabs(x1)), fmaxf(fabs(y0), fabs(y1))); + const auto abs_max = warp::reduce_max(local_max); + // use normal fp32 scale + const auto scale = fmaxf(1e-4f, abs_max) / math::FP8_E4M3_MAX; + const auto inv_scale = 1.0f / scale; + const int32_t page = index >> kPageBits; + const int32_t offset = index & ((1 << kPageBits) - 1); + const auto page_ptr = pointer::offset(cache, page * kPageBytes); + const auto value_ptr = pointer::offset(page_ptr, offset * 128); + const auto scale_ptr = pointer::offset(page_ptr, 128 << kPageBits, offset * 4); + OutStorage result; + result[0] = pack_fp8(x0 * inv_scale, x1 * inv_scale); + result[1] = pack_fp8(y0 * inv_scale, y1 * inv_scale); + static_cast(value_ptr)[lane_id] = result; + static_cast(scale_ptr)[0] = scale; + + PDLTriggerSecondary(); +} + +template +struct FusedStoreCacheFlashMLAKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = host::div_ceil(584 * kPageSize, 576) * 576; + static constexpr auto kernel = fused_store_flashmla_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 512}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 256; + const auto num_blocks = num_tokens; + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +template +struct FusedStoreCacheIndexerKernel { + static constexpr int32_t kLogSize = std::countr_zero(kPageSize); + static constexpr int64_t kPageBytes = 132 * kPageSize; + static constexpr auto kernel = fused_store_indexer_cache; + + static_assert(std::has_single_bit(kPageSize), "kPageSize must be a power of 2"); + static_assert(1 << kLogSize == kPageSize); + + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView cache, tvm::ffi::TensorView indices) { + using namespace host; + + auto N = SymbolicSize{"num_tokens"}; + auto device_ = 
SymbolicDevice{}; + device_.set_options(); + TensorMatcher({N, 128}) // input + .with_dtype() + .with_device(device_) + .verify(input); + TensorMatcher({-1, -1}) // cache + .with_strides({kPageBytes, 1}) + .with_dtype() + .with_device(device_) + .verify(cache); + TensorMatcher({N}) // indices + .with_dtype() + .with_device(device_) + .verify(indices); + const auto num_tokens = static_cast(N.unwrap()); + const auto params = FusedStoreCacheParam{ + .input = input.data_ptr(), + .cache = cache.data_ptr(), + .indices = indices.data_ptr(), + .num_tokens = num_tokens, + }; + const auto kBlockSize = 128; + const auto num_blocks = div_ceil(num_tokens * 32, kBlockSize); + LaunchKernel(num_blocks, kBlockSize, device_.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh new file mode 100644 index 000000000000..ef2be43c07e2 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk.cuh @@ -0,0 +1,336 @@ +#include +#include + +#include + +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t kTopK = 512; +constexpr uint32_t kTopKBlockSize = 512; +constexpr uint32_t kSMEM = 16 * 1024 * sizeof(uint32_t); // 64KB (bytes) + +struct TopK512Params { + const float* __restrict__ scores; + const int32_t* __restrict__ seq_lens; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + int32_t* __restrict__ raw_indices; // optional: output raw abs position indices before page transform + const int64_t score_stride; + const int64_t page_table_stride; + uint32_t page_bits; +}; + +SGL_DEVICE uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +SGL_DEVICE uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? 
~bits : (bits | 0x80000000u); +} + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + int32_t* __restrict__ raw_indices, // optional: output raw abs position indices + const uint32_t length, + const uint32_t page_bits) { + static_assert(kTopK <= kTopKBlockSize); + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + if (raw_indices != nullptr) { + raw_indices[tx] = tx; + } + } else if (kTopK == kTopKBlockSize || tx < kTopK) { + indices[tx] = -1; // fill invalid indices to -1 + if (raw_indices != nullptr) { + raw_indices[tx] = -1; + } + } +} + +[[maybe_unused]] +SGL_DEVICE void radix_topk(const float* __restrict__ input, int32_t* __restrict__ output, const uint32_t length) { + constexpr uint32_t RADIX = 256; + constexpr uint32_t BLOCK_SIZE = kTopKBlockSize; + constexpr uint32_t SMEM_INPUT_SIZE = kSMEM / (2 * sizeof(int32_t)); + + alignas(128) __shared__ uint32_t _s_histogram_buf[2][RADIX + 32]; + alignas(128) __shared__ uint32_t s_counter; + alignas(128) __shared__ uint32_t s_threshold_bin_id; + alignas(128) __shared__ uint32_t s_num_input[2]; + alignas(128) __shared__ int32_t s_last_remain; + + extern __shared__ uint32_t s_input_idx[][kSMEM / (2 * sizeof(int32_t))]; + + const uint32_t tx = threadIdx.x; + uint32_t remain_topk = kTopK; + auto& s_histogram = _s_histogram_buf[0]; + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int32_t i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (tx < RADIX) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = _s_histogram_buf[k][tx]; + if (tx + j < RADIX) { + value += _s_histogram_buf[k][tx + j]; + } + _s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + if (remain_topk == 0) { + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const uint32_t bin = convert_to_uint8(input[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (uint32_t idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw_input = input[idx]; + const uint32_t bin = convert_to_uint8(raw_input); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (pos < SMEM_INPUT_SIZE) { + [[likely]] s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + 
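+            // while staging the tied candidates, their histogram for the next radix
+            // pass (top 8 bits of the 32-bit key) is built in the same sweep, so
+            // stage 2 can pick its threshold bin without re-reading the input
+            // (same fusion as the "(dark)" note in stage 2 below)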
} + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto raw_num_input = s_num_input[r_idx]; + const auto num_input = raw_num_input < SMEM_INPUT_SIZE ? raw_num_input : SMEM_INPUT_SIZE; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > remain_topk && s_histogram[tx + 1] <= remain_topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = remain_topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + remain_topk -= s_histogram[threshold_bin + 1]; + + if (remain_topk == 0) { + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (uint32_t i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + output[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + output[kTopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (pos < SMEM_INPUT_SIZE) { + /// NOTE: (dark) fuse the histogram computation here + [[likely]] s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +__global__ void topk_512_transform(const __grid_constant__ TopK512Params params) { + const auto &[ + scores, seq_lens, page_table, page_indices, raw_indices, // pointers + score_stride, page_table_stride, page_bits // sizes + ] = params; + const uint32_t work_id = blockIdx.x; + + /// NOTE: dangerous prefetch seq_len before PDL wait + const uint32_t seq_len = seq_lens[work_id]; + const auto score_ptr = scores + work_id * score_stride; + const auto page_ptr = page_table + work_id * page_table_stride; + const auto indices_ptr = page_indices + work_id * kTopK; + const auto raw_indices_ptr = raw_indices != nullptr ? 
raw_indices + work_id * kTopK : nullptr; + + device::PDLWaitPrimary(); + + if (seq_len <= kTopK) { + naive_transform(score_ptr, page_ptr, indices_ptr, raw_indices_ptr, seq_len, page_bits); + } else { + __shared__ int32_t s_topk_indices[kTopK]; + radix_topk(score_ptr, s_topk_indices, seq_len); + static_assert(kTopK <= kTopKBlockSize); + const auto tx = threadIdx.x; + if (kTopK == kTopKBlockSize || tx < kTopK) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], page_bits); + if (raw_indices_ptr != nullptr) { + raw_indices_ptr[tx] = s_topk_indices[tx]; + } + } + } + + device::PDLTriggerSecondary(); +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return ::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +template +struct TopK512Kernel { + static constexpr auto kernel = topk_512_transform; + + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional raw_indices) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 512}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + int32_t* raw_indices_ptr = nullptr; + if (raw_indices.has_value()) { + TensorMatcher({B, 512}) // optional raw indices output, must be contiguous + .with_dtype() + .with_device(device) + .verify(raw_indices.value()); + raw_indices_ptr = static_cast(raw_indices.value().data_ptr()); + } + + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopK512Params{ + .scores = static_cast(scores.data_ptr()), + .seq_lens = static_cast(seq_lens.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .raw_indices = raw_indices_ptr, + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + constexpr auto kSMEM_ = kSMEM + sizeof(int32_t); // align up a little + setup_kernel_smem_once(); + LaunchKernel(batch_size, kTopKBlockSize, device.unwrap(), kSMEM_).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh new file mode 100644 index 000000000000..8be610d54aa6 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/deepseek_v4/topk_v2.cuh @@ -0,0 +1,396 @@ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace { + +constexpr uint32_t K = 512; +constexpr uint32_t kBlockSize = 1024; 
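+// topk_v2 strategy (one 1024-thread block per request): each score is mapped to an
+// order-preserving 12-bit key taken from its fp16 bits, a shared-memory histogram of
+// those keys locates the threshold bin, everything strictly above it is scattered to
+// the output directly, and ties inside the bin are rank-resolved in a final phase.
+// Key sketch (see extract_bin): key = sign ? ~bits(half(x)) : bits(half(x)) | 0x8000;
+// e.g. half(+1.0f) = 0x3C00 -> key 0xBC00, half(-1.0f) = 0xBC00 -> key 0x43FF,
+// so a larger score always yields a larger key.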
+constexpr uint32_t kNumWarps = kBlockSize / device::kWarpThreads; +static_assert(K <= kBlockSize); + +// always use float4 to load from global memory +using Vec4 = device::AlignedVector; + +// --------------------------------------------------------------------------- +// Order-preserving FP16 key -> histogram bin +// --------------------------------------------------------------------------- + +template +SGL_DEVICE uint32_t extract_bin(float x) { + static_assert(0 < kBits && kBits < 15); + const auto hx = device::cast(x); + const uint16_t bits = *reinterpret_cast(&hx); + const uint16_t key = (bits & 0x8000) ? ~bits : bits | 0x8000; + return key >> (16 - kBits); +} + +SGL_DEVICE uint32_t warp_inclusive_sum(uint32_t lane_id, uint32_t val) { + static_assert(device::kWarpThreads == 32); +#pragma unroll + for (uint32_t offset = 1; offset < 32; offset *= 2) { + uint32_t n = __shfl_up_sync(0xFFFFFFFF, val, offset); + if (lane_id >= offset) val += n; + } + return val; +} + +struct TopKProblem { + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ indices; + const uint32_t length; + const uint32_t page_bits; +}; + +struct SmallTopKImpl { + static constexpr uint32_t kHistBits = 12; + static constexpr uint32_t kHistBins = 1 << kHistBits; + static constexpr uint32_t kVecsPerThread = 4; + static constexpr uint32_t kMaxTolerance = 2; + [[maybe_unused]] + static constexpr uint32_t kMaxSeqLen = kVecsPerThread * 4 * kBlockSize; + + struct alignas(16) MatchBin { + uint32_t bin; + uint32_t above_count; + uint32_t equal_count; + }; + + struct alignas(8) Tie { + uint32_t idx; + float score; + }; + + struct Smem { + using HistVec = device::AlignedVector; + alignas(128) uint32_t counter_gt; + alignas(128) uint32_t counter_eq; + alignas(128) MatchBin match; + alignas(128) uint32_t warp_sum[kNumWarps]; + alignas(16) union { + uint32_t histogram[kHistBins]; + HistVec histogram_vec[kBlockSize]; + Tie tie_buffer[kBlockSize]; + }; + }; + + SGL_DEVICE static void run(const TopKProblem problem, void* _smem) { + using namespace device; + + const auto smem = static_cast(_smem); + const auto tx = threadIdx.x; + Smem::HistVec hist_vec; + hist_vec.fill(0); + smem->histogram_vec[tx] = hist_vec; + __syncthreads(); + + PDLWaitPrimary(); + + // Load scores into registers + Vec4 local[kVecsPerThread]; +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { + const uint32_t base = (tx + v * kBlockSize) * 4; + if (base >= problem.length) break; + local[v].load(problem.scores, tx + v * kBlockSize); + } + + // Accumulate histogram via shared-memory atomics +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= problem.length) goto LABEL_ACC_FINISH; + atomicAdd(&smem->histogram[extract_bin(local[v][e])], 1); + } + } + LABEL_ACC_FINISH: + __syncthreads(); + + // Phase 2: Exclusive prefix scan -> find threshold bin + constexpr uint32_t kItems = kHistBins / kBlockSize; + + const auto lane_id = tx % kWarpThreads; + const auto warp_id = tx / kWarpThreads; + + { + smem->counter_gt = smem->counter_eq = 0; + + uint32_t orig[kItems]; + const auto hist_vec = smem->histogram_vec[tx]; + uint32_t tmp_local_sum = 0; + +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + orig[i] = hist_vec[i]; + tmp_local_sum += orig[i]; + } + + const auto warp_inclusive = warp_inclusive_sum(lane_id, tmp_local_sum); + const auto warp_exclusive = warp_inclusive - 
tmp_local_sum; + if (lane_id == device::kWarpThreads - 1) { + smem->warp_sum[warp_id] = warp_inclusive; + } + + __syncthreads(); + + const auto tmp = smem->warp_sum[lane_id]; + // Exactly one bin satisfies: above < K && above + count >= K + uint32_t prefix_sum = warp::reduce_sum(lane_id < warp_id ? tmp : 0); + prefix_sum += warp_exclusive; +#pragma unroll + for (uint32_t i = 0; i < kItems; ++i) { + prefix_sum += orig[i]; + const auto above = problem.length - prefix_sum; + if (above < K && above + orig[i] >= K) { + smem->match = { + .bin = tx * kItems + i, + .above_count = above, + .equal_count = orig[i], + }; + } + } + __syncthreads(); + } + + const auto [thr_bin, num_above, num_equal] = smem->match; + const bool need_tiebreak = (num_equal + num_above > K + kMaxTolerance); + + // Phase 3: Scatter + // Elements strictly above threshold go directly to output. + // Tied elements: simple path admits first-come; tiebreak path collects into tie_buffer. +#pragma unroll + for (uint32_t v = 0; v < kVecsPerThread; ++v) { +#pragma unroll + for (uint32_t e = 0; e < 4; ++e) { + const uint32_t idx = (tx + v * kBlockSize) * 4 + e; + if (idx >= problem.length) goto LABEL_SCATTER_DONE; + const uint32_t bin = extract_bin(local[v][e]); + if (bin > thr_bin) { + problem.indices[atomicAdd(&smem->counter_gt, 1)] = idx; + } else if (bin == thr_bin) { + const auto pos = atomicAdd(&smem->counter_eq, 1); + if (need_tiebreak) { + if (pos < kBlockSize) { + smem->tie_buffer[pos] = {.idx = idx, .score = local[v][e]}; + } + } else { + if (const auto which = pos + num_above; which < K) { + problem.indices[which] = idx; + } + } + } + } + } + LABEL_SCATTER_DONE: + if (!need_tiebreak) return; + + // Phase 4: Tie-breaking within the threshold bin. + // Assume num_ties <= kBlockSize (at most 1 block of ties). + // Each thread takes one tied element, computes its rank (number of + // elements with strictly higher score, breaking exact float ties by + // original index), and writes to output if rank < topk_remain. + __syncthreads(); + + const uint32_t num_ties = num_equal < kBlockSize ? num_equal : kBlockSize; + const uint32_t topk_remain = K - num_above; + + const auto is_greater = [](const Tie& a, const Tie& b) { + return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); + }; + + if (num_ties <= kWarpThreads) { + static_assert(kWarpThreads <= kNumWarps); + if (lane_id >= num_ties || warp_id >= num_ties) return; // some threads are idle + /// NOTE: use long long to avoid mask overflow when num_ties == 32 + const uint32_t mask = (1ull << num_ties) - 1u; + const auto tie = smem->tie_buffer[lane_id]; + const auto target_tie = smem->tie_buffer[warp_id]; + const bool pred = is_greater(tie, target_tie); + const auto rank = static_cast(__popc(__ballot_sync(mask, pred))); + if (lane_id == 0 && rank < topk_remain) { + problem.indices[num_above + rank] = target_tie.idx; + } + } else if (num_ties <= kWarpThreads * 2) { + [[unlikely]]; + // 64 x 64 topk implementation: each thread takes 2 elements + const auto lane_id_1 = lane_id + kWarpThreads; + const auto warp_id_1 = warp_id + kWarpThreads; + const auto invalid = Tie{.idx = 0xFFFFFFFF, .score = -FLT_MAX}; + const auto tie_0 = smem->tie_buffer[lane_id]; + const auto tie_1 = lane_id_1 < num_ties ? 
smem->tie_buffer[lane_id_1] : invalid; + if (true) { + const auto target = smem->tie_buffer[warp_id]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + problem.indices[num_above + rank] = target.idx; + } + } + if (warp_id_1 < num_ties) { + const auto target = smem->tie_buffer[warp_id_1]; + const bool pred_0 = is_greater(tie_0, target); + const bool pred_1 = is_greater(tie_1, target); + const auto rank_0 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_0))); + const auto rank_1 = static_cast(__popc(__ballot_sync(0xFFFFFFFF, pred_1))); + const auto rank = rank_0 + rank_1; + if (lane_id == 0 && rank < topk_remain) { + problem.indices[num_above + rank] = target.idx; + } + } + } else { + [[unlikely]]; + // Block-level: each thread reads from tie_buffer in shared memory + if (tx >= num_ties) return; + const auto target_tie = smem->tie_buffer[tx]; + uint32_t rank = 0; + for (uint32_t i = 0; i < num_ties; i++) { + const auto tie = smem->tie_buffer[i]; + if (is_greater(tie, target_tie)) rank++; + } + if (rank < topk_remain) { + problem.indices[num_above + rank] = target_tie.idx; + } + } + } +}; + +struct TopKParams { + const uint32_t* __restrict__ seq_lens; + const float* __restrict__ scores; + const int32_t* __restrict__ page_table; + int32_t* __restrict__ page_indices; + const int64_t score_stride; + const int64_t page_table_stride; + /// NOTE: indices stride must = K + uint32_t page_bits; +}; + +SGL_DEVICE int32_t page_to_indices(const int32_t* __restrict__ page_table, uint32_t i, uint32_t page_bits) { + const uint32_t mask = (1u << page_bits) - 1u; + return (page_table[i >> page_bits] << page_bits) | (i & mask); +} + +[[maybe_unused]] +SGL_DEVICE void naive_transform( + const float* __restrict__, // unused + const int32_t* __restrict__ page_table, + int32_t* __restrict__ indices, + const uint32_t length, + const uint32_t page_bits) { + if (const auto tx = threadIdx.x; tx < length) { + indices[tx] = page_to_indices(page_table, tx, page_bits); + } else if (tx < K) { + indices[tx] = -1; // fill invalid indices to -1 + } +} + +__global__ __launch_bounds__(kBlockSize, 2) // optimize prefill + void topk_transform_v2(const __grid_constant__ TopKParams params) { + alignas(128) extern __shared__ uint8_t smem[]; + const auto batch_id = blockIdx.x; + const auto seq_len = params.seq_lens[batch_id]; + const auto score_ptr = params.scores + batch_id * params.score_stride; + const auto page_ptr = params.page_table + batch_id * params.page_table_stride; + const auto indices_ptr = params.page_indices + batch_id * K; + if (seq_len <= K) return naive_transform(score_ptr, page_ptr, indices_ptr, seq_len, params.page_bits); + __shared__ int32_t s_topk_indices[K]; + const auto problem = TopKProblem{ + .scores = score_ptr, + .page_table = page_ptr, + .indices = s_topk_indices, + .length = seq_len, + .page_bits = params.page_bits, + }; + SmallTopKImpl::run(problem, smem); + device::PDLTriggerSecondary(); + __syncthreads(); + if (const auto tx = threadIdx.x; tx < K) { + indices_ptr[tx] = page_to_indices(page_ptr, s_topk_indices[tx], params.page_bits); + } +} + +template +void setup_kernel_smem_once(host::DebugInfo where = {}) { + [[maybe_unused]] + static const auto result = [] { + const auto fptr = std::bit_cast(f); + return 
::cudaFuncSetAttribute(fptr, ::cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSMEM); + }(); + host::RuntimeDeviceCheck(result, where); +} + +struct TopK512Kernel { + static constexpr auto kSMEM = sizeof(typename SmallTopKImpl::Smem) + 128; + static void transform( + const tvm::ffi::TensorView scores, + const tvm::ffi::TensorView seq_lens, + const tvm::ffi::TensorView page_table, + const tvm::ffi::TensorView page_indices, + const uint32_t page_size, + const tvm::ffi::Optional unused) { + using namespace host; + auto B = SymbolicSize{"batch_size"}; + auto S = SymbolicSize{"score_stride"}; + auto P = SymbolicSize{"page_table_stride"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({B, -1}) // strided scores + .with_strides({S, 1}) + .with_dtype() + .with_device(device) + .verify(scores); + TensorMatcher({B}) // seq_lens, must be contiguous + .with_dtype() + .with_device(device) + .verify(seq_lens); + TensorMatcher({B, -1}) // strided page table + .with_strides({P, 1}) + .with_dtype() + .with_device(device) + .verify(page_table); + TensorMatcher({B, 512}) // output, must be contiguous + .with_dtype() + .with_device(device) + .verify(page_indices); + + RuntimeCheck(!unused.has_value(), "topk_transform_v2 only accepts 5 arguments"); + RuntimeCheck(std::has_single_bit(page_size), "page_size must be power of 2"); + const auto page_bits = static_cast(std::countr_zero(page_size)); + const auto batch_size = static_cast(B.unwrap()); + const auto params = TopKParams{ + .seq_lens = static_cast(seq_lens.data_ptr()), + .scores = static_cast(scores.data_ptr()), + .page_table = static_cast(page_table.data_ptr()), + .page_indices = static_cast(page_indices.data_ptr()), + .score_stride = S.unwrap(), + .page_table_stride = P.unwrap(), + .page_bits = page_bits, + }; + RuntimeCheck(std::bit_cast(params.scores) % 16 == 0, "scores must be 16-byte aligned"); + RuntimeCheck(params.score_stride % 4 == 0, "score_stride must be a multiple of 4"); + constexpr auto kernel = topk_transform_v2; + setup_kernel_smem_once(); + LaunchKernel(batch_size, kBlockSize, device.unwrap(), kSMEM) // + .enable_pdl(true)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/hisparse.cuh b/python/sglang/jit_kernel/csrc/hisparse.cuh new file mode 100644 index 000000000000..bebee01c17b2 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/hisparse.cuh @@ -0,0 +1,655 @@ +#include +#include + +#include + +#include +#include + +#include +#include +#include +#include + +namespace { + +constexpr int WARP_SIZE = 32; +constexpr int32_t TOKEN_HIT = 0xFFFFFFFF; + +__device__ __forceinline__ void +transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) { + const uint64_t* __restrict__ src = static_cast(src_addr); + uint64_t* __restrict__ dst = static_cast(dst_addr); + const int total_chunks = item_size_bytes / sizeof(uint64_t); + +#pragma unroll + for (int j = lane_id; j < total_chunks; j += WARP_SIZE) { + uint64_t tmp; + asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory"); + asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory"); + } +} + +__device__ __forceinline__ int warp_inclusive_scan(int* s_data, int lane_id, int offset, int count, int accumulator) { + int idx = lane_id + offset; + int val = (idx < count) ? 
s_data[idx] : 0; + +#pragma unroll + for (int i = 1; i < 32; i *= 2) { + int n = __shfl_up_sync(0xffffffff, val, i); + if (lane_id >= i) val += n; + } + val += accumulator; + if (idx < count) { + s_data[idx] = val; + } + accumulator = __shfl_sync(0xffffffff, val, 31); + return accumulator; +} + +template +__global__ void parallel_transfer_kernel( + const int64_t* __restrict__ transfer_tasks_src, + const int64_t* __restrict__ transfer_tasks_dst, + const void* __restrict__ host_cache_k, + const void* __restrict__ host_cache_v, + void* __restrict__ device_buffer_k, + void* __restrict__ device_buffer_v, + int64_t total_tasks, + int64_t page_size, + int64_t src_padding, + int64_t dst_padding, + int64_t item_size_bytes, + bool is_mla) { + const int64_t tasks_per_block = NUM_TOP_K; + const int64_t stride_per_block = tasks_per_block + 1; + + const int global_task_id = blockIdx.x * (BLOCK_SIZE / WARP_SIZE) + threadIdx.x / WARP_SIZE; + if (global_task_id >= total_tasks) return; + + const int block_id = global_task_id / tasks_per_block; + const int task_in_block = global_task_id % tasks_per_block; + + // Layout: [task_0...task_N | count] per block + const int64_t block_base = block_id * stride_per_block; + const int64_t count_idx = block_base + tasks_per_block; + const int64_t valid_count = transfer_tasks_src[count_idx]; + + if (task_in_block >= valid_count) return; + + const int64_t task_idx = block_base + task_in_block; + const int64_t src_loc = transfer_tasks_src[task_idx]; + const int64_t dst_loc = transfer_tasks_dst[task_idx]; + + const int64_t src_offset = src_loc * item_size_bytes; + const int64_t dst_offset = dst_loc * item_size_bytes; + + const int32_t lane_id = threadIdx.x % WARP_SIZE; + + const auto src_k = static_cast(host_cache_k) + src_offset; + auto dst_k = static_cast(device_buffer_k) + dst_offset; + transfer_item_warp(lane_id, src_k, dst_k, item_size_bytes); + + if (!is_mla) { + const auto src_v = static_cast(host_cache_v) + src_offset; + auto dst_v = static_cast(device_buffer_v) + dst_offset; + transfer_item_warp(lane_id, src_v, dst_v, item_size_bytes); + } +} + +// Each block processes one request +// IndexT: type for req_pool_indices and seq_lens (int32_t or int64_t), The cuda graph mode requires int32_t +// Layout: [HOT_BUFFER_SIZE slots for LRU] + [page_size slots for newest token] +// newest_slot is at HOT_BUFFER_SIZE (first position of extra page) +template +__global__ void load_cache_to_device_buffer_kernel( + const int32_t* __restrict__ top_k_tokens, + int32_t* __restrict__ device_buffer_tokens, + const int64_t* __restrict__ host_cache_locs, + const int32_t* __restrict__ device_buffer_locs, + const void* __restrict__ host_cache_k, + const void* __restrict__ host_cache_v, + void* __restrict__ device_buffer_k, + void* __restrict__ device_buffer_v, + int32_t* __restrict__ top_k_device_locs, + int16_t* __restrict__ diff_map, + const IndexT* __restrict__ req_pool_indices, + const IndexT* __restrict__ seq_lens, + int16_t* __restrict__ lru_slots, + int64_t* __restrict__ transfer_tasks_src, + int64_t* __restrict__ transfer_tasks_dst, + int64_t buffer_stride_0, + int64_t buffer_stride_1, + int64_t host_stride, + int64_t diff_map_stride, + int64_t lru_slot_stride_0, + int64_t lru_slot_stride_1, + int64_t top_k_tokens_stride, + int64_t top_k_device_locs_stride, + int64_t page_size, + int64_t layer_id, + int64_t item_size_bytes) { + constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE; + constexpr int NUM_TOKEN_CHUNKS = (NUM_TOP_K + WARP_SIZE - 1) / WARP_SIZE; + // LRU uses all 
HOT_BUFFER_SIZE slots (newest_slot is now outside at HOT_BUFFER_SIZE) + constexpr int LRU_SIZE = HOT_BUFFER_SIZE; + constexpr int NUM_BUFFER_CHUNKS = (LRU_SIZE + WARP_SIZE - 1) / WARP_SIZE; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int64_t rid = req_pool_indices[bid]; + const int64_t seq_len = seq_lens[bid]; + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid % WARP_SIZE; + const unsigned int lanes_before = ((unsigned int)1 << lane_id) - 1; + + // Calculate offsets for this request + const int top_k_tokens_offset = bid * top_k_tokens_stride; + const int top_k_device_locs_offset = bid * top_k_device_locs_stride; + const int diff_map_offset = bid * diff_map_stride; + const int32_t* my_top_k_tokens = top_k_tokens + top_k_tokens_offset; + int32_t* my_top_k_device_locs = top_k_device_locs + top_k_device_locs_offset; + int16_t* my_diff_map = diff_map + diff_map_offset; + + const int buffer_offset = rid * buffer_stride_0 + layer_id * buffer_stride_1; + const int host_offset = rid * host_stride; + const int lru_slot_offset = rid * lru_slot_stride_0 + layer_id * lru_slot_stride_1; + int32_t* my_device_buffer_tokens = device_buffer_tokens + buffer_offset; + const int32_t* my_device_buffer_locs = device_buffer_locs + buffer_offset; + const int64_t* my_host_cache_locs = host_cache_locs + host_offset; + int16_t* my_lru_slots = lru_slots + lru_slot_offset; + + // Fast path: if seq_len <= HOT_BUFFER_SIZE, use sequential buffer layout + // Set device locations for all tokens up to seq_len (indexed by token/page id) + if (seq_len <= HOT_BUFFER_SIZE) { + // Calculate max token/page index based on seq_len + const int32_t newest_idx = seq_len - 1; + // (page_size > 1) ? static_cast(seq_len / page_size) : static_cast(seq_len); + + // Tokens 1 to seq_len-1: read from page table + for (int i = tid; i < newest_idx; i += BLOCK_SIZE) { + int32_t page_start = my_device_buffer_locs[i * page_size]; + // top_k indices might not come in order of token locations + my_top_k_device_locs[i] = page_start / page_size; + } + + // seq_len position (newest): fixed at HOT_BUFFER_SIZE + if (tid == 0) { + my_top_k_device_locs[newest_idx] = my_device_buffer_locs[HOT_BUFFER_SIZE]; + } + + if (tid == 0) { + const int64_t tasks_per_block = NUM_TOP_K * page_size; + const int64_t stride_per_block = tasks_per_block + 1; + const int64_t block_base = bid * stride_per_block; + const int64_t count_idx = block_base + tasks_per_block; + transfer_tasks_src[count_idx] = 0; + } + return; + } + + __shared__ int32_t s_top_k_tokens[NUM_TOP_K]; + __shared__ int32_t s_chunk_offset[NUM_BUFFER_CHUNKS + 1]; + __shared__ int32_t s_missed_tokens[NUM_TOP_K]; + __shared__ int32_t s_evictable_slots[NUM_TOP_K]; + __shared__ int32_t s_total_misses; + __shared__ int32_t s_total_hits; + __shared__ int32_t s_total_evictable; + __shared__ int32_t s_newest_hit; + __shared__ bool s_lru_bitmap[HOT_BUFFER_SIZE]; + __shared__ int16_t s_lru_slots_out[LRU_SIZE]; + + // Initialize shared memory counters used across phases. + if (tid == 0) { + s_total_misses = 0; + s_total_hits = 0; + s_total_evictable = 0; + s_newest_hit = 0; + } + + const int newest_slot = HOT_BUFFER_SIZE; + // For page-wise topk, use page id. + const int32_t newest_token = seq_len - 1; + // (seq_len >= 0) ? static_cast((page_size > 1) ? (seq_len / page_size) : seq_len) : -1; + + // Build diff_map for top-k tokens and reset shared buffers. 
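+  // diff_map[token_id] holds that token's position in this request's top-k list, so
+  // the LRU scan below detects a buffer hit with a single lookup; the entries
+  // written here are restored to -1 once the scan is done (tokens outside the top-k
+  // are assumed to already hold -1 from initialization elsewhere).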
+ for (int i = tid; i < NUM_TOP_K; i += BLOCK_SIZE) { + int32_t top_k_val = my_top_k_tokens[i]; + my_diff_map[top_k_val] = i; + s_top_k_tokens[i] = top_k_val; + s_evictable_slots[i] = -1; + } + for (int i = tid; i < HOT_BUFFER_SIZE; i += BLOCK_SIZE) { + s_lru_bitmap[i] = false; + } + + __syncthreads(); + + // If topk includes the latest token, bind it to newest_slot (at HOT_BUFFER_SIZE) and mark as hit. + // newest_slot is at the first position of the extra page, excluded from LRU tracking. + if (tid == 0 && newest_token >= 0 && newest_token < diff_map_stride) { + const int newest_topk_idx = my_diff_map[newest_token]; + if (newest_topk_idx >= 0) { + s_top_k_tokens[newest_topk_idx] = TOKEN_HIT; + my_top_k_device_locs[newest_topk_idx] = my_device_buffer_locs[newest_slot]; + my_device_buffer_tokens[newest_slot] = newest_token; + s_newest_hit = 1; + } + } + __syncthreads(); + + for (int i = tid; i < NUM_BUFFER_CHUNKS + 1; i += BLOCK_SIZE) { + s_chunk_offset[i] = 0; + } + __syncthreads(); + + constexpr int ITERATIONS_PER_WARP_BUFFER = (NUM_BUFFER_CHUNKS + NUM_WARPS - 1) / NUM_WARPS; + int total_hit_count = 0; + for (int iter = 0; iter < ITERATIONS_PER_WARP_BUFFER; iter++) { + int chunk_idx = warp_id + iter * NUM_WARPS; + bool has_valid_chunk = chunk_idx < NUM_BUFFER_CHUNKS; + + const int slot_idx = chunk_idx * WARP_SIZE + lane_id; + const bool has_valid_slot = has_valid_chunk && (slot_idx < HOT_BUFFER_SIZE); + const int32_t buf_slot = has_valid_slot ? static_cast(my_lru_slots[slot_idx]) : -1; + const bool has_valid_buf_slot = has_valid_slot && (buf_slot >= 0) && (buf_slot < HOT_BUFFER_SIZE); + int32_t my_buffer_token = has_valid_buf_slot ? my_device_buffer_tokens[buf_slot] : -1; + int my_found_top_k_idx = my_buffer_token >= 0 ? my_diff_map[my_buffer_token] : -1; + + // Record hits + if (my_found_top_k_idx >= 0 && has_valid_buf_slot) { + s_top_k_tokens[my_found_top_k_idx] = TOKEN_HIT; + my_top_k_device_locs[my_found_top_k_idx] = my_device_buffer_locs[buf_slot]; + } + __syncthreads(); + + bool is_hit = my_found_top_k_idx != -1; + int local_hit_offset = 0; + + // seems unnecessary + if (warp_id == 0) { + const int base_chunk = iter * NUM_WARPS; + const int idx = base_chunk + lane_id + 1; + if (idx < NUM_BUFFER_CHUNKS + 1) { + s_chunk_offset[idx] = 0; + } + } + __syncthreads(); + if (has_valid_chunk) { + const unsigned int hit_mask = __ballot_sync(0xFFFFFFFF, is_hit); + local_hit_offset = __popc(hit_mask & lanes_before); + int warp_hit_count = __popc(hit_mask); + if (lane_id == 0) { + s_chunk_offset[chunk_idx + 1] = warp_hit_count; + } + } + __syncthreads(); + + if (warp_id == 0) { + total_hit_count = + warp_inclusive_scan(s_chunk_offset, lane_id, chunk_idx + 1, NUM_BUFFER_CHUNKS + 1, total_hit_count); + if (tid == 0) { + s_total_hits = total_hit_count; + } + } + __syncthreads(); + + if (is_hit && has_valid_buf_slot) { + int hit_offset = s_chunk_offset[chunk_idx] + local_hit_offset; + s_lru_slots_out[hit_offset] = buf_slot; + s_lru_bitmap[buf_slot] = true; + } + } + __syncthreads(); + + // Move staged hits to the tail so hits are most recent in LRU order. 
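The hit pass above is the usual warp stream-compaction pattern: `__ballot_sync` packs the per-lane predicate into a mask, `__popc(mask & lanes_before)` gives each active lane its offset within its warp, per-chunk counts are written into `s_chunk_offset`, and `warp_inclusive_scan` turns those counts into global base offsets. A small Python model of that arithmetic (illustrative, not tied to the kernel's exact buffers):

```python
# Python model of the ballot/popc offset arithmetic (illustrative).
def ballot(predicates):
    """Pack one bit per lane, lane 0 in the least significant bit."""
    mask = 0
    for lane, p in enumerate(predicates):
        if p:
            mask |= 1 << lane
    return mask


def lane_offsets(predicates):
    """Offset of each 'hit' lane among the hits of the same warp."""
    mask = ballot(predicates)
    offsets = []
    for lane, p in enumerate(predicates):
        lanes_before = (1 << lane) - 1
        offsets.append(bin(mask & lanes_before).count("1") if p else None)
    return offsets, bin(mask).count("1")   # per-lane offsets, warp count


def chunk_prefix(counts):
    """Exclusive prefix over per-chunk counts -> global base offsets."""
    bases, running = [], 0
    for c in counts:
        bases.append(running)
        running += c
    return bases, running
```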
+ for (int i = s_total_hits - 1 - tid; i >= 0; i -= BLOCK_SIZE) { + const int dst = LRU_SIZE - s_total_hits + i; + s_lru_slots_out[dst] = s_lru_slots_out[i]; + } + __syncthreads(); + + // Second pass to collect evictable slots + for (int i = tid; i < NUM_BUFFER_CHUNKS + 1; i += BLOCK_SIZE) { + s_chunk_offset[i] = 0; + } + __syncthreads(); + + int total_evictable = 0; + for (int iter = 0; iter < ITERATIONS_PER_WARP_BUFFER; iter++) { + const int chunk_idx = warp_id + iter * NUM_WARPS; + const bool has_valid_chunk = chunk_idx < NUM_BUFFER_CHUNKS; + + const int slot_idx = chunk_idx * WARP_SIZE + lane_id; + const bool has_valid_slot = has_valid_chunk && (slot_idx < LRU_SIZE); + const int32_t buf_slot = has_valid_slot ? static_cast(my_lru_slots[slot_idx]) : -1; + const bool has_valid_buf_slot = has_valid_slot && (buf_slot >= 0) && (buf_slot < HOT_BUFFER_SIZE); + bool is_evictable = has_valid_buf_slot && !s_lru_bitmap[buf_slot]; + int local_evictable_offset = 0; + if (warp_id == 0) { + const int base_chunk = iter * NUM_WARPS; + const int idx = base_chunk + lane_id + 1; + if (idx < NUM_BUFFER_CHUNKS + 1) { + s_chunk_offset[idx] = 0; + } + } + __syncthreads(); + + if (has_valid_chunk) { + const unsigned int evictable_mask = __ballot_sync(0xFFFFFFFF, is_evictable); + local_evictable_offset = __popc(evictable_mask & lanes_before); + const int warp_evictable_count = __popc(evictable_mask); + if (lane_id == 0) { + s_chunk_offset[chunk_idx + 1] = warp_evictable_count; + } + } + __syncthreads(); + + if (warp_id == 0) { + total_evictable = + warp_inclusive_scan(s_chunk_offset, lane_id, chunk_idx + 1, NUM_BUFFER_CHUNKS + 1, total_evictable); + } + __syncthreads(); + + if (is_evictable && has_valid_buf_slot) { + const int evictable_offset = s_chunk_offset[chunk_idx] + local_evictable_offset; + int num_misses = NUM_TOP_K - s_total_hits - s_newest_hit; + if (num_misses < 0) { + num_misses = 0; + } + if (evictable_offset < num_misses) { + s_evictable_slots[evictable_offset] = buf_slot; + s_lru_slots_out[LRU_SIZE - s_total_hits - 1 - evictable_offset] = buf_slot; + } else { + s_lru_slots_out[evictable_offset - num_misses] = buf_slot; + } + } + } + __syncthreads(); + if (tid == 0) { + s_total_evictable = total_evictable; + } + + for (int i = tid; i < HOT_BUFFER_SIZE; i += BLOCK_SIZE) { + if (i < NUM_TOP_K) { + int32_t top_k_val = my_top_k_tokens[i]; + my_diff_map[top_k_val] = -1; + } + if (i < LRU_SIZE) { + my_lru_slots[i] = s_lru_slots_out[i]; + } + } + // Reset offsets for the miss counting phase. + for (int i = tid; i < NUM_BUFFER_CHUNKS + 1; i += BLOCK_SIZE) { + s_chunk_offset[i] = 0; + } + __syncthreads(); + + constexpr int ITERATIONS_PER_WARP_TOKEN = (NUM_TOKEN_CHUNKS + NUM_WARPS - 1) / NUM_WARPS; + for (int iter = 0; iter < ITERATIONS_PER_WARP_TOKEN; iter++) { + int chunk_idx = warp_id + iter * NUM_WARPS; + bool has_valid_chunk = chunk_idx < NUM_TOKEN_CHUNKS; + + const int chunk_token_start = chunk_idx * WARP_SIZE; + const int my_token_idx = chunk_token_start + lane_id; + const bool has_valid_token = has_valid_chunk && (my_token_idx < NUM_TOP_K); + + int32_t my_token = 0; + bool is_miss = false; + int local_miss_offset = 0; + + if (has_valid_token) { + is_miss = s_top_k_tokens[my_token_idx] != TOKEN_HIT; + if (is_miss) { + my_token = s_top_k_tokens[my_token_idx]; + } + } + + // Intra-warp communication for miss counting. 
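The passes above rebuild `s_lru_slots_out` for the next step. As far as the indexing suggests (treat this ordering as an assumption rather than a spec), evictable slots that are not reused stay at the oldest end, slots that receive missed tokens come next, and slots that were hit move to the tail as most recently used. A sequential sketch:

```python
# Sequential sketch of the assumed LRU reordering (illustrative, unverified).
def rebuild_lru(lru_slots, hit_slots, refilled_slots):
    hit_set = set(hit_slots)
    refilled_set = set(refilled_slots)
    untouched = [s for s in lru_slots
                 if s not in hit_set and s not in refilled_set]
    # oldest ............................................ newest
    return untouched + list(refilled_slots) + list(hit_slots)
```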
+ const unsigned int miss_mask = __ballot_sync(0xFFFFFFFF, is_miss); + if (warp_id == 0) { + const int base_chunk = iter * NUM_WARPS; + const int idx = base_chunk + lane_id + 1; + if (idx < NUM_TOKEN_CHUNKS + 1) { + s_chunk_offset[idx] = 0; + } + } + __syncthreads(); + if (has_valid_chunk) { + local_miss_offset = __popc(miss_mask & lanes_before); + const int warp_miss_count = __popc(miss_mask); + if (lane_id == 0) { + s_chunk_offset[chunk_idx + 1] = warp_miss_count; + } + } + __syncthreads(); + + if (warp_id == 0) { + s_total_misses = + warp_inclusive_scan(s_chunk_offset, lane_id, chunk_idx + 1, NUM_TOKEN_CHUNKS + 1, s_total_misses); + } + __syncthreads(); + + // Clamp misses to the number of available evictable slots. + if (tid == 0 && s_total_misses > s_total_evictable) { + s_total_misses = s_total_evictable; + } + __syncthreads(); + + if (is_miss) { + int miss_offset = s_chunk_offset[chunk_idx] + local_miss_offset; + if (miss_offset >= s_total_evictable) { + continue; + } + int evict_slot = s_evictable_slots[miss_offset]; + s_missed_tokens[miss_offset] = my_token; + if (evict_slot >= 0 && evict_slot < HOT_BUFFER_SIZE) { + my_top_k_device_locs[my_token_idx] = my_device_buffer_locs[evict_slot]; + my_device_buffer_tokens[evict_slot] = my_token; + } else { + my_top_k_device_locs[my_token_idx] = -1; + } + } + __syncthreads(); + } + + const int64_t tasks_per_block = NUM_TOP_K * page_size; + const int64_t stride_per_block = tasks_per_block + 1; + const int64_t block_base = bid * stride_per_block; + + // Emit transfer tasks for each miss (page_size items per miss). + for (int miss_idx = tid; miss_idx < s_total_misses; miss_idx += BLOCK_SIZE) { + const int32_t miss_token = s_missed_tokens[miss_idx]; + const int evict_slot = s_evictable_slots[miss_idx]; + + if (evict_slot >= 0 && evict_slot < HOT_BUFFER_SIZE && miss_token >= 0) { + for (int page_offset = 0; page_offset < page_size; page_offset++) { + const int64_t src_loc = my_host_cache_locs[miss_token * page_size + page_offset]; + const int64_t dst_loc = my_device_buffer_locs[evict_slot] * page_size + page_offset; + + const int task_idx = block_base + miss_idx * page_size + page_offset; + transfer_tasks_src[task_idx] = src_loc; + transfer_tasks_dst[task_idx] = dst_loc; + } + } + } + + if (tid == 0) { + const int64_t count_idx = block_base + tasks_per_block; + transfer_tasks_src[count_idx] = s_total_misses * page_size; + } +} + +template +struct SparseCacheKernel { + template + static void + run(tvm::ffi::TensorView top_k_tokens, + tvm::ffi::TensorView device_buffer_tokens, + tvm::ffi::TensorView host_cache_locs, + tvm::ffi::TensorView device_buffer_locs, + tvm::ffi::TensorView host_cache_k, + tvm::ffi::TensorView host_cache_v, + tvm::ffi::TensorView device_buffer_k, + tvm::ffi::TensorView device_buffer_v, + tvm::ffi::TensorView top_k_device_locs, + tvm::ffi::TensorView diff_map, + tvm::ffi::TensorView req_pool_indices, + tvm::ffi::TensorView seq_lens, + tvm::ffi::TensorView lru_slots, + tvm::ffi::TensorView transfer_tasks_src, + tvm::ffi::TensorView transfer_tasks_dst, + int64_t page_size, + int64_t layer_id, + int64_t item_size_bytes) { + using namespace host; + + const int64_t bs = top_k_tokens.shape()[0]; + const int64_t host_stride = host_cache_locs.shape()[1]; + const int64_t buffer_stride_0 = device_buffer_tokens.strides()[0]; + const int64_t buffer_stride_1 = device_buffer_tokens.strides()[1]; + const int64_t diff_map_stride = diff_map.shape()[1]; + const int64_t lru_slot_stride_0 = lru_slots.strides()[0]; + const int64_t 
lru_slot_stride_1 = lru_slots.strides()[1]; + const int64_t top_k_tokens_stride = top_k_tokens.strides()[0]; + const int64_t top_k_device_locs_stride = top_k_device_locs.strides()[0]; + + const int32_t* top_k_tokens_ptr = static_cast(top_k_tokens.data_ptr()); + int32_t* device_buffer_tokens_ptr = static_cast(device_buffer_tokens.data_ptr()); + const int64_t* host_cache_locs_ptr = static_cast(host_cache_locs.data_ptr()); + const int32_t* device_buffer_locs_ptr = static_cast(device_buffer_locs.data_ptr()); + const void* host_cache_k_ptr = host_cache_k.data_ptr(); + const void* host_cache_v_ptr = (IsMLA || host_cache_v.ndim() == 0) ? nullptr : host_cache_v.data_ptr(); + void* device_buffer_k_ptr = device_buffer_k.data_ptr(); + void* device_buffer_v_ptr = (IsMLA || device_buffer_v.ndim() == 0) ? nullptr : device_buffer_v.data_ptr(); + int32_t* top_k_device_locs_ptr = static_cast(top_k_device_locs.data_ptr()); + int16_t* diff_map_ptr = static_cast(diff_map.data_ptr()); + const IndexT* req_pool_indices_ptr = static_cast(req_pool_indices.data_ptr()); + const IndexT* seq_lens_ptr = static_cast(seq_lens.data_ptr()); + int16_t* lru_slots_ptr = static_cast(lru_slots.data_ptr()); + + const auto device = LaunchKernel::resolve_device(top_k_tokens.device()); + + int64_t* transfer_tasks_src_ptr = static_cast(transfer_tasks_src.data_ptr()); + int64_t* transfer_tasks_dst_ptr = static_cast(transfer_tasks_dst.data_ptr()); + const int64_t max_transfer_tasks = bs * NUM_TOP_K * page_size; + + // Kernel 1: Determine hits/misses and collect transfer tasks + LaunchKernel(bs, BLOCK_SIZE, device)( + load_cache_to_device_buffer_kernel, + top_k_tokens_ptr, + device_buffer_tokens_ptr, + host_cache_locs_ptr, + device_buffer_locs_ptr, + host_cache_k_ptr, + host_cache_v_ptr, + device_buffer_k_ptr, + device_buffer_v_ptr, + top_k_device_locs_ptr, + diff_map_ptr, + req_pool_indices_ptr, + seq_lens_ptr, + lru_slots_ptr, + transfer_tasks_src_ptr, + transfer_tasks_dst_ptr, + buffer_stride_0, + buffer_stride_1, + host_stride, + diff_map_stride, + lru_slot_stride_0, + lru_slot_stride_1, + top_k_tokens_stride, + top_k_device_locs_stride, + page_size, + layer_id, + item_size_bytes); + + // Kernel 2: parallel transfer + constexpr int TRANSFER_BLOCK_SIZE = 256; + constexpr int WARPS_PER_BLOCK = TRANSFER_BLOCK_SIZE / WARP_SIZE; + const int64_t num_transfer_blocks = (max_transfer_tasks + WARPS_PER_BLOCK - 1) / WARPS_PER_BLOCK; + + LaunchKernel(num_transfer_blocks, TRANSFER_BLOCK_SIZE, device)( + parallel_transfer_kernel, + transfer_tasks_src_ptr, + transfer_tasks_dst_ptr, + host_cache_k_ptr, + host_cache_v_ptr, + device_buffer_k_ptr, + device_buffer_v_ptr, + max_transfer_tasks, + page_size, + 0, // src padding + 0, // dst padding + item_size_bytes, + IsMLA); + } +}; + +template +void load_cache_to_device_buffer( + tvm::ffi::TensorView top_k_tokens, + tvm::ffi::TensorView device_buffer_tokens, + tvm::ffi::TensorView host_cache_locs, + tvm::ffi::TensorView device_buffer_locs, + tvm::ffi::TensorView host_cache_k, + tvm::ffi::TensorView host_cache_v, + tvm::ffi::TensorView device_buffer_k, + tvm::ffi::TensorView device_buffer_v, + tvm::ffi::TensorView top_k_device_locs, + tvm::ffi::TensorView diff_map, + tvm::ffi::TensorView req_pool_indices, + tvm::ffi::TensorView seq_lens, + tvm::ffi::TensorView lru_slots, + tvm::ffi::TensorView transfer_tasks_src, + tvm::ffi::TensorView transfer_tasks_dst, + int64_t page_size, + int64_t layer_id, + int64_t item_size_bytes) { + const auto& dtype = req_pool_indices.dtype(); + const bool is_int64 = 
(dtype.bits == 64); + + if (is_int64) { + SparseCacheKernel::template run( + top_k_tokens, + device_buffer_tokens, + host_cache_locs, + device_buffer_locs, + host_cache_k, + host_cache_v, + device_buffer_k, + device_buffer_v, + top_k_device_locs, + diff_map, + req_pool_indices, + seq_lens, + lru_slots, + transfer_tasks_src, + transfer_tasks_dst, + page_size, + layer_id, + item_size_bytes); + } else { + SparseCacheKernel::template run( + top_k_tokens, + device_buffer_tokens, + host_cache_locs, + device_buffer_locs, + host_cache_k, + host_cache_v, + device_buffer_k, + device_buffer_v, + top_k_device_locs, + diff_map, + req_pool_indices, + seq_lens, + lru_slots, + transfer_tasks_src, + transfer_tasks_dst, + page_size, + layer_id, + item_size_bytes); + } +} + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh new file mode 100644 index 000000000000..6476a3be232f --- /dev/null +++ b/python/sglang/jit_kernel/csrc/moe/moe_fused_gate.cuh @@ -0,0 +1,363 @@ +#include +#include + +#include +#include +#include + +#include + +#include +#include + +namespace { + +constexpr uint32_t kWarpSize = 32; +constexpr uint32_t kWarpsPerCTA = 6; +constexpr uint32_t kSmallTokenThreshold = 512; +constexpr uint32_t kMaxExperts = 512; +constexpr uint32_t kMaxTopK = 16; + +enum class ScoringFunc : uint32_t { + kSigmoid = 0, + kSqrtSoftplus = 1, +}; + +struct MoEFusedGateParams { + const float* __restrict__ input; + const float* __restrict__ bias; + float* __restrict__ output; + int32_t* __restrict__ indices; + uint32_t num_rows; + uint32_t num_experts; + uint32_t topk; + uint32_t num_fused_shared_experts; + bool renormalize; + float routed_scaling_factor; + bool apply_routed_scaling_factor_on_output; +}; + +template +__device__ __forceinline__ float compute_score(float x) { + if constexpr (kScoringFunc == ScoringFunc::kSigmoid) { + // sigmoid(x) = 1 / (1 + exp(-x)) + return 1.0f / (1.0f + expf(-x)); + } else { + // sqrt(softplus(x)) = sqrt(log(1 + exp(x))) + float softplus = log1pf(expf(x)); + return sqrtf(softplus); + } +} + +template +__global__ void moe_fused_gate_kernel_small_token(const MoEFusedGateParams __grid_constant__ params) { + const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] = + params; + + uint32_t row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + // number of routed experts to select (excluding fused shared experts) + const uint32_t topk_routed = topk - num_fused_shared_experts; + + uint32_t tid = threadIdx.x; + uint32_t warp_id = tid / kWarpSize; + uint32_t lane_id = tid % kWarpSize; + + extern __shared__ float shared_mem[]; + float* shared_scores = shared_mem; + float* shared_original_scores = shared_mem + num_experts; + + // For warp-level reduction + __shared__ float warp_maxs[kWarpsPerToken]; + __shared__ int warp_experts[kWarpsPerToken]; + __shared__ int selected_experts[kMaxTopK]; + + for (uint32_t e = tid; e < num_experts; e += blockDim.x) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncthreads(); + + // only select topk_routed experts (excluding shared experts) + for (uint32_t k = 0; k < topk_routed; k++) { + float my_val = -FLT_MAX; + int my_expert = -1; + for 
(uint32_t e = tid; e < num_experts; e += blockDim.x) { + if (shared_scores[e] > my_val) { + my_val = shared_scores[e]; + my_expert = e; + } + } + + float warp_max_val = my_val; + int warp_max_expert = my_expert; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, warp_max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, warp_max_expert, offset); + if (other_val > warp_max_val) { + warp_max_val = other_val; + warp_max_expert = other_expert; + } + } + + if (lane_id == 0 && warp_id < kWarpsPerToken) { + warp_maxs[warp_id] = warp_max_val; + warp_experts[warp_id] = warp_max_expert; + } + + __syncthreads(); + + if (warp_id == 0) { + float final_max = (lane_id < kWarpsPerToken) ? warp_maxs[lane_id] : -FLT_MAX; + int final_expert = (lane_id < kWarpsPerToken) ? warp_experts[lane_id] : -1; + +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, final_max, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, final_expert, offset); + if (other_val > final_max) { + final_max = other_val; + final_expert = other_expert; + } + } + + if (lane_id == 0) { + selected_experts[k] = final_expert; + } + } + + __syncthreads(); + + int selected = selected_experts[k]; + if (selected >= 0 && tid == 0) { + shared_scores[selected] = -FLT_MAX; + } + + __syncthreads(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + if (tid >= device::kWarpThreads) return; + + // only use the first warp to perform write to global operation + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (tid < topk_routed) { + int expert_id = selected_experts[tid]; + float score = shared_original_scores[expert_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = score; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (tid < topk) { + const bool is_shared = tid >= topk_routed; + const auto output_offset = row_idx * topk + tid; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + tid - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? 
routed_sum : 1.0f; + output[output_offset] = weight / norm * scale; + indices[output_offset] = expert_id; + } +} + +template +__global__ void moe_fused_gate_kernel(const MoEFusedGateParams __grid_constant__ params) { + const auto& [input, bias, output, indices, num_rows, num_experts, topk, num_fused_shared_experts, renormalize, routed_scaling_factor, apply_routed_scaling_factor_on_output] = + params; + + uint32_t row_idx = blockIdx.x * kWarpsPerCTA + threadIdx.y; + if (row_idx >= num_rows) return; + + // number of routed experts to select (excluding fused shared experts) + const uint32_t topk_routed = topk - num_fused_shared_experts; + + uint32_t lane_id = threadIdx.x; + uint32_t warp_id = threadIdx.y; + + extern __shared__ float shared_mem[]; + float* shared_scores = shared_mem + warp_id * num_experts * 2; + float* shared_original_scores = shared_scores + num_experts; + __shared__ int selected_experts[kWarpsPerCTA][kMaxTopK]; + int* warp_selected_experts = selected_experts[warp_id]; + + for (uint32_t e = lane_id; e < num_experts; e += kWarpSize) { + float input_val = input[row_idx * num_experts + e]; + float bias_val = bias[e]; + float score_val = compute_score(input_val); + float biased_val = score_val + bias_val; + shared_scores[e] = biased_val; + shared_original_scores[e] = score_val; + } + + __syncwarp(); + + // only select topk_routed experts + for (uint32_t k = 0; k < topk_routed; k++) { + float max_val = -FLT_MAX; + int max_expert = -1; + + for (uint32_t expert = lane_id; expert < num_experts; expert += kWarpSize) { + if (shared_scores[expert] > max_val) { + max_val = shared_scores[expert]; + max_expert = expert; + } + } + + for (int offset = kWarpSize / 2; offset > 0; offset /= 2) { + float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset); + int other_expert = __shfl_down_sync(0xFFFFFFFF, max_expert, offset); + + if (other_val > max_val || (other_val == max_val && other_expert < max_expert)) { + max_val = other_val; + max_expert = other_expert; + } + } + + if (lane_id == 0) { + warp_selected_experts[k] = max_expert; + if (max_expert != -1) { + shared_scores[max_expert] = -FLT_MAX; + } + } + + __syncwarp(); + } + + static_assert(kMaxTopK <= device::kWarpThreads); + + float routed_weight = 0.0f; + int32_t selected_expert = 0; + if (lane_id < topk_routed) { + int expert_id = warp_selected_experts[lane_id]; + if (expert_id >= 0 && expert_id < static_cast(num_experts)) { + routed_weight = shared_original_scores[expert_id]; + selected_expert = expert_id; + } + } + const auto routed_sum = device::warp::reduce_sum(routed_weight); + if (lane_id < topk) { + const bool is_shared = lane_id >= topk_routed; + const auto output_idx = row_idx * topk + lane_id; + const auto weight = is_shared ? (routed_sum / routed_scaling_factor) : routed_weight; + const auto expert_id = is_shared ? (num_experts + lane_id - topk_routed) : selected_expert; + const auto scale = apply_routed_scaling_factor_on_output ? routed_scaling_factor : 1.0f; + const auto norm = renormalize && routed_sum > 0.0f ? 
routed_sum : 1.0f; + output[output_idx] = weight / norm * scale; + indices[output_idx] = expert_id; + } +} + +template +void dispatch_small_token_kernel( + uint32_t num_rows, + uint32_t threads_per_block, + uint32_t warps_per_token, + DLDevice device, + size_t smem_per_row, + const MoEFusedGateParams& params) { + using namespace host; + if (warps_per_token <= 8) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<8, kScoringFunc>, params); + } else if (warps_per_token <= 12) { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<12, kScoringFunc>, params); + } else { + LaunchKernel(num_rows, threads_per_block, device, smem_per_row)( + moe_fused_gate_kernel_small_token<16, kScoringFunc>, params); + } +} + +struct MoEFusedGateKernel { + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView bias, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView indices, + uint32_t topk, + uint32_t scoring_func, // 0 = sigmoid, 1 = sqrtsoftplus + uint32_t num_fused_shared_experts, + bool renormalize, + float routed_scaling_factor, + bool apply_routed_scaling_factor_on_output) { + using namespace host; + + auto N = SymbolicSize{"num_rows"}; + auto E = SymbolicSize{"num_experts"}; + auto K = SymbolicSize{"topk"}; + auto device = SymbolicDevice{}; + K.set_value(topk); + device.set_options(); + + TensorMatcher({N, E}).with_dtype().with_device(device).verify(input); + TensorMatcher({E}).with_dtype().with_device(device).verify(bias); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(output); + TensorMatcher({N, K}).with_dtype().with_device(device).verify(indices); + + const auto num_rows = static_cast(N.unwrap()); + const auto num_experts = static_cast(E.unwrap()); + + RuntimeCheck(num_experts <= kMaxExperts, "num_experts exceeds maximum supported value"); + RuntimeCheck(scoring_func <= 1, "scoring_func must be 0 (sigmoid) or 1 (sqrtsoftplus)"); + RuntimeCheck(topk > num_fused_shared_experts, "topk must be greater than num_fused_shared_experts"); + + const auto params = MoEFusedGateParams{ + .input = static_cast(input.data_ptr()), + .bias = static_cast(bias.data_ptr()), + .output = static_cast(output.data_ptr()), + .indices = static_cast(indices.data_ptr()), + .num_rows = num_rows, + .num_experts = num_experts, + .topk = topk, + .num_fused_shared_experts = num_fused_shared_experts, + .renormalize = renormalize, + .routed_scaling_factor = routed_scaling_factor, + .apply_routed_scaling_factor_on_output = apply_routed_scaling_factor_on_output, + }; + + const size_t smem_per_row = 2 * num_experts * sizeof(float); + + bool use_small_token_kernel = num_rows <= kSmallTokenThreshold; + + if (use_small_token_kernel) { + // 1 token per block + uint32_t warps_per_token = div_ceil(num_experts, kWarpSize); + warps_per_token = std::min(warps_per_token, 16u); + uint32_t threads_per_block = warps_per_token * kWarpSize; + + if (scoring_func == 0) { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } else { + dispatch_small_token_kernel( + num_rows, threads_per_block, warps_per_token, device.unwrap(), smem_per_row, params); + } + } else { + // multiple tokens per block + uint32_t num_blocks = div_ceil(num_rows, kWarpsPerCTA); + dim3 block_dim(kWarpSize, kWarpsPerCTA); + size_t large_smem = smem_per_row * kWarpsPerCTA; + + if (scoring_func == 0) { + LaunchKernel(num_blocks, block_dim, device.unwrap(), 
large_smem)( + moe_fused_gate_kernel, params); + } else { + LaunchKernel(num_blocks, block_dim, device.unwrap(), large_smem)( + moe_fused_gate_kernel, params); + } + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/deepseek_v4.py b/python/sglang/jit_kernel/deepseek_v4.py new file mode 100644 index 000000000000..d092c424dc24 --- /dev/null +++ b/python/sglang/jit_kernel/deepseek_v4.py @@ -0,0 +1,845 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Literal, NamedTuple, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from sglang.jit_kernel.utils import ( + cache_once, + is_arch_support_pdl, + load_jit, + make_cpp_args, +) +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +def make_name(name: str) -> str: + return f"dpsk_v4_{name}" + + +@cache_once +def _jit_common_module() -> Module: + return load_jit( + make_name(f"common"), + cuda_files=[f"deepseek_v4/common.cuh"], + cuda_wrappers=[("plan_compress_prefill", "plan_compress_prefill")], + ) + + +@cache_once +def _jit_topk_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("topk"), + *args, + cuda_files=["deepseek_v4/topk.cuh"], + cuda_wrappers=[("topk_transform", f"TopK512Kernel<{args}>::transform")], + ) + + +@cache_once +def _jit_topk_v2_module() -> Module: + return load_jit( + make_name("topk_v2"), + cuda_files=["deepseek_v4/topk_v2.cuh"], + cuda_wrappers=[("topk_transform", "TopK512Kernel::transform")], + ) + + +@cache_once +def _jit_hash_topk_module() -> Module: + args = make_cpp_args("act_sqrt_softplus", is_arch_support_pdl()) + return load_jit( + make_name("hash_topk"), + *args, + cuda_files=["deepseek_v4/hash_topk.cuh"], + cuda_wrappers=[("hash_topk", f"HashTopKKernel<{args}>::run")], + ) + + +@cache_once +def _jit_compress_module( + head_dim: int, + dtype_in: torch.dtype, + dtype_out: torch.dtype, + ratio: Literal[4, 128], +) -> Module: + args = make_cpp_args(head_dim, dtype_in, dtype_out, is_arch_support_pdl()) + kernel_class = f"FlashCompress{ratio}Kernel<{args}>" + return load_jit( + make_name(f"compress_{ratio}"), + *args, + cuda_files=[f"deepseek_v4/c{ratio}.cuh"], + cuda_wrappers=[ + ("decode", f"{kernel_class}::run_decode"), + ("prefill", f"{kernel_class}::run_prefill"), + ], + ) + + +@cache_once +def _jit_fused_rope_module() -> Module: + args = make_cpp_args(is_arch_support_pdl()) + return load_jit( + make_name("fused_rope"), + *args, + cuda_files=["deepseek_v4/rope.cuh"], + cuda_wrappers=[("forward", f"FusedQKRopeKernel<{args}>::forward")], + ) + + +@cache_once +def _jit_norm_rope_module( + dtype: torch.dtype, + head_dim: int, + rope_dim: int, +) -> Module: + args = make_cpp_args(dtype, head_dim, rope_dim, is_arch_support_pdl()) + return load_jit( + make_name(f"fused_norm_rope"), + *args, + cuda_files=[f"deepseek_v4/fused_norm_rope.cuh"], + cuda_wrappers=[ + ("forward", f"FusedNormRopeKernel<{args}>::forward"), + ], + ) + + +@cache_once +def _jit_fused_store_module( + name: Literal["flashmla", "indexer"], + input_dtype: torch.dtype, + index_dtype: torch.dtype, + page_size: int, +) -> Module: + args = make_cpp_args(input_dtype, index_dtype, page_size, is_arch_support_pdl()) + cname = "FlashMLA" if name == "flashmla" else "Indexer" + kernel_class = f"FusedStoreCache{cname}Kernel<{args}>" + return load_jit( + make_name("store_" + name), + *args, + cuda_files=["deepseek_v4/store.cuh"], + 
cuda_wrappers=[("run", f"{kernel_class}::run")], + ) + + +@cache_once +def _jit_metadata_module(): + return load_jit( + make_name("metadata"), + cuda_files=["deepseek_v4/paged_mqa_metadata.cuh"], + cuda_wrappers=[("run", "IndexerMetadataKernel::run")], + ) + + +def topk_transform_512( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, + ver: Literal[1, 2] = 1, +) -> None: + """Output to page_indices tensor, optionally also output raw abs position indices""" + module = _jit_topk_v2_module() if ver == 2 else _jit_topk_module() + module.topk_transform( + scores, seq_lens, page_tables, out_page_indices, page_size, out_raw_indices + ) + + +def hash_topk( + router_logits: torch.Tensor, + input_ids: torch.Tensor, + tid2eid: torch.Tensor, + num_fused_shared_experts: int = 0, + routed_scaling_factor: float = 1.0, + scoring_func: str = "sqrtsoftplus", +) -> Tuple[torch.Tensor, torch.Tensor]: + assert scoring_func == "sqrtsoftplus" + num_tokens = router_logits.size(0) + topk_routed = tid2eid.size(1) + topk_fused = topk_routed + num_fused_shared_experts + topk_ids = torch.empty( + (num_tokens, topk_fused), dtype=torch.int32, device=router_logits.device + ) + topk_weights = torch.empty( + (num_tokens, topk_fused), dtype=torch.float32, device=router_logits.device + ) + module = _jit_hash_topk_module() + module.hash_topk( + router_logits, + input_ids, + tid2eid, + topk_weights, + topk_ids, + routed_scaling_factor, + ) + return topk_weights, topk_ids + + +class CompressorPrefillPlan(NamedTuple): + compress_ratio: int + compress_plan: torch.Tensor + write_plan: torch.Tensor + + def copy_(self, other: CompressorPrefillPlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.compress_plan.copy_(other.compress_plan) + self.write_plan.copy_(other.write_plan) + + @staticmethod + def generate( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + device: torch.device, + use_cuda_graph: bool = False, + ) -> CompressorPrefillPlan: + assert seq_lens.device == extend_lens.device + seq_lens = seq_lens.to(torch.int64) + extend_lens = extend_lens.to(torch.int64) + plan_tensor = torch.empty( + (2, num_q_tokens, 16), + dtype=torch.uint8, + device=seq_lens.device, + pin_memory=seq_lens.is_cpu, + ) + module = _jit_common_module() + is_overlap = compress_ratio == 4 + # NOTE: when seq_lens on CUDA device or use_cuda_graph = True, + # the C++/CUDA implementation will pad up to num_q_tokens + plan_lens = module.plan_compress_prefill( + extend_lens, + seq_lens, + plan_tensor[0], + plan_tensor[1], + compress_ratio, + is_overlap, + use_cuda_graph, + ) + return CompressorPrefillPlan( + compress_ratio, + plan_tensor[0, : plan_lens[0]].to(device, non_blocking=True), + plan_tensor[1, : plan_lens[1]].to(device, non_blocking=True), + ) + + +# NOTE: only decode plan is compatible with cuda graph +class CompressorDecodePlan(NamedTuple): + compress_ratio: int + seq_lens: torch.Tensor + + def copy_(self, other: CompressorDecodePlan) -> None: + assert self.compress_ratio == other.compress_ratio + self.seq_lens.copy_(other.seq_lens) + + +def compress_plan( + compress_ratio: Literal[4, 128], + num_q_tokens: int, + seq_lens: torch.Tensor, + extend_lens: Optional[torch.Tensor], + device: torch.device, +) -> Union[CompressorDecodePlan, CompressorPrefillPlan]: + if extend_lens is not None: + return CompressorPrefillPlan.generate( + 
compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + device, + ) + else: + assert num_q_tokens == len(seq_lens) + seq_lens = seq_lens.to(device, non_blocking=True) + return CompressorDecodePlan(compress_ratio, seq_lens) + + +def compress_forward( + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + ape: torch.Tensor, + indices: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan, None] = None, + extra_data: Optional[torch.Tensor] = None, + *, + head_dim: int, + compress_ratio: Literal[4, 128], + out: Optional[torch.Tensor] = None, + seq_lens: Optional[torch.Tensor] = None, + extend_lens: Optional[torch.Tensor] = None, +) -> torch.Tensor: + # TODO(dark): support dynamic plan and dispatch for decode kernel + # Currently, there's no load-balancing for compression kernel + # In worst cases, few SM will be overloaded with most compression work. + # For C4, this may not be a big issue, since the compression is fast enough, + # and the compression is quite common (with an probability of 1/4 in average). + # For C128, the compression involves CTA reduction, which is relatively slow, + # and the compression is rare (with an probability of 1/128 in average). + # We may need to implement dynamic dispatch to better balance the load among SMs. + # We may need some interface like `module.plan(...)` to prepare before forward pass. + assert head_dim % 128 == 0 + num_q_tokens = kv_score_input.shape[0] + if out is None: + out = kv_score_input.new_empty((num_q_tokens, head_dim)) + if plan is None: + assert seq_lens is not None + plan = compress_plan( + compress_ratio, + num_q_tokens, + seq_lens, + extend_lens, + kv_score_input.device, + ) + assert plan.compress_ratio == compress_ratio, "Mismatched compress ratio in plan!" + module = _jit_compress_module( + head_dim, + kv_score_input.dtype, + out.dtype, + compress_ratio, + ) + F = module.decode if isinstance(plan, CompressorDecodePlan) else module.prefill + F(kv_score_buffer, kv_score_input, out, ape, indices, *plan[1:], extra_data) + return out + + +def compress_fused_norm_rope_inplace( + kv: torch.Tensor, + weight: torch.Tensor, + eps: float, + freq_cis: torch.Tensor, + plan: Union[CompressorDecodePlan, CompressorPrefillPlan], +) -> None: + freq_cis = torch.view_as_real(freq_cis).flatten(-2) + module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) + module.forward( + kv, + weight, + plan[1], # decode: seq_lens, prefill: compress_plan + freq_cis, + 1 if isinstance(plan, CompressorDecodePlan) else 0, # mode + eps, + plan.compress_ratio, + ) + + +def fused_norm_rope_inplace( + kv: torch.Tensor, + weight: torch.Tensor, + eps: float, + freq_cis: torch.Tensor, + positions: torch.Tensor, +) -> None: + freq_cis = torch.view_as_real(freq_cis).flatten(-2) + module = _jit_norm_rope_module(kv.dtype, kv.shape[-1], freq_cis.shape[-1]) + module.forward( + kv, + weight, + positions, + freq_cis, + 2, # mode + eps, + 0, # compress_ratio (no use in this mode) + ) + + +def fused_rope( + q: torch.Tensor, + k: Optional[torch.Tensor], + freqs_cis: torch.Tensor, + positions: torch.Tensor, + inverse: bool = False, +) -> None: + """Apply rotary embeddings to both Q and K in a single fused CUDA kernel. 
+ + Args: + q: [batch_size, num_q_heads, rope_dim] bfloat16 + k: [batch_size, num_k_heads, rope_dim] bfloat16 or None + freqs_cis: [max_seq_len, rope_dim // 2] complex64 (full table) + positions: [batch_size] int32 or int64, indices into freqs_cis + inverse: if True, apply inverse rotation (conjugate freqs) + """ + from sglang.srt.utils import is_hip + + if is_hip(): + from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton + + apply_rotary_emb_triton(q, freqs_cis, positions=positions, inverse=inverse) + if k is not None: + apply_rotary_emb_triton(k, freqs_cis, positions=positions, inverse=inverse) + return + + freqs_real = torch.view_as_real(freqs_cis).flatten(-2).contiguous() + module = _jit_fused_rope_module() + module.forward(q, k, freqs_real, positions, inverse) + + +@cache_once +def _tilelang_make_swa_indices_kernel(swa_window_size: int, threads: int = 128) -> Any: + import tilelang + import tilelang.language as T + + batch_size = T.dynamic("batch_size") + batch_size_plus_1 = T.dynamic("batch_size_plus_1") + num_q_tokens = T.dynamic("num_q_tokens") + num_warps = threads // 32 + assert swa_window_size % 32 == 0 + + @tilelang.jit + def make_swa_prefill_indices( + seq_lens_k: T.Tensor[(batch_size,), T.int32], + seq_lens_q: T.Tensor[(batch_size,), T.int32], + cu_seqlens_q: T.Tensor[(batch_size_plus_1,), T.int32], + swa_indices: T.Tensor[(num_q_tokens, swa_window_size), T.int32], + ): + _ = batch_size_plus_1 # unused, but don't remove it + with T.Kernel(T.ceildiv(num_q_tokens, num_warps), threads=threads) as bx: + # each warp handles 1 q token + tx = T.get_thread_binding() + warp_id = tx // 32 + lane_id = tx % 32 + s_batch_id = T.alloc_shared((num_warps,), dtype=T.int32) + + token_id = warp_id + bx * num_warps + if token_id >= num_q_tokens: + return + for i in T.serial(0, batch_size, step=32): + j = i + lane_id + if cu_seqlens_q[j] <= token_id < cu_seqlens_q[j + 1]: + s_batch_id[warp_id] = j + T.sync_warp() + + seq_idx = s_batch_id[warp_id] + kv_len = seq_lens_k[seq_idx] + qo_len = seq_lens_q[seq_idx] + cum_qo_len = cu_seqlens_q[seq_idx] + prefix_len = kv_len - qo_len + curr_seq_qo_idx = token_id - cum_qo_len + end_abs_pos = prefix_len + curr_seq_qo_idx + 1 + start_abs_pos = T.max(end_abs_pos - swa_window_size, 0) + old_kv_start = seq_idx * swa_window_size + new_kv_start = batch_size * swa_window_size + cum_qo_len + + for i in T.unroll(0, swa_window_size, step=32): + j = i + lane_id + abs_pos = start_abs_pos + j + swa_indices[token_id, j] = T.if_then_else( + abs_pos < end_abs_pos, + T.if_then_else( + abs_pos < prefix_len, + old_kv_start + abs_pos % swa_window_size, + new_kv_start + (abs_pos - prefix_len), + ), + -1, + ) + + return make_swa_prefill_indices + + +def tilelang_make_swa_prefill_indices( + seq_lens_k: torch.Tensor, + seq_lens_q: torch.Tensor, + swa_indices: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if cu_seqlens_q is None: + cu_seqlens_q = torch.cumsum(seq_lens_q, dim=0, dtype=torch.int32) + cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0), value=0) + swa_window_size = swa_indices.shape[1] + kernel = _tilelang_make_swa_indices_kernel(swa_window_size) + kernel(seq_lens_k, seq_lens_q, cu_seqlens_q, swa_indices) + return swa_indices + + +@triton.jit +def create_paged_compress_data_kernel( + req_pool_indices_ptr, # int32 [batch] + seq_lens_ptr, # int32 [batch] + extend_seq_lens_ptr, # int32 [batch] + req_to_token_ptr, # int32 [A, B] + full_to_swa_index_mapping_ptr, # int32 [C] + out_0_ptr, # int32 [batch] + 
out_1_ptr, # int32 [batch, out_dim] + batch_size, + stride_req_to_token_0, + stride_req_to_token_1: tl.constexpr, # 1 + stride_out_1_0, + stride_out_1_1: tl.constexpr, # 1 + compress_ratio: tl.constexpr, + is_overlap: tl.constexpr, # 0/1 + swa_page_size: tl.constexpr, + ring_size: tl.constexpr, + BLOCK: tl.constexpr, +) -> None: + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < batch_size + + # load per-batch + rid = tl.load(req_pool_indices_ptr + offs, mask=mask, other=0).to(tl.int32) + seq_len = tl.load(seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + extend_len = tl.load(extend_seq_lens_ptr + offs, mask=mask, other=0).to(tl.int32) + prefix_len = seq_len - extend_len + + cr = compress_ratio + write_pos = ((seq_len - 1) // cr) * cr + load_pos = ((prefix_len - 1) // cr) * cr + write_overlap_pos = write_pos - cr + load_overlap_pos = load_pos - cr + v0 = tl.zeros([BLOCK], tl.int32) + v1 = tl.zeros([BLOCK], tl.int32) + v2 = tl.zeros([BLOCK], tl.int32) + v3 = tl.zeros([BLOCK], tl.int32) + + for i in tl.static_range(4): + if i == 0: + pos = load_pos + elif i == 1: + pos = write_pos + elif i == 2: + pos = load_overlap_pos + else: + pos = write_overlap_pos + pos = tl.maximum(pos, 0) + # req_to_token[rid, pos] + loc = tl.load( + req_to_token_ptr + + rid * stride_req_to_token_0 + + pos * stride_req_to_token_1, + mask=mask, + other=0, + ).to(tl.int32) + swa_loc = tl.load(full_to_swa_index_mapping_ptr + loc, mask=mask, other=0).to( + tl.int32 + ) + swa_page = swa_loc // swa_page_size + state_loc = swa_page * ring_size + (swa_loc % ring_size) + state_loc = state_loc // cr + if i == 0: + v0 = state_loc + elif i == 1: + v1 = state_loc + elif i == 2: + v2 = state_loc + else: + v3 = state_loc + + tl.store(out_0_ptr + offs, v1, mask=mask) + + if is_overlap: + base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v2, mask=mask) + tl.store(base + 1 * stride_out_1_1, v0, mask=mask) + tl.store(base + 2 * stride_out_1_1, v3, mask=mask) + tl.store(base + 3 * stride_out_1_1, write_pos.to(tl.int32), mask=mask) + else: + base = out_1_ptr + offs * stride_out_1_0 + tl.store(base + 0 * stride_out_1_1, v0, mask=mask) + + +def triton_create_paged_compress_data( + *, + compress_ratio: int, + is_overlap: bool, + swa_page_size: int, + ring_size: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + req_to_token: torch.Tensor, + full_to_swa_index_mapping: torch.Tensor, + block: int = 128, +) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = req_pool_indices.shape[0] + out_dim = 4 if is_overlap else 1 + device_args: dict = dict(device=req_pool_indices.device, dtype=torch.int32) + out_0 = torch.empty((batch_size,), **device_args) + out_1 = torch.empty((batch_size, out_dim), **device_args) + grid = (triton.cdiv(batch_size, block),) + create_paged_compress_data_kernel[grid]( + req_pool_indices, + seq_lens, + extend_seq_lens, + req_to_token, + full_to_swa_index_mapping, + out_0, + out_1, + batch_size=batch_size, # type: ignore + stride_req_to_token_0=req_to_token.stride(0), # type: ignore + stride_req_to_token_1=req_to_token.stride(1), # type: ignore + stride_out_1_0=out_1.stride(0), # type: ignore + stride_out_1_1=out_1.stride(1), # type: ignore + compress_ratio=compress_ratio, # type: ignore + is_overlap=1 if is_overlap else 0, # type: ignore + swa_page_size=swa_page_size, # type: ignore + ring_size=ring_size, # type: ignore + BLOCK=block, # type: ignore + ) + if not is_overlap: + out_1.squeeze_(1) + return out_0, 
out_1 + + +def fused_store_cache( + input: torch.Tensor, + cache: torch.Tensor, + indices: torch.Tensor, + *, + page_size: int, + type: Literal["flashmla", "indexer"], +) -> None: + module = _jit_fused_store_module( + name=type, + input_dtype=input.dtype, + index_dtype=indices.dtype, + page_size=page_size, + ) + module.run(input, cache, indices) + + +@cache_once +def _jit_silu_mul_quant_module( + quant_group_size: int, scale_ue8m0: bool, apply_swiglu_limit: bool +) -> Module: + args = make_cpp_args( + quant_group_size, scale_ue8m0, is_arch_support_pdl(), apply_swiglu_limit + ) + return load_jit( + make_name("silu_mul_quant"), + *args, + cuda_files=["deepseek_v4/silu_and_mul_masked_post_quant.cuh"], + cuda_wrappers=[("run", f"SiluAndMulMaskedPostQuantKernel<{args}>::run")], + ) + + +def silu_and_mul_masked_post_quant( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + scale_ue8m0: bool = False, + topk: int = 8, + transposed: bool = False, + swiglu_limit: Optional[float] = None, +) -> None: + """ + Fused SiLU-and-mul with per-group FP8 quantization for expert-parallel MoE. + + input shape: [expert_num, token_num_padded, hidden_dim] + output shape: [expert_num, token_num_padded, hidden_dim // 2], dtype fp8_e4m3 + output_scale shape: [expert_num, token_num_padded, hidden_dim // 2 // quant_group_size], dtype float32 + masked_m shape: [expert_num], dtype int32. i.e. actual token count per expert + topk: max routed experts per token (grid = token_num_padded * topk blocks) + swiglu_limit: Optional. When None (default), use the original fast path (no clamp). + When set, JIT-compiles a separate kernel variant that clamps gate to + [-inf, L] and up to [-L, L] before silu (fused). + """ + apply_swiglu_limit = swiglu_limit is not None + if apply_swiglu_limit: + deepseek_v4_moe_code_path_checker.observed += 1 + module = _jit_silu_mul_quant_module( + quant_group_size, scale_ue8m0, apply_swiglu_limit + ) + module.run( + input, + output, + output_scale, + masked_m, + topk, + transposed, + float(swiglu_limit) if apply_swiglu_limit else 0.0, + ) + + +def get_paged_mqa_logits_metadata(seq_lens: torch.Tensor, page_size: int, num_sm: int): + assert page_size == 64 + seq_lens = seq_lens.to(torch.int32) + metadata = seq_lens.new_empty(num_sm + 1, 2) + module = _jit_metadata_module() + module.run(seq_lens, metadata) + return metadata + + +@cache_once +def _jit_torch_cublas_bf16_fp32() -> Any: + import torch.utils.cpp_extension + + source = """ +#include +#include +#include + +torch::Tensor linear_bf16_fp32( + torch::Tensor X, + torch::Tensor W) +{ + int batch = X.size(0); + int in_features = X.size(1); + int out_features = W.size(0); + + auto Y = torch::empty( + {batch, out_features}, + torch::dtype(torch::kFloat32).device(X.device())); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + + float alpha = 1.0f; + float beta = 0.0f; + + cublasGemmEx( + handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + out_features, + batch, + in_features, + &alpha, + W.data_ptr(), CUDA_R_16BF, in_features, + X.data_ptr(), CUDA_R_16BF, in_features, + &beta, + Y.data_ptr(), CUDA_R_32F, out_features, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP + ); + + return Y; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("linear_bf16_fp32", &linear_bf16_fp32, "BF16xBF16 -> FP32 linear (no bias)"); +} +""" + module = torch.utils.cpp_extension.load_inline( + name="linear_bf16_fp32", + cpp_sources="", + cuda_sources=source, + extra_cflags=["-O3"], + 
extra_cuda_cflags=["-O3"], + verbose=False, + ) + return module + + +def linear_bf16_fp32(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + from sglang.srt.environ import envs + + algo = envs.SGLANG_OPT_BF16_FP32_GEMM_ALGO.get() + + if algo == "cublas": + module = _jit_torch_cublas_bf16_fp32() + return module.linear_bf16_fp32(x, y) + elif algo == "deep_gemm": + import deep_gemm + + z = x.new_empty(x.size(0), y.size(0), dtype=torch.float32) + deep_gemm.bf16_gemm_nt(x, y, z) + return z + else: # fall back to torch fp32 GEMM + return torch.nn.functional.linear(x.float(), y.float()) + + +def _compile_one(*input_tuple) -> None: + name, job_fn, *args = input_tuple + print(f"Compiling {name}...", flush=True) + job_fn(*args) + print(f"Finished compiling {name}.", flush=True) + + +def compile_aot(): + c_dtype = torch.float32 # compress uses float32 + jobs = [ + ("cublas", _jit_torch_cublas_bf16_fp32), + ("common", _jit_common_module), + ("topk", _jit_topk_module), + ("hash_topk", _jit_hash_topk_module), + ("rope", _jit_fused_rope_module), + ("metadata", _jit_metadata_module), + ( + "compress_128_4", + _jit_compress_module, + 128, + c_dtype, + c_dtype, + 4, + ), + ( + "compress_512_4", + _jit_compress_module, + 512, + c_dtype, + c_dtype, + 4, + ), + ( + "compress_512_128", + _jit_compress_module, + 512, + c_dtype, + c_dtype, + 128, + ), + ( + "norm_rope_128_64", + _jit_norm_rope_module, + c_dtype, + 128, + 64, + ), + ( + "norm_rope_512_64", + _jit_norm_rope_module, + c_dtype, + 512, + 64, + ), + ( + "store_flashmla_bf16_swa_256", + _jit_fused_store_module, + "flashmla", + torch.bfloat16, + torch.int32, + 256, + ), + ( + "store_flashmla_fp32_c4_64", + _jit_fused_store_module, + "flashmla", + torch.float32, + torch.int32, + 64, + ), + ( + "store_flashmla_fp32_c128_2", + _jit_fused_store_module, + "flashmla", + torch.float32, + torch.int32, + 2, + ), + ( + "store_indexer_fp32_c4_64", + _jit_fused_store_module, + "indexer", + torch.float32, + torch.int32, + 64, + ), + ] + # use multiprocess to speed up compilation + import multiprocessing + + max_parallel_jobs = min(len(jobs), multiprocessing.cpu_count()) + with multiprocessing.Pool(processes=max_parallel_jobs) as pool: + pool.starmap(_compile_one, jobs) + + +if __name__ == "__main__": + compile_aot() diff --git a/python/sglang/jit_kernel/hisparse.py b/python/sglang/jit_kernel/hisparse.py new file mode 100644 index 000000000000..cc9f6b5f15df --- /dev/null +++ b/python/sglang/jit_kernel/hisparse.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING + +import torch + +from sglang.jit_kernel.utils import load_jit, make_cpp_args + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +@functools.cache +def _jit_sparse_transfer_module() -> Module: + return load_jit( + "sparse_transfer", + cuda_files=["deepseek_v4/hisparse_transfer.cuh"], + cuda_wrappers=[("offload", "hisparse_transfer")], + ) + + +@functools.cache +def _jit_sparse_module( + item_size_bytes: int, + block_size: int, + num_top_k: int, + hot_buffer_size: int, + is_mla: bool = False, +) -> Module: + template_args = make_cpp_args(block_size, num_top_k, hot_buffer_size, is_mla) + cache_args = make_cpp_args( + item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla + ) + return load_jit( + "sparse_cache", + *cache_args, + cuda_files=["hisparse.cuh"], + cuda_wrappers=[ + ( + "load_cache_to_device_buffer", + f"load_cache_to_device_buffer<{template_args}>", + ) + ], + ) + + +def load_cache_to_device_buffer_mla( + top_k_tokens: 
torch.Tensor, + device_buffer_tokens: torch.Tensor, + host_cache_locs: torch.Tensor, + device_buffer_locs: torch.Tensor, + host_cache: torch.Tensor, + device_buffer: torch.Tensor, + top_k_device_locs: torch.Tensor, + diff_map: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + lru_slots: torch.Tensor, + transfer_tasks_src: torch.Tensor, + transfer_tasks_dst: torch.Tensor, + page_size: int, + layer_id: int, + item_size_bytes: int, + *, + block_size: int = 256, + num_top_k: int = 512, + hot_buffer_size: int = 1024, +) -> None: + # Infer parameters if not provided + if num_top_k <= 0: + num_top_k = top_k_tokens.size(-1) + if hot_buffer_size <= 0: + hot_buffer_size = device_buffer_tokens.size(-1) + + # Validate that HOT_BUFFER_SIZE >= NUM_TOP_K + assert ( + hot_buffer_size >= num_top_k + ), f"hot_buffer_size ({hot_buffer_size}) must be >= num_top_k ({num_top_k})" + + module = _jit_sparse_module( + item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla=True + ) + + # Create empty tensors for V cache (not used in MLA) + empty = torch.empty(0) + + module.load_cache_to_device_buffer( + top_k_tokens, + device_buffer_tokens, + host_cache_locs, + device_buffer_locs, + host_cache, + empty, + device_buffer, + empty, + top_k_device_locs, + diff_map, + req_pool_indices, + seq_lens, + lru_slots, + transfer_tasks_src, + transfer_tasks_dst, + page_size, + layer_id, + item_size_bytes, + ) + + +def load_cache_to_device_buffer( + top_k_tokens: torch.Tensor, + device_buffer_tokens: torch.Tensor, + host_cache_locs: torch.Tensor, + device_buffer_locs: torch.Tensor, + host_cache_k: torch.Tensor, + host_cache_v: torch.Tensor, + device_buffer_k: torch.Tensor, + device_buffer_v: torch.Tensor, + top_k_device_locs: torch.Tensor, + diff_map: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + lru_slots: torch.Tensor, + transfer_tasks_src: torch.Tensor, + transfer_tasks_dst: torch.Tensor, + page_size: int, + layer_id: int, + item_size_bytes: int, + *, + block_size: int = 256, + num_top_k: int = 512, + hot_buffer_size: int = 1024, +) -> None: + # Infer parameters if not provided + if num_top_k <= 0: + num_top_k = top_k_tokens.size(-1) + if hot_buffer_size <= 0: + hot_buffer_size = device_buffer_tokens.size(-1) + + # Validate that HOT_BUFFER_SIZE >= NUM_TOP_K + assert ( + hot_buffer_size >= num_top_k + ), f"hot_buffer_size ({hot_buffer_size}) must be >= num_top_k ({num_top_k})" + + module = _jit_sparse_module( + item_size_bytes, block_size, num_top_k, hot_buffer_size, is_mla=False + ) + + module.load_cache_to_device_buffer( + top_k_tokens, + device_buffer_tokens, + host_cache_locs, + device_buffer_locs, + host_cache_k, + host_cache_v, + device_buffer_k, + device_buffer_v, + top_k_device_locs, + diff_map, + req_pool_indices, + seq_lens, + lru_slots, + transfer_tasks_src, + transfer_tasks_dst, + page_size, + layer_id, + item_size_bytes, + ) + + +def offload_to_host( + gpu_ptrs: torch.Tensor, + cpu_ptrs: torch.Tensor, + gpu_indices: torch.Tensor, + cpu_indices: torch.Tensor, +) -> None: + module = _jit_sparse_transfer_module() + module.offload(gpu_ptrs, cpu_ptrs, gpu_indices, cpu_indices) diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh new file mode 100644 index 000000000000..02b166d01c73 --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/compress.cuh @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include + +#include +#include + +#include + 
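Both wrappers above expect pre-allocated `transfer_tasks_src` / `transfer_tasks_dst` buffers. Based on the indexing in the kernel (each request owns `num_top_k * page_size` task slots plus one trailing count slot), a plausible sizing and read-back helper looks like the following; the helper names are hypothetical and not part of the module.

```python
# Hypothetical helpers for the transfer-task buffers, derived from the
# per-request layout [task_0 ... task_{T-1} | count] with T = num_top_k * page_size.
import torch


def alloc_transfer_tasks(batch_size: int, num_top_k: int, page_size: int,
                         device: str = "cuda"):
    stride_per_block = num_top_k * page_size + 1     # tasks + trailing count
    shape = (batch_size * stride_per_block,)
    src = torch.zeros(shape, dtype=torch.int64, device=device)
    dst = torch.zeros(shape, dtype=torch.int64, device=device)
    return src, dst


def read_task_counts(transfer_tasks_src: torch.Tensor, batch_size: int,
                     num_top_k: int, page_size: int) -> torch.Tensor:
    """Per-request number of host->device copy items queued by kernel 1."""
    stride = num_top_k * page_size + 1
    view = transfer_tasks_src.view(batch_size, stride)
    return view[:, -1]                               # trailing count slot
```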
+namespace device::compress { + +struct alignas(16) PrefillPlan { + uint32_t ragged_id; + uint32_t batch_id; + uint32_t position; + uint32_t window_len; // must be in `[0, compress_ratio * (1 + is_overlap))` + + bool is_valid(const uint32_t ratio, const bool is_overlap) const { + const uint32_t max_window_len = ratio * (1 + is_overlap); + return window_len < max_window_len; + } +}; + +} // namespace device::compress + +namespace host::compress { + +using device::compress::PrefillPlan; +using PrefillPlanTensorDtype = uint8_t; +inline constexpr int64_t kPrefillPlanDim = 16; + +static_assert(alignof(PrefillPlan) == sizeof(PrefillPlan)); +static_assert(sizeof(PrefillPlan) == kPrefillPlanDim * sizeof(PrefillPlanTensorDtype)); + +} // namespace host::compress diff --git a/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh new file mode 100644 index 000000000000..0a3acc47734a --- /dev/null +++ b/python/sglang/jit_kernel/include/sgl_kernel/deepseek_v4/kvcacheio.cuh @@ -0,0 +1,96 @@ +#include +#include + +#include + +#include + +namespace device::hisparse { + +/// NOTE: We call nope+rope as a "value" here. +/// GPU Cache layout: +/// VALUE 0, VALUE 1, ..., VALUE 63, +/// SCALE 0, SCALE 1, ..., SCALE 63, +/// [Padding to align to 576 bytes] +/// CPU Cache follow a trivial linear layout without any padding. +inline constexpr int64_t kGPUPageSize = 64; +inline constexpr int64_t kGPUPageBits = 6; // log2(kGPUPageSize) +inline constexpr int64_t kValueBytes = 576; +inline constexpr int64_t kScaleBytes = 8; +/// NOTE: FlashMLA requires each page to be aligned to 576 bytes +inline constexpr int64_t kCPUItemBytes = kValueBytes + kScaleBytes; +inline constexpr int64_t kGPUPageBytes = host::div_ceil(kCPUItemBytes * kGPUPageSize, 576) * 576; +inline constexpr int64_t kGPUScaleOffset = kValueBytes * kGPUPageSize; + +struct PointerInfo { + int64_t* value_ptr; + int64_t* scale_ptr; +}; + +SGL_DEVICE PointerInfo get_pointer_gpu(void* cache, int32_t index) { + using namespace device; + static_assert(1 << kGPUPageBits == kGPUPageSize); + const int32_t page_num = index >> kGPUPageBits; + const int32_t page_offset = index & (kGPUPageSize - 1); + const auto page_ptr = pointer::offset(cache, page_num * kGPUPageBytes); + const auto value_ptr = pointer::offset(page_ptr, page_offset * kValueBytes); + const auto scale_ptr = pointer::offset(page_ptr, kGPUScaleOffset + page_offset * kScaleBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +SGL_DEVICE PointerInfo get_pointer_cpu(void* cache, int32_t index) { + using namespace device; + const auto value_ptr = pointer::offset(cache, index * kCPUItemBytes); + const auto scale_ptr = pointer::offset(value_ptr, kValueBytes); + return {static_cast(value_ptr), static_cast(scale_ptr)}; +} + +enum class TransferDirection { + DeviceToDevice = 0, + DeviceToHost = 1, + HostToDevice = 2, +}; + +template +SGL_DEVICE void transfer_item(void* dst_cache, void* src_cache, const int32_t dst_index, const int32_t src_index) { + constexpr bool is_dst_device = (direction != TransferDirection::DeviceToHost); + constexpr bool is_src_device = (direction != TransferDirection::HostToDevice); + constexpr auto dst_fn = is_dst_device ? get_pointer_gpu : get_pointer_cpu; + constexpr auto src_fn = is_src_device ? 
get_pointer_gpu : get_pointer_cpu; + + const auto [dst_value_ptr, dst_scale_ptr] = dst_fn(dst_cache, dst_index); + const auto [src_value_ptr, src_scale_ptr] = src_fn(src_cache, src_index); + + int64_t local_items[2]; + const int64_t* tail_src_ptr; + int64_t* tail_dst_ptr; + + const int32_t lane_id = threadIdx.x % 32; + + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + local_items[i] = src_value_ptr[j]; + } + + if (lane_id < 8) { // handle the tail element safely + const auto last_id = 64 + lane_id; + tail_src_ptr = src_value_ptr + last_id; + tail_dst_ptr = dst_value_ptr + last_id; + } else { // broadcast load/store is safe + tail_src_ptr = src_scale_ptr; + tail_dst_ptr = dst_scale_ptr; + } + + const auto tail_item = *tail_src_ptr; + + // store first 512 bytes of value + for (int i = 0; i < 2; ++i) { + const auto j = lane_id + i * 32; + dst_value_ptr[j] = local_items[i]; + } + + // store the tail element + *tail_dst_ptr = tail_item; +} + +} // namespace device::hisparse diff --git a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh index 5510b44746c9..f48d34181d4f 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/vec.cuh @@ -50,14 +50,10 @@ struct AlignedVector { using storage_t = AlignedStorage; public: - template - SGL_DEVICE void load(const U* ptr, std::size_t offset = 0) { - static_assert(std::is_same_v || std::is_same_v); + SGL_DEVICE void load(const void* ptr, std::size_t offset = 0) { m_storage = reinterpret_cast(ptr)[offset]; } - template - SGL_DEVICE void store(U* ptr, std::size_t offset = 0) const { - static_assert(std::is_same_v || std::is_same_v); + SGL_DEVICE void store(void* ptr, std::size_t offset = 0) const { reinterpret_cast(ptr)[offset] = m_storage; } SGL_DEVICE void fill(T value) { diff --git a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh index d69526e97f29..079ac0155872 100644 --- a/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh +++ b/python/sglang/jit_kernel/include/sgl_kernel/warp.cuh @@ -1,23 +1,24 @@ #pragma once #include +#include // Some warp primitives namespace device::warp { static constexpr uint32_t kFullMask = 0xffffffffu; -template +template SGL_DEVICE T reduce_sum(T value, uint32_t active_mask = kFullMask) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) + for (auto mask = kThreads >> 1; mask > 0; mask >>= 1) value = value + __shfl_xor_sync(active_mask, value, mask, 32); return value; } -template +template SGL_DEVICE T reduce_max(T value, uint32_t active_mask = kFullMask) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) + for (auto mask = kThreads >> 1; mask > 0; mask >>= 1) value = math::max(value, __shfl_xor_sync(active_mask, value, mask, 32)); return value; } diff --git a/python/sglang/jit_kernel/moe_fused_gate.py b/python/sglang/jit_kernel/moe_fused_gate.py new file mode 100644 index 000000000000..a2dd8dd5283c --- /dev/null +++ b/python/sglang/jit_kernel/moe_fused_gate.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Tuple + +import torch + +from sglang.jit_kernel.utils import cache_once, load_jit + +if TYPE_CHECKING: + from tvm_ffi.module import Module + + +_SCORING_FUNC_MAP = { + "sigmoid": 0, + "sqrtsoftplus": 1, +} + + +@cache_once +def _jit_moe_fused_gate_module() -> Module: + return load_jit( + "moe_fused_gate", + cuda_files=["moe/moe_fused_gate.cuh"], + 
cuda_wrappers=[("moe_fused_gate", "MoEFusedGateKernel::run")], + ) + + +@cache_once +def can_use_moe_fused_gate() -> bool: + logger = logging.getLogger(__name__) + try: + _jit_moe_fused_gate_module() + return True + except Exception as e: + logger.warning(f"Failed to load JIT MoE fused gate kernel: {e}") + return False + + +def moe_fused_gate( + input: torch.Tensor, + bias: torch.Tensor, + topk: int, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + renormalize: bool = True, + routed_scaling_factor: float = 1.0, + apply_routed_scaling_factor_on_output: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused MoE gating kernel that computes top-k experts with biased scoring, no grouping. + + Args: + input: Input tensor of shape (num_rows, num_experts), dtype float32 + bias: Bias tensor of shape (num_experts,), dtype float32 + topk: Number of top experts to select (including fused shared experts) + scoring_func: Scoring function type: "sigmoid" or "sqrtsoftplus" + num_fused_shared_experts: Number of fused shared experts. If > 0, the last + `num_fused_shared_experts` slots in topk are reserved for shared experts. + renormalize: Whether to renormalize the output weights + routed_scaling_factor: Scaling factor for routed weights + apply_routed_scaling_factor_on_output: Whether to apply scaling factor on output + + Returns: + Tuple of (topk_weights, topk_indices): + - topk_weights: (num_rows, topk), float32 + - topk_indices: (num_rows, topk), int32 + """ + scoring_func_int = _SCORING_FUNC_MAP.get(scoring_func.lower()) + assert ( + scoring_func_int is not None + ), f"Unknown scoring_func '{scoring_func}', must be one of {list(_SCORING_FUNC_MAP.keys())}" + + assert input.dtype == torch.float32, "input must be float32" + assert bias.dtype == torch.float32, "bias must be float32" + assert input.ndim == 2, "input must be 2D" + assert bias.ndim == 1, "bias must be 1D" + assert input.size(1) == bias.size(0), "input and bias must have same num_experts" + assert topk > num_fused_shared_experts, "topk must be > num_fused_shared_experts" + + num_rows, _ = input.shape + device = input.device + + output = torch.empty(num_rows, topk, dtype=torch.float32, device=device) + indices = torch.empty(num_rows, topk, dtype=torch.int32, device=device) + + module = _jit_moe_fused_gate_module() + module.moe_fused_gate( + input, + bias, + output, + indices, + topk, + scoring_func_int, + num_fused_shared_experts, + renormalize, + routed_scaling_factor, + apply_routed_scaling_factor_on_output, + ) + + return output, indices diff --git a/python/sglang/jit_kernel/utils.py b/python/sglang/jit_kernel/utils.py index e8358d35d68e..563e8b3322b5 100644 --- a/python/sglang/jit_kernel/utils.py +++ b/python/sglang/jit_kernel/utils.py @@ -42,7 +42,7 @@ def _package_install(): DEFAULT_CFLAGS = ["-std=c++20", "-O3"] DEFAULT_CUDA_CFLAGS = ["-std=c++20", "-O3", "--expt-relaxed-constexpr"] DEFAULT_LDFLAGS = [] -CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, bool, torch.dtype] +CPP_TEMPLATE_TYPE: TypeAlias = Union[int, float, str, bool, torch.dtype] class CPPArgList(list[str]): @@ -54,6 +54,8 @@ def __str__(self) -> str: torch.float: "fp32_t", torch.float16: "fp16_t", torch.bfloat16: "bf16_t", + torch.int32: "int32_t", + torch.int64: "int64_t", } @@ -61,7 +63,8 @@ def make_cpp_args(*args: CPP_TEMPLATE_TYPE) -> CPPArgList: def _convert(arg: CPP_TEMPLATE_TYPE) -> str: if isinstance(arg, bool): return "true" if arg else "false" - if isinstance(arg, (int, float)): + # NOTE: str are treated as global symbols 
rather than string literals + if isinstance(arg, (int, str, float)): return str(arg) if isinstance(arg, torch.dtype): return CPP_DTYPE_MAP[arg] @@ -169,5 +172,9 @@ def wrapper(*args, **kwargs): def is_arch_support_pdl() -> bool: import torch + from sglang.srt.utils import is_hip + + if is_hip(): + return False device = torch.cuda.current_device() return torch.cuda.get_device_capability(device)[0] >= 9 diff --git a/python/sglang/multimodal_gen/runtime/models/registry.py b/python/sglang/multimodal_gen/runtime/models/registry.py index 5e6367a40c18..07d956420769 100644 --- a/python/sglang/multimodal_gen/runtime/models/registry.py +++ b/python/sglang/multimodal_gen/runtime/models/registry.py @@ -218,7 +218,9 @@ def _try_load_model_cls( try: return model.load_model_cls() except Exception: - logger.exception("Ignore import error when loading '%s'", model_arch) + logger.exception( + "In _try_load_model_cls: Ignore import error when loading '%s'", model_arch + ) return None diff --git a/python/sglang/srt/configs/deepseek_v4.py b/python/sglang/srt/configs/deepseek_v4.py new file mode 100644 index 000000000000..dd15c69f3931 --- /dev/null +++ b/python/sglang/srt/configs/deepseek_v4.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass, field +from typing import Dict, List + +from transformers import PretrainedConfig + +from sglang.srt.layers.quantization.base_config import QuantizationConfig + + +@dataclass +class DeepSeekV4Config(PretrainedConfig): + architectures: List[str] + attention_bias: bool = False + attention_dropout: float = 0.0 + bos_token_id: int = 0 + eos_token_id: int = 1 + ep_size: int = 1 + first_k_dense_replace: int = 0 + hidden_act: str = "silu" + hidden_size: int = 4096 + index_head_dim: int = 128 + index_n_heads: int = 64 + index_topk: int = 512 + initializer_range: float = 0.02 + intermediate_size: int = 2048 + kv_lora_rank: int = 512 + max_position_embeddings: int = 65536 + model_type: str = "deepseek_v4" + moe_intermediate_size: int = 2048 + moe_layer_freq: int = 1 + n_group: int = 8 + n_routed_experts: int = 256 + n_shared_experts: int = 1 + norm_topk_prob: bool = True + + num_attention_heads: int = 64 + num_experts_per_tok: int = 6 + num_hidden_layers: int = 43 + num_key_value_heads: int = 1 + + q_lora_rank: int = 1024 + qk_nope_head_dim: int = 448 + qk_rope_head_dim: int = 64 + + quantization_config: QuantizationConfig = field(default_factory=QuantizationConfig) + + rms_norm_eps: float = 1e-6 + + rope_scaling: Dict[str, float] = field(default_factory=dict) + rope_theta: int = 10000 + + routed_scaling_factor: float = 1.5 + scoring_func: str = "sqrtsoftplus" + + tie_word_embeddings: bool = False + + topk_group: int = 8 + topk_method: str = "noaux_tc" + + use_cache: bool = True + v_head_dim: int = 512 + vocab_size: int = 129280 + o_lora_rank: int = 1024 + o_groups: int = 8 + window_size: int = 128 + + compress_rope_theta: int = 40000 + compress_ratios: List[int] = field(default_factory=list) + + n_hash_layers: int = 3 + hc_mult: int = 4 + hc_sinkhorn_iters: int = 20 + hc_eps: float = 1e-6 diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index b3a49f8a410e..8ddb7d558c7a 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -66,8 +66,15 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool: ) +def is_deepseek_compressed(config: PretrainedConfig) -> bool: + return config.architectures is not None and ( + config.architectures[0] == "DeepseekV4ForCausalLM" + or 
config.architectures[0] == "DeepseekV4ForCausalLMNextN" + ) + + def get_nsa_index_head_dim(config: PretrainedConfig) -> int: - assert is_deepseek_nsa(config) + assert is_deepseek_nsa(config) or is_deepseek_compressed(config) return config.index_head_dim @@ -272,6 +279,14 @@ def _config_draft_model(self): ): self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN" + if ( + is_draft_model + and self.hf_config.architectures[0] == "DeepseekV4ForCausalLM" + ): + self.hf_config.architectures[0] = "DeepseekV4ForCausalLMNextN" + # TODO: Tmp hardcode to 1 since num_nextn_predict_layers is not in config.json + self.hf_config.num_nextn_predict_layers = 1 + if is_draft_model and self.hf_config.architectures[0] in [ "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", @@ -318,7 +333,19 @@ def _derive_hybrid_model(self): and not self.disable_hybrid_swa_memory ) - if self.is_hybrid_swa: + if not self.is_hybrid_swa: + return + + logger.info(f"Hybrid swa model: {self.hf_config.architectures=}") + + # FIXME: distinguish Compressed Attention SWA from Mimo's SWA compress + self.is_swa_with_compressed_attention = any( + arch in ["DeepseekV4ForCausalLM", "DeepseekV4ForCausalLMNextN"] + for arch in self.hf_config.architectures + ) + + if self.is_hybrid_swa and not self.is_swa_with_compressed_attention: + # NOTE: hybrid swa with compressed attention does not need to get layer ids self.swa_attention_layer_ids, self.full_attention_layer_ids = ( get_hybrid_layer_ids( self.hf_config.architectures, @@ -419,6 +446,31 @@ def _derive_model_shapes(self): scaling_factor = self.hf_config.rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + elif ( + "DeepseekV4ForCausalLM" in self.hf_config.architectures + or "DeepseekV4ForCausalLMNextN" in self.hf_config.architectures + ): + self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim + if envs.SGLANG_DSV4_MODE.get() == "2604": + self.qk_nope_head_dim = self.hf_config.head_dim - self.qk_rope_head_dim + self.window_size = self.hf_config.sliding_window + else: + self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim + self.window_size = self.hf_config.window_size + self.head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + if envs.SGLANG_DSV4_MODE.get() == "2604": + self.v_head_dim = self.head_dim + self.index_head_dim = self.hf_config.index_head_dim + self.compress_ratios = self.hf_config.compress_ratios + self.attention_arch = AttentionArch.MHA + self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim) + if self.hf_config.rope_scaling: + mscale_all_dim = self.hf_config.rope_scaling.get( + "mscale_all_dim", False + ) + scaling_factor = self.hf_config.rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale elif "MiniCPM3ForCausalLM" in self.hf_config.architectures: self.head_dim = 128 @@ -483,6 +535,14 @@ def _derive_model_shapes(self): if self.num_key_value_heads is None: self.num_key_value_heads = self.num_attention_heads self.hidden_size = self.hf_text_config.hidden_size + hc_mult = getattr(self.hf_text_config, "hc_mult", 1) + self.spec_hidden_size = ( + self.hidden_size * hc_mult + if hc_mult > 1 + and envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + else self.hidden_size + ) self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.num_attention_layers = self.num_hidden_layers if "LongcatFlashForCausalLM" in self.hf_config.architectures: @@ -723,10 
+783,11 @@ def _get_modelopt_quant_type(self) -> str: return "fp8" # Default fallback def _get_sliding_window_size(self) -> Optional[int]: - sliding_window_size = getattr(self.hf_text_config, "sliding_window_size", None) - if sliding_window_size is None: - sliding_window_size = getattr(self.hf_text_config, "sliding_window", None) - return sliding_window_size + key_list = ["sliding_window_size", "sliding_window", "window_size"] + for key in key_list: + if hasattr(self.hf_text_config, key): + return getattr(self.hf_text_config, key) + return None def _validate_quantize_and_serve_config(self): """Validate quantize_and_serve configuration.""" @@ -1230,6 +1291,8 @@ def is_hybrid_swa_model(model_architectures: List[str]): hybrid_swa_archs = { "Llama4ForConditionalGeneration", + "DeepseekV4ForCausalLM", + "DeepseekV4ForCausalLMNextN", "GptOssForCausalLM", "MiMoV2FlashForCausalLM", "MiMoV2MTP", diff --git a/python/sglang/srt/debug_utils/deepseek_v4_debug_utils.py b/python/sglang/srt/debug_utils/deepseek_v4_debug_utils.py new file mode 100644 index 000000000000..1f734b945c09 --- /dev/null +++ b/python/sglang/srt/debug_utils/deepseek_v4_debug_utils.py @@ -0,0 +1,6 @@ +class _MoECodePathChecker: + def __init__(self): + self.observed = 0 + + +deepseek_v4_moe_code_path_checker = _MoECodePathChecker() diff --git a/python/sglang/srt/debug_utils/dump_comparator.py b/python/sglang/srt/debug_utils/dump_comparator.py index ed84e13c5b3d..0537c20f1a7c 100644 --- a/python/sglang/srt/debug_utils/dump_comparator.py +++ b/python/sglang/srt/debug_utils/dump_comparator.py @@ -1,71 +1,63 @@ +"""Simplified dump comparator — a self-contained single-file script for comparing +two dump directories tensor-by-tensor. + +For advanced features (unshard, token alignment, per-dimension annotations), see the +full ``comparator/`` package: ``python -m sglang.srt.debug_utils.comparator``. 
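+
+Typical invocation (paths are placeholders):
+``python -m sglang.srt.debug_utils.dump_comparator --baseline-path DUMP_A --target-path DUMP_B``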
+""" + import argparse import functools +import os import re from dataclasses import dataclass from pathlib import Path -from typing import Callable, Dict, List, Optional +from typing import Callable, List, Optional -import einops -import polars as pl import torch -from sglang.srt.debug_utils.dump_loader import find_row, read_meta from sglang.srt.debug_utils.dumper import get_truncated_value def main(args): + import polars as pl + + from sglang.srt.debug_utils.dump_loader import find_row, read_meta + df_target = read_meta(args.target_path) df_target = df_target.filter( - (pl.col("forward_pass_id") >= args.start_id) - & (pl.col("forward_pass_id") <= args.end_id) + (pl.col("step") >= args.start_step) & (pl.col("step") <= args.end_step) ) if args.filter: df_target = df_target.filter(pl.col("filename").str.contains(args.filter)) - assert all( - c in df_target.columns - for c in ["rank", "forward_pass_id", "dump_index", "name"] - ) + assert all(c in df_target.columns for c in ["rank", "step", "dump_index", "name"]) df_baseline = read_meta(args.baseline_path) print("df_target", df_target) print("df_baseline", df_baseline) - location_info_of_target_pass_id = _get_location_info_of_target_pass_id() - tensor_dim_descs = _get_tensor_dim_descs() + tensor_dim_descs: List[TensorDimDesc] = _get_tensor_dim_descs() for row in df_target.iter_rows(named=True): path_target = Path(args.target_path) / row["filename"] - if location_info_of_target_pass_id is not None: - location_info = location_info_of_target_pass_id.get(row["forward_pass_id"]) - if location_info is None: - continue - baseline_forward_pass_id = location_info.baseline_forward_pass_id - baseline_token_slice = location_info.baseline_token_slice - else: - baseline_forward_pass_id = ( - row["forward_pass_id"] - args.start_id + args.baseline_start_id - ) - baseline_token_slice = None - - tensor_dim_desc = None - if tensor_dim_descs is not None: - tensor_dim_descs_filtered = [ + tensor_dim_desc: Optional[TensorDimDesc] = None + if tensor_dim_descs: + matched: list[TensorDimDesc] = [ desc for desc in tensor_dim_descs - if re.search(desc["pattern"], row["filename"]) is not None + if re.search(desc.pattern, row["filename"]) is not None ] - if tensor_dim_descs_filtered: - tensor_dim_desc = tensor_dim_descs_filtered[0] + if matched: + tensor_dim_desc = matched[0] row_baseline = find_row( df_baseline, conditions=dict( - forward_pass_id=baseline_forward_pass_id, + step=row["step"], **{ k: v for k, v in row.items() - if k not in ["forward_pass_id", "dump_index", "filename"] + if k not in ["step", "dump_index", "filename"] }, ), ) @@ -88,27 +80,16 @@ def main(args): path_target=path_target, diff_threshold=args.diff_threshold, name=row["name"], - baseline_token_slice=baseline_token_slice, tensor_dim_desc=tensor_dim_desc, ) print() -def _split_einops_pattern(pattern): - return re.findall(r"\([^()]*\)|\S+", pattern) - - -def _get_einops_dim_index(pattern: str, dim_name: str): - pattern_list = _split_einops_pattern(pattern) - return pattern_list.index(dim_name) - - def check_tensor_pair( path_baseline, path_target, diff_threshold: float = 1e-3, name="", - baseline_token_slice=None, tensor_dim_desc: Optional["TensorDimDesc"] = None, ): x_baseline = _load_object(path_baseline) @@ -127,18 +108,15 @@ def check_tensor_pair( ) if tensor_dim_desc is not None: - if (s := baseline_token_slice) is not None: - dim = _get_einops_dim_index(tensor_dim_desc.baseline_desc, "num_tokens") - x_baseline = x_baseline.narrow( - dim=dim, start=s.start, length=s.stop - s.start - ) + import 
einops + x_baseline = einops.rearrange( x_baseline, tensor_dim_desc.baseline_desc + " -> " + tensor_dim_desc.target_desc, ) - if (f := tensor_dim_desc.baseline_cropper) is not None: + if tensor_dim_desc.baseline_cropper is not None: print("Apply baseline_cropper") - x_baseline = f(x_baseline) + x_baseline = tensor_dim_desc.baseline_cropper(x_baseline) x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name) x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape) @@ -213,20 +191,20 @@ def _compute_and_print_diff( ): raw_abs_diff = (x_target - x_baseline).abs() + if raw_abs_diff.numel() == 0: + print(prefix_text + "⚠️ Empty tensor, skipping diff computation") + return dict(max_abs_diff=0.0) + max_abs_diff = raw_abs_diff.max().item() mean_abs_diff = raw_abs_diff.mean().item() rel_diff = _calc_rel_diff(x_target, x_baseline) + rel_diff_marker: str = "❌" if rel_diff > diff_threshold else "✅" print( prefix_text - + "\t".join( - f"{'❌' if value > diff_threshold else '✅'} {name}={value}" - for name, value in [ - ("rel_diff", rel_diff), - ("max_abs_diff", max_abs_diff), - ("mean_abs_diff", mean_abs_diff), - ] - ) + + f"{rel_diff_marker} rel_diff={rel_diff}\t" + + f"max_abs_diff={max_abs_diff}\t" + + f"mean_abs_diff={mean_abs_diff}" ) max_diff_coord = _argmax_coord(raw_abs_diff) @@ -280,38 +258,66 @@ def _load_object(path): print(f"Skip load {path} since error {e}") return None + if isinstance(x, dict) and "value" in x: + x = x["value"] + if not isinstance(x, torch.Tensor): print(f"Skip load {path} since {type(x)=} is not a Tensor ({x=})") return None return x.cuda() -# TODO may make customization endpoints configurable via args pointing to code file def _comparison_preprocessor(x_baseline, x_target, name): """Customization endpoint. Can insert arbitrary adhoc postprocessing logic here.""" - return x_baseline, x_target - - -@dataclass -class LocationInfo: - baseline_forward_pass_id: int - baseline_token_slice: slice - + if bool(int(os.environ.get("SGLANG_HACK_TRUNCATE_DIM0", "0"))): + if ( + x_baseline.dim() >= 1 + and x_target.dim() >= 1 + and x_baseline.shape[0] != x_target.shape[0] + and x_baseline.shape[1:] == x_target.shape[1:] + ): + n = min(x_baseline.shape[0], x_target.shape[0]) + print( + f"[HACK] Truncating dim0: {x_baseline.shape[0]} vs {x_target.shape[0]} -> {n}" + ) + x_baseline = x_baseline[:n] + x_target = x_target[:n] + + # CP round-robin stride-select: baseline has FULL tokens [N], + # target has LOCAL tokens [N/cp_size] from round-robin stride. + # Stride-select baseline[cp_rank::cp_size] to align with target. 
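+    # Example: with cp_size=4, cp_rank=1, a 16-token baseline and a 4-token target,
+    # baseline rows 1, 5, 9, 13 are selected so both sides hold the same tokens.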
+ # Usage: SGLANG_HACK_CP_ROUND_ROBIN_STRIDE="0:4" (rank:size) + cp_stride_spec = os.environ.get("SGLANG_HACK_CP_ROUND_ROBIN_STRIDE", "") + if cp_stride_spec: + cp_rank, cp_size = (int(x) for x in cp_stride_spec.split(":")) + if ( + x_baseline.dim() >= 1 + and x_target.dim() >= 1 + and x_baseline.shape[1:] == x_target.shape[1:] + and x_baseline.shape[0] != x_target.shape[0] + and x_baseline.shape[0] == x_target.shape[0] * cp_size + ): + print( + f"[HACK] CP round-robin stride-select: " + f"baseline[{cp_rank}::{cp_size}] " + f"({x_baseline.shape[0]} -> {x_target.shape[0]})" + ) + x_baseline = x_baseline[cp_rank::cp_size] -def _get_location_info_of_target_pass_id() -> Optional[Dict[int, LocationInfo]]: - """Customization endpoint.""" - return None + return x_baseline, x_target @dataclass class TensorDimDesc: + pattern: str baseline_desc: str target_desc: str - baseline_cropper: Optional[Callable[[torch.Tensor], torch.Tensor]] + baseline_cropper: Optional[Callable[[torch.Tensor], torch.Tensor]] = None def _get_tensor_dim_descs() -> List[TensorDimDesc]: - """Customization endpoint.""" + """Customization endpoint. Return a list of TensorDimDesc to rearrange baseline + dimensions to match target layout via einops before comparison.""" return [] @@ -320,9 +326,8 @@ def _get_tensor_dim_descs() -> List[TensorDimDesc]: parser = argparse.ArgumentParser() parser.add_argument("--baseline-path", type=str) parser.add_argument("--target-path", type=str) - parser.add_argument("--start-id", type=int, default=0) - parser.add_argument("--end-id", type=int, default=1000000) - parser.add_argument("--baseline-start-id", type=int, default=0) + parser.add_argument("--start-step", type=int, default=0) + parser.add_argument("--end-step", type=int, default=1000000) parser.add_argument("--diff-threshold", type=float, default=1e-3) parser.add_argument( "--filter", type=str, default=None, help="Regex to filter filenames" diff --git a/python/sglang/srt/debug_utils/dump_loader.py b/python/sglang/srt/debug_utils/dump_loader.py index e798e815d6bf..f35a455c2c99 100644 --- a/python/sglang/srt/debug_utils/dump_loader.py +++ b/python/sglang/srt/debug_utils/dump_loader.py @@ -1,11 +1,60 @@ import functools import os +from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict +from typing import Any, Callable, Dict, Optional, Tuple import polars as pl import torch +LOAD_FAILED: object = object() + + +def parse_meta_from_filename(path: Path) -> Dict[str, Any]: + stem = Path(path).stem + result: Dict[str, Any] = {} + for kv in stem.split("___"): + if "=" in kv: + k, v = kv.split("=", 1) + result[k] = v + for field_name, converter in _TYPED_FIELDS: + if field_name in result: + result[field_name] = converter(result[field_name]) + return result + + +@dataclass +class ValueWithMeta: + value: Any + meta: Dict[str, Any] + + @staticmethod + def load(path: Path) -> "ValueWithMeta": + path = Path(path) + meta_from_filename = parse_meta_from_filename(path) + + try: + raw = torch.load(path, weights_only=False, map_location="cpu") + except Exception as e: + print(f"Skip load {path} since error {e}") + return ValueWithMeta( + value=LOAD_FAILED, meta={**meta_from_filename, "filename": path.name} + ) + + value, meta_from_embedded = _unwrap_dict_format(raw) + return ValueWithMeta( + value=value, + meta={**meta_from_filename, **meta_from_embedded, "filename": path.name}, + ) + + +def _unwrap_dict_format(obj: Any) -> Tuple[Any, Dict[str, Any]]: + if isinstance(obj, dict) and "value" in obj: + meta = obj.get("meta", {}) 
+ assert isinstance(meta, dict), f"Expected meta to be dict, got {type(meta)}" + return obj["value"], meta + return obj, {} + class DumpLoader: def __init__(self): @@ -25,8 +74,8 @@ def load(self, name, **kwargs): from sglang.srt.debug_utils.dumper import dumper - forward_pass_id = dumper._forward_pass_id - conditions = dict(name=name, forward_pass_id=forward_pass_id, **kwargs) + step = dumper._state.step + conditions = dict(name=name, step=step, **kwargs) row = find_row(self._df, conditions=conditions) assert ( row is not None @@ -34,6 +83,8 @@ def load(self, name, **kwargs): path = self._directory / row["filename"] output = torch.load(path, weights_only=False) + if isinstance(output, dict) and "value" in output: + output = output["value"] print( f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})" @@ -48,10 +99,7 @@ def read_meta(directory): rows = [] for p in directory.glob("*.pt"): try: - full_kwargs = {} - for kv in p.stem.split("___"): - k, v = kv.split("=") - full_kwargs[k] = v + full_kwargs = parse_meta_from_filename(p) rows.append( { "filename": str(p.name), @@ -63,7 +111,7 @@ def read_meta(directory): df = pl.DataFrame(rows) df = df.with_columns( - pl.col("forward_pass_id").cast(int), + pl.col("step").cast(int), pl.col("rank").cast(int), pl.col("dump_index").cast(int), ) @@ -81,26 +129,27 @@ def _add_duplicate_index(df: pl.DataFrame) -> pl.DataFrame: return df -def find_row(df, conditions: Dict[str, Any]): - df_sub = df.filter( - functools.reduce( - lambda a, b: a & b, - [ - ( - pl.col(col) - == _cast_to_polars_dtype(conditions[col], df.schema[col]) - if conditions[col] is not None - else pl.col(col).is_null() - ) - for col in conditions.keys() - if col in df.columns - ], +def filter_rows(df: pl.DataFrame, conditions: Dict[str, Any]) -> list[dict]: + filter_exprs = [ + ( + pl.col(col) == _cast_to_polars_dtype(conditions[col], df.schema[col]) + if conditions[col] is not None + else pl.col(col).is_null() ) - ) - if len(df_sub) > 1: - print(f"find_row find ambiguous results: {df_sub=}") + for col in conditions + if col in df.columns + ] + if not filter_exprs: + return [] + return df.filter(functools.reduce(lambda a, b: a & b, filter_exprs)).to_dicts() + + +def find_row(df: pl.DataFrame, conditions: Dict[str, Any]): + rows = filter_rows(df, conditions) + if len(rows) > 1: + print(f"find_row find ambiguous results: {rows=}") return None - return df_sub.to_dicts()[0] if len(df_sub) > 0 else None + return rows[0] if rows else None def _cast_to_polars_dtype(value, target_dtype): @@ -116,4 +165,19 @@ def _cast_to_polars_dtype(value, target_dtype): return value +def read_tokenizer_path(directory: Path) -> Optional[str]: + """Read tokenizer_path from any .pt file's embedded metadata in a dump directory.""" + for p in directory.glob("*.pt"): + item: ValueWithMeta = ValueWithMeta.load(p) + tokenizer_path: Optional[str] = item.meta.get("tokenizer_path") + if tokenizer_path is not None: + return str(tokenizer_path) + return None + + +_TYPED_FIELDS: list[tuple[str, Callable[[str], Any]]] = [ + ("rank", int), +] + + dump_loader = DumpLoader() diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 26df1bcb7d48..6dc012368f55 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -52,6 +52,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.common 
import release_kv_cache +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import ( HybridLinearKVPool, HybridReqToTokenPool, @@ -280,6 +281,17 @@ def _init_kv_manager(self) -> BaseKVManager: self.metadata_buffers.get_buf_infos() ) + if isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool): + from sglang.srt.environ import envs + + assert ( + envs.SGLANG_OPT_DPSK_V4_RADIX.get() + ), "V4 PD disaggregation requires radix mode (SGLANG_OPT_DPSK_V4_RADIX=1)" + assert self.prefill_pp_size == 1, ( + "V4 PD disaggregation requires PP=1 " + "(get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized flat list)" + ) + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): state_data_ptrs, state_data_lens, state_item_lens = ( self.token_to_kv_pool.get_state_buf_infos() @@ -288,7 +300,7 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.state_data_lens = state_data_lens kv_args.state_item_lens = state_item_lens - if isinstance(self.token_to_kv_pool, SWAKVPool): + if isinstance(self.token_to_kv_pool, (SWAKVPool, DeepSeekV4TokenToKVPool)): kv_args.state_type = "swa" elif isinstance(self.token_to_kv_pool, HybridLinearKVPool): kv_args.state_type = "mamba" @@ -541,8 +553,10 @@ def pop_preallocated( .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool, SWAKVPool): - # SWA hybrid model: send decode-side SWA window indices + elif isinstance( + self.token_to_kv_pool, (SWAKVPool, DeepSeekV4TokenToKVPool) + ): + # SWA / V4 hybrid model: send decode-side SWA window indices seq_len = len(decode_req.req.origin_input_ids) window_size = self.scheduler.sliding_window_size diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 35762a7446dd..417bfb1f1744 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -49,6 +49,7 @@ ScheduleBatch, ) from sglang.srt.mem_cache.common import release_kv_cache +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, NSATokenToKVPool from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end @@ -149,6 +150,20 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.gpu_id = self.scheduler.gpu_id + if isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool): + from sglang.srt.environ import envs + + assert ( + envs.SGLANG_OPT_DPSK_V4_RADIX.get() + ), "V4 PD disaggregation requires radix mode (SGLANG_OPT_DPSK_V4_RADIX=1)" + assert self.pp_size == 1, ( + "V4 PD disaggregation requires PP=1 " + "(get_mla_kv_ptrs_with_pp cannot slice V4's buffer-type-organized flat list)" + ) + assert ( + self.decode_tp_size == self.scheduler.tp_size + ), "V4 PD disaggregation requires same TP size on prefill and decode" + if hasattr(self.token_to_kv_pool, "get_state_buf_infos"): state_data_ptrs, state_data_lens, state_item_lens = ( self.token_to_kv_pool.get_state_buf_infos() @@ -157,7 +172,7 @@ def _init_kv_manager(self) -> BaseKVManager: kv_args.state_data_lens = state_data_lens kv_args.state_item_lens = state_item_lens - if isinstance(self.token_to_kv_pool, SWAKVPool): + if isinstance(self.token_to_kv_pool, (SWAKVPool, DeepSeekV4TokenToKVPool)): kv_args.state_type = "swa" elif isinstance(self.token_to_kv_pool, HybridLinearKVPool): kv_args.state_type = "mamba" @@ 
-697,8 +712,11 @@ def send_kv_chunk( .cpu() .numpy() ] - elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool): - # SWA hybrid model: send last window KV indices + elif isinstance( + self.token_to_kv_pool_allocator.get_kvcache(), + (SWAKVPool, DeepSeekV4TokenToKVPool), + ): + # SWA / V4 hybrid model: send last window KV indices seq_len = len(req.fill_ids) window_size = self.sliding_window_size window_start = max(0, seq_len - window_size) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index d660172de587..8f3e4416d35f 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -340,9 +340,10 @@ def kv_to_page_num(num_kv_indices: int, page_size: int): def is_mla_backend(target_kv_pool) -> bool: + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool - return isinstance(target_kv_pool, MLATokenToKVPool) + return isinstance(target_kv_pool, (MLATokenToKVPool, DeepSeekV4TokenToKVPool)) def prepare_abort(req: Req, error_message: str, status_code=None): diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index e71f93ebc3b8..6fd822f140eb 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -243,6 +243,9 @@ def capture(self): try: self._IS_CAPTURING = True yield + except Exception as e: + print(e, flush=True) + raise e finally: self._IS_CAPTURING = False if not self.disabled: diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index afac1d03dace..b2129103006b 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -1596,6 +1596,7 @@ def _execute_server_warmup(server_args: ServerArgs): ) assert res.status_code == 200, f"{res.text}" _global_state.tokenizer_manager.server_status = ServerStatus.Up + print(f"Warmup request successful: {res.json()}", flush=True) # TODO: remove after debugging else: logger.info(f"Start of pd disaggregation warmup ...") diff --git a/python/sglang/srt/entrypoints/openai/encoding_dsv4.py b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py new file mode 100644 index 000000000000..e9f4b02d5c9a --- /dev/null +++ b/python/sglang/srt/entrypoints/openai/encoding_dsv4.py @@ -0,0 +1,840 @@ +# Adapted from the DeepSeek-V4 release reference implementation. +""" +DeepSeek-V4 Encoding + +A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages +with tool calling, thinking mode, and quick instruction task support. 
+""" + +import copy +import json +import re +from typing import Any, Dict, List, Optional, Tuple, Union + +# ============================================================ +# Special Tokens +# ============================================================ + +bos_token: str = "<|begin▁of▁sentence|>" +eos_token: str = "<|end▁of▁sentence|>" +thinking_start_token: str = "" +thinking_end_token: str = "" +dsml_token: str = "|DSML|" + +USER_SP_TOKEN = "<|User|>" +ASSISTANT_SP_TOKEN = "<|Assistant|>" +LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>" + +# Task special tokens for internal classification tasks +DS_TASK_SP_TOKENS = { + "action": "<|action|>", + "query": "<|query|>", + "authority": "<|authority|>", + "domain": "<|domain|>", + "title": "<|title|>", + "read_url": "<|read_url|>", +} +VALID_TASKS = set(DS_TASK_SP_TOKENS.keys()) + +# ============================================================ +# Templates +# ============================================================ + +system_msg_template: str = "{content}" +user_msg_template: str = "{content}" +latest_reminder_msg_template: str = "{content}" +assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token +assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}" +thinking_template: str = "{reasoning_content}" + +response_format_template: str = ( + "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}" +) +tool_call_template: str = ( + '<{dsml_token}invoke name="{name}">\n{arguments}\n' +) +tool_calls_template = ( + "<{dsml_token}{tc_block_name}>\n{tool_calls}\n" +) +tool_calls_block_name: str = "tool_calls" + +tool_output_template: str = "{content}" + +REASONING_EFFORT_MAX = ( + "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n" + "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n" + "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n" +) + +TOOLS_TEMPLATE = """## Tools + +You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following: + +<{dsml_token}tool_calls> +<{dsml_token}invoke name="$TOOL_NAME"> +<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE +... + +<{dsml_token}invoke name="$TOOL_NAME2"> +... + + + +String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`. + +If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response. + +Otherwise, output directly after {thinking_end_token} with tool calls or final response. + +### Available Tool Schemas + +{tool_schemas} + +You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls. 
+""" + +# ============================================================ +# Utility Functions +# ============================================================ + + +def to_json(value: Any) -> str: + """Serialize a value to JSON string.""" + try: + return json.dumps(value, ensure_ascii=False) + except: + return json.dumps(value, ensure_ascii=True) + + +def tools_from_openai_format(tools): + """Extract function definitions from OpenAI-format tool list.""" + return [tool["function"] for tool in tools] + + +def tool_calls_from_openai_format(tool_calls): + """Convert OpenAI-format tool calls to internal format.""" + return [ + { + "name": tool_call["function"]["name"], + "arguments": tool_call["function"]["arguments"], + } + for tool_call in tool_calls + ] + + +def tool_calls_to_openai_format(tool_calls): + """Convert internal tool calls to OpenAI format.""" + return [ + { + "type": "function", + "function": { + "name": tool_call["name"], + "arguments": tool_call["arguments"], + }, + } + for tool_call in tool_calls + ] + + +def encode_arguments_to_dsml(tool_call: Dict[str, str]) -> str: + """ + Encode tool call arguments into DSML parameter format. + + Args: + tool_call: Dict with "name" and "arguments" (JSON string) keys. + + Returns: + DSML-formatted parameter string. + """ + p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' + P_dsml_strs = [] + + try: + arguments = json.loads(tool_call["arguments"]) + except Exception as err: + arguments = {"arguments": tool_call["arguments"]} + + for k, v in arguments.items(): + p_dsml_str = p_dsml_template.format( + dsml_token=dsml_token, + key=k, + is_str="true" if isinstance(v, str) else "false", + value=v if isinstance(v, str) else to_json(v), + ) + P_dsml_strs.append(p_dsml_str) + + return "\n".join(P_dsml_strs) + + +def decode_dsml_to_arguments( + tool_name: str, tool_args: Dict[str, Tuple[str, str]] +) -> Dict[str, str]: + """ + Decode DSML parameters back to a tool call dict. + + Args: + tool_name: Name of the tool. + tool_args: Dict mapping param_name -> (value, is_string_flag). + + Returns: + Dict with "name" and "arguments" (JSON string) keys. + """ + + def _decode_value(key: str, value: str, string: str): + if string == "true": + value = to_json(value) + return f"{to_json(key)}: {value}" + + tool_args_json = ( + "{" + + ", ".join( + [_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()] + ) + + "}" + ) + return dict(name=tool_name, arguments=tool_args_json) + + +def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str: + """ + Render tool schemas into the system prompt format. + + Args: + tools: List of tool schema dicts (each with name, description, parameters). + + Returns: + Formatted tools section string. 
+ """ + tools_json = [to_json(t) for t in tools] + + return TOOLS_TEMPLATE.format( + tool_schemas="\n".join(tools_json), + dsml_token=dsml_token, + thinking_start_token=thinking_start_token, + thinking_end_token=thinking_end_token, + ) + + +def find_last_user_index(messages: List[Dict[str, Any]]) -> int: + """Find the index of the last user/developer message.""" + last_user_index = -1 + for idx in range(len(messages) - 1, -1, -1): + if messages[idx].get("role") in ["user", "developer"]: + last_user_index = idx + break + return last_user_index + + +# ============================================================ +# Message Rendering +# ============================================================ + + +def render_message( + index: int, + messages: List[Dict[str, Any]], + thinking_mode: str, + drop_thinking: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Render a single message at the given index into its encoded string form. + + This is the core function that converts each message in the conversation + into the DeepSeek-V4 format. + + Args: + index: Index of the message to render. + messages: Full list of messages in the conversation. + thinking_mode: Either "chat" or "thinking". + drop_thinking: Whether to drop reasoning content from earlier turns. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + Encoded string for this message. + """ + assert 0 <= index < len(messages) + assert thinking_mode in [ + "chat", + "thinking", + ], f"Invalid thinking_mode `{thinking_mode}`" + + prompt = "" + msg = messages[index] + last_user_idx = find_last_user_index(messages) + + role = msg.get("role") + content = msg.get("content") + tools = msg.get("tools") + response_format = msg.get("response_format") + tool_calls = msg.get("tool_calls") + reasoning_content = msg.get("reasoning_content") + wo_eos = msg.get("wo_eos", False) + + if tools: + tools = tools_from_openai_format(tools) + if tool_calls: + tool_calls = tool_calls_from_openai_format(tool_calls) + + # Reasoning effort prefix (only at index 0 in thinking mode with max effort) + assert reasoning_effort in [ + "max", + None, + "high", + ], f"Invalid reasoning effort: {reasoning_effort}" + if index == 0 and thinking_mode == "thinking" and reasoning_effort == "max": + prompt += REASONING_EFFORT_MAX + + if role == "system": + prompt += system_msg_template.format(content=content or "") + if tools: + prompt += "\n\n" + render_tools(tools) + if response_format: + prompt += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + elif role == "developer": + assert content, f"Invalid message for role `{role}`: {msg}" + + content_developer = USER_SP_TOKEN + content_developer += content + + if tools: + content_developer += "\n\n" + render_tools(tools) + if response_format: + content_developer += "\n\n" + response_format_template.format( + schema=to_json(response_format) + ) + + prompt += user_msg_template.format(content=content_developer) + + elif role == "user": + prompt += USER_SP_TOKEN + + # Handle content blocks (tool results mixed with text) + content_blocks = msg.get("content_blocks") + if content_blocks: + parts = [] + for block in content_blocks: + block_type = block.get("type") + if block_type == "text": + parts.append(block.get("text", "")) + elif block_type == "tool_result": + tool_content = block.get("content", "") + if isinstance(tool_content, list): + text_parts = [] + for b in tool_content: + if b.get("type") == "text": + text_parts.append(b.get("text", 
"")) + else: + text_parts.append(f"[Unsupported {b.get('type')}]") + tool_content = "\n\n".join(text_parts) + parts.append(tool_output_template.format(content=tool_content)) + else: + parts.append(f"[Unsupported {block_type}]") + prompt += "\n\n".join(parts) + else: + prompt += content or "" + + elif role == "latest_reminder": + prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format( + content=content + ) + + elif role == "tool": + raise NotImplementedError( + "deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()" + ) + + elif role == "assistant": + thinking_part = "" + tc_content = "" + + if tool_calls: + tc_list = [ + tool_call_template.format( + dsml_token=dsml_token, + name=tc.get("name"), + arguments=encode_arguments_to_dsml(tc), + ) + for tc in tool_calls + ] + tc_content += "\n\n" + tool_calls_template.format( + dsml_token=dsml_token, + tool_calls="\n".join(tc_list), + tc_block_name=tool_calls_block_name, + ) + + summary_content = content or "" + rc = reasoning_content or "" + + # Check if previous message has a task - if so, this is a task output (no thinking) + prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None + + if thinking_mode == "thinking" and not prev_has_task: + if not drop_thinking or index > last_user_idx: + thinking_part = ( + thinking_template.format(reasoning_content=rc) + thinking_end_token + ) + else: + thinking_part = "" + + if wo_eos: + prompt += assistant_msg_wo_eos_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + prompt += assistant_msg_template.format( + reasoning=thinking_part, + content=summary_content, + tool_calls=tc_content, + ) + else: + raise NotImplementedError(f"Unknown role: {role}") + + # Append transition tokens based on what follows + if index + 1 < len(messages) and messages[index + 1].get("role") not in [ + "assistant", + "latest_reminder", + ]: + return prompt + + task = messages[index].get("task") + if task is not None: + # Task special token for internal classification tasks + assert ( + task in VALID_TASKS + ), f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}" + task_sp_token = DS_TASK_SP_TOKENS[task] + + if task != "action": + # Non-action tasks: append task sp token directly after the message + prompt += task_sp_token + else: + # Action task: append Assistant + thinking token + action sp token + prompt += ASSISTANT_SP_TOKEN + prompt += ( + thinking_end_token + if thinking_mode != "thinking" + else thinking_start_token + ) + prompt += task_sp_token + + elif messages[index].get("role") in ["user", "developer"]: + # Normal generation: append Assistant + thinking token + prompt += ASSISTANT_SP_TOKEN + if not drop_thinking and thinking_mode == "thinking": + prompt += thinking_start_token + elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx: + prompt += thinking_start_token + else: + prompt += thinking_end_token + + return prompt + + +# ============================================================ +# Preprocessing +# ============================================================ + + +def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merge tool messages into the preceding user message using content_blocks format. + + DeepSeek-V4 does not have a standalone "tool" role; instead, tool results + are encoded as blocks within user messages. 
+ + This function converts a standard OpenAI-format conversation (with separate + "tool" role messages) into V4 format where tool results are merged into + user messages. + + Args: + messages: List of message dicts in OpenAI format. + + Returns: + Processed message list with tool messages merged into user messages. + """ + merged: List[Dict[str, Any]] = [] + + for msg in messages: + msg = copy.deepcopy(msg) + role = msg.get("role") + + if role == "tool": + # Convert tool message to a user message with tool_result block + tool_block = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + "content": msg.get("content", ""), + } + # Merge into previous message if it's already a user (merged tool) + if ( + merged + and merged[-1].get("role") == "user" + and "content_blocks" in merged[-1] + ): + merged[-1]["content_blocks"].append(tool_block) + else: + merged.append( + { + "role": "user", + "content_blocks": [tool_block], + } + ) + elif role == "user": + text_block = {"type": "text", "text": msg.get("content", "")} + if ( + merged + and merged[-1].get("role") == "user" + and "content_blocks" in merged[-1] + and merged[-1].get("task") is None + ): + merged[-1]["content_blocks"].append(text_block) + else: + new_msg = { + "role": "user", + "content": msg.get("content", ""), + "content_blocks": [text_block], + } + # Preserve extra fields (task, wo_eos, mask, etc.) + for key in ("task", "wo_eos", "mask"): + if key in msg: + new_msg[key] = msg[key] + merged.append(new_msg) + else: + merged.append(msg) + + return merged + + +def sort_tool_results_by_call_order( + messages: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: + """ + Sort tool_result blocks within user messages by the order of tool_calls + in the preceding assistant message. + + Args: + messages: Preprocessed message list (after merge_tool_messages). + + Returns: + Message list with sorted tool result blocks. + """ + last_tool_call_order: Dict[str, int] = {} + + for msg in messages: + role = msg.get("role") + if role == "assistant" and msg.get("tool_calls"): + last_tool_call_order = {} + for idx, tc in enumerate(msg["tool_calls"]): + tc_id = tc.get("id") or tc.get("function", {}).get("id", "") + if tc_id: + last_tool_call_order[tc_id] = idx + + elif role == "user" and msg.get("content_blocks"): + tool_blocks = [ + b for b in msg["content_blocks"] if b.get("type") == "tool_result" + ] + if len(tool_blocks) > 1 and last_tool_call_order: + sorted_blocks = sorted( + tool_blocks, + key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0), + ) + sorted_idx = 0 + new_blocks = [] + for block in msg["content_blocks"]: + if block.get("type") == "tool_result": + new_blocks.append(sorted_blocks[sorted_idx]) + sorted_idx += 1 + else: + new_blocks.append(block) + msg["content_blocks"] = new_blocks + + return messages + + +# ============================================================ +# Main Encoding Function +# ============================================================ + + +def encode_messages( + messages: List[Dict[str, Any]], + thinking_mode: str, + context: Optional[List[Dict[str, Any]]] = None, + drop_thinking: bool = True, + add_default_bos_token: bool = True, + reasoning_effort: Optional[str] = None, +) -> str: + """ + Encode a list of messages into the DeepSeek-V4 prompt format. + + This is the main entry point for encoding conversations. 
It handles: + - BOS token insertion + - Thinking mode with optional reasoning content dropping + - Tool message merging into user messages + - Multi-turn conversation context + + Args: + messages: List of message dicts to encode. + thinking_mode: Either "chat" or "thinking". + context: Optional preceding context messages (already encoded prefix). + drop_thinking: If True, drop reasoning_content from earlier assistant turns + (only keep reasoning for messages after the last user message). + add_default_bos_token: Whether to prepend BOS token at conversation start. + reasoning_effort: Optional reasoning effort level ("max", "high", or None). + + Returns: + The encoded prompt string. + """ + context = context if context else [] + + # Preprocess: merge tool messages and sort tool results + messages = merge_tool_messages(messages) + messages = sort_tool_results_by_call_order(context + messages)[len(context) :] + if context: + context = merge_tool_messages(context) + context = sort_tool_results_by_call_order(context) + + full_messages = context + messages + + prompt = bos_token if add_default_bos_token and len(context) == 0 else "" + + # Resolve drop_thinking: if any message has tools defined, don't drop thinking + effective_drop_thinking = drop_thinking + if any(m.get("tools") for m in full_messages): + effective_drop_thinking = False + + if thinking_mode == "thinking" and effective_drop_thinking: + full_messages = _drop_thinking_messages(full_messages) + # After dropping, recalculate how many messages to render + # (context may have shrunk too) + num_to_render = len(full_messages) - len(_drop_thinking_messages(context)) + context_len = len(full_messages) - num_to_render + else: + num_to_render = len(messages) + context_len = len(context) + + for idx in range(num_to_render): + prompt += render_message( + idx + context_len, + full_messages, + thinking_mode=thinking_mode, + drop_thinking=effective_drop_thinking, + reasoning_effort=reasoning_effort, + ) + + return prompt + + +def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Drop reasoning_content and non-essential messages before the last user message. + + Behavior: + - Messages with role in ["user", "system", "tool", "latest_reminder"] are always kept. + - Messages at or after the last user index are always kept. + - Assistant messages before the last user get reasoning_content removed. + - Developer messages before the last user are dropped entirely. + """ + last_user_idx = find_last_user_index(messages) + result = [] + keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"} + + for idx, msg in enumerate(messages): + role = msg.get("role") + if role in keep_roles or idx >= last_user_idx: + result.append(msg) + elif role == "assistant": + msg = copy.copy(msg) + msg.pop("reasoning_content", None) + result.append(msg) + # developer and other roles before last_user_idx are dropped + + return result + + +# ============================================================ +# Parsing (Decoding model output) +# ============================================================ + + +def _read_until_stop( + index: int, text: str, stop: List[str] +) -> Tuple[int, str, Optional[str]]: + """ + Read text from index until one of the stop strings is found. + + Returns: + Tuple of (new_index, content_before_stop, matched_stop_string_or_None). 
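+
+    Example:
+        _read_until_stop(0, "abc<STOP>xyz", ["<STOP>"]) returns (9, "abc", "<STOP>").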
+ """ + min_pos = len(text) + matched_stop = None + + for s in stop: + pos = text.find(s, index) + if pos != -1 and pos < min_pos: + min_pos = pos + matched_stop = s + + if matched_stop: + content = text[index:min_pos] + return min_pos + len(matched_stop), content, matched_stop + else: + content = text[index:] + return len(text), content, None + + +def parse_tool_calls( + index: int, text: str +) -> Tuple[int, Optional[str], List[Dict[str, str]]]: + """ + Parse DSML tool calls from text starting at the given index. + + Args: + index: Starting position in text. + text: The full text to parse. + + Returns: + Tuple of (new_index, last_stop_token, list_of_tool_call_dicts). + Each tool call dict has "name" and "arguments" keys. + """ + tool_calls: List[Dict[str, Any]] = [] + stop_token = None + tool_calls_end_token = f"" + + while index < len(text): + index, _, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}invoke", tool_calls_end_token] + ) + if _ != ">\n": + raise ValueError(f"Tool call format error: expected '>\\n' but got '{_}'") + + if stop_token == tool_calls_end_token: + break + + if stop_token is None: + raise ValueError("Missing special token in tool calls") + + index, tool_name_content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n$', tool_name_content, flags=re.DOTALL + ) + if len(p_tool_name) != 1: + raise ValueError(f"Tool name format error: '{tool_name_content}'") + tool_name = p_tool_name[0] + + tool_args: Dict[str, Tuple[str, str]] = {} + while stop_token == f"<{dsml_token}parameter": + index, param_content, stop_token = _read_until_stop( + index, text, [f"/{dsml_token}parameter"] + ) + + param_kv = re.findall( + r'^ name="(.*?)" string="(true|false)">(.*?)<$', + param_content, + flags=re.DOTALL, + ) + if len(param_kv) != 1: + raise ValueError(f"Parameter format error: '{param_content}'") + param_name, string, param_value = param_kv[0] + + if param_name in tool_args: + raise ValueError(f"Duplicate parameter name: '{param_name}'") + tool_args[param_name] = (param_value, string) + + index, content, stop_token = _read_until_stop( + index, text, [f"<{dsml_token}parameter", f"\n": + raise ValueError( + f"Parameter format error: expected '>\\n' but got '{content}'" + ) + + tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args) + tool_calls.append(tool_call) + + return index, stop_token, tool_calls + + +def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]: + """ + Parse a model completion text into a structured assistant message. + + This function takes the raw text output from the model (a single assistant turn) + and extracts: + - reasoning_content (thinking block) + - content (summary/response) + - tool_calls (if any) + + NOTE: This function is designed to parse only correctly formatted strings and + will raise ValueError for malformed output. + + Args: + text: The raw completion text (including EOS token). + thinking_mode: Either "chat" or "thinking". + + Returns: + Dict with keys: "role", "content", "reasoning_content", "tool_calls". + tool_calls are in OpenAI format. 
+ """ + summary_content, reasoning_content, tool_calls = "", "", [] + index, stop_token = 0, None + tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}" + + is_thinking = thinking_mode == "thinking" + is_tool_calling = False + + if is_thinking: + index, content_delta, stop_token = _read_until_stop( + index, text, [thinking_end_token, tool_calls_start_token] + ) + reasoning_content = content_delta + assert ( + stop_token == thinking_end_token + ), "Invalid thinking format: missing " + + index, content_delta, stop_token = _read_until_stop( + index, text, [eos_token, tool_calls_start_token] + ) + summary_content = content_delta + if stop_token == tool_calls_start_token: + is_tool_calling = True + else: + assert stop_token == eos_token, "Invalid format: missing EOS token" + + if is_tool_calling: + index, stop_token, tool_calls = parse_tool_calls(index, text) + + index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token]) + assert not tool_ends_text, "Unexpected content after tool calls" + + assert len(text) == index and stop_token in [ + eos_token, + None, + ], "Unexpected content at end" + + for sp_token in [ + bos_token, + eos_token, + thinking_start_token, + thinking_end_token, + dsml_token, + ]: + assert ( + sp_token not in summary_content and sp_token not in reasoning_content + ), f"Unexpected special token '{sp_token}' in content" + + return { + "role": "assistant", + "content": summary_content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls_to_openai_format(tool_calls), + } diff --git a/python/sglang/srt/entrypoints/openai/serving_chat.py b/python/sglang/srt/entrypoints/openai/serving_chat.py index 3b8773a551cb..5dc4e4c8533a 100644 --- a/python/sglang/srt/entrypoints/openai/serving_chat.py +++ b/python/sglang/srt/entrypoints/openai/serving_chat.py @@ -13,7 +13,7 @@ from fastapi.responses import ORJSONResponse, StreamingResponse from jsonschema import Draft202012Validator, SchemaError -from sglang.srt.entrypoints.openai.encoding_dsv32 import encode_messages +from sglang.srt.entrypoints.openai import encoding_dsv4, encoding_dsv32 from sglang.srt.entrypoints.openai.protocol import ( ChatCompletionRequest, ChatCompletionResponse, @@ -41,6 +41,7 @@ process_routed_experts_from_ret, to_openai_style_logprobs, ) +from sglang.srt.environ import envs from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.function_call_parser import FunctionCallParser from sglang.srt.function_call.json_array_parser import JsonArrayParser @@ -112,7 +113,9 @@ def __init__( and self.tokenizer_manager.model_config.hf_config.model_type == "gpt_oss" ) - self.use_dpsk_v32_encoding = self._use_dpsk_v32_encoding() + # Which Python-based chat encoder (if any) bypasses apply_chat_template. + # Values: "dsv32", "dsv4", or None. 
+ self.chat_encoding_spec = self._resolve_chat_encoding_spec() def _handle_last_assistant_message( self, @@ -170,14 +173,25 @@ def _append_assistant_prefix_to_prompt_ids( encoded = encoded[1:] return prompt_ids + encoded - def _use_dpsk_v32_encoding(self) -> bool: + def _resolve_chat_encoding_spec(self) -> Optional[str]: + if self.tool_call_parser == "deepseekv4": + return "dsv4" + if self.tool_call_parser == "deepseekv32": + return "dsv32" + + architectures = self.tokenizer_manager.model_config.hf_config.architectures + arch = architectures[0] if architectures else "" + + if "DeepseekV4" in arch: + return "dsv4" + has_chat_template = ( self.tokenizer_manager.tokenizer is not None and self.tokenizer_manager.tokenizer.chat_template is not None ) - architectures = self.tokenizer_manager.model_config.hf_config.architectures - is_dpsk_v32 = "DeepseekV3" in architectures[0] if architectures else False - return not has_chat_template and is_dpsk_v32 + if "DeepseekV3" in arch and not has_chat_template: + return "dsv32" + return None def _request_id_prefix(self) -> str: return "chatcmpl-" @@ -377,14 +391,14 @@ def _apply_jinja_template( template_content_format = self.template_manager.jinja_template_content_format - if self.use_dpsk_v32_encoding: - thinking_mode = ( - "thinking" - if (request.chat_template_kwargs or {}).get("thinking") - else "chat" + if self.chat_encoding_spec is not None: + # Per-request wins; env is fallback so existing + # `export SGLANG_ENABLE_THINKING=1` workflow keeps working here. + thinking_requested = (request.chat_template_kwargs or {}).get( + "thinking", envs.SGLANG_ENABLE_THINKING.get() ) - messages = request.messages - messages = [msg.model_dump() for msg in messages] + thinking_mode = "thinking" if thinking_requested else "chat" + messages = [msg.model_dump() for msg in request.messages] # Handle continue_final_message: separate final assistant message messages, assistant_prefix = self._handle_last_assistant_message( @@ -396,7 +410,28 @@ def _apply_jinja_template( messages.insert(0, {"role": "system", "content": ""}) if request.tools: messages[0]["tools"] = [tool.model_dump() for tool in request.tools] - real_input = encode_messages(messages, thinking_mode=thinking_mode) + + if self.chat_encoding_spec == "dsv4": + # V4 encoder only accepts "max" / "high" / None. + # OpenAI protocol defaults to "medium" which V4 rejects; drop it. + # Fallback: if request didn't set it, try env SGLANG_REASONING_EFFORT. 
+ effort_source = request.reasoning_effort + if effort_source is None: + env_val = envs.SGLANG_REASONING_EFFORT.get() + if env_val: + effort_source = env_val + v4_reasoning_effort = ( + effort_source if effort_source in ("max", "high") else None + ) + real_input = encoding_dsv4.encode_messages( + messages, + thinking_mode=thinking_mode, + reasoning_effort=v4_reasoning_effort, + ) + else: + real_input = encoding_dsv32.encode_messages( + messages, thinking_mode=thinking_mode + ) prompt_ids = self.tokenizer_manager.tokenizer.encode(real_input) # Append assistant prefix if continue_final_message is enabled @@ -446,17 +481,16 @@ def _apply_jinja_template( ) try: + chat_template_kwargs = request.chat_template_kwargs or {} + if envs.SGLANG_ENABLE_THINKING.get(): + chat_template_kwargs["thinking"] = True prompt_ids = self.tokenizer_manager.tokenizer.apply_chat_template( openai_compatible_messages, tokenize=True, add_generation_prompt=True, tools=tools, reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), + **(chat_template_kwargs), return_dict=False, ) except Exception as e: @@ -475,11 +509,7 @@ def _apply_jinja_template( add_generation_prompt=True, tools=tools, reasoning_effort=request.reasoning_effort, - **( - request.chat_template_kwargs - if request.chat_template_kwargs - else {} - ), + **(chat_template_kwargs), return_dict=False, ) except jinja2.TemplateError as template_error: @@ -1194,7 +1224,7 @@ def _get_reasoning_from_request(self, request: ChatCompletionRequest) -> bool: """Judge whether the request needs reasoning""" if not self.reasoning_parser: return False - if self.reasoning_parser in ["deepseek-v3"]: + if self.reasoning_parser in ["deepseek-v3", "deepseek-v4"]: return ( request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 928ec998ee99..a091564e0656 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -337,6 +337,7 @@ class Envs: SGLANG_DG_CACHE_DIR = EnvStr(os.path.expanduser("~/.cache/deep_gemm")) SGLANG_DG_USE_NVRTC = EnvBool(False) SGLANG_USE_DEEPGEMM_BMM = EnvBool(False) + SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT = EnvBool(True) SGLANG_CHUNKED_PREFIX_CACHE_THRESHOLD = EnvInt(8192) # DeepEP @@ -454,11 +455,71 @@ class Envs: # TokenizerManager SGLANG_REQUEST_STATE_WAIT_TIMEOUT = EnvInt(4) + # Chat Template + SGLANG_ENABLE_THINKING = EnvBool(False) + # Default reasoning_effort for dsv4 chat encoder when request doesn't set it. + # Accepts "", "max", "high" (empty string means unset). Other values filtered to None. 
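    # Illustrative precedence (mirrors the dsv4 branch in serving_chat.py; values are
    # hypothetical): request "high" + env ""    -> "high"
    #                request None   + env "max" -> "max"
    #                request "medium" (OpenAI default) -> None; env is not consulted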
+ SGLANG_REASONING_EFFORT = EnvStr("") + + # DeepSeek V4 + SGLANG_DSV4_MODE = EnvStr("2604") + SGLANG_DSV4_2604_SUBMODE = EnvStr("2604B") + SGLANG_DSV4_FP4_EXPERTS = EnvBool(True) # Set False when using FP4-to-FP8 converted checkpoint with 2604 config + SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True) + SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True) + SGLANG_OPT_USE_FUSED_COMPRESS = EnvBool(True) + SGLANG_HACK_FLASHMLA_BACKEND = EnvStr("kernel") + SGLANG_HACK_SKIP_FP4_FP8_GEMM = EnvBool(False) + SGLANG_OPT_FP8_WO_A_GEMM = EnvBool(False) + + + SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK = EnvBool(True) + SGLANG_OPT_USE_TILELANG_SWA_PREPARE = EnvBool(True) + SGLANG_OPT_USE_MULTI_STREAM_OVERLAP = EnvBool(True) + + SGLANG_OPT_DEBUG_PAGED_COMPRESS = EnvBool(False) + SGLANG_OPT_USE_FUSED_PAGED_COMPRESS = EnvBool(True) + SGLANG_FIX_MTP_HC_HIDDEN = EnvBool(True) + SGLANG_OPT_V4_DRAFT_EXTEND_CUDA_GRAPH = EnvBool(False) + SGLANG_OPT_TRITON_PREPARE_COMPRESS = EnvBool(True) + SGLANG_OPT_USE_FUSED_STORE_CACHE = EnvBool(True) + SGLANG_OPT_USE_OVERLAP_STORE_CACHE = EnvBool(True) + SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvStr("cublas") + SGLANG_OPT_USE_FUSED_HASH_TOPK = EnvBool(True) + SGLANG_OPT_USE_JIT_EP_ACTIVATION = EnvBool(True) + SGLANG_OPT_CACHE_SWA_TRANSLATION = EnvBool(True) + SGLANG_OPT_SWA_RADIX_CACHE_COMPACT = EnvBool(True) + SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False) + SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(False) + SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE = EnvInt(-1) + SGLANG_DSV4_FIX_ATTN_PADDING = EnvBool(False) + + # Advanced CUDA Graph Capture to reduce the prepare overhead + SGLANG_ADVANCED_CUDA_GRAPH_CAPTURE = EnvBool(False) + + SGLANG_OPT_USE_TRITON_CA_METADATA = EnvBool(True) + + # dsv4 radix + SGLANG_OPT_DPSK_V4_RADIX = EnvBool(True) + # ds temp, for backward compatibility + SGLANG_OPT_USE_OLD_COMPRESSOR = EnvBool(False) + + # for AMD support + SGLANG_OPT_USE_TILELANG_INDEXER = EnvBool(False) + SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False) + SGLANG_FP8_PAGED_MQA_LOGITS_TORCH = EnvBool(False) + # Symmetric Memory SGLANG_SYMM_MEM_PREALLOC_GB_SIZE = EnvInt(-1) # Aiter SGLANG_USE_AITER_FP8_PER_TOKEN = EnvBool(False) + # on HIP, force FP8 MoE off the aiter fused_moe path and onto the + # Triton runner. Skips aiter's weight shuffle during load AND bypasses the aiter + # dispatch at apply time, while leaving SGLANG_USE_AITER behavior intact for all + # other code paths (topk, non-MoE FP8, etc.). 
+ SGLANG_FORCE_TRITON_MOE_FP8 = EnvBool(False) # fmt: on @@ -491,10 +552,6 @@ def _convert_SGL_to_SGLANG(): _print_deprecated_env( "SGLANG_MOE_NVFP4_DISPATCH", "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH" ) - _print_deprecated_env( - "SGLANG_ENABLE_TP_MEMORY_INBALANCE_CHECK", - "SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK", - ) for key, value in os.environ.items(): if key.startswith("SGL_"): diff --git a/python/sglang/srt/flashmla_tests/__init__.py b/python/sglang/srt/flashmla_tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/srt/flashmla_tests/kernelkit/.gitignore b/python/sglang/srt/flashmla_tests/kernelkit/.gitignore new file mode 100644 index 000000000000..42e7a8a6a5e9 --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/.gitignore @@ -0,0 +1,9 @@ +build +*.so +*.egg-info/ +__pycache__/ +dist/ +/.vscode +.cache +/temp +/profiles diff --git a/python/sglang/srt/flashmla_tests/kernelkit/__init__.py b/python/sglang/srt/flashmla_tests/kernelkit/__init__.py new file mode 100644 index 000000000000..43750b3bec72 --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/__init__.py @@ -0,0 +1,20 @@ +from . import bench, compare, generate, precision, utils +from .bench import bench_by_cuda_events, bench_kineto +from .compare import ( + check_is_allclose, + check_is_allclose_comparator, + check_is_bitwise_equal, + check_is_bitwise_equal_comparator, + get_cos_diff, +) +from .generate import ( + gen_non_contiguous_randn_tensor, + gen_non_contiguous_tensor, + non_contiguousify, +) +from .precision import ( + LowPrecisionMode, + is_low_precision_mode, + optional_cast_to_bf16_and_cast_back, +) +from .utils import Counter, cdiv, colors, is_using_profiling_tools, set_random_seed diff --git a/python/sglang/srt/flashmla_tests/kernelkit/bench.py b/python/sglang/srt/flashmla_tests/kernelkit/bench.py new file mode 100644 index 000000000000..b0d6ab1a2d32 --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/bench.py @@ -0,0 +1,290 @@ +import dataclasses +from typing import Callable, Dict, List, Tuple, Union, overload + +import torch +import triton + +from .utils import is_using_profiling_tools + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +@triton.jit +def profiler_range_start_marker_kernel(): + pass + + +def _run_profiler_range_start_marker_kernel(): + profiler_range_start_marker_kernel[(1,)]() + + +@dataclasses.dataclass +class BenchKinetoRawResult: + """ + A struct holding the result of `bench_kineto` + """ + + is_using_nsys: bool + num_tests: int + time_ranges: Dict[str, List[Tuple[float, float]]] + + def _get_matched_kernel_name( + self, + name_substr: str, + allow_no_match: bool = False, + allow_multiple_match: bool = False, + ) -> List[str]: + matched_names = [ + name for name in self.time_ranges.keys() if name_substr in name + ] + if not allow_no_match and len(matched_names) == 0: + all_kernel_names_str = "\n - " + "\n - ".join(self.time_ranges.keys()) + raise ValueError( + f"Error: No kernel name matched for substring {name_substr}.\nAvailable kernels are: {all_kernel_names_str}" + ) + if not allow_multiple_match and len(matched_names) > 1: + raise ValueError( + f"Error: Multiple kernel matched for substring {name_substr}: {', '.join(matched_names)}" + ) + return matched_names + + def get_kernel_names(self) -> List[str]: + return list(self.time_ranges.keys()) + + def get_kernel_times( + self, + kernel_names_substr: List[str], + allow_indivisible_run_count: bool = False, + allow_missing: 
bool = False, + allow_multiple_match: bool = False, + return_avg_individual_run: bool = False, + ) -> List[float]: + """ + Get the average each-run time usage of each kernel provided in `kernel_names` + + If return_avg_individual_run is False, return sum(time) / num_tests, else return sum(time) / len(time) + If is_using_profiling_tools (which is conflict with bench_kineto), return a series of 1 seconds + """ + if is_using_profiling_tools(): + return [1 for _ in range(len(kernel_names_substr))] + + result = [] + for substr in kernel_names_substr: + matched_names = self._get_matched_kernel_name( + substr, + allow_no_match=allow_missing, + allow_multiple_match=allow_multiple_match, + ) + if len(matched_names) == 0: + assert allow_missing + result.append(0) + else: + time_usage_sum = 0 + run_cnt_sum = 0 + for matched_name in matched_names: + run_cnt = len(self.time_ranges[matched_name]) + if ( + not allow_indivisible_run_count + and run_cnt % self.num_tests != 0 + ): + raise RuntimeError( + f"Error: the number of runs for kernel {matched_name} ({run_cnt}) is indivisible by `num_tests` ({self.num_tests})" + ) + time_usage_sum += sum( + [end - start for (start, end) in self.time_ranges[matched_name]] + ) + run_cnt_sum += run_cnt + denominator = ( + run_cnt_sum if return_avg_individual_run else self.num_tests + ) + result.append(time_usage_sum / denominator) + return result + + def get_kernel_time(self, kernel_name_substr: str) -> float: + return self.get_kernel_times([kernel_name_substr])[0] + + def get_e2e_time( + self, start_kernel_name_substr: str, end_kenrel_name_substr: str + ) -> float: + """ + Get the end-to-end time usage for a sequence of kernels + defined as "last kernel end time" - "first kernel start time" + If is_using_profiling_tools (which is conflict with bench_kineto), return 1 second + """ + if is_using_profiling_tools(): + return 1 + + start_kernel_name = self._get_matched_kernel_name(start_kernel_name_substr)[0] + end_kernel_name = self._get_matched_kernel_name(end_kenrel_name_substr)[0] + num_start_kernels = len(self.time_ranges[start_kernel_name]) + num_end_kernels = len(self.time_ranges[end_kernel_name]) + if num_start_kernels % self.num_tests != 0: + raise RuntimeError( + f"Error: the number of runs for kernel {start_kernel_name} ({num_start_kernels}) is indivisible by `num_tests` ({self.num_tests})" + ) + if num_end_kernels % self.num_tests != 0: + raise RuntimeError( + f"Error: the number of runs for kernel {end_kernel_name} ({num_end_kernels}) is indivisible by `num_tests` ({self.num_tests})" + ) + time_spans = [] + for i in range(self.num_tests): + end_time = self.time_ranges[end_kernel_name][ + (i + 1) * (num_end_kernels // self.num_tests) - 1 + ][1] + start_time = self.time_ranges[start_kernel_name][ + i * (num_start_kernels // self.num_tests) + ][0] + time_spans.append((start_time, end_time)) + result = sum([end - start for (start, end) in time_spans]) / self.num_tests + return result + + +def bench_kineto( + fn: Callable, num_tests: int = 30, flush_l2: bool = True +) -> BenchKinetoRawResult: + """ + Run `fn` for `num_tests` times under `bench_kineto` (CUPTI), and returns a BenchKinetoRawResult + """ + using_nsys = is_using_profiling_tools() + + # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + schedule = ( + torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + if not using_nsys + else None + ) + profiler = ( + torch.profiler.profile( + 
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule + ) + if not using_nsys + else empty_suppress() + ) + with profiler: + for i in range(2): + if i == 1 and not using_nsys: + _run_profiler_range_start_marker_kernel() # This marks the start of the profiling range + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device="cuda").zero_() + enable_nvtx_range = i == 1 and _ == num_tests - 1 + if enable_nvtx_range: + torch.cuda.nvtx.range_push("profile_target") + fn() + if enable_nvtx_range: + torch.cuda.nvtx.range_pop() + if not using_nsys: + if i == 0: + torch.cuda.synchronize() + profiler.step() + + if using_nsys: + return BenchKinetoRawResult(True, num_tests, {}) + + from torch.autograd.profiler_util import ( # pylint: disable=import-outside-toplevel + EventList, + FunctionEvent, + ) + + events: EventList = profiler.events() # type: ignore + + # Filter out all events that are not function events + events: List[FunctionEvent] = [ + event for event in events if isinstance(event, FunctionEvent) + ] + + # Filter out all events before the range marker + for idx, event in enumerate(events): + if event.name == "profiler_range_start_marker_kernel": + events = events[idx + 1 :] + break + else: + raise RuntimeError("Could not find profiler range start marker kernel event") + + # Get time ranges of each kernel + kernel_times = {} + for event in events: + kernel_name = event.name + if kernel_name not in kernel_times: + kernel_times[kernel_name] = [] + kernel_times[kernel_name].append( + (event.time_range.start / 1e6, event.time_range.end / 1e6) + ) + + return BenchKinetoRawResult(False, num_tests, kernel_times) + + +@overload +def bench_by_cuda_events( + kernels: List[Callable], num_warmups_each: int, num_runs_each: int +) -> List[float]: ... + + +@overload +def bench_by_cuda_events( + kernels: Callable, num_warmups_each: int, num_runs_each: int +) -> float: ... 
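# A minimal usage sketch for bench_kineto (illustrative only; nothing in this file
# calls it). The substring given to get_kernel_time just has to occur somewhere in
# the profiled CUDA kernel name, so "gemm" below is a hypothetical choice for a
# bf16 matmul. Under nsys/ncu the result is a stub and timings are 1s placeholders.
def _example_bench_kineto_usage() -> None:
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    result = bench_kineto(lambda: a @ b, num_tests=10)
    print(result.get_kernel_names())
    print(f"avg matmul time: {result.get_kernel_time('gemm') * 1e6:.1f} us")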
+ + +def bench_by_cuda_events( + kernels: Union[List[Callable], Callable], num_warmups_each: int, num_runs_each: int +) -> Union[List[float], float]: + buf_for_l2_clear = torch.empty(int(256e6 // 4), dtype=torch.int32, device="cuda") + + is_kernel_single_callable = isinstance(kernels, Callable) + if is_kernel_single_callable: + kernels = [kernels] + + torch.cuda.synchronize() + for i in range(num_warmups_each): + for kernel in kernels: + kernel() + if i == 0: + # Ensure the first run is successful + try: + torch.cuda.synchronize() + except Exception as e: + print(f"Kernel {kernel.__name__} failed on warmup run {i}: {e}") + return [] + + start_events = [ + [torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] + for _ in kernels + ] + end_events = [ + [torch.cuda.Event(enable_timing=True) for _ in range(num_runs_each)] + for _ in kernels + ] + for i in range(num_runs_each): + for j, kernel in enumerate(kernels): + buf_for_l2_clear.random_() + if i == num_runs_each - 1: + torch.cuda.nvtx.range_push("profile_target") + start_events[j][i].record() + kernel() + end_events[j][i].record() + if i == num_runs_each - 1: + torch.cuda.nvtx.range_pop() + + torch.cuda.synchronize() + time_usages = [ + sum( + [ + start_events[j][i].elapsed_time(end_events[j][i]) * 1e-3 + for i in range(num_runs_each) + ] + ) + / num_runs_each + for j in range(len(kernels)) + ] + if is_kernel_single_callable: + time_usages = time_usages[0] + return time_usages diff --git a/python/sglang/srt/flashmla_tests/kernelkit/compare.py b/python/sglang/srt/flashmla_tests/kernelkit/compare.py new file mode 100644 index 000000000000..3ecc6e0e2d1e --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/compare.py @@ -0,0 +1,138 @@ +from typing import List + +import torch + + +def check_is_bitwise_equal_comparator( + ans: torch.Tensor, ref: torch.Tensor, result: torch.Tensor +): + """ + Return if two tensors are bitwise equal + Return a bool if avoid_sync is False, else return a tensor + """ + assert ans.shape == ref.shape, "Shape mismatch" + torch.all(torch.eq(ans, ref), out=result) + + +def check_is_bitwise_equal( + name: str, ans: torch.Tensor, ref: torch.Tensor, quiet: bool = False +) -> bool: + is_bitwise_equal = torch.equal(ans, ref) + if not quiet and not is_bitwise_equal: + print( + f"`{name}` mismatch: not bitwise equal. 
Mismatch count: {(ans != ref).sum().item()} out of {ans.numel()}" + ) + return is_bitwise_equal + + +def get_cos_diff(ans: torch.Tensor, ref: torch.Tensor) -> float: + """ + Calculate the cosine diff between two tensors + Return a float if avoid_sync is False, else return a tensor + """ + ans, ref = ans.double(), ref.double() + if (ref * ref).sum().item() < 1e-12: + return 0 + denominator = (ans * ans + ref * ref).sum().item() + sim = 2 * (ans * ref).sum().item() / denominator + return 1 - sim + + +def check_is_allclose( + name: str, + ans: torch.Tensor, + ref: torch.Tensor, + abs_tol: float = 1e-5, + rel_tol: float = 1e-2, + cos_diff_tol: float = 1e-7, + quiet: bool = False, +) -> bool: + """ + Check if two tensors are close enough + Return a bool if avoid_sync is False, else return a tensor + """ + assert ( + ans.shape == ref.shape + ), f"`{name}` Shape mismatch: {ans.shape} vs {ref.shape}" + assert ( + ans.dtype == ref.dtype + ), f"`{name}` Dtype mismatch: {ans.dtype} vs {ref.dtype}" + + ans = ans.clone().to(torch.float) + ref = ref.clone().to(torch.float) + + def report_err(*args, **kwargs): + if not quiet: + print(*args, **kwargs) + + # Deal with anomalies + def deal_with_anomalies(val: float): + ref_mask = (ref == val) if (val == val) else (ref != ref) + ans_mask = (ans == val) if (val == val) else (ans != ans) + ref[ref_mask] = 0.0 + ans[ans_mask] = 0.0 + if not torch.equal(ref_mask, ans_mask): + report_err( + f"`{name}` Anomaly number `{val}` mismatch: {ans_mask.sum().item()} in ans but {ref_mask.sum().item()} in ref" + ) + return False + return True + + anomalies_check_passed = True + anomalies_check_passed &= deal_with_anomalies(float("inf")) + anomalies_check_passed &= deal_with_anomalies(float("-inf")) + anomalies_check_passed &= deal_with_anomalies(float("nan")) + + cos_diff = get_cos_diff(ans, ref) + raw_abs_err = torch.abs(ans - ref) + raw_rel_err = raw_abs_err / (torch.abs(ref) + (1e-6)) + rel_err = raw_rel_err.masked_fill(raw_abs_err < abs_tol, 0) + abs_err = raw_abs_err.masked_fill(raw_rel_err < rel_tol, 0) + pass_mask = (abs_err < abs_tol) | (rel_err < rel_tol) + + if not anomalies_check_passed: + return False + + if not pass_mask.all(): + report_err(f"`{name}` mismatch") + max_abs_err_pos: int = torch.argmax(abs_err, keepdim=True).item() + max_rel_err_pos: int = torch.argmax(rel_err, keepdim=True).item() + + def get_pos_in_tensor(t: torch.Tensor, pos: int) -> List[int]: + result = [] + for size in t.shape[::-1]: + result.append(pos % size) + pos = pos // size + assert pos == 0 + return result[::-1] + + report_err( + f"max abs err: {torch.max(abs_err).item()}: pos {get_pos_in_tensor(ans, max_abs_err_pos)}, {ans.reshape(-1)[max_abs_err_pos].item()} vs {ref.reshape(-1)[max_abs_err_pos].item()}" + ) + report_err( + f"max rel err: {torch.max(rel_err).item()}: pos {get_pos_in_tensor(ans, max_rel_err_pos)}, {ans.reshape(-1)[max_rel_err_pos].item()} vs {ref.reshape(-1)[max_rel_err_pos].item()}" + ) + report_err( + f"{pass_mask.sum()} out of {pass_mask.numel()} passed ({pass_mask.sum()/pass_mask.numel()*100.0:.2f}%)" + ) + report_err(f"Cosine diff: {cos_diff} (threshold: {cos_diff_tol})") + return False + else: + if abs(cos_diff) > cos_diff_tol: + report_err( + f"`{name}` mismatch: Cosine diff too large: {cos_diff} vs {cos_diff_tol})" + ) + return False + return True + + +def check_is_allclose_comparator( + name: str, + ans: torch.Tensor, + ref: torch.Tensor, + out: torch.Tensor, + abs_tol: float = 1e-5, + rel_tol: float = 1e-2, + cos_diff_tol: float = 1e-7, +): + 
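    # Comparator-style wrapper: run check_is_allclose and write the boolean verdict
    # into the caller-provided `out` tensor (e.g. a scalar torch.bool tensor), so it
    # can be used where a tensor result is expected instead of a Python bool.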
out.fill_(check_is_allclose(name, ans, ref, abs_tol, rel_tol, cos_diff_tol)) diff --git a/python/sglang/srt/flashmla_tests/kernelkit/generate.py b/python/sglang/srt/flashmla_tests/kernelkit/generate.py new file mode 100644 index 000000000000..7c37a9457caf --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/generate.py @@ -0,0 +1,34 @@ +import torch + + +def _get_new_non_contiguous_tensor_shape(shape): + """ + Get the expanded shape for a non-contiguous tensor. + The last dimension is increased by 128 (for alignment), and all other dimensions are increased by 1 + """ + return [ + dim + 128 if dim_idx == len(shape) - 1 else dim + 1 + for dim_idx, dim in enumerate(shape) + ] + + +def gen_non_contiguous_randn_tensor(shape, *args, **kwargs): + new_shape = _get_new_non_contiguous_tensor_shape(shape) + base_tensor = torch.randn(new_shape, *args, **kwargs) + slices = [slice(0, dim) for dim in shape] + return base_tensor[slices] + + +def gen_non_contiguous_tensor(shape, *args, **kwargs): + new_shape = _get_new_non_contiguous_tensor_shape(shape) + base_tensor = torch.empty(new_shape, *args, **kwargs) + slices = [slice(0, dim) for dim in shape] + return base_tensor[slices] + + +def non_contiguousify(tensor: torch.Tensor) -> torch.Tensor: + new_tensor = gen_non_contiguous_tensor( + tensor.shape, dtype=tensor.dtype, device=tensor.device + ) + new_tensor[:] = tensor + return new_tensor diff --git a/python/sglang/srt/flashmla_tests/kernelkit/precision.py b/python/sglang/srt/flashmla_tests/kernelkit/precision.py new file mode 100644 index 000000000000..5a85562c15d1 --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/precision.py @@ -0,0 +1,35 @@ +import torch + +_is_low_precision_mode_stack = [] + + +class LowPrecisionMode: + def __init__(self, enabled: bool = True): + self.enabled = enabled + + def __enter__(self): + global _is_low_precision_mode_stack + _is_low_precision_mode_stack.append(self.enabled) + + def __exit__(self, exc_type, exc_value, traceback): + global _is_low_precision_mode_stack + _is_low_precision_mode_stack.pop() + + +def is_low_precision_mode() -> bool: + global _is_low_precision_mode_stack + if len(_is_low_precision_mode_stack) == 0: + return False + return _is_low_precision_mode_stack[-1] + + +def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor: + assert ( + tensor.dtype == torch.float32 + ), "Input tensor must be of dtype torch.float32 for optional casting." 
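    # Round-trip through bf16 only when a LowPrecisionMode() context is active on the
    # stack; otherwise return the fp32 tensor untouched.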
+ if is_low_precision_mode(): + tensor_bf16 = tensor.to(torch.bfloat16) + tensor_fp32 = tensor_bf16.to(torch.float32) + return tensor_fp32 + else: + return tensor diff --git a/python/sglang/srt/flashmla_tests/kernelkit/utils.py b/python/sglang/srt/flashmla_tests/kernelkit/utils.py new file mode 100644 index 000000000000..29101c2a734f --- /dev/null +++ b/python/sglang/srt/flashmla_tests/kernelkit/utils.py @@ -0,0 +1,57 @@ +import functools +import os + +colors = { + "RED_FG": "\033[31m", + "GREEN_FG": "\033[32m", + "CYAN_FG": "\033[36m", + "GRAY_FG": "\033[90m", + "YELLOW_FG": "\033[33m", + "RED_BG": "\033[41m", + "GREEN_BG": "\033[42m", + "CYAN_BG": "\033[46m", + "YELLOW_BG": "\033[43m", + "GRAY_BG": "\033[100m", + "CLEAR": "\033[0m", +} + + +def cdiv(a: int, b: int) -> int: + return (a + b - 1) // b + + +@functools.lru_cache() +def is_using_profiling_tools() -> bool: + """ + Return whether we are running under profiling tools like nsys or ncu + + NOTE cuda-gdb will also cause conflict with CUPTI (bench_kineto) but currently we lack ways to detect it + """ + is_using_nsys = os.environ.get("NSYS_PROFILING_SESSION_ID") is not None + is_using_ncu = os.environ.get("NV_COMPUTE_PROFILER_PERFWORKS_DIR") is not None + is_using_compute_sanitizer = ( + os.environ.get("NV_SANITIZER_INJECTION_PORT_RANGE_BEGIN") is not None + ) + return is_using_nsys or is_using_ncu or is_using_compute_sanitizer + + +def set_random_seed(seed: int): + import random + + import numpy as np + import torch + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +class Counter: + def __init__(self): + self.count = 0 + + def next(self) -> int: + self.count += 1 + return self.count - 1 diff --git a/python/sglang/srt/flashmla_tests/lib.py b/python/sglang/srt/flashmla_tests/lib.py new file mode 100644 index 000000000000..43763cd4a4bb --- /dev/null +++ b/python/sglang/srt/flashmla_tests/lib.py @@ -0,0 +1,574 @@ +import dataclasses +import enum +import os +import random +from typing import List, Optional + +from sglang.srt.utils import is_hip + +if is_hip(): + pass +else: + import flash_mla + +import torch + +from . import kernelkit as kk +from . import quant + + +class TestTarget(enum.Enum): + FWD = 0 + DECODE = 1 + + +@dataclasses.dataclass +class ExtraTestParamForDecode: + b: int + is_varlen: bool + have_zero_seqlen_k: bool + extra_s_k: Optional[int] = None + extra_topk: Optional[int] = None + block_size: int = 64 + extra_block_size: Optional[int] = None + have_extra_topk_length: bool = False + + +@dataclasses.dataclass +class TestParam: + s_q: int + s_kv: int + topk: int + h_q: int = 128 + h_kv: int = 1 + d_qk: int = 512 + d_v: int = 512 + seed: int = -1 # -1: to be filled automatically + check_correctness: bool = True + is_all_indices_invalid: bool = ( + False # All indices are invalid, i.e., all indices are set to a large number (e.g., 2147483647) + ) + num_runs: int = 10 + have_attn_sink: bool = False + have_topk_length: bool = False + decode: Optional[ExtraTestParamForDecode] = None + + +@dataclasses.dataclass +class RawTestParamForDecode: + """ + "Flattened" test parameters for decoding test + + In our test script, to maintain compatibility with TestParam, we embed decode-only parameters into TestParam.decode, which is not very convinient when construct testcases. So here we have a "flattened" version of test parameters for decoding test. 
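    Example (illustrative values): RawTestParamForDecode(b=4, h_q=128, s_q=1, h_kv=1,
    s_kv=4096, is_varlen=True, topk=256).to_test_param() returns a TestParam whose
    decode-only fields (b, is_varlen, block_size, ...) are packed under `.decode`.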
+ """ + + b: int + h_q: int + s_q: int + h_kv: int + s_kv: int + is_varlen: bool + topk: int + is_all_indices_invalid: bool = False + have_zero_seqlen_k: bool = False + have_topk_length: bool = False + enable_attn_sink: bool = True + extra_s_k: Optional[int] = None + extra_topk: Optional[int] = None + block_size: int = 64 + extra_block_size: Optional[int] = None + have_extra_topk_length: bool = False + d_qk: int = 576 # Q/K head dim (= dv + RoPE dim) + d_v: int = 512 # V head dim + check_correctness: bool = True + num_runs: int = 10 + seed: int = -1 + + def to_test_param(self) -> TestParam: + return TestParam( + self.s_q, + self.s_kv, + self.topk, + self.h_q, + self.h_kv, + self.d_qk, + self.d_v, + self.seed, + self.check_correctness, + self.is_all_indices_invalid, + self.num_runs, + self.enable_attn_sink, + self.have_topk_length, + decode=ExtraTestParamForDecode( + self.b, + self.is_varlen, + self.have_zero_seqlen_k, + self.extra_s_k, + self.extra_topk, + self.block_size, + self.extra_block_size, + self.have_extra_topk_length, + ), + ) + + +@dataclasses.dataclass +class Testcase: + p: TestParam + dOut: torch.Tensor # [s_q, h_q, d_v] + q: torch.Tensor # [s_q, h_q, d_qk] + kv: torch.Tensor # [s_kv, h_kv, d_qk] + indices: torch.Tensor # [s_q, h_kv, topk] + sm_scale: float + attn_sink: Optional[torch.Tensor] # [h_q] + topk_length: Optional[torch.Tensor] # [s_q] + + +def _randperm_batch( + batch_size: int, perm_range: torch.Tensor, perm_size: int, paddings: List[int] +) -> torch.Tensor: + """ + Generate random permutations in batch + The return tensor, denoted as `res`, has a shape of [batch_size, perm_size]. `0 <= res[i, :] < perm_range[i]` holds. + Values within each row are unique. + If, for some `i`, `perm_range[i] < perm_size` holds, then `res[i, :]` contains values in `[0, perm_range[i])` as many as possible, and the rest are filled with `padding`. 
+ """ + assert not torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(True) + perm_range_max = max(int(torch.max(perm_range).item()), perm_size) + rand = torch.rand(batch_size, perm_range_max, dtype=torch.float32) + rand[ + torch.arange(0, perm_range_max).broadcast_to(batch_size, perm_range_max) + >= perm_range.view(batch_size, 1) + ] = float( + "-inf" + ) # Fill invalid positions, so that the following `topk` operators will select positions within `perm_range` first + res = rand.topk(perm_size, dim=-1, sorted=True).indices.to(torch.int32) + if len(paddings) == 1: + res[res >= perm_range.view(batch_size, 1)] = paddings[0] + else: + fillers = torch.tensor(paddings, dtype=torch.int32).index_select( + 0, torch.randint(0, len(paddings), (res.numel(),), dtype=torch.int32) + ) + res.masked_scatter_(res >= perm_range.view(batch_size, 1), fillers) + torch.use_deterministic_algorithms(False) + return res + + +def generate_testcase(t: TestParam) -> Testcase: + kk.set_random_seed(t.seed) + q = ( + torch.randn((t.s_q, t.h_q, t.d_qk), dtype=torch.bfloat16) / 10 + + (random.random() - 0.5) / 10 + ) + kv = ( + torch.randn((t.s_kv, t.h_kv, t.d_qk), dtype=torch.bfloat16) / 10 + + (random.random() - 0.5) / 10 + ) + do = ( + torch.randn((t.s_q, t.h_q, t.d_v), dtype=torch.bfloat16) / 10 + + (random.random() - 0.5) / 10 + ) + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + do.clamp_(-10, 10) + + invalid_indices_candidate = [ + -2147483648, + -123456, + -1, + t.s_kv, + 114514, + 1919810, + 2147480000, + 2147483647, + ] + indices = _randperm_batch( + t.s_q, + torch.full((t.s_q,), t.s_kv, dtype=torch.int32), + t.topk, + invalid_indices_candidate, + ).view(t.s_q, t.h_kv, t.topk) + + if t.is_all_indices_invalid: + all_indices_invalid_mask = torch.randn(t.s_q, device="cpu") < -2 + indices[all_indices_invalid_mask[:, None, None].broadcast_to(indices.shape)] = ( + random.choice(invalid_indices_candidate) + ) + indices = indices.to(q.device) + + attn_sink = None + if t.have_attn_sink: + attn_sink = torch.randn((t.h_q,), dtype=torch.float32) + mask = torch.randn((t.h_q,), dtype=torch.float32) + attn_sink[mask < -0.5] = float("-inf") + attn_sink[mask > +0.5] = float("+inf") + + topk_length = None + if t.have_topk_length: + topk_length = torch.randint( + 0, max(t.topk + 1, 64), (t.s_q,), dtype=torch.int32, device=q.device + ).clamp_max(t.topk) + + q = kk.non_contiguousify(q) + kv = kk.non_contiguousify(kv) + do = kk.non_contiguousify(do) + indices = kk.non_contiguousify(indices) + + return Testcase( + p=t, + dOut=do, + q=q, + kv=kv, + indices=indices, + sm_scale=0.5, # Otherwise dK is too small compared to dV + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +@dataclasses.dataclass +class KVScope: + t: TestParam + cache_seqlens: torch.Tensor + block_table: torch.Tensor + blocked_k: torch.Tensor + abs_indices: torch.Tensor + indices_in_kvcache: torch.Tensor + topk_length: Optional[torch.Tensor] + blocked_k_quantized: Optional[torch.Tensor] = None + + def quant_and_dequant_(self): + """ + For FP8 cases, we need to quantize the KV cache for Flash MLA. 
+ Besides, the quantization error may be too large to be distinguished from wrong kernels, so we de-quantize kvcache here to mitigate quantization error + """ + fp8_kvcache_layout = None + if self.t.d_qk == 576: + fp8_kvcache_layout = quant.FP8KVCacheLayout.V32_FP8Sparse + elif self.t.d_qk == 512: + assert self.abs_indices is not None + fp8_kvcache_layout = quant.FP8KVCacheLayout.MODEL1_FP8Sparse + else: + assert False + self.blocked_k_quantized = quant.quantize_k_cache( + self.blocked_k, fp8_kvcache_layout + ) + blocked_k_dequantized = quant.dequantize_k_cache( + self.blocked_k_quantized, fp8_kvcache_layout + ) + self.blocked_k = blocked_k_dequantized + + def get_kvcache_for_flash_mla(self) -> torch.Tensor: + """ + Return the quantized blocked_k for Flash MLA + """ + assert ( + self.blocked_k_quantized is not None + ), "Please call `quant_and_dequant_` first before calling `get_kvcache_for_flash_mla`" + return self.blocked_k_quantized + + def apply_perm(self, perm: torch.Tensor) -> "KVScope": + """ + Apply a batch permutation to this KVScope. Used for batch-invariance test + """ + new_kvscope = KVScope( + self.t, + self.cache_seqlens[perm], + self.block_table[perm], + self.blocked_k, + self.abs_indices[perm], + self.indices_in_kvcache[perm], + self.topk_length[perm] if self.topk_length is not None else None, + self.blocked_k_quantized, + ) + return new_kvscope + + +@dataclasses.dataclass +class TestcaseForDecode: + p: TestParam + q: torch.Tensor # [b, s_q, h_q, d_qk] + attn_sink: Optional[torch.Tensor] # [h_q] + sm_scale: float + kv_scope: KVScope + extra_kv_scope: Optional[KVScope] + + +def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode: + kk.set_random_seed(t.seed) + assert t.h_q % t.h_kv == 0 + assert t.decode is not None + + q = torch.randn((t.decode.b, t.s_q, t.h_q, t.d_qk)) + q.clamp_(min=-1.0, max=1.0) + + attn_sink = None + if t.have_attn_sink: + attn_sink = torch.randn((t.h_q,), dtype=torch.float32) + inf_mask = torch.randn((t.h_q,), dtype=torch.float32) + attn_sink[inf_mask > 0.5] = float("inf") + attn_sink[inf_mask < -0.5] = float("-inf") + + def generate_one_k_scope( + s_k: int, + block_size: int, + topk: int, + is_varlen: bool, + have_zero_seqlen: bool, + is_all_indices_invalid: bool, + have_topk_length: bool, + ) -> KVScope: + b = t.decode.b # type: ignore + cache_seqlens_cpu = torch.full((b,), s_k, dtype=torch.int32, device="cpu") + if is_varlen: + for i in range(b): + cache_seqlens_cpu[i] = max(random.normalvariate(s_k, s_k / 2), t.s_q) + + if have_zero_seqlen: + zeros_mask = torch.randn(b, dtype=torch.float32, device="cpu") > 0 + cache_seqlens_cpu[zeros_mask] = 0 + + max_seqlen_alignment = 4 * block_size + max_seqlen_pad = ( + max(kk.cdiv(int(cache_seqlens_cpu.max().item()), max_seqlen_alignment), 1) + * max_seqlen_alignment + ) + cache_seqlens = cache_seqlens_cpu.cuda() + + assert max_seqlen_pad % block_size == 0 + block_table = torch.arange( + b * max_seqlen_pad // block_size, dtype=torch.int32 + ).view(b, max_seqlen_pad // block_size) + block_table = block_table.view(-1)[torch.randperm(block_table.numel())].view( + b, -1 + ) + + blocked_k = ( + kk.gen_non_contiguous_randn_tensor( + (block_table.numel(), block_size, t.h_kv, t.d_qk) + ) + / 10 + ) + blocked_k.clamp_(min=-1.0, max=1.0) + + abs_indices = torch.empty((b, t.s_q, topk), dtype=torch.int32) + if is_all_indices_invalid: + abs_indices.fill_(-1) + else: + abs_indices[:] = _randperm_batch( + b * t.s_q, cache_seqlens.repeat_interleave(t.s_q), topk, [-1] + ).view(b, t.s_q, topk) + indices_in_kvcache 
= quant.abs_indices2indices_in_kvcache( + abs_indices, block_table, block_size + ) + + topk_length = ( + torch.randint(0, topk + 1, (b,), dtype=torch.int32, device=q.device) + if have_topk_length + else None + ) + + # Mask nonused KV as NaN + if have_topk_length: + indices_in_kvcache_masked = indices_in_kvcache.clone() + indices_in_kvcache_masked[ + torch.arange(0, topk).view(1, 1, topk).broadcast_to(b, t.s_q, topk) + >= (topk_length.view(b, 1, 1) if have_topk_length else topk) + ] = -1 + else: + indices_in_kvcache_masked = indices_in_kvcache + + blocked_k = blocked_k.view(-1, t.h_kv, t.d_qk) + nonused_indices_mask = torch.ones( + blocked_k.size(0) * blocked_k.size(1), dtype=torch.bool, device="cpu" + ) + nonused_indices_mask[indices_in_kvcache_masked] = False + blocked_k[nonused_indices_mask, :, :] = float("nan") + blocked_k = blocked_k.view(-1, block_size, t.h_kv, t.d_qk) + + block_table = kk.non_contiguousify(block_table) + abs_indices = kk.non_contiguousify(abs_indices) + indices_in_kvcache = kk.non_contiguousify(indices_in_kvcache) + return KVScope( + t, + cache_seqlens, + block_table, + blocked_k, + abs_indices, + indices_in_kvcache, + topk_length, + ) + + kv_scope0 = generate_one_k_scope( + t.s_kv, + t.decode.block_size, + t.topk, + t.decode.is_varlen, + t.decode.have_zero_seqlen_k, + t.is_all_indices_invalid, + t.have_topk_length, + ) + kv_scope0.quant_and_dequant_() + if t.decode.extra_topk is not None: + if t.decode.extra_s_k is None: + t.decode.extra_s_k = t.decode.extra_topk * 2 + if t.decode.extra_block_size is None: + t.decode.extra_block_size = t.decode.block_size + kv_scope1 = generate_one_k_scope( + t.decode.extra_s_k, + t.decode.extra_block_size, + t.decode.extra_topk, + t.decode.is_varlen, + t.decode.have_zero_seqlen_k, + t.is_all_indices_invalid, + t.decode.have_extra_topk_length, + ) + kv_scope1.quant_and_dequant_() + else: + assert ( + t.decode.extra_block_size is None + and t.decode.extra_s_k is None + and not t.decode.have_extra_topk_length + ) + kv_scope1 = None + + sm_scale = t.d_qk**-0.55 + + q = kk.non_contiguousify(q) + return TestcaseForDecode(t, q, attn_sink, sm_scale, kv_scope0, kv_scope1) + + +def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool): + assert not return_p_sum + return flash_mla.flash_mla_sparse_fwd( + t.q, + t.kv, + t.indices, + sm_scale=t.sm_scale, + attn_sink=t.attn_sink, + topk_length=t.topk_length, + ) + + +def run_flash_mla_decode( + p: TestParam, t: TestcaseForDecode, tile_scheduler_metadata, num_splits +): + assert p.decode is not None + return flash_mla.flash_mla_with_kvcache( + t.q, + t.kv_scope.get_kvcache_for_flash_mla(), + None, + None, + p.d_v, + tile_scheduler_metadata, + num_splits, + t.sm_scale, + False, + True, + t.kv_scope.indices_in_kvcache, + t.attn_sink, + ( + t.extra_kv_scope.get_kvcache_for_flash_mla() + if t.extra_kv_scope is not None + else None + ), + t.extra_kv_scope.indices_in_kvcache if t.extra_kv_scope is not None else None, + t.kv_scope.topk_length, + ( + t.extra_kv_scope.topk_length + if t.extra_kv_scope is not None and t.extra_kv_scope.topk_length is not None + else None + ), + ) + + +@dataclasses.dataclass +class FlopsAndMemVolStatistics: + """ + FLOPs and memory volume statistics for prefilling + """ + + fwd_flop: float + fwd_mem_vol: float + + +def count_flop_and_mem_vol(p: TestParam, t: Testcase) -> FlopsAndMemVolStatistics: + total_topk = ( + (p.s_q * p.topk) if t.topk_length is None else t.topk_length.sum().item() + ) + indices_valid_mask = (t.indices >= 0) & (t.indices < p.s_kv) + 
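    # t.indices may deliberately contain out-of-range entries (e.g. -1 or 2**31 - 1);
    # only in-range entries are counted toward the KV read volume below.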
if t.topk_length is not None: + indices_valid_mask &= ( + torch.arange(p.topk)[None, None, :].broadcast_to(p.s_q, p.h_kv, p.topk) + ) < t.topk_length[:, None, None] + num_valid_indices = indices_valid_mask.sum().item() + + fwd_flop = 2 * total_topk * p.h_q * (p.d_qk + p.d_v) + fwd_mem_vol = num_valid_indices * p.d_qk * 2 + p.s_q * p.h_q * (p.d_qk + p.d_v) * 2 + return FlopsAndMemVolStatistics( + fwd_flop, + fwd_mem_vol, + ) + + +@dataclasses.dataclass +class FlopsAndMemVolStatisticsForDecode: + """ + FLOPs and memory volume statistics for decoding + """ + + flop: float + mem_vol: float + + +def count_flop_and_mem_vol_for_decode( + p: TestParam, t: TestcaseForDecode +) -> FlopsAndMemVolStatisticsForDecode: + assert p.decode + b = p.decode.b + + def get_num_attended_tokens(kv_scope: KVScope) -> int: + topk = kv_scope.indices_in_kvcache.shape[-1] + if kv_scope.topk_length is None: + return b * p.s_q * topk + else: + return int(kv_scope.topk_length.sum().item()) * p.s_q + + def get_num_retrieved_tokens(kv_scope: KVScope) -> int: + if kv_scope.topk_length is None: + indices = kv_scope.indices_in_kvcache + else: + indices = kv_scope.indices_in_kvcache.clone() + batch, s_q, topk = indices.shape + mask = torch.arange(0, topk, device=indices.device).view( + 1, 1, topk + ).broadcast_to(batch, s_q, topk) >= kv_scope.topk_length.view(batch, 1, 1) + indices[mask] = -1 + num_unique_tokens = indices.unique().numel() # type: ignore + return num_unique_tokens + + num_attended_tokens = get_num_attended_tokens(t.kv_scope) + ( + get_num_attended_tokens(t.extra_kv_scope) if t.extra_kv_scope is not None else 0 + ) + num_retrieved_tokens = get_num_retrieved_tokens(t.kv_scope) + ( + get_num_retrieved_tokens(t.extra_kv_scope) + if t.extra_kv_scope is not None + else 0 + ) + + compute_flop = 2 * p.h_q * num_attended_tokens * (p.d_qk + p.d_v) + kv_token_size = 656 if p.d_qk == 576 else 576 # Assume FP8 KV Cache + mem_vol = sum( + [ + 2 * b * p.s_q * p.h_q * p.d_qk, # Q + num_retrieved_tokens * kv_token_size, # K + 2 * b * p.s_q * p.h_q * p.d_v, # O + ] + ) + return FlopsAndMemVolStatisticsForDecode(compute_flop, mem_vol) + + +def is_no_cooldown() -> bool: + return os.environ.get("NO_COOLDOWN", "").lower() in ["1", "yes", "y"] diff --git a/python/sglang/srt/flashmla_tests/quant.py b/python/sglang/srt/flashmla_tests/quant.py new file mode 100644 index 000000000000..28b9b024a429 --- /dev/null +++ b/python/sglang/srt/flashmla_tests/quant.py @@ -0,0 +1,256 @@ +import enum +from typing import Tuple + +import torch + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +# from sglang.srt.utils import is_hip +# FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn +FP8_DTYPE = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +class FP8KVCacheLayout(enum.Enum): + V32_FP8Sparse = 1 + MODEL1_FP8Sparse = 2 + + def get_meta(self) -> Tuple[int, int, int, int, int]: + # Return: (d, d_nope, d_rope, tile_size, num_tiles) + return { + FP8KVCacheLayout.V32_FP8Sparse: (576, 512, 64, 128, 4), + FP8KVCacheLayout.MODEL1_FP8Sparse: (512, 448, 64, 64, 7), + }[self] + + +def _cast_scale_inv_to_ue8m0( + scales_inv: torch.Tensor, out_dtype=torch.float32 +) -> torch.Tensor: + return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype) + + +def quantize_k_cache( + input_k_cache: torch.Tensor, # (num_blocks, block_size, h_k, d) + kvcache_layout: FP8KVCacheLayout, +) -> torch.Tensor: + """ + Quantize the k-cache + For more detail about the layout of K/V, please refer to 
comments in flash_mla_interface.py + """ + d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() + assert input_k_cache.shape[-1] == d + num_blocks, block_size, h_k, _ = input_k_cache.shape + assert h_k == 1 + input_k_cache = input_k_cache.squeeze(2) # [num_blocks, block_size, d] + input_elem_size = input_k_cache.element_size() + + if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: + bytes_per_token = d_nope + num_tiles * 4 + input_elem_size * d_rope + result = torch.empty( + (num_blocks, block_size + 1, bytes_per_token), + dtype=FP8_DTYPE, + device=input_k_cache.device, + )[:, :block_size, :] + result_k_nope_part = result[..., :d_nope] + result_k_scale_factor = result[..., d_nope : d_nope + num_tiles * 4].view( + torch.float32 + ) + result_k_rope_part = result[..., d_nope + num_tiles * 4 :].view( + input_k_cache.dtype + ) + result_k_rope_part[:] = input_k_cache[..., d_nope:] + + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = ( + torch.abs( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ] + ) + .max(dim=-1) + .values.float() + / 448.0 + ) # [num_blocks, block_size] + cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv + + cur_scale_factors_inv.unsqueeze_(-1) # [num_blocks, block_size, 1] + cur_quantized_nope = ( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].float() + / cur_scale_factors_inv.float() + ).to(FP8_DTYPE) + result_k_nope_part[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ] = cur_quantized_nope + + result = result.view(num_blocks, block_size, 1, -1) + return result + + elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: + bytes_per_token = d_nope + 2 * d_rope + num_tiles + 1 + size_per_block_padded = (block_size * bytes_per_token + 576 - 1) // 576 * 576 + result = torch.empty( + (num_blocks, size_per_block_padded), + dtype=FP8_DTYPE, + device=input_k_cache.device, + )[:, : block_size * bytes_per_token] + result_k_nope_rope_part = result[:, : block_size * (d_nope + 2 * d_rope)].view( + num_blocks, block_size, d_nope + 2 * d_rope + ) + result_k_nope = result_k_nope_rope_part[ + :, :, :d_nope + ] # [num_blocks, block_size, d_nope] + result_k_rope = result_k_nope_rope_part[:, :, d_nope:].view( + input_k_cache.dtype + ) # [num_blocks, block_size, d_rope] + result_k_scale_factor = ( + result[:, block_size * (d_nope + 2 * d_rope) :] + .view(num_blocks, block_size, 8)[:, :, :7] + .view(torch.float8_e8m0fnu) + ) # [num_blocks, block_size, num_tiles] + + result_k_rope[:] = input_k_cache[..., d_nope:] + for tile_idx in range(0, num_tiles): + cur_scale_factors_inv = ( + torch.abs( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ] + ) + .max(dim=-1) + .values.float() + / 448.0 + ) # [num_blocks, block_size] + cur_scale_factors_inv = _cast_scale_inv_to_ue8m0(cur_scale_factors_inv) + result_k_scale_factor[:, :, tile_idx] = cur_scale_factors_inv.to( + torch.float8_e8m0fnu + ) + + cur_scale_factors_inv = cur_scale_factors_inv.view( + num_blocks, block_size, 1 + ) + cur_quantized_nope = ( + input_k_cache[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].float() + / cur_scale_factors_inv.float() + ).to(FP8_DTYPE) + result_k_nope[:, :, tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_quantized_nope + ) + + result = result.view(num_blocks, block_size, 1, -1) + return result + + else: + raise NotImplementedError(f"Unsupported kvcache_layout: 
{kvcache_layout}") + + +def dequantize_k_cache( + quant_k_cache: torch.Tensor, # (num_blocks, block_size, 1, bytes_per_token) + kvcache_layout: FP8KVCacheLayout, +) -> torch.Tensor: + """ + De-quantize the k-cache + """ + # NOTE ADD + assert quant_k_cache.dtype == FP8_DTYPE + + d, d_nope, d_rope, tile_size, num_tiles = kvcache_layout.get_meta() + num_blocks, block_size, h_k, _ = quant_k_cache.shape + assert h_k == 1 + result = torch.empty( + (num_blocks, block_size, d), dtype=torch.bfloat16, device=quant_k_cache.device + ) + + if kvcache_layout == FP8KVCacheLayout.V32_FP8Sparse: + quant_k_cache = quant_k_cache.view(num_blocks, block_size, -1) + + input_nope = quant_k_cache[..., :d_nope] + input_scale = quant_k_cache[..., d_nope : d_nope + num_tiles * 4].view( + torch.float32 + ) + input_rope = quant_k_cache[..., d_nope + num_tiles * 4 :].view(torch.bfloat16) + result[..., d_nope:] = input_rope + + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.float32) + cur_scales = input_scale[..., tile_idx].unsqueeze(-1) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ) + + elif kvcache_layout == FP8KVCacheLayout.MODEL1_FP8Sparse: + quant_k_cache = quant_k_cache.view(num_blocks, -1) # [num_blocks, ...] + input_nope_rope = quant_k_cache[:, : block_size * (d_nope + 2 * d_rope)].view( + num_blocks, block_size, d_nope + 2 * d_rope + ) + input_nope = input_nope_rope[:, :, :d_nope] + input_rope = input_nope_rope[:, :, d_nope:].view(torch.bfloat16) + input_scale = ( + quant_k_cache[:, block_size * (d_nope + 2 * d_rope) :] + .view(num_blocks, block_size, 8)[:, :, :7] + .view(torch.float8_e8m0fnu) + ) # [num_blocks, block_size, num_tiles] + + result[..., d_nope:] = input_rope + for tile_idx in range(0, num_tiles): + cur_nope = input_nope[ + ..., tile_idx * tile_size : (tile_idx + 1) * tile_size + ].to(torch.bfloat16) + cur_scales = input_scale[:, :, tile_idx].to(torch.bfloat16).unsqueeze(-1) + result[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( + cur_nope * cur_scales + ) + + else: + raise NotImplementedError(f"Unsupported kvcache_layout: {kvcache_layout}") + + result = result.view(num_blocks, block_size, 1, d) + return result + + +def abs_indices2indices_in_kvcache( + abs_indices: torch.Tensor, # [b, s_q, topk] + block_table: torch.Tensor, # [b, /] + block_size: int, +) -> torch.Tensor: + """ + Convert abs_indices (logical index, ranging from 0 to s_k-1) to index expected by the sparse attn kernel + Equivalent to: + + b, s_q, topk = abs_indices.shape + indices_in_kvcache = torch.empty_like(abs_indices) + for i in range(b): + cur_abs_indices = abs_indices[i, :, :].clone() # [s_q, topk] + invalid_mask = cur_abs_indices == -1 + cur_abs_indices[invalid_mask] = 0 + cur_indices_in_kvcache = block_table[i].index_select(0, cur_abs_indices.flatten()//block_size).view(s_q, topk)*block_size + cur_abs_indices%block_size + cur_indices_in_kvcache[invalid_mask] = -1 + indices_in_kvcache[i] = cur_indices_in_kvcache + return indices_in_kvcache + + """ + b, s_q, topk = abs_indices.shape + _, max_blocks_per_seq = block_table.shape + + abs_indices = abs_indices.clone() + invalid_mask = abs_indices == -1 + abs_indices[invalid_mask] = 0 + + real_block_idxs = block_table.view(-1).index_select( + 0, + ( + abs_indices // block_size + + torch.arange(0, b).view(b, 1, 1) * max_blocks_per_seq + ).view(-1), + ) + indices_in_kvcache = ( + real_block_idxs.view(b, s_q, topk) * block_size + 
abs_indices % block_size + ) + indices_in_kvcache[invalid_mask] = -1 + + return indices_in_kvcache diff --git a/python/sglang/srt/flashmla_tests/ref.py b/python/sglang/srt/flashmla_tests/ref.py new file mode 100644 index 000000000000..23acb2b1461e --- /dev/null +++ b/python/sglang/srt/flashmla_tests/ref.py @@ -0,0 +1,137 @@ +from typing import Optional, Tuple + +import torch + +from .lib import KVScope, Testcase, TestcaseForDecode, TestParam + + +def _merge_two_lse( + lse0: torch.Tensor, lse1: Optional[torch.Tensor], s_q: int, h_q: int +) -> torch.Tensor: + if lse1 is None: + return lse0 + else: + return torch.logsumexp( + torch.stack([lse0.view(s_q, h_q), lse1.broadcast_to(s_q, h_q)], dim=0), + dim=0, + ) + + +def ref_sparse_attn_fwd( + p: TestParam, t: Testcase +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Returns: + - o: [s_q, h_q, dv] + - o_fp32: [s_q, h_q, dv] + - max_logits: [s_q, h_q] + - lse: [s_q, h_q] + """ + indices = t.indices.clone().squeeze(1) + if t.topk_length is not None: + mask = torch.arange(p.topk, device=t.topk_length.device).unsqueeze( + 0 + ).broadcast_to(p.s_q, p.topk) >= t.topk_length.unsqueeze( + 1 + ) # [s_q, topk] + indices[mask] = -1 + invalid_mask = (indices < 0) | (indices >= p.s_kv) # [s_q, topk] + indices[invalid_mask] = 0 + + q = t.q.float() + gathered_kv = ( + t.kv.index_select(dim=0, index=indices.flatten()) + .reshape(p.s_q, p.topk, p.d_qk) + .float() + ) # [s_q, topk, d_qk] + P = q @ gathered_kv.transpose(1, 2) # [s_q, h_q, topk] + P *= t.sm_scale + P[invalid_mask.unsqueeze(1).broadcast_to(P.shape)] = float("-inf") + + orig_lse = torch.logsumexp(P, dim=-1) # [s_q, h_q] + max_logits = P.max(dim=-1).values # [s_q, h_q] + + lse_for_o = _merge_two_lse(orig_lse, t.attn_sink, p.s_q, p.h_q) + if not torch.is_inference_mode_enabled(): + lse_for_o = lse_for_o.clone() + lse_for_o[lse_for_o == float("-inf")] = float( + "+inf" + ) # So that corresponding O will be 0 + s_for_o = torch.exp(P - lse_for_o.unsqueeze(-1)) + out = s_for_o @ gathered_kv[..., : p.d_v] # [s_q, h_q, dv] + + lonely_q_mask = orig_lse == float("-inf") # [s_q, h_q] + orig_lse[lonely_q_mask] = float("+inf") + return (out.to(torch.bfloat16), out, max_logits, orig_lse) + + +def ref_sparse_attn_decode( + p: TestParam, t: TestcaseForDecode +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference implementation of sparse decoding attention in PyTorch + """ + assert p.h_kv == 1 + assert p.decode is not None + b = p.decode.b + + def process_kv_scope(kv_scope: KVScope) -> Tuple[torch.Tensor, torch.Tensor]: + assert kv_scope.indices_in_kvcache is not None + topk = kv_scope.indices_in_kvcache.size(-1) + indices_in_kv_cache_fixed = torch.clamp_min( + kv_scope.indices_in_kvcache, 0 + ) # Otherwise torch.index_select will complain + gathered_kv = ( + kv_scope.blocked_k.view(-1, p.d_qk) + .index_select(0, indices_in_kv_cache_fixed.view(-1)) + .view(b, p.s_q, topk, p.d_qk) + ) # [b, s_q, topk, d] + invalid_mask = kv_scope.indices_in_kvcache == -1 + if kv_scope.topk_length is not None: + invalid_mask |= torch.arange(0, topk, device=invalid_mask.device).view( + 1, 1, topk + ).broadcast_to(b, p.s_q, topk) >= kv_scope.topk_length.view(b, 1, 1) + return gathered_kv, invalid_mask + + gathered_kv, invalid_mask = process_kv_scope(t.kv_scope) + if t.extra_kv_scope is not None: + gathered_kv1, invalid_mask1 = process_kv_scope(t.extra_kv_scope) + gathered_kv = torch.cat( + [gathered_kv, gathered_kv1], dim=2 + ) # [b, s_q, topk+extra_topk, d] + invalid_mask = torch.cat( + [invalid_mask, 
invalid_mask1], dim=2 + ) # [b, s_q, topk+extra_topk] + + # NOTE: a more advanced approach could be used here + + gathered_kv = gathered_kv.view(b * p.s_q, -1, p.d_qk).float() + gathered_kv[gathered_kv != gathered_kv] = 0.0 + q = t.q.float().view(b * p.s_q, p.h_q, p.d_qk) + attn_weight = q @ gathered_kv.transpose( + -1, -2 + ) # [t.b*t.s_q, t.h_q, topk+extra_topk] + attn_weight *= t.sm_scale + attn_weight[ + invalid_mask.view(b * p.s_q, 1, -1).broadcast_to( + b * p.s_q, p.h_q, invalid_mask.size(-1) + ) + ] = float("-inf") + lse = attn_weight.logsumexp(dim=-1) # [t.b*t.s_q, t.h_q] + attn_weight = torch.exp(attn_weight - lse.unsqueeze(-1)) + output = attn_weight @ gathered_kv[..., : p.d_v] # [t.b*t.s_q, t.h_q, t.dv] + output = output.view(b, p.s_q, p.h_q, p.d_v) + lse = lse.view(b, p.s_q, p.h_q) + + # Attention sink + if t.attn_sink is not None: + output *= ( + 1.0 / (1.0 + torch.exp(t.attn_sink.view(1, 1, p.h_q) - lse)) + ).unsqueeze(-1) + + # Correct q tokens that have no attendable k + lonely_q_mask = lse == float("-inf") + output[lonely_q_mask.unsqueeze(-1).broadcast_to(b, p.s_q, p.h_q, p.d_v)] = 0.0 + lse[lonely_q_mask] = float("+inf") + + return output.to(torch.bfloat16), lse.transpose(1, 2) diff --git a/python/sglang/srt/function_call/deepseekv4_detector.py b/python/sglang/srt/function_call/deepseekv4_detector.py new file mode 100644 index 000000000000..2bb74ceebf3a --- /dev/null +++ b/python/sglang/srt/function_call/deepseekv4_detector.py @@ -0,0 +1,27 @@ +from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector + + +class DeepSeekV4Detector(DeepSeekV32Detector): + """ + Detector for DeepSeek V4 DSML tool-call format. + + Identical to V3.2 except the outer block wrapper is + ``<|DSML|tool_calls>...`` instead of + ``<|DSML|function_calls>...``. The inner + ``<|DSML|invoke>`` / ``<|DSML|parameter>`` shape is unchanged.
+ + Example (XML parameters): + ``` + <|DSML|tool_calls> + <|DSML|invoke name="get_weather"> + <|DSML|parameter name="city" string="true">San Francisco + + + ``` + """ + + def __init__(self): + super().__init__() + self.bot_token = "<|DSML|tool_calls>" + self.eot_token = "" + self.function_calls_regex = r"<|DSML|tool_calls>(.*?)" diff --git a/python/sglang/srt/function_call/function_call_parser.py b/python/sglang/srt/function_call/function_call_parser.py index 10d14cc432eb..6f2ac60f091c 100644 --- a/python/sglang/srt/function_call/function_call_parser.py +++ b/python/sglang/srt/function_call/function_call_parser.py @@ -12,6 +12,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import ToolCallItem from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.deepseekv4_detector import DeepSeekV4Detector from sglang.srt.function_call.deepseekv31_detector import DeepSeekV31Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector @@ -48,6 +49,7 @@ class FunctionCallParser: "deepseekv3": DeepSeekV3Detector, "deepseekv31": DeepSeekV31Detector, "deepseekv32": DeepSeekV32Detector, + "deepseekv4": DeepSeekV4Detector, "glm": Glm4MoeDetector, "glm45": Glm4MoeDetector, "glm47": Glm47MoeDetector, diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 246e2554f4b6..f641a3e380b3 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -1,6 +1,8 @@ import logging from typing import TYPE_CHECKING +from sglang.srt.environ import envs + logger = logging.getLogger(__name__) @@ -84,6 +86,22 @@ def create_nsa_backend(runner): return NativeSparseAttnBackend(runner) +@register_attention_backend("compressed") +def create_compressed_backend(runner): + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4BackendRadix, + ) + + logger.info("Using DeepseekV4BackendRadix for compressed attention backend.") + return DeepseekV4BackendRadix(runner) + else: + from sglang.srt.layers.attention.deepseek_v4_backend import DeepseekV4Backend + + logger.info("Using DeepseekV4Backend for compressed attention backend.") + return DeepseekV4Backend(runner) + + @register_attention_backend("triton") def create_triton_backend(runner): assert not runner.model_config.is_encoder_decoder, ( diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 8d14e32a916b..f59645c6870e 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -57,6 +57,10 @@ def get_cuda_graph_seq_len_fill_value(self): """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" raise NotImplementedError() + # TODO improve naming + def on_after_cuda_graph_warmup_pass(self): + pass + def get_verify_buffers_to_fill_after_draft(self): """ Return buffers of verify attention kernels that needs to be filled after draft. 
@@ -128,6 +132,7 @@ def forward_decode( layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache: bool = True, + **kwargs, ): """Run a forward for decode.""" raise NotImplementedError() @@ -140,6 +145,7 @@ def forward_extend( layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache: bool = True, + **kwargs, ): """Run a forward for extend.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/compressed/__init__.py b/python/sglang/srt/layers/attention/compressed/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/sglang/srt/layers/attention/compressed/compressor.py b/python/sglang/srt/layers/attention/compressed/compressor.py new file mode 100644 index 000000000000..2b6b5a4cc090 --- /dev/null +++ b/python/sglang/srt/layers/attention/compressed/compressor.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Literal, NamedTuple, Optional, Union + +import torch + +from sglang.jit_kernel.deepseek_v4 import ( + CompressorDecodePlan, + CompressorPrefillPlan, + compress_forward, + compress_fused_norm_rope_inplace, + triton_create_paged_compress_data, +) +from sglang.srt.environ import envs +from sglang.srt.layers.attention.nsa.quant_k_cache_v4 import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.layers.attention.nsa.triton_kernel import act_quant + +if TYPE_CHECKING: + from sglang.srt.layers.attention.compressed.metadata import DeepseekV4Metadata + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.models.deepseek_v4 import Compressor, DeepseekRefRMSNorm + + +class FusedCompressMetadata(NamedTuple): + write_loc: torch.Tensor # shape [num_q_tokens] + extra_data: Optional[torch.Tensor] # shape [num_q_tokens] or None + plan: Union[CompressorDecodePlan, CompressorPrefillPlan] + + def copy_(self, other: FusedCompressMetadata) -> None: + from .metadata import maybe_copy_inplace + + self.write_loc.copy_(other.write_loc) + maybe_copy_inplace(self.extra_data, src=other.extra_data) + self.plan.copy_(other.plan) # type: ignore + + +class CompressorBackend: + def __init__(self): + super().__init__() + self.forward_metadata: DeepseekV4Metadata + + def get_paged_compress_metadata(self, compress_ratio: int) -> FusedCompressMetadata: + attr_name = f"c{compress_ratio}_compress_metadata" + metadata = getattr(self.forward_metadata, attr_name) + assert isinstance(metadata, FusedCompressMetadata) + return metadata + + def forward_compress( + self, + *, + kv_score_buffer: torch.Tensor, + kv_score_input: torch.Tensor, + ape: torch.Tensor, + head_dim: int, + norm: DeepseekRefRMSNorm, + freqs_cis_cache: torch.Tensor, + rotate: bool, + forward_batch: ForwardBatch, + compress_ratio: int, + is_paged: bool = False, + ) -> torch.Tensor: + from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation + + assert compress_ratio == 4 or compress_ratio == 128 + if is_paged: + metadata = self.get_paged_compress_metadata(compress_ratio) + coff = 2 if is_overlap_compress(compress_ratio) else 1 + last_dim = 2 * head_dim * coff + assert kv_score_buffer.shape[-1] == last_dim + kv_score_buffer = kv_score_buffer.view(-1, compress_ratio, last_dim) + else: + plan = make_compressor_plan(compress_ratio, forward_batch) + metadata = (forward_batch.req_pool_indices.to(torch.int32), None, plan) + indices, extra_data, plan = metadata + + # NOTE: shape [num_q_tokens, head_dim] + 
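+ # Rough summary of the three calls below: `compress_forward` reduces each
+ # window of `compress_ratio` buffered kv/score rows (together with the current
+ # token in `kv_score_input` and the `ape` score bias) into one compressed KV
+ # row per query token; `compress_fused_norm_rope_inplace` then applies RMSNorm
+ # and RoPE in place using the same plan; `rotate_activation` is only applied
+ # when `rotate` is set.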
kv_compressed = compress_forward( + kv_score_buffer=kv_score_buffer, + kv_score_input=kv_score_input, + ape=ape, + indices=indices, + plan=plan, + compress_ratio=compress_ratio, + head_dim=head_dim, + extra_data=extra_data, + ) + compress_fused_norm_rope_inplace( + kv_compressed, + norm.weight, + norm.eps, + freqs_cis_cache, + plan, + ) + return rotate_activation(kv_compressed) if rotate else kv_compressed + + def forward_core_compressor( + self, + x: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + compressor: Compressor, + ) -> None: + # NOTE: this function will 1. compress kv 2. store to kv pool + if forward_batch.forward_mode.is_idle(): + return + token_to_kv_pool = forward_batch.token_to_kv_pool + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + new_compressed_kv = compressor(x, forward_batch) + core_metadata = self.forward_metadata.core_metadata + out_loc = ( + core_metadata.c4_out_loc + if compressor.ratio == 4 + else core_metadata.c128_out_loc + ) + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + token_to_kv_pool.set_extra_key_buffer_fused( + layer_id=layer_id, + loc=out_loc, + cache_k=new_compressed_kv, + ) + else: + pack = quant_to_nope_fp8_rope_bf16_pack_triton(new_compressed_kv.bfloat16()) + token_to_kv_pool.set_extra_key_buffer(layer_id, out_loc, pack) + + def forward_indexer_compressor( + self, + x: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + compressor: Compressor, + ) -> None: + assert is_overlap_compress(compressor.ratio) + token_to_kv_pool = forward_batch.token_to_kv_pool + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + new_compressed_kv = compressor(x, forward_batch) + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + token_to_kv_pool.set_index_k_fused( + layer_id=layer_id, + loc=self.forward_metadata.core_metadata.c4_out_loc, + cache_k=new_compressed_kv, + ) + else: + new_compressed_kv_fp8, new_compressed_kv_scale = act_quant( + new_compressed_kv + ) + token_to_kv_pool.set_index_k_scale_buffer( + layer_id=layer_id, + loc=self.forward_metadata.core_metadata.c4_out_loc, + index_k=new_compressed_kv_fp8, + index_k_scale=new_compressed_kv_scale, + ) + + +def is_overlap_compress(compress_ratio: int) -> bool: + return compress_ratio == 4 + + +def make_compressor_plan( + compress_ratio: Literal[4, 128], + forward_batch: ForwardBatch, +) -> Union[CompressorDecodePlan, CompressorPrefillPlan]: + if forward_batch.forward_mode.is_decode(): + seq_lens_32 = forward_batch.seq_lens.to(torch.int32) + return CompressorDecodePlan(compress_ratio, seq_lens_32) + if forward_batch.forward_mode.is_prefill(): + assert not forward_batch.forward_mode.is_target_verify() + extend_lens_list = forward_batch.extend_seq_lens_cpu + seq_lens_cpu = forward_batch.seq_lens_cpu + assert extend_lens_list is not None and seq_lens_cpu is not None + return CompressorPrefillPlan.generate( + compress_ratio=compress_ratio, + num_q_tokens=sum(extend_lens_list), + seq_lens=seq_lens_cpu, + extend_lens=torch.tensor(extend_lens_list), + device=forward_batch.seq_lens.device, + ) + elif forward_batch.forward_mode.is_target_verify(): + raise NotImplementedError("target verify mode to be implemented") + else: + raise NotImplementedError(f"unsupported mode {forward_batch.forward_mode=}") + + +def create_paged_compressor_data( + compress_ratio: Literal[4, 128], + *, + is_prefill: bool, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, 
+ extend_lens: Optional[torch.Tensor] = None, + seq_lens_cpu: Optional[List[int]] = None, + extend_lens_cpu: Optional[List[int]] = None, + use_prefill_cuda_graph: bool = False, + num_q_tokens: Optional[int] = None, +) -> FusedCompressMetadata: + # TODO(dark): remove this hardcode + swa_page_size = token_to_kv_pool.swa_page_size + ring_size = token_to_kv_pool.get_ring_size(compress_ratio=compress_ratio) + assert ring_size % compress_ratio == 0 + + def clip_down(positions: torch.Tensor) -> torch.Tensor: + return positions // compress_ratio * compress_ratio + + def get_raw_loc(positions: torch.Tensor) -> torch.Tensor: + # NOTE: special case for overlap, we will handle it properly in kernel + positions = positions.masked_fill(positions < 0, 0) + loc = req_to_token[req_pool_indices, positions] + swa_loc = token_to_kv_pool.translate_loc_from_full_to_swa(loc) + swa_pages = swa_loc // swa_page_size + state_loc = swa_pages * ring_size + swa_loc % ring_size + if envs.SGLANG_OPT_DEBUG_PAGED_COMPRESS.get(): + assert torch.all(state_loc % compress_ratio == 0) + return (state_loc // compress_ratio).to(torch.int32) + + is_overlap = is_overlap_compress(compress_ratio) + + # NOTE(dark): + # the spec of the following extra data is highly coupled with kernel implementation + # DO NOT modify it unless you know what you are doing + if is_prefill: + assert extend_lens is not None + if envs.SGLANG_OPT_TRITON_PREPARE_COMPRESS.get(): + write_loc, extra_data = triton_create_paged_compress_data( + compress_ratio=compress_ratio, + is_overlap=is_overlap, + swa_page_size=swa_page_size, + ring_size=ring_size, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_seq_lens=extend_lens, + req_to_token=req_to_token, + full_to_swa_index_mapping=token_to_kv_pool.full_to_swa_index_mapping, + ) + else: + prefix_lens = seq_lens - extend_lens + write_positions = clip_down(seq_lens - 1) + load_positions = clip_down(prefix_lens - 1) + write_loc = get_raw_loc(write_positions) + load_loc = get_raw_loc(load_positions) + if envs.SGLANG_OPT_DEBUG_PAGED_COMPRESS.get(): + assert torch.all(prefix_lens >= 0) + if is_overlap: + load_overlap_loc = get_raw_loc(load_positions - compress_ratio) + write_overlap_loc = get_raw_loc(write_positions - compress_ratio) + extra_data = torch.stack( + [ + load_overlap_loc, + load_loc, + write_overlap_loc, + write_positions.to(torch.int32), + ], + dim=1, + ) + else: + extra_data = load_loc + + plan_kwargs: dict + if seq_lens_cpu is None: + assert num_q_tokens is not None + plan_kwargs = dict( + num_q_tokens=num_q_tokens, + seq_lens=seq_lens, + extend_lens=extend_lens, + ) + else: + assert extend_lens_cpu is not None + plan_kwargs = dict( + num_q_tokens=sum(extend_lens_cpu), + seq_lens=torch.tensor(seq_lens_cpu), + extend_lens=torch.tensor(extend_lens_cpu), + ) + plan = CompressorPrefillPlan.generate( + compress_ratio=compress_ratio, + device=seq_lens.device, + use_cuda_graph=use_prefill_cuda_graph, + **plan_kwargs, + ) + else: + write_positions = clip_down(seq_lens - 1) + write_loc = get_raw_loc(write_positions) + if is_overlap: + write_overlap_loc = get_raw_loc(write_positions - compress_ratio) + extra_data = write_overlap_loc.view(-1, 1) + else: + extra_data = None + plan = CompressorDecodePlan(compress_ratio, seq_lens.to(torch.int32)) + + return FusedCompressMetadata(write_loc=write_loc, extra_data=extra_data, plan=plan) diff --git a/python/sglang/srt/layers/attention/compressed/indexer.py b/python/sglang/srt/layers/attention/compressed/indexer.py new file mode 100644 index 
000000000000..93abea7d478b --- /dev/null +++ b/python/sglang/srt/layers/attention/compressed/indexer.py @@ -0,0 +1,605 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from sglang.jit_kernel.deepseek_v4 import topk_transform_512 +from sglang.srt.environ import envs +from sglang.srt.layers.attention.compressed.metadata import ( + PagedCoreMetadata, + PagedIndexerMetadata, +) +from sglang.srt.layers.attention.indexer_topk_capturer import ( + get_global_indexer_capturer, +) +from sglang.srt.layers.attention.nsa.triton_kernel import act_quant +from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp +from sglang.srt.utils import is_hip + +if TYPE_CHECKING: + from sglang.srt.layers.attention.compressed.compressor import CompressorBackend + from sglang.srt.layers.attention.compressed.metadata import DeepseekV4Metadata + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + from sglang.srt.models.deepseek_v4 import C4Indexer + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +# if is_hip(): +if is_fp8_fnuz(): + FP8_DTYPE = torch.float8_e4m3fnuz + # FP8_MAX = torch.finfo(FP8_DTYPE).max + FP8_MAX = 224.0 +else: + FP8_DTYPE = torch.float8_e4m3fn + FP8_MAX = torch.finfo(FP8_DTYPE).max + + +def fp8_paged_mqa_logits_torch( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = q_fp8.shape + block_size = kvcache_fp8.shape[1] + + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + for i in range(batch_size): + q = q_fp8[i, 0] # (num_heads, head_dim) + q = q.to(torch.float32) + q_scale = weight[i] # (num_heads) + seq_len = int(seq_lens[i].item()) + assert seq_len <= max_seq_len + num_pages = (seq_len + block_size - 1) // block_size + padded_seq_len = num_pages * block_size + pages = page_table[i, :num_pages] # (num_pages,) + kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + kvcache = kvcache_fp8[pages] # (num_pages, block_size * (head_dim + 4)) + SCALE_OFFSET = block_size * head_dim + kvcache_value = kvcache[..., :SCALE_OFFSET].view(dtype=FP8_DTYPE) + kvcache_scale = kvcache[..., SCALE_OFFSET:].view(dtype=torch.float32) + kvcache_value = kvcache_value.to(torch.float32) + kvcache_scale = kvcache_scale.contiguous() + kvcache_value = kvcache_value.view(padded_seq_len, head_dim) + kvcache_scale = kvcache_scale.view(padded_seq_len) + score = F.linear(kvcache_value, q) + score = F.relu(score) + score *= q_scale[None, :] + score = score.sum(dim=1) # (padded_seq_len,) + score *= kvcache_scale + logits[i, :seq_len] = score[:seq_len] + + return logits + + +# def fp8_paged_mqa_logits_torch( +# q_fp8: torch.Tensor, +# kvcache_fp8: torch.Tensor, +# weight: torch.Tensor, +# seq_lens: torch.Tensor, +# 
page_table: torch.Tensor, +# deep_gemm_metadata: Any, +# max_seq_len: int, +# clean_logits: bool = True, +# ) -> torch.Tensor: +# """ +# Vectorized PyTorch implementation of fp8_paged_mqa_logits. +# Processes all batches in parallel without Python for loops. +# """ +# _ = deep_gemm_metadata +# batch_size, _, num_heads, head_dim = q_fp8.shape +# block_size = kvcache_fp8.shape[1] +# device = q_fp8.device + +# assert head_dim == 128, "TODO" +# assert block_size == 64, "TODO" +# assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) +# assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) +# assert weight.shape == (batch_size, num_heads) +# assert seq_lens.shape == (batch_size,) +# assert page_table.shape[0] == batch_size +# assert clean_logits == False + +# # Prepare q: (batch_size, num_heads, head_dim) +# q = q_fp8[:, 0].to(torch.float32) # (batch_size, num_heads, head_dim) + +# # Calculate number of pages per batch element +# num_pages_per_batch = (seq_lens + block_size - 1) // block_size # (batch_size,) +# max_num_pages = int( +# num_pages_per_batch.max().item() +# ) # Single sync, outside main computation + +# # Padded seq len for each batch +# padded_seq_lens = num_pages_per_batch * block_size # (batch_size,) +# max_padded_seq_len = max_num_pages * block_size + +# # Reshape kvcache for gathering +# # Original: (num_blocks_total, block_size, 1, head_dim + 4) +# # Reshape to: (num_blocks_total, block_size * (head_dim + 4)) +# kvcache_flat = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + +# # Gather pages for all batches: page_table[:, :max_num_pages] +# # Shape: (batch_size, max_num_pages) +# pages = page_table[:, :max_num_pages] + +# # Gather kvcache for all batches +# # Shape: (batch_size, max_num_pages, block_size * (head_dim + 4)) +# gathered_kvcache = kvcache_flat[pages] + +# # Split into values and scales +# SCALE_OFFSET = block_size * head_dim +# # Shape: (batch_size, max_num_pages, block_size * head_dim) +# kvcache_value_flat = gathered_kvcache[..., :SCALE_OFFSET] +# # Shape: (batch_size, max_num_pages, block_size * 4) -> scales are 4 bytes per position +# kvcache_scale_flat = gathered_kvcache[..., SCALE_OFFSET:] + +# # Convert FP8 values to float32 +# kvcache_value_fp8 = kvcache_value_flat.view(dtype=FP8_DTYPE) +# kvcache_value = kvcache_value_fp8.to(torch.float32) +# # Reshape to (batch_size, max_padded_seq_len, head_dim) +# kvcache_value = kvcache_value.view(batch_size, max_padded_seq_len, head_dim) + +# # Convert scales to float32 +# kvcache_scale = kvcache_scale_flat.view(dtype=torch.float32) +# # Reshape to (batch_size, max_padded_seq_len) +# kvcache_scale = kvcache_scale.reshape(batch_size, max_padded_seq_len) + +# # Compute attention scores: kvcache_value @ q^T +# # kvcache_value: (batch_size, max_padded_seq_len, head_dim) +# # q: (batch_size, num_heads, head_dim) +# # score: (batch_size, max_padded_seq_len, num_heads) +# score = torch.bmm(kvcache_value, q.transpose(1, 2)) + +# # Apply ReLU +# score = F.relu(score) + +# # Multiply by weight (q_scale): (batch_size, num_heads) +# # score: (batch_size, max_padded_seq_len, num_heads) +# score = score * weight.unsqueeze(1) + +# # Sum over heads: (batch_size, max_padded_seq_len) +# score = score.sum(dim=2) + +# # Multiply by kvcache_scale: (batch_size, max_padded_seq_len) +# score = score * kvcache_scale + +# # Create output logits with proper masking +# logits = torch.full( +# (batch_size, max_seq_len), +# float("-inf"), # or 0.0 depending on requirements +# dtype=torch.float32, +# device=device, +# ) + +# # 
Create position indices for masking +# positions = torch.arange(max_seq_len, device=device).unsqueeze( +# 0 +# ) # (1, max_seq_len) +# valid_mask = positions < seq_lens.unsqueeze(1) # (batch_size, max_seq_len) + +# # Copy valid scores to logits +# # We need to handle the case where max_padded_seq_len might differ from max_seq_len +# copy_len = min(max_padded_seq_len, max_seq_len) +# logits[:, :copy_len] = torch.where( +# valid_mask[:, :copy_len], score[:, :copy_len], logits[:, :copy_len] +# ) + +# return logits + + +# Vectorized version (faster but uses more memory) - for AMD/HIP +def topk_transform_512_pytorch_vectorized( + scores: torch.Tensor, + seq_lens: torch.Tensor, + page_tables: torch.Tensor, + out_page_indices: torch.Tensor, + page_size: int, + out_raw_indices: Optional[torch.Tensor] = None, +) -> None: + """ + Vectorized PyTorch fallback for topk_transform_512. + Faster than the loop version but may use more memory. + """ + + TOPK = 512 + batch_size = scores.shape[0] + max_seq_len = scores.shape[1] + device = scores.device + + page_bits = (page_size - 1).bit_length() if page_size > 1 else 0 + page_mask = page_size - 1 + + # Create mask for valid positions based on seq_lens + positions = ( + torch.arange(max_seq_len, device=device).unsqueeze(0).expand(batch_size, -1) + ) + valid_mask = positions < seq_lens.unsqueeze(1) + + # Mask out invalid positions with -inf + masked_scores = scores.clone() + masked_scores[~valid_mask] = float("-inf") + + # Get top-k indices + actual_k = min(TOPK, max_seq_len) + _, raw_indices = torch.topk( + masked_scores, k=actual_k, dim=1, largest=True, sorted=False + ) + raw_indices = raw_indices.to(torch.int32) + + # Pad raw_indices to TOPK size if needed + if actual_k < TOPK: + padding = torch.zeros( + (batch_size, TOPK - actual_k), dtype=torch.int32, device=device + ) + raw_indices = torch.cat([raw_indices, padding], dim=1) + + # Check which indices are valid + batch_indices = ( + torch.arange(batch_size, device=device).unsqueeze(1).expand(-1, TOPK) + ) + gathered_scores = scores[ + batch_indices.flatten(), raw_indices.clamp(min=0).flatten() + ].view(batch_size, TOPK) + + valid_topk = gathered_scores != float("-inf") + if actual_k < TOPK: + pad_mask = torch.arange(TOPK, device=device).unsqueeze(0) >= actual_k + valid_topk = valid_topk & ~pad_mask + + # For short sequences, use sequential indices + needs_sequential = seq_lens <= TOPK + if needs_sequential.any(): + sequential_indices = ( + torch.arange(TOPK, device=device, dtype=torch.int32) + .unsqueeze(0) + .expand(batch_size, -1) + ) + sequential_valid = sequential_indices < seq_lens.unsqueeze(1) + + raw_indices = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), + torch.where( + sequential_valid, + sequential_indices, + torch.tensor(-1, device=device, dtype=torch.int32), + ), + raw_indices, + ) + valid_topk = torch.where( + needs_sequential.unsqueeze(1).expand(-1, TOPK), sequential_valid, valid_topk + ) + + # Transform to page indices + page_idx = raw_indices >> page_bits + offset_in_page = raw_indices & page_mask + + page_idx_clamped = torch.clamp(page_idx, min=0) + physical_pages = torch.gather(page_tables, dim=1, index=page_idx_clamped.long()) + + page_indices = (physical_pages << page_bits) | offset_in_page + page_indices = page_indices.to(torch.int32) + + page_indices = torch.where( + valid_topk, page_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) + + out_page_indices.copy_(page_indices) + + if out_raw_indices is not None: + raw_indices = torch.where( + valid_topk, 
raw_indices, torch.tensor(-1, device=device, dtype=torch.int32) + ) + out_raw_indices.copy_(raw_indices) + + +@triton.jit +def _fused_scale_kernel( + weight_ptr, # [B, H] + q_scale_ptr, # [B, H, 1] + out_ptr, # [B, H, 1] + numel, # B * H + out_scale, # scalar + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < numel + + w = tl.load(weight_ptr + offs, mask=mask) + qs = tl.load(q_scale_ptr + offs, mask=mask) + + # Compute in fp32 for better numerical stability, then cast back. + acc = w.to(tl.float32) * out_scale * qs.to(tl.float32) + tl.store(out_ptr + offs, acc.to(out_ptr.dtype.element_ty), mask=mask) + + +def fused_scale( + weight: torch.Tensor, + out_scale: float, + q_scale: torch.Tensor, +) -> torch.Tensor: + """ + Triton version of: + weight.unsqueeze(-1) * out_scale * q_scale + + Args: + weight: [B, H], contiguous + q_scale: [B, H, 1], contiguous + out_scale: Python float / scalar + + Returns: + out: [B, H, 1] + """ + assert weight.is_contiguous() and q_scale.is_contiguous() + B, H = weight.shape + numel = B * H + out_dtype = torch.promote_types(weight.dtype, q_scale.dtype) + out = torch.empty((B, H, 1), device=weight.device, dtype=out_dtype) + BLOCK = 1024 + grid = (triton.cdiv(numel, BLOCK),) + _fused_scale_kernel[grid]( + weight, + q_scale, + out, + numel, + out_scale, + BLOCK=BLOCK, + ) + return out + + +class C4IndexerBackend: + def __init__(self): + super().__init__() + self.forward_metadata: DeepseekV4Metadata + self.debug_use_external_c4_sparse_indices: bool = False + + # this method should be type method + # see srt/layers/attention/compressed/compressor.py + + def _forward_prepare_multi_stream( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + c4_indexer: C4Indexer, + positions: torch.Tensor, + forward_batch: ForwardBatch, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + x_for_compressor: Optional[torch.Tensor] = None, + alt_streams: Optional[List[torch.cuda.Stream]] = None, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if TYPE_CHECKING: + assert isinstance(self, CompressorBackend) + + assert alt_streams is not None + assert len(alt_streams) >= 2 + current_stream = torch.cuda.current_stream() + stream_q = alt_streams[0] + stream_weights = alt_streams[1] + + stream_q.wait_stream(current_stream) + stream_weights.wait_stream(current_stream) + + # main stream + self.forward_indexer_compressor( + x=x_for_compressor if (is_nsa_enable_prefill_cp() and x_for_compressor is not None) else x, + forward_batch=forward_batch, + layer_id=c4_indexer.layer_id, + compressor=c4_indexer.compressor, + ) + c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer( + layer_id=c4_indexer.layer_id, + ) + + # alt stream 0: compute q + with torch.cuda.stream(stream_q): + if q_lora_ready is not None: + stream_q.wait_event(q_lora_ready) + q = c4_indexer.compute_q(q_lora, positions=positions) + q_fp8, q_scale = act_quant(q) + q_scale_ready = stream_q.record_event() + + # alt stream 1: compute weights + with torch.cuda.stream(stream_weights): + weights = c4_indexer.compute_weights(x, skip_scale=True) + stream_weights.wait_event(q_scale_ready) + weights = fused_scale(weights, c4_indexer.weight_scale, q_scale) + + current_stream.wait_stream(stream_q) + current_stream.wait_stream(stream_weights) + + return q_fp8, weights, c4_indexer_kv_cache + + def _forward_prepare_normal( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + c4_indexer: C4Indexer, + positions: torch.Tensor, + 
forward_batch: ForwardBatch, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + x_for_compressor: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if TYPE_CHECKING: + assert isinstance(self, CompressorBackend) + + q = c4_indexer.compute_q(q_lora, positions=positions) + q_fp8, q_scale = act_quant(q) + weights = c4_indexer.compute_weights(x, skip_scale=True) + weights = fused_scale(weights, c4_indexer.weight_scale, q_scale) + self.forward_indexer_compressor( + x=x_for_compressor if (is_nsa_enable_prefill_cp() and x_for_compressor is not None) else x, + forward_batch=forward_batch, + layer_id=c4_indexer.layer_id, + compressor=c4_indexer.compressor, + ) + c4_indexer_kv_cache = token_to_kv_pool.get_index_k_with_scale_buffer( + layer_id=c4_indexer.layer_id, + ) + return q_fp8, weights, c4_indexer_kv_cache + + def forward_c4_indexer( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + c4_indexer: C4Indexer, + forward_batch: ForwardBatch, + x_for_compressor: Optional[torch.Tensor] = None, + alt_streams: Optional[List[torch.cuda.Stream]] = None, + enable_multi_stream: bool = False, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> None: + if forward_batch.forward_mode.is_idle(): + return + token_to_kv_pool = forward_batch.token_to_kv_pool + + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + assert isinstance(self, CompressorBackend) + + metadata = self.forward_metadata + indexer_metadata = metadata.indexer_metadata + core_metadata = metadata.core_metadata + + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DSV4AttnMetadataRadix, + ) + + assert isinstance(core_metadata, (PagedCoreMetadata, DSV4AttnMetadataRadix)) + assert isinstance(indexer_metadata, PagedIndexerMetadata) + + _x_comp = x_for_compressor if (is_nsa_enable_prefill_cp() and x_for_compressor is not None) else x + if enable_multi_stream: + q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_multi_stream( + x=x, + q_lora=q_lora, + c4_indexer=c4_indexer, + positions=core_metadata.positions, + forward_batch=forward_batch, + token_to_kv_pool=token_to_kv_pool, + x_for_compressor=_x_comp, + alt_streams=alt_streams, + q_lora_ready=q_lora_ready, + ) + else: + assert q_lora_ready is None + q_fp8, weights, c4_indexer_kv_cache = self._forward_prepare_normal( + x=x, + q_lora=q_lora, + c4_indexer=c4_indexer, + positions=core_metadata.positions, + forward_batch=forward_batch, + token_to_kv_pool=token_to_kv_pool, + x_for_compressor=_x_comp, + ) + + assert len(q_fp8.shape) == 3 + q_fp8 = q_fp8.unsqueeze(1) # the next_n dim is 1 now + assert len(c4_indexer_kv_cache.shape) == 2 + block_kv = 64 + num_heads_kv = 1 + head_dim_with_sf = 132 + + # DeepGEMM#280 does not change test_attention.py for fp8_paged_mqa_logits, thus + c4_indexer_kv_cache = c4_indexer_kv_cache.view( + c4_indexer_kv_cache.shape[0], block_kv, num_heads_kv, head_dim_with_sf + ) + assert len(weights.shape) == 3 + weights = weights.squeeze(2) + # CUDA path: use deep_gemm + if envs.SGLANG_OPT_USE_TILELANG_INDEXER.get(): + from sglang.srt.layers.attention.nsa.tilelang_kernel import ( + tilelang_fp8_paged_mqa_logits as fn, + ) + # elif is_hip(): + elif envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + fn = fp8_paged_mqa_logits_torch + else: + if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1: + from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import ( + fp8_paged_mqa_logits_chunked as fn, + ) + else: + from deep_gemm import fp8_paged_mqa_logits as fn + + logits = fn( + 
q_fp8, + c4_indexer_kv_cache, + weights, + indexer_metadata.c4_seq_lens, + indexer_metadata.page_table, + indexer_metadata.deep_gemm_metadata, + indexer_metadata.max_seq_len, + False, + ) + + assert indexer_metadata.page_table is core_metadata.page_table + if self.debug_use_external_c4_sparse_indices: + return # skip updating page indices + + indexer_capturer = get_global_indexer_capturer() + capture_enabled = indexer_capturer.is_enabled() + + raw_indices = None + if capture_enabled or forward_batch.hisparse_coordinator is not None: + raw_indices = torch.empty_like(core_metadata.c4_sparse_page_indices) + + if envs.SGLANG_TOPK_TRANSFORM_512_TORCH.get(): + topk_transform_512_pytorch_vectorized( + logits, + indexer_metadata.c4_seq_lens, + core_metadata.page_table, + core_metadata.c4_sparse_page_indices, + indexer_metadata.c4_page_size, + raw_indices, + ) + else: + topk_transform_512( + logits, + indexer_metadata.c4_seq_lens, + core_metadata.page_table, + core_metadata.c4_sparse_page_indices, + indexer_metadata.c4_page_size, + raw_indices, + ) + + if forward_batch.hisparse_coordinator is not None: + if forward_batch.forward_mode.is_decode(): + # todo hisparse: to coordinate with kernel signature + core_metadata.c4_sparse_page_indices = ( + forward_batch.hisparse_coordinator.get_front_topk_tokens( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + raw_indices, + ) + ) + else: + core_metadata.c4_sparse_page_indices = token_to_kv_pool.c4_kv_pool.translate_loc_from_compressed_to_hisparse_device( + core_metadata.c4_sparse_page_indices + ) + + if capture_enabled: + compress_layer_id = token_to_kv_pool.layer_mapping[ + c4_indexer.layer_id + ].compress_layer_id + indexer_capturer.capture(compress_layer_id, raw_indices) diff --git a/python/sglang/srt/layers/attention/compressed/metadata.py b/python/sglang/srt/layers/attention/compressed/metadata.py new file mode 100644 index 000000000000..7ab742940177 --- /dev/null +++ b/python/sglang/srt/layers/attention/compressed/metadata.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass, field, fields +from typing import Any, TYPE_CHECKING, List, Optional + +import torch + +from sglang.srt.environ import envs +from sglang.srt.utils import is_hip + +if TYPE_CHECKING: + from flash_mla.flash_mla_interface import FlashMLASchedMeta + + +""" +Some comments on the common terms used in DeepSeekV4Backend: + +topk_lengths: + NOTE: TL;DR: topk_lengths == seq_lens + The FlashMLA sparse decode kernel will attend to `k` tokens for each query. + `topk_lengths` indicates how many tokens each query will attend to. + This should be named as `seq_lens`, but we simply follow the naming convention. + +page_table: + The page table indicates which pages each request is assigned to. + Each value in the page table is the page index in the TokenToKVPool. + This page index is irrelevant to the actual `page_size`. + +page_indices: + The real indices used to index into the KV cache. + This can be computed from the `page_table` and `page_size`. + e.g. page_indices[i, j] = page_table[i, j // page_size] * page_size + (j % page_size) + For sparse C4 top-512 attention, the indices will be selected from the C4 page indices. + In implementation, we don't materialize the full C4 `page_indices`, + but calculate them from `page_table` on-the-fly in the attention kernel. + +positions: + The position of the last token for each request. + For compress token, the positions must be times of compress ratio. 
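+ (i.e. the effective position is raw_position // compress_ratio * compress_ratio)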
+ For example, for C4, raw_position=11 will trigger a compression, + But the RoPE's position, during compression, must be 8 instead of 11. + +Some other notes: + c4_ / c128_: means "compressed by 4" / "compressed by 128". + c4_page_size: page_size // 4 + c4_seq_lens: seq_lens // 4, but bounded by at least 1, due to flash_mla requirement. + c4_sparse: means "compressed by 4" but only attend to top-512 tokens. + all related length will be clipped to 512. +""" + + +def copy_metadata( + *, + src, + dst, + check_eq_fields: List[str], + copy_fields: List[str], + assign_fields: Optional[List[str]] = None, +): + assign_fields = assign_fields or [] + + for field_name in check_eq_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + assert src_val == dst_val, f"{field_name=} {src_val=} {dst_val=}" + + for field_name in copy_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + assert dst_val is not None, f"{field_name=} {src_val=} {dst_val=}" + dst_val.copy_(src_val) + + for field_name in assign_fields: + setattr(dst, field_name, getattr(src, field_name)) + + provided_fields = check_eq_fields + copy_fields + assign_fields + assert len(provided_fields) == len( + set(provided_fields) + ), f"{provided_fields=} has dup" + all_fields = {f.name for f in fields(src)} + assert set(provided_fields) == all_fields, f"{provided_fields=} {all_fields=}" + + +def create_flashmla_metadata(): + # if is_hip(): + if os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND") == "torch" or is_hip(): + return None + else: + import flash_mla + + return flash_mla.get_mla_metadata()[0] + + +@dataclass +class CoreMetadata: + positions: torch.Tensor # needed for sliding window and others + # NOTE: swa_out_loc only applies to indices that needs to be written + # to the swa_kv_pool. For prefill, we will take a slicing Tensor + # that selects the k/v values that needs to be written. + swa_slice: Optional[torch.Tensor] + swa_out_loc_sliced: torch.Tensor + # NOTE: c4/c128 out_loc will mask the invalid write locations to 0. 
+ # When no compression happens, out_loc will be 0, which is the "padded slot" + c4_out_loc: torch.Tensor + c128_out_loc: torch.Tensor + + def init_swa_slice(self, swa_slice: torch.Tensor): + assert self.swa_slice is None, "can only update once" + self.swa_slice = swa_slice + self.swa_out_loc_sliced = self.swa_out_loc_sliced[swa_slice] + + def copy_(self, other): + raise NotImplementedError + + +@dataclass +class IndexerMetadata: + def copy_(self, other): + raise NotImplementedError + + +@dataclass +class PagedIndexerMetadata(IndexerMetadata): + page_size: int + page_table: torch.Tensor + c4_seq_lens: torch.Tensor + deep_gemm_metadata: Any = field(init=False, repr=False) + + def __post_init__(self): + # if is_hip(): + if envs.SGLANG_FP8_PAGED_MQA_LOGITS_TORCH.get(): + # For HIP/ROCm, we don't need deep_gemm metadata + # Will use aiter's deepgemm_fp8_paged_mqa_logits instead + self.deep_gemm_metadata = None + else: + import deep_gemm + + if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1: + from sglang.srt.layers.deep_gemm_wrapper.paged_mqa_logits import ( + get_paged_mqa_logits_metadata_chunked as get_paged_mqa_logits_metadata, + ) + elif envs.SGLANG_OPT_USE_JIT_INDEXER_METADATA.get(): + from sglang.jit_kernel.deepseek_v4 import get_paged_mqa_logits_metadata + else: + from deep_gemm import get_paged_mqa_logits_metadata + + self.deep_gemm_metadata = get_paged_mqa_logits_metadata( + self.c4_seq_lens.to(torch.int32), + self.c4_page_size, + deep_gemm.get_num_sms(), + ) + + if envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() != -1: + pass + else: + # It is a tensor, thus our CUDA graph replay copy will be easier (just copy it) + assert isinstance(self.deep_gemm_metadata, torch.Tensor) + + assert self.page_size == 256 + + @property + def c4_page_size(self) -> int: + return self.page_size // 4 + + @property + def max_seq_len(self) -> int: + return self.page_table.shape[1] * self.page_size + + def copy_(self, other: "PagedIndexerMetadata"): + if is_hip(): + # HIP/ROCm: don't copy deep_gemm_metadata (it's None) + copy_fields = ["page_table", "c4_seq_lens"] + else: + # CUDA: original behavior + copy_fields = ["page_table", "c4_seq_lens", "deep_gemm_metadata"] + + copy_metadata( + src=other, + dst=self, + check_eq_fields=["page_size"], + copy_fields=copy_fields, + ) + + +@dataclass +class PagedCoreMetadata(CoreMetadata): + page_table: torch.Tensor + # sliding window attention (core) + swa_page_indices: torch.Tensor # at most (sum_qo_len, 128) + swa_topk_lengths: torch.Tensor # clipped to 128 + # C128 dense attention (core) + c128_page_indices: torch.Tensor + c128_topk_lengths_clamp1: torch.Tensor + # C4 sparse attention (core) + c4_topk_lengths_raw: torch.Tensor + c4_topk_lengths_clamp1: torch.Tensor # i.e. 
c4_seq_lens + c4_sparse_topk: int # must be 512 + c4_sparse_topk_lengths: torch.Tensor = field(init=False) # clipped to 512 + c4_sparse_page_indices: torch.Tensor = field(init=False) # (bs, 512) + # FlashMLA + c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + + def get_flashmla_metadata(self, compress_ratio: int): + if compress_ratio == 0: + return self.c1_flashmla_metadata + elif compress_ratio == 4: + return self.c4_flashmla_metadata + elif compress_ratio == 128: + return self.c128_flashmla_metadata + else: + raise ValueError(f"invalid {compress_ratio=}") + + def __post_init__(self): + assert self.c4_sparse_topk == 512 + self.c4_sparse_topk_lengths = torch.clamp( + self.c4_topk_lengths_clamp1, max=self.c4_sparse_topk + ) + self.c4_sparse_page_indices = torch.full( + (self.c4_topk_lengths_clamp1.size(0), self.c4_sparse_topk), + -1, + dtype=torch.int32, + device=self.c4_topk_lengths_clamp1.device, + ) + self.c1_flashmla_metadata = create_flashmla_metadata() + self.c4_flashmla_metadata = create_flashmla_metadata() + self.c128_flashmla_metadata = create_flashmla_metadata() + + def copy_(self, other: PagedCoreMetadata) -> None: + copy_metadata( + src=other, + dst=self, + check_eq_fields=["c4_sparse_topk", "swa_slice"], + copy_fields=[ + "positions", + "swa_out_loc_sliced", + "c4_out_loc", + "c128_out_loc", + "page_table", + "swa_page_indices", + "swa_topk_lengths", + "c128_page_indices", + "c128_topk_lengths_clamp1", + "c4_topk_lengths_raw", + "c4_topk_lengths_clamp1", + "c4_sparse_topk_lengths", + "c4_sparse_page_indices", + ], + assign_fields=[ + # For the new API, the metadata has the following lifecycle: + # + # Graph capture warmup forward pass: + # (ignore, we will reset to brand new object after such passes) + # + # Graph capture real-capture forward pass: + # * Layer 0: Set python & tensor objects to metadata + # * Layer >=1: Read them from metadata + # + # Graph replay: + # * Layer 0: The kernels are in "generate metadata" mode + # * Layer >=1: The kernels are in "non-generate metadata" mode + # + # Thus this field can be ignored. + # However, to allow running replay w/o in real cuda graph, we do an assignment. + # (Do we really need that? 
If no, we can change this field to skip-copy mode) + "c1_flashmla_metadata", + "c4_flashmla_metadata", + "c128_flashmla_metadata", + ], + ) + + +# TODO: implement the ragged metadata + + +@dataclass +class RaggedCoreMetadata(CoreMetadata): + swa_ragged_indices: torch.Tensor + swa_c4_ragged_indices: torch.Tensor + swa_c128_ragged_indices: torch.Tensor + + +@dataclass +class RaggedIndexerMetadata(IndexerMetadata): + c4_k_start: torch.Tensor + c4_k_finish: torch.Tensor + + +@dataclass +class DeepseekV4Metadata: + core_metadata: CoreMetadata + indexer_metadata: IndexerMetadata + debug_seq_lens_expanded: torch.Tensor + + def copy_(self, other: "DeepseekV4Metadata"): + self.core_metadata.copy_(other.core_metadata) + self.indexer_metadata.copy_(other.indexer_metadata) + + +def maybe_copy_inplace(dst, *, src) -> None: + assert type(src) == type(dst) + if dst is not None: + dst.copy_(src) diff --git a/python/sglang/srt/layers/attention/compressed/paged_prefill.py b/python/sglang/srt/layers/attention/compressed/paged_prefill.py new file mode 100644 index 000000000000..29c3c16f37c7 --- /dev/null +++ b/python/sglang/srt/layers/attention/compressed/paged_prefill.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, List, Tuple + +import torch + +from sglang.jit_kernel.deepseek_v4 import tilelang_make_swa_prefill_indices +from sglang.srt.environ import envs +from sglang.srt.layers.attention.nsa import index_buf_accessor_v4 +from sglang.srt.layers.attention.nsa.quant_k_cache_v4 import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.utils import ceil_align + +if TYPE_CHECKING: + from sglang.srt.layers.attention.compressed.metadata import PagedCoreMetadata + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +_HOST_INT32_KWARGS: Dict = dict(dtype=torch.int32, device="cpu", pin_memory=True) + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +def expand_seq_lens( + *, + seq_lens: List[int], + extend_seq_lens: List[int], + device: torch.device, +) -> Tuple[torch.Tensor, torch.Tensor]: + num_tokens = sum(extend_seq_lens) + seq_lens_expanded = torch.empty(num_tokens, **_HOST_INT32_KWARGS) + expanded_idx_to_unexpanded_idx = torch.empty(num_tokens, **_HOST_INT32_KWARGS) + offset = 0 + for i, (kv_len, qo_len) in enumerate(zip(seq_lens, extend_seq_lens)): + out = seq_lens_expanded[offset : offset + qo_len] + offset += qo_len + torch.arange(kv_len - qo_len + 1, kv_len + 1, out=out) + expanded_idx_to_unexpanded_idx[offset - qo_len : offset].fill_(i) + return ( + seq_lens_expanded.to(device, non_blocking=True), + expanded_idx_to_unexpanded_idx.to(device, non_blocking=True), + ) + + +# NOTE: about the ring buffer layout: +# TODO(dark): add doc + + +def make_swa_ring_buffer_indices( + forward_batch: ForwardBatch, + device: torch.device, + *, + max_seq_len: int, + swa_window_size: int, +) -> torch.Tensor: + SWA_WINDOW = swa_window_size + extend_num_tokens = forward_batch.extend_num_tokens + assert extend_num_tokens is not None + if envs.SGLANG_OPT_USE_TILELANG_SWA_PREPARE.get(): + seq_lens = forward_batch.seq_lens + extend_lens = forward_batch.extend_seq_lens + assert extend_lens is not None + seq_lens_k = seq_lens.to(torch.int32) + seq_lens_q = extend_lens.to(torch.int32) + swa_indices = torch.empty( + (extend_num_tokens, SWA_WINDOW), device=device, 
dtype=torch.int32 + ) + return tilelang_make_swa_prefill_indices( + seq_lens_k=seq_lens_k, + seq_lens_q=seq_lens_q, + swa_indices=swa_indices, + ) + seq_lens = forward_batch.seq_lens_cpu + extend_lens = forward_batch.extend_seq_lens_cpu + assert seq_lens is not None and extend_lens is not None + batch_size = len(seq_lens) + num_tokens = extend_num_tokens + swa_indices = torch.full((num_tokens, swa_window_size), -1, **_HOST_INT32_KWARGS) + cum_qo_len = 0 + abs_pos_buf = torch.arange(max_seq_len, dtype=torch.int32) + for seq_idx, (kv_len, qo_len) in enumerate(zip(seq_lens.tolist(), extend_lens)): + # already existing KV + old_kv_start = seq_idx * SWA_WINDOW + # newly computed KV + new_kv_start = batch_size * SWA_WINDOW + cum_qo_len + prefix_len = kv_len - qo_len + for curr_seq_qo_idx in range(qo_len): + # layout | prefix_len (cached) | qo_len | + # | 0 ... prefix_len-1 | prefix_len ... kv_len-1 | + # + # Step 1: compute chosen_abs_positions - absolute positions to look at for this specific query token + end_abs_pos = prefix_len + curr_seq_qo_idx + 1 + start_abs_pos = max(end_abs_pos - SWA_WINDOW, 0) + chosen_abs_positions = abs_pos_buf[start_abs_pos:end_abs_pos] + # Step 2: compute swa_indices + # For one abs_pos in chosen_abs_positions, the swa_indices will be: + # 1. abs_pos < prefix_len -> old_kv_start + abs_pos % SWA_WINDOW + # 2. abs_pos >= prefix_len -> new_kv_start + (abs_pos - prefix_len) + torch.where( + chosen_abs_positions < prefix_len, + old_kv_start + chosen_abs_positions % SWA_WINDOW, + new_kv_start + (chosen_abs_positions - prefix_len), + out=swa_indices[ + cum_qo_len + curr_seq_qo_idx, : end_abs_pos - start_abs_pos + ], + ) + cum_qo_len += qo_len + return swa_indices.to(device, non_blocking=True) + + +def prepare_swa_ring_buffer_cache( + swa_k: torch.Tensor, + forward_batch: ForwardBatch, + layer_id: int, + token_to_kv_pool: DeepSeekV4TokenToKVPool, + core_metadata: PagedCoreMetadata, + debug_dump_hook: Any, +) -> Tuple[torch.Tensor, index_buf_accessor_v4.NopeFp8RopeBf16Pack]: + # Quick example: A prefill batch, with: + # * request 0/1: generates 3 token, cache_len = 5 (i.e. seq_len = 5 + 3 = 8) + # * req 2: generate 2 token, cache_len=1000 + # + # Then, the temporary KV Cache has: + # * indices [0, 128) = request 0 window. + # abs_pos=0,1,2,3,4 <-> pool_idx=0,1,2,3,4 + # pool_idx=5...127 contains no data + # * indices [128, 256) = request 1 window. + # abs_pos=0,1,2,3,4 <-> pool_idx=128,...,132 + # pool_idx=128...255 contains no data + # * indices [256, 384) = request 2 window. + # it is a ring buffer, thus abs_pos % 128 = pool_idx_inside_the_block + # abs_pos=0...,781 <-> no valid corresponding pool idx + # abs_pos=872...,895 <-> pool_idx=360,...,383 + # abs_pos=896,...,999 <-> pool_idx=256,...,359 + # * indices [384, 387) = request 0 newly gen 3 kv token + # abs_pos=5,6,7 <-> pool_idx=384,...,386 + # * indices [387, 390) = request 1 newly gen 3 kv token + # abs_pos=5,6,7 <-> pool_idx=387,...,389 + # * indices [390, 392) = request 2 newly gen 2 kv token + # abs_pos=1000,1001 <-> pool_idx=390,391 + + pool_swa_k_cache = token_to_kv_pool.get_swa_key_buffer(layer_id) + num_pool_pages = forward_batch.batch_size + num_newly_gen_tokens, _ = swa_k.shape + + swa_kv_pool = token_to_kv_pool.swa_kv_pool + swa_page_size = swa_kv_pool.page_size + assert swa_page_size == 128 + effective_swa_k_cache = swa_kv_pool.create_buffer( + num_pages=num_pool_pages + ceil_align(num_newly_gen_tokens, swa_page_size), + ) + + # a. 
SWA data in real kv cache + loc_swa = forward_batch.req_pool_indices + assert loc_swa.shape[0] == forward_batch.batch_size == num_pool_pages + effective_swa_k_cache[:num_pool_pages, :] = pool_swa_k_cache[loc_swa, :].view( + effective_swa_k_cache.dtype + ) + + # b. Newly generated data + swa_k_pack = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k) + offset = num_pool_pages * swa_page_size + loc_newly_gen = torch.arange( + offset, + offset + num_newly_gen_tokens, + device=loc_swa.device, + ) + index_buf_accessor_v4.SetKAndS.execute( + pool=swa_kv_pool, + buf=effective_swa_k_cache, + loc=loc_newly_gen, + nope_fp8_rope_bf16_pack=swa_k_pack, + ) + + if h := debug_dump_hook: + h( + "forward__swa_info", + dict( + loc_swa=loc_swa, + loc_newly_gen=loc_newly_gen, + ), + ) + + return effective_swa_k_cache, swa_k_pack.slice_pack(core_metadata.swa_slice) diff --git a/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py b/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py new file mode 100644 index 000000000000..631fe1a54c75 --- /dev/null +++ b/python/sglang/srt/layers/attention/debug_flash_mla_adapter.py @@ -0,0 +1,190 @@ +from typing import Any, Optional + +import torch + +from sglang.srt.utils import is_hip +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +FP8_DTYPE = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): + if is_hip(): + # backend == "torch" + import os + + backend = os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND", "kernel") + else: + import flash_mla + + if backend == "comparison": + pack_ref, pack_fast_via_tester = flash_mla_with_kvcache_entrypoint( + backend="torch", **kwargs + ) + pack_fast_via_api = flash_mla_with_kvcache_entrypoint( + backend="kernel", **kwargs + ) + _assert_close(pack_ref=pack_fast_via_tester, pack_fast=pack_fast_via_api) + _assert_close(pack_ref=pack_ref, pack_fast=pack_fast_via_tester) + _assert_close(pack_ref=pack_ref, pack_fast=pack_fast_via_api) + return pack_ref + + if backend == "torch": + return flash_mla_with_kvcache_torch(**kwargs) + + if backend == "kernel": + return flash_mla.flash_mla_with_kvcache(**kwargs) + + raise NotImplementedError + + +def flash_mla_with_kvcache_torch( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: Any, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +): + + from sglang.srt.flashmla_tests import quant as flashmla_quant + from sglang.srt.flashmla_tests.lib import ( + ExtraTestParamForDecode, + KVScope, + TestcaseForDecode, + TestParam, + ) + from sglang.srt.flashmla_tests.ref import ref_sparse_attn_decode + + assert block_table is None + assert cache_seqlens is None + assert is_fp8_kvcache + + b, s_q, h_q, d_qk = q.shape + d_v = head_dim_v + + fp8_layout = flashmla_quant.FP8KVCacheLayout.MODEL1_FP8Sparse + + p = TestParam( + s_q=s_q, + s_kv="unused", + topk="unused", + h_q=h_q, + h_kv=1, + d_qk=d_qk, + d_v=d_v, + decode=ExtraTestParamForDecode( + b=b, + is_varlen="unused", + have_zero_seqlen_k="unused", + extra_s_k="unused", + 
extra_topk="unused", + extra_block_size="unused", + have_extra_topk_length="unused", + ), + # unused? + seed=-1, + check_correctness=True, + is_all_indices_invalid=False, + num_runs=10, + have_attn_sink=True, + have_topk_length=True, + ) + + blocked_k_quantized = k_cache + blocked_k = flashmla_quant.dequantize_k_cache( + blocked_k_quantized.view(FP8_DTYPE), fp8_layout + ) + # blocked_k_requantized = flashmla_quant.quantize_k_cache(blocked_k, fp8_layout) + # assert torch.testing.assert_allclose(blocked_k_requantized.byte(), blocked_k_quantized.byte()) + kv_scope = KVScope( + t="unused", + cache_seqlens="unused", + block_table="unused", + blocked_k=blocked_k, + blocked_k_quantized=blocked_k_quantized, + abs_indices="unused", + indices_in_kvcache=indices, + topk_length=topk_length, + ) + + extra_kv_scope = None + if extra_k_cache is not None: + extra_blocked_k_quantized = extra_k_cache + extra_blocked_k = flashmla_quant.dequantize_k_cache( + extra_blocked_k_quantized.view(FP8_DTYPE), fp8_layout + ) + # extra_blocked_k_requantized = flashmla_quant.quantize_k_cache(extra_blocked_k, fp8_layout) + # assert torch.testing.assert_allclose(extra_blocked_k_requantized.byte(), extra_blocked_k_quantized.byte()) + extra_kv_scope = KVScope( + t="unused", + cache_seqlens="unused", + block_table="unused", + blocked_k=extra_blocked_k, + blocked_k_quantized=extra_blocked_k_quantized, + abs_indices="unused", + indices_in_kvcache=extra_indices_in_kvcache, + topk_length=extra_topk_length, + ) + + t = TestcaseForDecode( + p="unused", + q=q, + attn_sink=attn_sink, + sm_scale=softmax_scale, + kv_scope=kv_scope, + extra_kv_scope=extra_kv_scope, + ) + # print(f"hi {p=} {t=}") + # print( + # f"hi info " + # f"{get_tensor_info(t.kv_scope.blocked_k)=} " + # f"{get_tensor_info(t.kv_scope.blocked_k_quantized)=} " + # f"{get_tensor_info(t.extra_kv_scope.blocked_k) if t.extra_kv_scope is not None else None=} " + # f"{get_tensor_info(t.extra_kv_scope.blocked_k_quantized) if t.extra_kv_scope is not None else None=} " + # ) + + pack_ref = ref_sparse_attn_decode(p, t) + + # tile_scheduler_metadata, _ = flash_mla.get_mla_metadata() + # pack_fast_via_tester = flashmla_lib.run_flash_mla_decode( + # p, t, tile_scheduler_metadata, num_splits=None + # ) + + # return pack_ref, pack_fast_via_tester + return pack_ref + + +def _assert_close(pack_ref, pack_fast): + import sglang.srt.flashmla_tests.kernelkit as kk + + out_ref, lse_ref = pack_ref + out_fast, lse_fast = pack_fast + + # the copied threshold is too strict, not checked why + # copied from: test_flash_mla_sparse_decoding.py + # is_out_correct = kk.check_is_allclose( + # "out", out_fast, out_ref, abs_tol=1e-3, rel_tol=2.01 / 128, cos_diff_tol=5e-6 + # ) + # is_lse_correct = kk.check_is_allclose( + # "lse", lse_fast, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536 + # ) + + # loosen thresh + is_out_correct = kk.check_is_allclose( + "out", out_fast, out_ref, abs_tol=1e-2, rel_tol=10.0, cos_diff_tol=5e-6 + ) + is_lse_correct = kk.check_is_allclose( + "lse", lse_fast, lse_ref, abs_tol=1e-6, rel_tol=8.01 / 65536 + ) + + assert is_out_correct and is_lse_correct, f"{is_out_correct=} {is_lse_correct=}" diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend.py b/python/sglang/srt/layers/attention/deepseek_v4_backend.py new file mode 100644 index 000000000000..3df43d464ac6 --- /dev/null +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend.py @@ -0,0 +1,591 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import 
TYPE_CHECKING, Callable, Dict, List, Literal, Optional, TypeVar + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.debug_flash_mla_adapter import ( + flash_mla_with_kvcache_entrypoint, +) +from sglang.srt.layers.attention.nsa.quant_k_cache_v4 import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.spec_info import SpecInput +from sglang.srt.utils import ceil_align + +# components +from .compressed import paged_prefill +from .compressed.compressor import CompressorBackend +from .compressed.indexer import C4IndexerBackend +from .compressed.metadata import ( + DeepseekV4Metadata, + PagedCoreMetadata, + PagedIndexerMetadata, + create_flashmla_metadata, +) + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool + from sglang.srt.model_executor.model_runner import ModelRunner + + +SWA_WINDOW = 128 +C4_TOPK = 512 +PAGE_INDEX_ALIGNED_SIZE = 64 + + +_HOST_INT32_KWARGS = {"dtype": torch.int32, "pin_memory": True} + + +@dataclass +class _DecodeCudaGraphSharedData: + pass # TODO fields + + +T = TypeVar("T", bound=Optional[torch.Tensor]) + + +def _pad_last_dim(x: T, multiples_of: int = PAGE_INDEX_ALIGNED_SIZE) -> T: + if x is None: + return None # type: ignore + curr_size = x.shape[-1] + target_size = ceil_align(curr_size, multiples_of) + return F.pad(x, pad=(0, target_size - curr_size), mode="constant", value=-1) + + +class DeepseekV4Backend(AttentionBackend, C4IndexerBackend, CompressorBackend): + def __init__( + self, + model_runner: ModelRunner, + ): + super().__init__() + self.device = torch.device(model_runner.device) # type: ignore + head_dim = model_runner.model_config.head_dim + assert head_dim == 512 + self.softmax_scale: float = head_dim**-0.5 + self.head_dim_v: int = model_runner.model_config.v_head_dim + self.cuda_int32_kwargs = {"device": self.device, "dtype": torch.int32} + self.host_int32_kwargs = _HOST_INT32_KWARGS + self.debug_dump_hook: Optional[Callable] = None + self.swa_page_size = 128 + assert model_runner.page_size is not None + assert model_runner.req_to_token_pool is not None + self.page_size = model_runner.page_size + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.max_seq_len_for_capture = self.req_to_token.shape[1] + assert self.page_size == 256, "the system hardcodes page_size=256" + + #### Public API #### + + def init_forward_metadata(self, forward_batch: ForwardBatch): + + req_pool_indices = forward_batch.req_pool_indices + seq_lens = forward_batch.seq_lens.to(torch.int32) + batch_size = forward_batch.batch_size + seq_lens_cpu = forward_batch.seq_lens_cpu + assert forward_batch.req_to_token_pool.req_to_token is self.req_to_token + + assert self.swa_page_size % SWA_WINDOW == 0 and self.page_size % 128 == 0 + assert seq_lens_cpu is not None + max_seq_len = int(seq_lens_cpu.max().item()) + + if forward_batch.forward_mode.is_decode_or_idle(): + metadata = self._compute_decode_metadata( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=forward_batch.out_cache_loc, + ) + elif forward_batch.forward_mode.is_prefill(): + metadata = self._compute_prefill_metadata( + max_seq_len=max_seq_len, + forward_batch=forward_batch, + ) + else: + raise NotImplementedError(f"unsupported mode 
{forward_batch.forward_mode=}") + + # set metadata + self.forward_metadata = metadata + if h := self.debug_dump_hook: + h("init_forward_metadata_output", metadata) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + self.decode_cuda_graph_shared_data = _DecodeCudaGraphSharedData() + self.decode_cuda_graph_metadata_of_bs: Dict[int, DeepseekV4Metadata] = {} + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ): + assert req_pool_indices.size(0) == bs + assert seq_lens.size(0) == bs + + if forward_mode.is_decode_or_idle(): + # NOTE: we should use `self.decode_cuda_graph_shared_data` to avoid allocating + # a pack of tensors per cuda graph, but that is the NEXT step instead of current step. + # For example, we may write: + # + # metadata = compute_decode_metadata() + # use_shared_tensors(metadata, self.decode_cuda_graph_shared_data) + # + # def use_shared_tensors(): + # for field_name in ...: + # getattr(shared_data, field_name).copy_(getattr(metadata, field_name)[..maybe_some_slicing..]) + # setattr(metadata, field_name, getattr(shared_data, field_name)) + + metadata = self._compute_decode_metadata( + max_seq_len=self.max_seq_len_for_capture, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + # Dummy value + out_cache_loc=torch.zeros_like(seq_lens), + ) + + self.decode_cuda_graph_metadata_of_bs[bs] = metadata + self.forward_metadata = metadata + else: + raise NotImplementedError(f"unsupported mode {forward_mode=}") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + actual_forward_mode: Optional[ForwardMode] = None, + ): + # We observe error that len(out_cache_loc)=0 while len(seq_lens)>0. + # We only support DP attention, thus when IDLE, we will not execute attention backend, + # thus it is safe to delete it. 
+ if actual_forward_mode == ForwardMode.IDLE: + if hasattr(self, "forward_metadata"): + del self.forward_metadata # avoid misuse + return + + assert seq_lens_cpu is not None and out_cache_loc is not None + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] + + if forward_mode.is_decode_or_idle(): + # Future optimization: use real max seq len + actual_max_seq_len = seq_lens_cpu.max().item() + + chosen_max_seq_len = self.max_seq_len_for_capture + assert actual_max_seq_len <= chosen_max_seq_len + + assert len(out_cache_loc.shape) == 1, f"{out_cache_loc.shape=}" + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, bs - len(out_cache_loc)), + mode="constant", + value=0, + ) + + temp_metadata = self._compute_decode_metadata( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + ) + + # Future optimization: may not need to `copy` all things, + # But only copy partially such as `page_table[:, :max_seq_len]` + chosen_metadata = self.decode_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + else: + raise NotImplementedError(f"unsupported mode {forward_mode=}") + + def get_cuda_graph_seq_len_fill_value(self): + # FlashMLA, NSA backend, etc, use "1" + return 1 + + # TODO improve naming + def on_after_cuda_graph_warmup_pass(self): + metadata = self.forward_metadata + if isinstance(metadata.core_metadata, PagedCoreMetadata): + metadata.core_metadata.c1_flashmla_metadata = create_flashmla_metadata() + metadata.core_metadata.c4_flashmla_metadata = create_flashmla_metadata() + metadata.core_metadata.c128_flashmla_metadata = create_flashmla_metadata() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + *, + compress_ratio: Literal[0, 4, 128], + attn_sink: Optional[torch.Tensor] = None, + **_, + ) -> torch.Tensor: + + # NOTE: here set-kv only applies to swa kv + + assert k is v, "DeepseekV4 shares k and v" + swa_k = k + + layer_id = layer.layer_id + metadata = self.forward_metadata + core_metadata = metadata.core_metadata + token_to_kv_pool = forward_batch.token_to_kv_pool + if TYPE_CHECKING: + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + # This sanity check is to avoid, e.g., in CUDA graph capturing, we may accidentally + # run forward passes multiple times with one init_forward_metadata. + # If that happens, the real capturing pass will record that layer 0 do not have any meta init operations + # which is wrong. + + assert isinstance(core_metadata, PagedCoreMetadata), "TODO: support ragged" + # ------- 1. SWA attention k cache ------- + if forward_batch.forward_mode.is_prefill(): + # prefill is complex: concat kv and rearrange + swa_k_cache, swa_k_pack_sliced = ( + paged_prefill.prepare_swa_ring_buffer_cache( + swa_k, + forward_batch, + layer_id, + token_to_kv_pool, + core_metadata, + debug_dump_hook=self.debug_dump_hook, + ) + ) + else: + # decode is trivial: no slicing, no rearrangement + swa_k_cache = token_to_kv_pool.get_swa_key_buffer(layer_id) + swa_k_pack_sliced = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k) + + if save_kv_cache: + token_to_kv_pool.set_swa_key_buffer( + layer_id=layer_id, + loc=core_metadata.swa_out_loc_sliced, + cache_nope_fp8_rope_bf16_pack=swa_k_pack_sliced, + ) + + # ------- 2. 
Full (C4/C128) attention k cache ------- + extra_k_cache, extra_indices, extra_topk_lengths = None, None, None + if compress_ratio == 4: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_metadata.c4_sparse_page_indices + extra_topk_lengths = core_metadata.c4_sparse_topk_lengths + elif compress_ratio == 128: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_metadata.c128_page_indices + extra_topk_lengths = core_metadata.c128_topk_lengths_clamp1 + + # ------- Call attention core ------- + swa_window_size = token_to_kv_pool.swa_window_size + assert swa_k_cache.ndim == 2 + # view b/c flashmla expect dim=4 + # reference: FlashMLA/tests/test_flash_mla_sparse_prefill.py + k_cache_total_dim = token_to_kv_pool.swa_kv_pool.kv_cache_total_dim + swa_k_cache = swa_k_cache[:, : swa_window_size * k_cache_total_dim].view( + swa_k_cache.shape[0], swa_window_size, 1, k_cache_total_dim + ) + + if extra_k_cache is not None: + page_sizes = { + 4: token_to_kv_pool.page_size // 4, + 128: token_to_kv_pool.page_size // 128, + } + extra_k_cache = extra_k_cache[ + :, : page_sizes[compress_ratio] * k_cache_total_dim + ].view( + extra_k_cache.shape[0], + page_sizes[compress_ratio], + 1, + k_cache_total_dim, + ) + + swa_page_indices = core_metadata.swa_page_indices + + # unsqueeze to adapt decode kernel + if q.ndim == 3: + q = q.unsqueeze(1) + if swa_page_indices.ndim == 2: + swa_page_indices = swa_page_indices.unsqueeze(1) + if extra_indices is not None and extra_indices.ndim == 2: + extra_indices = extra_indices.unsqueeze(1) + + assert attn_sink is not None + + flashmla_metadata = core_metadata.get_flashmla_metadata(compress_ratio) + + # compute-sanitizer observe issue if this is not enforced + assert ( + swa_page_indices.shape[-1] % 64 == 0 + ), f"{swa_page_indices.shape[-1]=} is not aligned to 64" + if extra_indices is not None: + assert ( + extra_indices.shape[-1] % 64 == 0 + ), f"{extra_indices.shape[-1]=} is not aligned to 64" + + input_dict = dict( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=self.softmax_scale, + is_fp8_kvcache=True, + indices=swa_page_indices, + topk_length=core_metadata.swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + ) + + backend = os.environ.get("SGLANG_HACK_FLASHMLA_BACKEND", "kernel") + o = flash_mla_with_kvcache_entrypoint(**input_dict, backend=backend)[0] + o = o.squeeze(1) + + return o + + #### Helper functions #### + + def _compute_prefill_metadata( + self, + *, + max_seq_len: int, + forward_batch: ForwardBatch, + extend_seq_lens_cpu: Optional[List[int]] = None, + ) -> DeepseekV4Metadata: + seq_lens_cpu = forward_batch.seq_lens_cpu + extend_seq_lens_cpu = extend_seq_lens_cpu or forward_batch.extend_seq_lens_cpu + assert seq_lens_cpu is not None and extend_seq_lens_cpu is not None + # NOTE: expanded follow a `causal` mask pattern + seq_lens_expanded, idx_mapping = paged_prefill.expand_seq_lens( + seq_lens=seq_lens_cpu.tolist(), + extend_seq_lens=extend_seq_lens_cpu, + device=self.device, + ) + core_metadata = self._make_paged_core_metadata( + req_to_token=self.req_to_token, + req_pool_indices=forward_batch.req_pool_indices[idx_mapping], + seq_lens=seq_lens_expanded, + max_seq_len=max_seq_len, + out_loc=forward_batch.out_cache_loc, + is_prefill=True, + forward_batch=forward_batch, + ) + 
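# Illustration of the swa_slice selection below (editor's sketch with made-up numbers):
+ # e.g. with seq_len=300 and extend_len=200, the causally expanded seq lens run 101..300,
+ # so (seq_lens_raw_expanded - seq_lens_expanded) < SWA_WINDOW holds only for the last 128
+ # of the 200 extend tokens; only those positions need the SWA ring-buffer write.
+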
# NOTE: `raw` does not follow a `causal` mask pattern + seq_lens_raw_expanded = forward_batch.seq_lens[idx_mapping] + should_store_swa = (seq_lens_raw_expanded - seq_lens_expanded) < SWA_WINDOW + swa_slice = torch.nonzero(should_store_swa, as_tuple=False).squeeze(1) + core_metadata.init_swa_slice(swa_slice) + indexer_metadata = self._make_indexer_metadata(core_metadata) + return DeepseekV4Metadata(core_metadata, indexer_metadata, seq_lens_expanded) + + def _compute_decode_metadata( + self, + *, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + ) -> DeepseekV4Metadata: + assert ( + req_pool_indices.shape[0] == seq_lens.shape[0] == out_cache_loc.shape[0] + ), f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}" + core_metadata = self._make_paged_core_metadata( + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + out_loc=out_cache_loc, + forward_batch=None, # not prefill + ) + indexer_metadata = self._make_indexer_metadata(core_metadata) + return DeepseekV4Metadata(core_metadata, indexer_metadata, seq_lens) + + def _make_indexer_metadata(self, core_metadata: PagedCoreMetadata): + # TODO: handle the expanded seqlens for MTP here + return PagedIndexerMetadata( + page_size=self.page_size, + page_table=core_metadata.page_table, + # NOTE should use `raw` instead of `clamp1` + c4_seq_lens=core_metadata.c4_topk_lengths_raw, + ) + + def _make_paged_compress_tensors( + self, + *, + page_table: torch.Tensor, + page_size: int, + seq_lens: torch.Tensor, + out_loc: torch.Tensor, + compress_ratio: Literal[4, 128], + ) -> Dict[str, torch.Tensor]: + # NOTE(dark): c_ prefix means "compressed" + assert page_table.dim() == 2 + assert out_loc.shape == seq_lens.shape, f"{out_loc.shape=} {seq_lens.shape=}" + + # e.g. seq_lens = [4n - 1, 4n, 4n + 1, 4n + 2] + # raw_out_loc = [4X + 2, 4X + 3, 4Y, 4Y + 1] + # raw_positions = [4n - 2, 4n - 1, 4n, 4n + 1] (i.e. seq_lens - 1) + # then we have: + # c4_seq_lens = [n - 1 , n , n , n ] (i.e. seq_lens // 4) + # c4_out_loc = [0 , X , 0 , 0 ] (i.e. out_loc // 4) + # NOTE: 0 means "any" in this example + should_compress = seq_lens % compress_ratio == 0 + c_page_size = page_size // compress_ratio + c_seq_lens_raw = seq_lens // compress_ratio + c_out_loc = torch.where(should_compress, out_loc // compress_ratio, 0) + c_seq_lens_clamp1 = torch.clamp(c_seq_lens_raw, min=1) + + # NOTE(dark): c4 does not need page indices + if compress_ratio == 4: + return { + "c_out_loc": c_out_loc, + "c_seq_lens_raw": c_seq_lens_raw, + "c_seq_lens_clamp1": c_seq_lens_clamp1, + } + + max_pages = page_table.size(1) + c_max_seq_len = c_page_size * max_pages + # [bs, max_pages] -> [bs, max_pages, c_page_size] -> [bs, c_max_seq_len] + c_offsets = torch.arange(c_max_seq_len, **self.cuda_int32_kwargs) + c_page_indices = ( + (page_table.unsqueeze(2) * c_page_size + c_offsets[:c_page_size]) + .to(torch.int32) + .contiguous() + .view(-1, c_max_seq_len) + ) + # TODO(dark): whether this is a must + # As far as I know, only the padded 0 -> 1 must be filled with -1 + # Should other positions also be masked? 
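+ # Worked example (with the hardcoded page_size=256 and compress_ratio=128, hypothetical row values):
+ # c_page_size = 2, so a page_table row [7, 3] expands to c_page_indices = [14, 15, 6, 7];
+ # for seq_len = 300, c_seq_lens_raw = 300 // 128 = 2, and the mask below
+ # turns the row into [14, 15, -1, -1].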
+ mask = c_offsets.unsqueeze(0) >= c_seq_lens_raw.unsqueeze(1) + # NOTE: mask out the extra positions to -1 + c_page_indices.masked_fill_(mask, -1) + return { + "c_out_loc": c_out_loc, + "c_seq_lens_raw": c_seq_lens_raw, + "c_seq_lens_clamp1": c_seq_lens_clamp1, + "c_page_indices": c_page_indices, + } + + def _make_paged_core_metadata( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + out_loc: torch.Tensor, + # extra args for prefill + is_prefill: bool = False, + forward_batch: Optional[ForwardBatch] = None, + skip_compressor: bool = False, + ) -> PagedCoreMetadata: + assert self.swa_page_size == SWA_WINDOW # TODO(dark): relax this + + # -------------------- START compute SWA metadata -------------------- + swa_pages = req_pool_indices.to(torch.int32) + if is_prefill: + assert forward_batch is not None + swa_page_indices = paged_prefill.make_swa_ring_buffer_indices( + forward_batch=forward_batch, + device=self.device, + max_seq_len=max_seq_len, + swa_window_size=SWA_WINDOW, + ) + else: + # NOTE: for decode, we directly index into the ring buffer pool + # the "page_mapping" for SWA is the req_pool_indices themselves + offsets = torch.arange(SWA_WINDOW, **self.cuda_int32_kwargs) + swa_page_indices = swa_pages.unsqueeze(1) * self.swa_page_size + offsets + # if seq_len < 128, mask out the extra positions to -1 + mask = offsets.unsqueeze(0) >= seq_lens.unsqueeze(1) + swa_page_indices.masked_fill_(mask, -1) + + positions = seq_lens - 1 + swa_topk_lengths = torch.clamp(seq_lens, max=SWA_WINDOW) + swa_out_loc = swa_pages * self.swa_page_size + positions % self.swa_page_size + + # -------------------- END compute SWA metadata -------------------- + + if not skip_compressor: + page_table = req_to_token[req_pool_indices, : max_seq_len : self.page_size] + page_table = page_table.to(torch.int32) // self.page_size + c4_data = self._make_paged_compress_tensors( + page_table=page_table, + page_size=self.page_size, + seq_lens=seq_lens, + out_loc=out_loc, + compress_ratio=4, + ) + c128_data = self._make_paged_compress_tensors( + page_table=page_table, + page_size=self.page_size, + seq_lens=seq_lens, + out_loc=out_loc, + compress_ratio=128, + ) + c128_page_indices = c128_data["c_page_indices"] + swa_page_indices = _pad_last_dim( + swa_page_indices, multiples_of=PAGE_INDEX_ALIGNED_SIZE + ) + c128_page_indices = _pad_last_dim( + c128_page_indices, multiples_of=PAGE_INDEX_ALIGNED_SIZE + ) + else: + # TODO: For draft decode/draft extend + c4_data = { + "c_out_loc": None, + "c_seq_lens_raw": None, + "c_seq_lens_clamp1": None, + "c_page_indices": None, + } + c128_data = { + "c_out_loc": None, + "c_seq_lens_raw": None, + "c_seq_lens_clamp1": None, + "c_page_indices": None, + } + c128_page_indices = c128_data["c_page_indices"] + + return PagedCoreMetadata( + positions=positions, + page_table=page_table, + swa_page_indices=swa_page_indices, + swa_topk_lengths=swa_topk_lengths, + c4_out_loc=c4_data["c_out_loc"], + c4_topk_lengths_raw=c4_data["c_seq_lens_raw"], + c4_topk_lengths_clamp1=c4_data["c_seq_lens_clamp1"], + c128_out_loc=c128_data["c_out_loc"], + c128_page_indices=c128_page_indices, + c128_topk_lengths_clamp1=c128_data["c_seq_lens_clamp1"], + c4_sparse_topk=C4_TOPK, + swa_slice=None, + swa_out_loc_sliced=swa_out_loc, + ) + + #### Test-only API #### + + def extract_metadata(self, forward_batch: ForwardBatch) -> DeepseekV4Metadata: + # NOTE: in the future we may put metadata in the forward_batch itself + # this function is used for tests. 
Don't delete it. + return self.forward_metadata diff --git a/python/sglang/srt/layers/attention/deepseek_v4_backend_radix.py b/python/sglang/srt/layers/attention/deepseek_v4_backend_radix.py new file mode 100644 index 000000000000..e8a366fb00f1 --- /dev/null +++ b/python/sglang/srt/layers/attention/deepseek_v4_backend_radix.py @@ -0,0 +1,1317 @@ +""" +Some comments on the common terms used in DeepSeekV4Backend: + +topk_lengths: + NOTE: TL;DR: topk_lengths == seq_lens + The FlashMLA sparse decode kernel will attend to `k` tokens for each query. + `topk_lengths` indicates how many tokens each query will attend to. + This should be named as `seq_lens`, but we simply follow the naming convention. + +page_table: + The page table indicates which pages each request is assigned to. + Each value in the page table is the page index in the TokenToKVPool. + This page index is irrelevant to the actual `page_size`. + +page_indices: + The real indices used to index into the KV cache. + This can be computed from the `page_table` and `page_size`. + e.g. page_indices[i, j] = page_table[i, j // page_size] * page_size + (j % page_size) + For sparse C4 top-512 attention, the indices will be selected from the C4 page indices. + In implementation, we don't materialize the full C4 `page_indices`, + but calculate them from `page_table` on-the-fly in the attention kernel. + +positions: + The position of the last token for each request. + For compress token, the positions must be times of compress ratio. + For example, for C4, raw_position=11 will trigger a compression, + But the RoPE's position, during compression, must be 8 instead of 11. + +Some other notes: + c4_ / c128_: means "compressed by 4" / "compressed by 128". + c4_page_size: page_size // 4 + c4_seq_lens: seq_lens // 4, but bounded by at least 1, due to flash_mla requirement. + c4_sparse: means "compressed by 4" but only attend to top-512 tokens. + all related length will be clipped to 512. 
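+
+ Example (with the hardcoded page_size=256 and, e.g., seq_lens=300):
+     c4_page_size = 64, c4_seq_lens = 300 // 4 = 75, the c128 length is 300 // 128 = 2,
+     and the c4_sparse length is min(75, 512) = 75.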
+""" + +from __future__ import annotations + +import dataclasses +import functools +import warnings +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from sglang.srt.environ import envs +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend +from sglang.srt.layers.attention.compressed.compressor import ( + CompressorBackend, + FusedCompressMetadata, + create_paged_compressor_data, +) +from sglang.srt.layers.attention.compressed.indexer import C4IndexerBackend +from sglang.srt.layers.attention.compressed.metadata import ( + PagedIndexerMetadata, + maybe_copy_inplace, +) +from sglang.srt.layers.attention.debug_flash_mla_adapter import ( + flash_mla_with_kvcache_entrypoint, +) +from sglang.srt.layers.attention.deepseek_v4_backend import _pad_last_dim +from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_round_robin_split +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size +from sglang.srt.layers.attention.nsa.quant_k_cache_v4 import ( + quant_to_nope_fp8_rope_bf16_pack_triton, +) +from sglang.srt.layers.attention.triton_ops.compressed_metadata import ( + init_compressed_metadata as _init_compressed_metadata_triton, +) +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.spec_info import SpecInput + +if TYPE_CHECKING: + from flash_mla.flash_mla_interface import FlashMLASchedMeta + + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.model_runner import ModelRunner + +SWA_WINDOW = 128 +C4_TOPK = 512 +PAGE_INDEX_ALIGNED_SIZE = 64 + + +def _copy_metadata( + src, + dst, + check_eq_fields: List[str], + copy_fields: List[str], + assign_fields: Optional[List[str]] = None, +): + assign_fields = assign_fields or [] + + for field_name in check_eq_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + assert src_val == dst_val, f"{field_name=} {src_val=} {dst_val=}" + + for field_name in copy_fields: + src_val = getattr(src, field_name) + dst_val = getattr(dst, field_name) + # Skip if both src and dst are None (e.g., compress fields when need_compress=False) + if src_val is None and dst_val is None: + continue + assert dst_val is not None, f"{field_name=} {src_val=} {dst_val=}" + if hasattr(dst_val, "copy_"): + dst_val.copy_(src_val) + else: + warnings.warn( + f"{field_name=} {type(dst_val)=} does not have copy_, use setattr" + ) + setattr(dst, field_name, src_val) + + for field_name in assign_fields: + setattr(dst, field_name, getattr(src, field_name)) + + provided_fields = check_eq_fields + copy_fields + assign_fields + provided_fields_unique = set(provided_fields) + assert len(provided_fields) == len( + provided_fields_unique + ), f"{provided_fields=} has dup" + all_fields = {f.name for f in dataclasses.fields(src)} + provided_fields = set(provided_fields) + assert ( + provided_fields == all_fields + ), f"{provided_fields - all_fields=}, {all_fields - provided_fields=}" + + +def _create_flashmla_metadata(): + import flash_mla + + return flash_mla.get_mla_metadata()[0] + + +def _create_dummy_paged_compress_data(compress_ratio: int): + return None + + +@dataclass +class DSV4AttnMetadataRadix: + page_size: int + page_table: torch.Tensor + raw_out_loc: torch.Tensor + cuda_int32_kwargs: dict + + # to calculate compressed 
metadata + seq_lens_casual: torch.Tensor + positions_casual: torch.Tensor # positions expanded causally + + # sliding window attention (core) + swa_page_indices: torch.Tensor # at most (sum_qo_len, 128) + swa_topk_lengths: torch.Tensor # clipped to 128 + + # NOTE: c4/c128 out_loc will mask the invalid write locations to 0. + # When no compression happens, out_loc will be 0, which is the "padded slot" + c4_sparse_topk: int # must be 512 + c4_out_loc: Optional[torch.Tensor] = None + c4_positions: Optional[torch.Tensor] = None + c4_topk_lengths_raw: Optional[torch.Tensor] = None + c4_topk_lengths_clamp1: Optional[torch.Tensor] = None # i.e. c4_seq_lens + c4_sparse_topk_lengths: torch.Tensor = field(init=False) # clipped to 512 + c4_sparse_page_indices: torch.Tensor = field(init=False) # (bs, 512) + + # C128 dense attention (core) + c128_out_loc: Optional[torch.Tensor] = None + c128_positions: Optional[torch.Tensor] = None + c128_page_indices: Optional[torch.Tensor] = None + c128_topk_lengths_clamp1: Optional[torch.Tensor] = None + + # FlashMLA + c1_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c4_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + c128_flashmla_metadata: FlashMLASchedMeta = field(init=False, repr=False) + + @property + def positions(self) -> torch.Tensor: + return self.positions_casual + + def get_flashmla_metadata(self, compress_ratio: Literal[0, 4, 128]): + if compress_ratio == 0: + return self.c1_flashmla_metadata + elif compress_ratio == 4: + return self.c4_flashmla_metadata + elif compress_ratio == 128: + return self.c128_flashmla_metadata + else: + raise ValueError(f"invalid {compress_ratio=}") + + def copy_(self, other: DSV4AttnMetadataRadix) -> None: + _copy_metadata( + src=other, + dst=self, + check_eq_fields=[ + "c4_sparse_topk", + "page_size", + "cuda_int32_kwargs", + ], + copy_fields=[ + "raw_out_loc", + "seq_lens_casual", + "positions_casual", + "c4_positions", + "c128_positions", + "c4_out_loc", + "c128_out_loc", + "page_table", + "swa_page_indices", + "swa_topk_lengths", + "c128_page_indices", + "c128_topk_lengths_clamp1", + "c4_topk_lengths_raw", + "c4_topk_lengths_clamp1", + "c4_sparse_topk_lengths", + "c4_sparse_page_indices", + ], + assign_fields=[ + # For the new API, the metadata has the following lifecycle: + # + # Graph capture warmup forward pass: + # (ignore, we will reset to brand new object after such passes) + # + # Graph capture real-capture forward pass: + # * Layer 0: Set python & tensor objects to metadata + # * Layer >=1: Read them from metadata + # + # Graph replay: + # * Layer 0: The kernels are in "generate metadata" mode + # * Layer >=1: The kernels are in "non-generate metadata" mode + # + # Thus this field can be ignored. + # However, to allow running replay w/o in real cuda graph, we do an assignment. + # (Do we really need that? If no, we can change this field to skip-copy mode) + "c1_flashmla_metadata", + "c4_flashmla_metadata", + "c128_flashmla_metadata", + ], + ) + + def init_compressed_metadata(self): + """ + Initialize compressed metadata for both C4 and C128 using a single fused Triton kernel. + + NOTE: 0 means "any" in this example + e.g. seq_lens = [4n - 1, 4n, 4n + 1, 4n + 2] + raw_out_loc = [4X + 2, 4X + 3, 4Y, 4Y + 1] + raw_positions = [4n - 2, 4n - 1, 4n, 4n + 1] (i.e. seq_lens - 1) + then we have: + c4_seq_lens = [n - 1 , n , n , n ] (i.e. seq_lens // 4) + c4_positions = [0 , 4n - 4, 0 , 0 ] (i.e. positions // 4 * 4) + c4_out_loc = [0 , X , 0 , 0 ] (i.e. 
out_loc // 4) + """ + assert self.page_table.dim() == 2 + assert ( + self.raw_out_loc.shape == self.seq_lens_casual.shape + ), f"{self.raw_out_loc.shape=}, {self.seq_lens_casual.shape=}" + + # Compute both C4 and C128 metadata in a single kernel launch + ( + self.c4_out_loc, + self.c4_positions, + self.c4_topk_lengths_raw, + self.c4_topk_lengths_clamp1, + self.c128_out_loc, + self.c128_positions, + self.c128_topk_lengths_clamp1, + self.c128_page_indices, + ) = _init_compressed_metadata_triton( + self.seq_lens_casual, + self.positions_casual, + self.raw_out_loc, + self.page_table, + self.page_size, + compute_page_indices=True, + ) + + self.c128_page_indices = _pad_last_dim(self.c128_page_indices) + self.swa_page_indices = _pad_last_dim(self.swa_page_indices) + + _CP_REINDEX_FIELDS = [ + "seq_lens_casual", + "positions_casual", + "swa_page_indices", + "swa_topk_lengths", + "page_table", + "c4_positions", + "c4_topk_lengths_raw", + "c4_topk_lengths_clamp1", + "c128_positions", + "c128_page_indices", + "c128_topk_lengths_clamp1", + ] + + def apply_cp_reindex(self) -> None: + cp_rank = get_attention_tp_rank() + cp_size = get_attention_tp_size() + idx = slice(cp_rank, None, cp_size) + for field_name in self._CP_REINDEX_FIELDS: + val = getattr(self, field_name, None) + assert isinstance(val, torch.Tensor), f"CP reindex: {field_name} is {type(val)}, expected Tensor" + setattr(self, field_name, val[idx].contiguous()) + + def init_flashmla_related(self): + assert self.c4_sparse_topk == 512 + assert self.c4_topk_lengths_clamp1 is not None + self.c4_sparse_topk_lengths = torch.clamp( + self.c4_topk_lengths_clamp1, max=self.c4_sparse_topk + ) + self.c4_sparse_page_indices = torch.full( + (self.c4_topk_lengths_clamp1.size(0), self.c4_sparse_topk), + -1, + dtype=torch.int32, + device=self.c4_topk_lengths_clamp1.device, + ) + self.c4_sparse_page_indices = _pad_last_dim(self.c4_sparse_page_indices) + self.c1_flashmla_metadata = _create_flashmla_metadata() + self.c4_flashmla_metadata = _create_flashmla_metadata() + self.c128_flashmla_metadata = _create_flashmla_metadata() + + +@dataclass +class DSV4MetadataRadix: + core_attn_metadata: DSV4AttnMetadataRadix + indexer_metadata: Optional[PagedIndexerMetadata] + + c4_compress_metadata: Optional[FusedCompressMetadata] = None + c128_compress_metadata: Optional[FusedCompressMetadata] = None + + @property + def core_metadata(self) -> DSV4AttnMetadataRadix: + return self.core_attn_metadata + + def copy_(self, other: DSV4MetadataRadix): + self.core_attn_metadata.copy_(other.core_attn_metadata) + maybe_copy_inplace(self.indexer_metadata, src=other.indexer_metadata) + maybe_copy_inplace(self.c4_compress_metadata, src=other.c4_compress_metadata) + maybe_copy_inplace( + self.c128_compress_metadata, src=other.c128_compress_metadata + ) + + +@dataclass +class DSV4MetadataSimplified: + req_pool_indices: torch.Tensor + seq_lens: torch.Tensor + out_cache_loc: torch.Tensor + + # Constant tensor for CUDA graph related + extend_seq_lens: Optional[torch.Tensor] = None + real_metadata: Optional[DSV4MetadataRadix] = None + + def copy_(self, other: DSV4MetadataSimplified): + self.req_pool_indices.copy_(other.req_pool_indices) + self.seq_lens.copy_(other.seq_lens) + self.out_cache_loc.copy_(other.out_cache_loc) + + # constant buffer + self.extend_seq_lens = other.extend_seq_lens + + +@dataclass +class _DecodeCudaGraphSharedData: + pass # TODO fields + + +class DeepseekV4BackendRadix(AttentionBackend, C4IndexerBackend, CompressorBackend): + def __init__( + self, + model_runner: 
ModelRunner, + skip_prefill: bool = False, + speculative_step_id=0, + topk=0, + speculative_num_steps=0, + ): + super().__init__() + self.device = torch.device(model_runner.device) # type: ignore + head_dim = model_runner.model_config.head_dim + assert head_dim == 512 + self.softmax_scale: float = head_dim**-0.5 + self.head_dim_v: int = model_runner.model_config.v_head_dim + self.cuda_int32_kwargs = {"device": self.device, "dtype": torch.int32} + self.debug_dump_hook: Optional[Callable] = None + self.swa_page_size = 128 + assert model_runner.page_size is not None + assert model_runner.req_to_token_pool is not None + self.page_size = model_runner.page_size + assert self.page_size == 256, "the system hardcodes page_size=256" + + # Init Pools + self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.token_to_kv_pool: DeepSeekV4TokenToKVPool = model_runner.token_to_kv_pool # type: ignore + self.MAX_SEQ_LEN_FOR_CAPTURE = self.req_to_token.shape[1] + + assert isinstance(self.token_to_kv_pool, DeepSeekV4TokenToKVPool) + + # Speculative Decoding + self.topk = model_runner.server_args.speculative_eagle_topk or 0 + assert self.topk in [0, 1], "MTP Topk > 1 not supported for DeepSeek V4" + self.mtp_enabled = self.topk > 0 + self.speculative_num_steps = speculative_num_steps + self.speculative_num_draft_tokens: int = ( # type: ignore + model_runner.server_args.speculative_num_draft_tokens + ) + self.speculative_step_id = speculative_step_id + self.forward_metadata: Union[DSV4MetadataRadix, DSV4MetadataSimplified] = None + + def _move_to_device(self, x: List[int]) -> torch.Tensor: + # NOTE(dark): need to avoid sync + pin_tensor = torch.tensor(x, dtype=torch.int32, pin_memory=True) + return pin_tensor.to(self.device, non_blocking=True) + + #### Public API #### + + def init_forward_metadata_indexer(self, core_attn_metadata: DSV4AttnMetadataRadix): + return PagedIndexerMetadata( + page_size=self.page_size, + page_table=core_attn_metadata.page_table, + # NOTE should use `raw` instead of `clamp1` + c4_seq_lens=core_attn_metadata.c4_topk_lengths_raw, + ) + + def init_forward_metadata_decode( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + ) -> DSV4MetadataRadix: + assert ( + req_pool_indices.shape[0] == seq_lens.shape[0] == out_cache_loc.shape[0] + ), f"{req_pool_indices.shape=} {seq_lens.shape=} {out_cache_loc.shape=}" + + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices, + seq_lens_casual=seq_lens, + max_seq_len=max_seq_len, + out_loc=out_cache_loc, + need_compress=True, + ) + + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + + create = functools.partial( + create_paged_compressor_data, + is_prefill=False, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) + + if not envs.SGLANG_OPT_USE_FUSED_PAGED_COMPRESS.get(): + create = _create_dummy_paged_compress_data + + return DSV4MetadataRadix( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_prefill( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: List[int], + out_cache_loc: torch.Tensor, + num_tokens: int, + extend_seq_lens: torch.Tensor, + extend_seq_lens_cpu: List[int], + need_compress: bool = True, 
+ use_prefill_cuda_graph: bool = False, + ) -> DSV4MetadataRadix: + seq_lens_casual, req_pool_indices_repeated = self.expand_prefill_casually( + num_tokens=num_tokens, + seq_lens=seq_lens_cpu, + extend_seq_lens=extend_seq_lens_cpu, + req_pool_indices=req_pool_indices, + padded_num_tokens=out_cache_loc.shape[0], + ) + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices_repeated, + seq_lens_casual=seq_lens_casual, + max_seq_len=max_seq_len, + out_loc=out_cache_loc, + need_compress=need_compress, + is_prefill=True, + ) + indexer_metadata = ( + self.init_forward_metadata_indexer(core_attn_metadata) + if need_compress + else None + ) + if not (envs.SGLANG_OPT_USE_FUSED_PAGED_COMPRESS.get() and need_compress): + create = _create_dummy_paged_compress_data + else: + create = functools.partial( + create_paged_compressor_data, + is_prefill=True, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + extend_lens=extend_seq_lens, + extend_lens_cpu=extend_seq_lens_cpu, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + return DSV4MetadataRadix( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_target_verify( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> Union[DSV4MetadataRadix, DSV4MetadataSimplified]: + if envs.SGLANG_ADVANCED_CUDA_GRAPH_CAPTURE.get(): + assert out_cache_loc is not None + # FIXME: Constant tensor + if not hasattr(self, "extend_seq_lens_buffer"): + self.extend_seq_lens_buffer = torch.tensor( + [self.speculative_num_draft_tokens] * 1025, device=self.device + ) + extend_seq_lens = self.extend_seq_lens_buffer[: len(seq_lens)] + + return DSV4MetadataSimplified( + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + extend_seq_lens=extend_seq_lens, + ) + else: + seq_lens_cpu = seq_lens.tolist() + return self.init_forward_metadata_target_verify_old( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def init_forward_metadata_target_verify_old( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[List[int]] = None, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> DSV4MetadataRadix: + batch_size = len(seq_lens) + seq_lens = seq_lens + self.speculative_num_draft_tokens + seq_lens_cpu = [x + self.speculative_num_draft_tokens for x in seq_lens_cpu] + extend_seq_lens_cpu = [self.speculative_num_draft_tokens] * batch_size + extend_seq_lens = self._move_to_device(extend_seq_lens_cpu) + num_tokens = self.speculative_num_draft_tokens * batch_size + if out_cache_loc is None: # NOTE: for CUDA graph related + out_cache_loc = seq_lens.new_zeros(num_tokens) + return self.init_forward_metadata_prefill( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + num_tokens=num_tokens, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=True, + 
use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def make_forward_metadata_from_simplified( + self, simplified_metadata: DSV4MetadataSimplified + ) -> DSV4MetadataRadix: + # Extract the real metadata from the simplified metadata + req_pool_indices = simplified_metadata.req_pool_indices + seq_lens = simplified_metadata.seq_lens + out_cache_loc = simplified_metadata.out_cache_loc + + bs, num_draft_tokens = len(seq_lens), self.speculative_num_draft_tokens + seq_lens = seq_lens + self.speculative_num_draft_tokens + extend_seq_lens = simplified_metadata.extend_seq_lens + + seq_lens_casual, req_pool_indices_repeated = ( + self.expend_extend_with_same_length( + bs, num_draft_tokens, seq_lens, req_pool_indices + ) + ) + core_attn_metadata = self.make_core_attn_metadata( + req_to_token=self.req_to_token, + req_pool_indices_repeated=req_pool_indices_repeated, + seq_lens_casual=seq_lens_casual, + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + out_loc=out_cache_loc, + need_compress=True, + ) + indexer_metadata = self.init_forward_metadata_indexer(core_attn_metadata) + create = functools.partial( + create_paged_compressor_data, + is_prefill=True, + token_to_kv_pool=self.token_to_kv_pool, + req_to_token=self.req_to_token, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + extend_lens=extend_seq_lens, + seq_lens_cpu=None, + extend_lens_cpu=None, + use_prefill_cuda_graph=True, + num_q_tokens=num_draft_tokens, + ) + return DSV4MetadataRadix( + core_attn_metadata, + indexer_metadata, + c4_compress_metadata=create(compress_ratio=4), + c128_compress_metadata=create(compress_ratio=128), + ) + + def init_forward_metadata_draft_extend( + self, + max_seq_len: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: List[int], + num_tokens_per_bs: int, + out_cache_loc: Optional[torch.Tensor] = None, + use_prefill_cuda_graph: bool = False, + ) -> DSV4MetadataRadix: + batch_size = len(seq_lens) + extend_seq_lens_cpu = [num_tokens_per_bs] * batch_size + extend_seq_lens = self._move_to_device(extend_seq_lens_cpu) + num_tokens = num_tokens_per_bs * batch_size + if out_cache_loc is None: + out_cache_loc = seq_lens.new_zeros(num_tokens) + return self.init_forward_metadata_prefill( + seq_lens=seq_lens, + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens_cpu=seq_lens_cpu, + out_cache_loc=out_cache_loc, + num_tokens=num_tokens, + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=False, + use_prefill_cuda_graph=use_prefill_cuda_graph, + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch) -> None: + if self.mtp_enabled and forward_batch.forward_mode.is_idle(): + return + + req_pool_indices = forward_batch.req_pool_indices + seq_lens = forward_batch.seq_lens.to(torch.int32) + seq_lens_cpu = forward_batch.seq_lens_cpu + assert forward_batch.req_to_token_pool.req_to_token is self.req_to_token + + assert self.swa_page_size % SWA_WINDOW == 0 and self.page_size % 128 == 0 + assert seq_lens_cpu is not None + max_seq_len = int(seq_lens_cpu.max().item()) + + if forward_batch.forward_mode.is_decode_or_idle(): + metadata = self.init_forward_metadata_decode( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=forward_batch.out_cache_loc, + ) + elif forward_batch.forward_mode.is_target_verify(): + metadata = self.init_forward_metadata_target_verify( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=forward_batch.out_cache_loc, 
+ ) + elif forward_batch.forward_mode.is_prefill(include_draft_extend_v2=True): + extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu + extend_seq_lens = forward_batch.extend_seq_lens + assert ( + seq_lens is not None + and seq_lens_cpu is not None + and extend_seq_lens is not None + and extend_seq_lens_cpu is not None + ) + is_draft = forward_batch.forward_mode.is_draft_extend(include_v2=True) + metadata = self.init_forward_metadata_prefill( + max_seq_len=max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu.tolist(), + out_cache_loc=forward_batch.out_cache_loc, + num_tokens=sum(extend_seq_lens_cpu), + extend_seq_lens=extend_seq_lens, + extend_seq_lens_cpu=extend_seq_lens_cpu, + need_compress=not is_draft, # NOTE: draft model is swa only + ) + else: + raise NotImplementedError(f"unsupported mode {forward_batch.forward_mode=}") + + # set metadata + self.forward_metadata = metadata + if h := self.debug_dump_hook: + h("init_forward_metadata_output", metadata) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: + self.decode_cuda_graph_shared_data = _DecodeCudaGraphSharedData() + self.decode_cuda_graph_metadata_of_bs: Dict[int, DSV4MetadataRadix] = {} + self.target_verify_cuda_graph_metadata_of_bs: Dict[ + int, Union[DSV4MetadataRadix, DSV4MetadataSimplified] + ] = {} + self.draft_extend_cuda_graph_metadata_of_bs: Dict[int, DSV4MetadataRadix] = {} + self.draft_extend_num_tokens_per_bs = ( + max_num_tokens // max_bs if max_bs > 0 else 1 + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + ) -> None: + assert req_pool_indices.size(0) == bs + assert seq_lens.size(0) == bs + + if forward_mode.is_decode_or_idle(): + # NOTE: we should use `self.decode_cuda_graph_shared_data` to avoid allocating + # a pack of tensors per cuda graph, but that is the NEXT step instead of current step. 
+ # For example, we may write: + # + # metadata = compute_decode_metadata() + # use_shared_tensors(metadata, self.decode_cuda_graph_shared_data) + # + # def use_shared_tensors(): + # for field_name in ...: + # getattr(shared_data, field_name).copy_(getattr(metadata, field_name)[..maybe_some_slicing..]) + # setattr(metadata, field_name, getattr(shared_data, field_name)) + + metadata = self.init_forward_metadata_decode( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=torch.zeros_like(seq_lens), # Dummy value + ) + + self.decode_cuda_graph_metadata_of_bs[bs] = metadata + self.forward_metadata = metadata + elif forward_mode.is_target_verify(): + out_cache_loc = torch.zeros(num_tokens, **self.cuda_int32_kwargs) + metadata = self.init_forward_metadata_target_verify( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + use_prefill_cuda_graph=True, + ) + self.target_verify_cuda_graph_metadata_of_bs[bs] = metadata + self.forward_metadata = metadata + + # Track the current simplified metadata for resetting after warmup + self._current_capture_simplified = ( + metadata if isinstance(metadata, DSV4MetadataSimplified) else None + ) + elif forward_mode.is_draft_extend(include_v2=True): + num_tokens_per_bs = num_tokens // bs + metadata = self.init_forward_metadata_draft_extend( + max_seq_len=self.MAX_SEQ_LEN_FOR_CAPTURE, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens.tolist(), + num_tokens_per_bs=num_tokens_per_bs, + use_prefill_cuda_graph=True, + ) + self.draft_extend_cuda_graph_metadata_of_bs[bs] = metadata + self.forward_metadata = metadata + else: + raise NotImplementedError(f"{forward_mode=} not supported yet") + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInput], + seq_lens_cpu: Optional[torch.Tensor], + out_cache_loc: Optional[torch.Tensor] = None, + actual_forward_mode: Optional[ForwardMode] = None, + ) -> None: + # We observe error that len(out_cache_loc)=0 while len(seq_lens)>0. + # We only support DP attention, thus when IDLE, we will not execute attention backend, + # thus it is safe to delete it. 
+ if actual_forward_mode == ForwardMode.IDLE: + if hasattr(self, "forward_metadata"): + del self.forward_metadata # avoid misuse + return + + assert seq_lens_cpu is not None + seq_lens = seq_lens[:bs] + seq_lens_cpu = seq_lens_cpu[:bs] + req_pool_indices = req_pool_indices[:bs] + + if forward_mode.is_decode_or_idle(): + assert out_cache_loc is not None + # Future optimization: use real max seq len + actual_max_seq_len = seq_lens_cpu.max().item() + + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + + assert len(out_cache_loc.shape) == 1, f"{out_cache_loc.shape=}" + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, bs - len(out_cache_loc)), + mode="constant", + value=0, + ) + + temp_metadata = self.init_forward_metadata_decode( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + ) + + # Future optimization: may not need to `copy` all things, + # But only copy partially such as `page_table[:, :max_seq_len]` + chosen_metadata = self.decode_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + elif forward_mode.is_target_verify(): + assert out_cache_loc is not None + # Future optimization: use real max seq len + actual_max_seq_len = seq_lens_cpu.max().item() + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + # NOTE: extend length remains the same during target verify + num_tokens = self.speculative_num_draft_tokens * bs + out_cache_loc_padded = torch.nn.functional.pad( + out_cache_loc, + pad=(0, num_tokens - len(out_cache_loc)), + mode="constant", + value=0, + ) + temp_metadata = self.init_forward_metadata_target_verify( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc_padded, + use_prefill_cuda_graph=True, + ) + chosen_metadata = self.target_verify_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + elif forward_mode.is_draft_extend(include_v2=True): + actual_max_seq_len = seq_lens_cpu.max().item() + chosen_max_seq_len = self.MAX_SEQ_LEN_FOR_CAPTURE + assert actual_max_seq_len <= chosen_max_seq_len + num_tokens_per_bs = self.draft_extend_num_tokens_per_bs + # NOTE: draft extend doesn't need out_cache_loc since need_compress=False + temp_metadata = self.init_forward_metadata_draft_extend( + max_seq_len=chosen_max_seq_len, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu.tolist(), + num_tokens_per_bs=num_tokens_per_bs, + use_prefill_cuda_graph=True, + ) + chosen_metadata = self.draft_extend_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + else: + raise NotImplementedError + + def replay_cuda_graph_metadata_from( + self, + bs: int, + temp_metadata: DSV4MetadataRadix, + forward_mode: ForwardMode, + ) -> None: + """Copy pre-computed metadata to this backend's cuda graph metadata storage. + + This method is used to avoid redundant computation when multiple backends + need the same metadata (e.g., in speculative decoding with multiple steps). 
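+
+ For example, DeepseekV4MultiStepBackend computes the replay metadata once via
+ attn_backends[0].init_forward_metadata_replay_cuda_graph and then calls
+ replay_cuda_graph_metadata_from on the remaining draft backends instead of
+ recomputing it for every speculative step.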
+ """ + if forward_mode.is_decode_or_idle(): + chosen_metadata = self.decode_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + elif forward_mode.is_target_verify(): + chosen_metadata = self.target_verify_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + elif forward_mode.is_draft_extend(include_v2=True): + chosen_metadata = self.draft_extend_cuda_graph_metadata_of_bs[bs] + chosen_metadata.copy_(temp_metadata) + self.forward_metadata = chosen_metadata + else: + raise NotImplementedError + + def get_cuda_graph_seq_len_fill_value(self): + # FlashMLA, NSA backend, etc, use "1" + return 1 + + # TODO improve naming + def on_after_cuda_graph_warmup_pass(self): + metadata: DSV4MetadataRadix = self.forward_metadata + if isinstance(metadata.core_attn_metadata, DSV4AttnMetadataRadix): + metadata.core_attn_metadata.c1_flashmla_metadata = ( + _create_flashmla_metadata() + ) + metadata.core_attn_metadata.c4_flashmla_metadata = ( + _create_flashmla_metadata() + ) + metadata.core_attn_metadata.c128_flashmla_metadata = ( + _create_flashmla_metadata() + ) + + # For advanced CUDA graph capture, reset forward_metadata back to + # the current batch size's DSV4MetadataSimplified so that the next + # pass (including the actual capture pass) re-executes the + # simplified→real derivation, ensuring those GPU ops are recorded + # into the CUDA graph. + current_simplified = getattr(self, "_current_capture_simplified", None) + if current_simplified is not None: + self.forward_metadata = current_simplified + + def store_cache( + self, layer_id: int, swa_k: torch.Tensor, forward_batch: ForwardBatch + ) -> None: + raw_loc = forward_batch.out_cache_loc + if envs.SGLANG_OPT_USE_FUSED_STORE_CACHE.get(): + self.token_to_kv_pool.set_swa_key_buffer_radix_fused( + layer_id=layer_id, + raw_loc=raw_loc, + cache_k=swa_k, + ) + else: + swa_k_pack = quant_to_nope_fp8_rope_bf16_pack_triton(swa_k) + self.token_to_kv_pool.set_swa_key_buffer_radix( + layer_id=layer_id, + raw_loc=raw_loc, + cache_nope_fp8_rope_bf16_pack=swa_k_pack, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + compress_ratio: Literal[0, 4, 128], + save_kv_cache: bool = True, + attn_sink: Optional[torch.Tensor] = None, + **_, + ) -> torch.Tensor: + if isinstance(self.forward_metadata, DSV4MetadataSimplified): + real_metadata = self.make_forward_metadata_from_simplified( + simplified_metadata=self.forward_metadata, + ) + self.forward_metadata = real_metadata + + if self.mtp_enabled and forward_batch.forward_mode.is_idle(): + return q.new_empty(q.shape[0], q.shape[1], layer.v_head_dim) + + # NOTE: here set-kv only applies to swa kv + + assert k is v, "DeepseekV4 shares k and v" + swa_k = k + + layer_id = layer.layer_id + metadata = self.forward_metadata + core_attn_metadata = metadata.core_attn_metadata + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + # This sanity check is to avoid, e.g., in CUDA graph capturing, we may accidentally + # run forward passes multiple times with one init_forward_metadata. + # If that happens, the real capturing pass will record that layer 0 do not have any meta init operations + # which is wrong. + + if isinstance(core_attn_metadata, DSV4AttnMetadataRadix): + # ------- 1. 
SWA attention k cache ------- + if save_kv_cache: + self.store_cache(layer_id, swa_k, forward_batch) + swa_k_cache = token_to_kv_pool.get_swa_key_buffer_radix(layer_id) + + # ------- 2. Full (C4/C128) attention k cache ------- + extra_k_cache, extra_indices, extra_topk_lengths = None, None, None + if compress_ratio == 4: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_attn_metadata.c4_sparse_page_indices + extra_topk_lengths = core_attn_metadata.c4_sparse_topk_lengths + elif compress_ratio == 128: + extra_k_cache = token_to_kv_pool.get_extra_key_buffer(layer_id) + extra_indices = core_attn_metadata.c128_page_indices + extra_topk_lengths = core_attn_metadata.c128_topk_lengths_clamp1 + + # ------- Call attention core ------- + swa_window_size = token_to_kv_pool.swa_window_size + assert swa_k_cache.ndim == 2 + # view b/c flashmla expect dim=4 + # reference: FlashMLA/tests/test_flash_mla_sparse_prefill.py + k_cache_total_dim = token_to_kv_pool.swa_kv_pool.kv_cache_total_dim + swa_k_cache = swa_k_cache[:, : swa_window_size * k_cache_total_dim].view( + swa_k_cache.shape[0], swa_window_size, 1, k_cache_total_dim + ) + + if extra_k_cache is not None: + page_sizes = { + 4: token_to_kv_pool.page_size // 4, + 128: token_to_kv_pool.page_size // 128, + } + extra_k_cache = extra_k_cache[ + :, : page_sizes[compress_ratio] * k_cache_total_dim + ].view( + extra_k_cache.shape[0], + page_sizes[compress_ratio], + 1, + k_cache_total_dim, + ) + swa_page_indices = core_attn_metadata.swa_page_indices + swa_topk_lengths = core_attn_metadata.swa_topk_lengths + + if self.mtp_enabled: + if swa_page_indices.shape[0] != q.shape[0]: + swa_page_indices = _pad_tensor_to_size( + swa_page_indices, q.shape[0], value=0 + ) + + if swa_topk_lengths.shape[0] != q.shape[0]: + swa_topk_lengths = _pad_tensor_to_size( + swa_topk_lengths, q.shape[0], value=1 + ) + + # unsqueeze to adapt decode kernel + if q.ndim == 3: + q = q.unsqueeze(1) + if swa_page_indices.ndim == 2: + swa_page_indices = swa_page_indices.unsqueeze(1) + if extra_indices is not None and extra_indices.ndim == 2: + extra_indices = extra_indices.unsqueeze(1) + + assert attn_sink is not None + + flashmla_metadata = core_attn_metadata.get_flashmla_metadata(compress_ratio) + + # compute-sanitizer observe issue if this is not enforced + assert ( + swa_page_indices.shape[-1] % 64 == 0 + ), f"{swa_page_indices.shape=}'s last dimension is not aligned to 64" + if extra_indices is not None: + assert ( + extra_indices.shape[-1] % 64 == 0 + ), f"{extra_indices.shape=}'s last dimension is not aligned to 64" + + input_dict = dict( + q=q, + k_cache=swa_k_cache, + head_dim_v=self.head_dim_v, + block_table=None, + cache_seqlens=None, + tile_scheduler_metadata=flashmla_metadata, + softmax_scale=self.softmax_scale, + is_fp8_kvcache=True, + indices=swa_page_indices, + topk_length=swa_topk_lengths, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices, + extra_topk_length=extra_topk_lengths, + ) + + backend = envs.SGLANG_HACK_FLASHMLA_BACKEND.get() + o = flash_mla_with_kvcache_entrypoint(**input_dict, backend=backend)[0] + + o = o.squeeze(1) + return o + + raise NotImplementedError("ragged attention") + + #### Helper functions #### + + def expand_prefill_casually( + self, + num_tokens: int, + seq_lens: List[int], + extend_seq_lens: List[int], + req_pool_indices: torch.Tensor, + padded_num_tokens: Optional[int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # NOTE: expanded follow a `causal` mask pattern + 
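# Example (editor's illustration): seq_lens=[5, 3] and extend_seq_lens=[2, 3] expand to
+ # seq_lens_casual = [4, 5, 1, 2, 3] and idx_to_req_repeated = [0, 0, 1, 1, 1].
+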
seq_lens_casual = torch.empty(num_tokens, **self.cuda_int32_kwargs) + idx_to_req_repeated = torch.empty(num_tokens, **self.cuda_int32_kwargs) + offset = 0 + for i, (kv_len, qo_len) in enumerate(zip(seq_lens, extend_seq_lens)): + out = seq_lens_casual[offset : offset + qo_len] + offset += qo_len + torch.arange(kv_len - qo_len + 1, kv_len + 1, out=out) + idx_to_req_repeated[offset - qo_len : offset].fill_(i) + + assert offset == num_tokens + req_pool_indices_repeated = req_pool_indices[idx_to_req_repeated] + + # Padding is generic (out_cache_loc may be ceil_align'd beyond num_tokens). + # CP always needs it; non-CP can opt in via SGLANG_DSV4_FIX_ATTN_PADDING. + _need_pad = is_nsa_prefill_cp_round_robin_split() or envs.SGLANG_DSV4_FIX_ATTN_PADDING.get() + if _need_pad and padded_num_tokens is not None and padded_num_tokens > num_tokens: + pad_size = padded_num_tokens - num_tokens + seq_lens_casual = torch.nn.functional.pad( + # TODO: is pad value 1 ok? + seq_lens_casual, (0, pad_size), value=1 + ) + req_pool_indices_repeated = torch.nn.functional.pad( + req_pool_indices_repeated, (0, pad_size), value=req_pool_indices_repeated[-1].item() + ) + + return seq_lens_casual, req_pool_indices_repeated + + def expend_extend_with_same_length( + self, + bs: int, + qo_len: int, + seq_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + ): + seq_lens_casual = seq_lens[:, None] + torch.arange( + -qo_len + 1, 1, **self.cuda_int32_kwargs + ) + seq_lens_casual = seq_lens_casual.flatten() + idx_to_req_repeated = torch.arange( + bs, **self.cuda_int32_kwargs + ).repeat_interleave(qo_len) + req_pool_indices_repeated = req_pool_indices[idx_to_req_repeated] + return seq_lens_casual, req_pool_indices_repeated + + def make_core_attn_metadata( + self, + req_to_token: torch.Tensor, + req_pool_indices_repeated: torch.Tensor, + seq_lens_casual: torch.Tensor, + max_seq_len: int, + out_loc: torch.Tensor, + need_compress: bool = True, + is_prefill: bool = False, + ) -> DSV4AttnMetadataRadix: + # NOTE: the full attn page size is 256 and SWA page size is 128, + # which is OK in current SWA radix tree design + assert self.swa_page_size == SWA_WINDOW + + # -------------------- START compute SWA metadata -------------------- + swa_page_indices = self.get_swa_page_indices( + seq_lens_casual=seq_lens_casual, + req_pool_indices_repeated=req_pool_indices_repeated, + ) + + swa_page_indices = _pad_last_dim( + swa_page_indices, multiples_of=PAGE_INDEX_ALIGNED_SIZE + ) + + raw_positions = seq_lens_casual - 1 + swa_topk_lengths = torch.clamp(seq_lens_casual, max=SWA_WINDOW) + + # -------------------- END compute SWA metadata -------------------- + page_table = req_to_token[ + req_pool_indices_repeated, : max_seq_len : self.page_size + ] + page_table = (page_table // self.page_size).to(torch.int32) + + core_attn_metadata = DSV4AttnMetadataRadix( + page_size=self.page_size, + raw_out_loc=out_loc, + seq_lens_casual=seq_lens_casual, + cuda_int32_kwargs=self.cuda_int32_kwargs, + positions_casual=raw_positions, + page_table=page_table, + swa_page_indices=swa_page_indices, + swa_topk_lengths=swa_topk_lengths, + c4_sparse_topk=C4_TOPK, + ) + + if need_compress: + core_attn_metadata.init_compressed_metadata() + if is_prefill and is_nsa_prefill_cp_round_robin_split(): + core_attn_metadata.apply_cp_reindex() + core_attn_metadata.init_flashmla_related() + else: + # Draft model doesn't include c4/c128 compressors + core_attn_metadata.c4_sparse_topk_lengths = None + core_attn_metadata.c4_sparse_page_indices = None + 
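# need_compress=False: the draft model only runs plain (c1) attention, so only + # the c1 FlashMLA metadata is created below and the c4/c128 fields stay None. + 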
core_attn_metadata.c1_flashmla_metadata = _create_flashmla_metadata() + core_attn_metadata.c4_flashmla_metadata = None + core_attn_metadata.c128_flashmla_metadata = None + return core_attn_metadata + + def get_swa_page_indices( + self, + seq_lens_casual: torch.Tensor, + req_pool_indices_repeated: torch.Tensor, + ) -> torch.Tensor: + pos_causal = seq_lens_casual - 1 + num_qo_tokens = seq_lens_casual.size(0) + offsets = pos_causal.unsqueeze(1) - torch.arange( + SWA_WINDOW, **self.cuda_int32_kwargs + ).unsqueeze(0) + invalid_offset_mask = offsets < 0 + offsets.masked_fill_(invalid_offset_mask, 0) + raw_indices = self.req_to_token[req_pool_indices_repeated[:, None], offsets] + assert raw_indices.shape == (num_qo_tokens, SWA_WINDOW) + raw_indices.masked_fill_(invalid_offset_mask, -1) + swa_indices = self.token_to_kv_pool.translate_loc_from_full_to_swa(raw_indices) + return swa_indices + + #### Test-only API #### + def extract_metadata(self, forward_batch: ForwardBatch): + # NOTE: in the future we may put metadata in the forward_batch itself + # this function is used for tests. Don't delete it. + return self.forward_metadata + + +class DeepseekV4MultiStepBackend(DeepseekV4BackendRadix): + def __init__( + self, model_runner: ModelRunner, topk: int, speculative_num_steps: int + ): + super().__init__(model_runner) + self.model_runner = model_runner + self.topk = topk + self.speculative_num_steps = speculative_num_steps + self.attn_backends: List[DeepseekV4BackendRadix] = [] + for i in range(self.speculative_num_steps): + self.attn_backends.append( + DeepseekV4BackendRadix( + model_runner, + speculative_step_id=i, + topk=self.topk, + speculative_num_steps=self.speculative_num_steps, + ) + ) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + for i in range(self.speculative_num_steps - 1): + self.attn_backends[i].init_forward_metadata(forward_batch) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens) + + def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch): + for i in range(self.speculative_num_steps): + self.attn_backends[i].init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.batch_size * self.topk, + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + ) + + def init_forward_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int + ): + if self.speculative_num_steps == 1: + return + + # Compute metadata only once using the first backend + self.attn_backends[0].init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=forward_batch.req_pool_indices, + seq_lens=forward_batch.seq_lens, + seq_lens_sum=forward_batch.seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=forward_batch.spec_info, + seq_lens_cpu=forward_batch.seq_lens_cpu, + out_cache_loc=forward_batch.out_cache_loc, + ) + temp_metadata = self.attn_backends[0].forward_metadata + + # Copy to other backends without recomputing + for i in range(1, self.speculative_num_steps - 1): + self.attn_backends[i].replay_cuda_graph_metadata_from( + bs=bs, + temp_metadata=temp_metadata, + forward_mode=ForwardMode.DECODE, + ) + + +def _pad_tensor_to_size(tensor: torch.Tensor, size: int, *, value: int = 0): + if value == 0: + return torch.cat( + [tensor, tensor.new_zeros(size - tensor.shape[0], 
*tensor.shape[1:])], + dim=0, + ) + else: + return torch.cat( + [ + tensor, + tensor.new_full((size - tensor.shape[0], *tensor.shape[1:]), value), + ], + dim=0, + ) diff --git a/python/sglang/srt/layers/attention/indexer_topk_capturer.py b/python/sglang/srt/layers/attention/indexer_topk_capturer.py new file mode 100644 index 000000000000..05b086302e13 --- /dev/null +++ b/python/sglang/srt/layers/attention/indexer_topk_capturer.py @@ -0,0 +1,121 @@ +import logging +from typing import TYPE_CHECKING, Optional + +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.layers.topk_capturer_base import ( + _GB, + _MB, + BaseTopkCapturer, + BaseTopkCapturerNoop, +) + +if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig + +logger = logging.getLogger(__name__) + +INDEX_TOPK = 512 + + +def _count_indexer_layers(model_config: "ModelConfig") -> int: + # TODO very hacky now + compress_ratios = getattr(model_config.hf_text_config, "compress_ratios", None) + if compress_ratios is None: + return 0 + return sum(1 for r in compress_ratios if r == 4) + + +class IndexerTopkCapturer(BaseTopkCapturer): + def __init__( + self, + model_config: "ModelConfig", + num_tokens: int, + max_running_requests: int, + device: str, + ): + from sglang.srt.server_args import get_global_server_args + + self.num_indexer_layers = _count_indexer_layers(model_config) + self.index_topk = getattr(model_config.hf_text_config, "index_topk", INDEX_TOPK) + + if self.num_indexer_layers == 0: + logger.warning("No indexer layers found, IndexerTopkCapturer disabled") + self._enabled = False + return + + self._enabled = True + + server_args = get_global_server_args() + max_batch_size = max( + server_args.chunked_prefill_size * server_args.dp_size, + max_running_requests, + ) + + attn_tp_size = get_attention_tp_size() + assert attn_tp_size == 1, "IndexerTopkCapturer now only supports DP attention" + + super().__init__( + num_tokens=num_tokens, + max_batch_size=max_batch_size, + num_layers=self.num_indexer_layers, + topk_size=self.index_topk, + device=device, + ) + + self._log_allocation() + + def _log_allocation(self): + host_size_gb = self.host_cache.get_buffer_size_bytes() / _GB + device_size_mb = self.device_cache.get_buffer_size_bytes() / _MB + logger.info( + f"IndexerTopkCapturer allocated: " + f"num_indexer_layers={self.num_indexer_layers}, index_topk={self.index_topk}, " + f"host_cache={host_size_gb:.2f}GB, device_cache={device_size_mb:.2f}MB" + ) + + def _sync_to_host(self, forward_batch, can_run_graph, cuda_graph_batch): + # b/c DP attention, we will not use a global buffer and gather it (like MoE), + # and each rank should directly write to host + num_tokens = forward_batch.out_cache_loc.shape[0] + out_cache_loc_cpu = forward_batch.out_cache_loc.cpu() + self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[ + :num_tokens, :, : self.topk_size + ].cpu() + + def is_enabled(self) -> bool: + return self._enabled + + +class IndexerTopkCapturerNoop(BaseTopkCapturerNoop): + pass + + +_global_indexer_capturer: Optional[IndexerTopkCapturer] = IndexerTopkCapturerNoop() + + +def get_global_indexer_capturer(): + return _global_indexer_capturer + + +def set_global_indexer_capturer(capturer): + global _global_indexer_capturer + _global_indexer_capturer = capturer + + +def create_indexer_capturer( + enable: bool, + model_config: "ModelConfig", + num_tokens: int, + max_running_requests: int, + device: str, +): + if enable: + capturer = IndexerTopkCapturer( + model_config=model_config, + 
num_tokens=num_tokens, + max_running_requests=max_running_requests, + device=device, + ) + if capturer.is_enabled(): + return capturer + return IndexerTopkCapturerNoop() diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py index 1cdf65b91c29..fb40881a122e 100644 --- a/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor.py @@ -316,6 +316,9 @@ def vanilla(cls, pool, buf, loc, index_k, index_k_scale): @classmethod def triton(cls, pool, buf, loc, index_k, index_k_scale): + # Observe int32 loc and the kernel assert int64. may optimize later + loc = loc.to(torch.int64) + _set_k_and_s_triton( buf=buf, loc=loc, @@ -354,14 +357,16 @@ def _set_k_and_s_triton( f"index_k_scale must be 1D or 2D, got shape {index_k_scale.shape}" ) if _is_hip: - assert buf_numel_per_page == 1 * (128 + 4) + # assert buf_numel_per_page == 1 * (128 + 4) + pass else: assert buf_numel_per_page == 64 * (128 + 4) assert num_tokens_to_write == num_tokens_to_write_ == num_tokens_to_write__ assert index_head_dim == 128 assert scale_dim == 1 if _is_hip: - assert page_size == 1 + # assert page_size == 1 + pass else: assert page_size == 64 diff --git a/python/sglang/srt/layers/attention/nsa/index_buf_accessor_v4.py b/python/sglang/srt/layers/attention/nsa/index_buf_accessor_v4.py new file mode 100644 index 000000000000..e03a3b8cc37f --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/index_buf_accessor_v4.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + +# May move this dataclass to somewhere else in the future (low priority) +@dataclass +class NopeFp8RopeBf16Pack: + k_nope_fp8: torch.Tensor + k_rope_bf16: torch.Tensor + scale_k_nope_ue8m0: torch.Tensor + + def __post_init__(self): + assert self.k_nope_fp8.shape[-1] == 448 + assert self.k_rope_bf16.shape[-1] == 64 + assert self.scale_k_nope_ue8m0.shape[-1] == 7 + + def slice_pack(self, _slice: Any) -> NopeFp8RopeBf16Pack: + return NopeFp8RopeBf16Pack( + k_nope_fp8=self.k_nope_fp8[_slice], + k_rope_bf16=self.k_rope_bf16[_slice], + scale_k_nope_ue8m0=self.scale_k_nope_ue8m0[_slice], + ) + + +# TODO seems no need to have this and can remove +# class GetKAndS: +# # may also make a torch version for comparison if needed +# @classmethod +# def triton(cls, pool, buf, loc) -> NopeFp8RopeBf16Pack: +# return TODO + + +class SetKAndS: + @classmethod + def execute(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + # Use cls.torch for accuracy baseline + cls.triton(pool, buf, loc, nope_fp8_rope_bf16_pack) + + @classmethod + def torch(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + _set_k_and_s_torch(buf, loc, nope_fp8_rope_bf16_pack, pool.page_size) + + @classmethod + def triton(cls, pool, buf, loc, nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack): + _set_k_and_s_triton(buf, loc, nope_fp8_rope_bf16_pack, pool.page_size) + + +def _set_k_and_s_triton( + buf: torch.Tensor, + loc: torch.Tensor, + nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + page_size: int, +): + """ + :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8 + :param loc: (num_tokens_to_write,), int, element := the token index to write to + 
:param index_k: (num_tokens_to_write, 128 elem), fp8 + :param index_k_scale: (num_tokens_to_write, 1 elem), fp32 + :return: + """ + num_pages, buf_numel_per_page = buf.shape + (num_tokens_to_write,) = loc.shape + + k_nope, k_rope, scale_k_nope = ( + nope_fp8_rope_bf16_pack.k_nope_fp8, + nope_fp8_rope_bf16_pack.k_rope_bf16, + nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0, + ) + + num_tokens_to_write_nope, nope_dim = k_nope.shape + num_tokens_to_write_rope, rope_dim = k_rope.shape + num_tokens_to_write_scale, scale_dim = scale_k_nope.shape + + assert ( + num_tokens_to_write + == num_tokens_to_write_nope + == num_tokens_to_write_rope + == num_tokens_to_write_scale + ) + + assert buf.dtype == torch.uint8 + assert loc.dtype in [torch.int64, torch.int32], f"{loc.dtype=}" # can be int32 + + assert k_nope.dtype == fp8_dtype + assert k_rope.dtype == torch.bfloat16 + assert scale_k_nope.dtype == torch.uint8, f"{scale_k_nope.dtype=}" + + assert buf.is_contiguous() + assert loc.is_contiguous() + assert k_nope.is_contiguous() + assert k_rope.is_contiguous() + assert scale_k_nope.is_contiguous() + + buf_fp8 = buf.view(fp8_dtype) + buf_bf16 = buf.view(torch.bfloat16) + buf_uint8 = buf.view(torch.uint8) + + nope_rope_bytes = nope_dim + rope_dim * 2 + s_offset_nbytes_in_page = page_size * (nope_dim + rope_dim * 2) + + _set_k_and_s_triton_kernel[(num_tokens_to_write,)]( + buf_fp8, + buf_bf16, + buf_uint8, + loc, + k_nope, + k_rope, + scale_k_nope, + k_nope.stride(0), + k_rope.stride(0), + scale_k_nope.stride(0), + PAGE_SIZE=page_size, + BUF_NUMEL_PER_PAGE=buf_numel_per_page, + NUM_NOPE_ELEMS_PER_TOKEN=nope_dim, + NUM_ROPE_ELEMS_PER_TOKEN=rope_dim, + NUM_SCALE_ELEMS_PER_TOKEN=scale_dim, + NUM_NOPE_ROPE_BYTES_PER_TOKEN=nope_rope_bytes, + PADDED_SCALE_ELEMS_PER_TOKEN=scale_dim + 1, # 1B pad + S_OFFSET_NBYTES_IN_PAGE=s_offset_nbytes_in_page, + BLOCK_NOPE=512, + BLOCK_ROPE=64, + BLOCK_SCALE=8, + ) + + +@triton.jit +def _set_k_and_s_triton_kernel( + buf_fp8_ptr, + buf_bf16_ptr, + buf_uint8_ptr, + loc_ptr, + k_nope_ptr, + k_rope_ptr, + scale_k_nope_ptr, + k_nope_ptr_stride_0, + k_rope_ptr_stride_0, + scale_k_nope_ptr_stride_0, + PAGE_SIZE: tl.constexpr, + BUF_NUMEL_PER_PAGE: tl.constexpr, + NUM_NOPE_ELEMS_PER_TOKEN: tl.constexpr, + NUM_ROPE_ELEMS_PER_TOKEN: tl.constexpr, + NUM_NOPE_ROPE_BYTES_PER_TOKEN: tl.constexpr, + NUM_SCALE_ELEMS_PER_TOKEN: tl.constexpr, + PADDED_SCALE_ELEMS_PER_TOKEN: tl.constexpr, + S_OFFSET_NBYTES_IN_PAGE: tl.constexpr, + BLOCK_NOPE: tl.constexpr, + BLOCK_ROPE: tl.constexpr, + BLOCK_SCALE: tl.constexpr, +): + token_id = tl.program_id(0) + loc = tl.load(loc_ptr + token_id) + + nope_range = tl.arange(0, BLOCK_NOPE) + nope_mask = nope_range < NUM_NOPE_ELEMS_PER_TOKEN + in_k_nope_offsets = token_id * k_nope_ptr_stride_0 + nope_range + k_nope = tl.load(k_nope_ptr + in_k_nope_offsets, mask=nope_mask, other=0.0) + + rope_range = tl.arange(0, BLOCK_ROPE) + in_k_rope_offsets = token_id * k_rope_ptr_stride_0 + rope_range + k_rope = tl.load(k_rope_ptr + in_k_rope_offsets) + + scale_range = tl.arange(0, BLOCK_SCALE) + scale_mask = scale_range < NUM_SCALE_ELEMS_PER_TOKEN + in_scale_k_offsets = token_id * scale_k_nope_ptr_stride_0 + scale_range + k_scale = tl.load(scale_k_nope_ptr + in_scale_k_offsets, mask=scale_mask, other=0) + + loc_page_index = loc // PAGE_SIZE + loc_token_offset_in_page = loc % PAGE_SIZE + + out_k_nope_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE + + loc_token_offset_in_page * NUM_NOPE_ROPE_BYTES_PER_TOKEN + + nope_range + ) + + out_k_rope_offsets = ( + loc_page_index * 
BUF_NUMEL_PER_PAGE // 2 + + loc_token_offset_in_page * (NUM_NOPE_ROPE_BYTES_PER_TOKEN // 2) + + NUM_NOPE_ELEMS_PER_TOKEN // 2 + + rope_range + ) + + out_s_offsets = ( + loc_page_index * BUF_NUMEL_PER_PAGE + + S_OFFSET_NBYTES_IN_PAGE + + loc_token_offset_in_page * PADDED_SCALE_ELEMS_PER_TOKEN + + scale_range + ) + + tl.store(buf_fp8_ptr + out_k_nope_offsets, k_nope, mask=nope_mask) + tl.store(buf_bf16_ptr + out_k_rope_offsets, k_rope) + tl.store(buf_uint8_ptr + out_s_offsets, k_scale, mask=scale_mask) + + +def _set_k_and_s_torch( + buf: torch.Tensor, + loc: torch.Tensor, + nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + page_size: int, +): + """ + :param buf: (num_pages, page_size 64 * (128B data + 4B scale)), uint8 + :param loc: (num_tokens_to_write,), int, element := the token index to write to + :param index_k: (num_tokens_to_write, 128 elem), fp8 + :param index_k_scale: (num_tokens_to_write, 1 elem), fp32 + :return: + """ + num_pages, buf_numel_per_page = buf.shape + (num_tokens_to_write,) = loc.shape + + k_nope, k_rope, scale_k_nope = ( + nope_fp8_rope_bf16_pack.k_nope_fp8, + nope_fp8_rope_bf16_pack.k_rope_bf16, + nope_fp8_rope_bf16_pack.scale_k_nope_ue8m0, + ) + + num_tokens_to_write_nope, nope_dim = k_nope.shape + num_tokens_to_write_rope, rope_dim = k_rope.shape + num_tokens_to_write_scale, scale_dim = scale_k_nope.shape + + assert ( + num_tokens_to_write + == num_tokens_to_write_nope + == num_tokens_to_write_rope + == num_tokens_to_write_scale + ), f"{num_tokens_to_write=} {num_tokens_to_write_nope=} {num_tokens_to_write_rope=} {num_tokens_to_write_scale=}" + + assert buf.dtype == torch.uint8 + assert loc.dtype in [ + torch.int64, + torch.int32, + ], f"{loc.dtype=}" # can be int32 or int64 + + assert k_nope.dtype == fp8_dtype + assert k_rope.dtype == torch.bfloat16 + assert scale_k_nope.dtype == torch.uint8 + + assert buf.is_contiguous() + assert loc.is_contiguous() + assert k_nope.is_contiguous() + assert k_rope.is_contiguous() + assert scale_k_nope.is_contiguous() + + buf_fp8 = buf.view(fp8_dtype).flatten() + buf_bf16 = buf.view(torch.bfloat16).flatten() + buf_scale = buf.view(torch.uint8).flatten() + + loc_page_index = loc // page_size + loc_token_offset_in_page = loc % page_size + + s_offset_nbytes_in_page = page_size * (nope_dim + rope_dim * 2) + + nope_offset = loc_page_index * buf_numel_per_page + loc_token_offset_in_page * ( + nope_dim + rope_dim * 2 + ) + + rope_offset = ( + loc_page_index * buf_numel_per_page // 2 + + (loc_token_offset_in_page * (nope_dim + rope_dim * 2) + nope_dim) // 2 + ) + + s_offset = ( + loc_page_index * buf_numel_per_page + + s_offset_nbytes_in_page + + loc_token_offset_in_page * (scale_dim + 1) # +1 for padding byte + ) + + for i in range(num_tokens_to_write): + buf_fp8[nope_offset[i] : nope_offset[i] + nope_dim] = k_nope[i] + buf_bf16[rope_offset[i] : rope_offset[i] + rope_dim] = k_rope[i] + buf_scale[s_offset[i] : s_offset[i] + scale_dim] = scale_k_nope[i] diff --git a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py index d17523b41955..4dc79b0f5033 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_indexer.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_indexer.py @@ -46,6 +46,9 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import NSATokenToKVPool +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn DUAL_STREAM_TOKEN_THRESHOLD = 1024 if _is_cuda else 0 @@ -112,7 +115,7 @@ 
def topk_transform( def rotate_activation(x: torch.Tensor) -> torch.Tensor: - assert x.dtype == torch.bfloat16 + # assert x.dtype == torch.bfloat16 # from sgl_kernel import hadamard_transform if _is_hip: from fast_hadamard_transform import hadamard_transform @@ -738,7 +741,7 @@ def _get_topk_ragged_with_cp( actual_seq_q_list.append(actual_seq_q) batch_idx_list.append(batch_idx) - k_fp8 = torch.cat(k_fp8_list, dim=0).view(torch.float8_e4m3fn) + k_fp8 = torch.cat(k_fp8_list, dim=0).view(fp8_dtype) k_scale = torch.cat(k_scale_list, dim=0).view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) ks = torch.cat(ks_list, dim=0) @@ -779,7 +782,7 @@ def _get_topk_ragged_with_cp( block_tables[0], ) - k_fp8 = k_fp8.view(torch.float8_e4m3fn) + k_fp8 = k_fp8.view(fp8_dtype) k_scale = k_scale.view(torch.float32).squeeze(-1) kv_fp8 = (k_fp8, k_scale) ks = torch.full((actual_seq_q,), offset, dtype=torch.int32, device="cuda") @@ -872,7 +875,7 @@ def forward_indexer( block_tables[i], ) - k_fp8 = k_fp8.view(torch.float8_e4m3fn).unsqueeze(0).contiguous() + k_fp8 = k_fp8.view(fp8_dtype).unsqueeze(0).contiguous() k_scale = k_scale.view(torch.float32).squeeze(-1).unsqueeze(0).contiguous() index_score = fp8_index( diff --git a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py index 5454071b897d..fa9c9ba5dda7 100644 --- a/python/sglang/srt/layers/attention/nsa/quant_k_cache.py +++ b/python/sglang/srt/layers/attention/nsa/quant_k_cache.py @@ -2,6 +2,10 @@ import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + def quantize_k_cache(cache_k): return _quantize_k_cache_fast_wrapped(cache_k) @@ -75,7 +79,7 @@ def _quantize_k_cache_ref( result = torch.empty( (num_blocks, block_size, dv + num_tiles * 4 + input_elem_size * (d - dv)), - dtype=torch.float8_e4m3fn, + dtype=fp8_dtype, device=input_k_cache.device, ) result_k_nope_part = result[..., :dv] @@ -100,7 +104,7 @@ def _quantize_k_cache_ref( ..., tile_idx * tile_size : (tile_idx + 1) * tile_size ].float() / cur_scale_factors_inv.float() - ).to(torch.float8_e4m3fn) + ).to(fp8_dtype) result_k_nope_part[..., tile_idx * tile_size : (tile_idx + 1) * tile_size] = ( cur_quantized_nope ) @@ -152,7 +156,7 @@ def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128): output = torch.empty( (num_tokens, dim_nope + num_tiles * 4 + k_rope.element_size() * dim_rope), - dtype=torch.float8_e4m3fn, + dtype=fp8_dtype, device=k_nope.device, ) output_nope_q = output[..., :dim_nope] @@ -180,8 +184,8 @@ def _quantize_k_cache_fast(k_nope, k_rope, group_size: int = 128): GROUP_SIZE=group_size, DIM_NOPE=dim_nope, DIM_ROPE=dim_rope, - FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, - FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + FP8_MIN=torch.finfo(fp8_dtype).min, + FP8_MAX=torch.finfo(fp8_dtype).max, ) return output @@ -232,7 +236,7 @@ def _quantize_k_cache_fast_separate(k_nope, k_rope, group_size: int = 128): # Create typed views for the kernel to write into # Fixed byte layout for nope_part: [nope_fp8 (dim_nope bytes) | scales_fp32 (num_tiles*4 bytes)] # Fixed byte layout for rope_part: [rope_bf16 (dim_rope*2 bytes)] - nope_q_view = nope_part_u8[:, :dim_nope].view(torch.float8_e4m3fn) + nope_q_view = nope_part_u8[:, :dim_nope].view(fp8_dtype) nope_s_view = nope_part_u8[:, dim_nope:].view(torch.float32) rope_view = rope_part_u8.view(torch.bfloat16) @@ -256,8 +260,8 @@ def 
_quantize_k_cache_fast_separate(k_nope, k_rope, group_size: int = 128): GROUP_SIZE=group_size, DIM_NOPE=dim_nope, DIM_ROPE=dim_rope, - FP8_MIN=torch.finfo(torch.float8_e4m3fn).min, - FP8_MAX=torch.finfo(torch.float8_e4m3fn).max, + FP8_MIN=torch.finfo(fp8_dtype).min, + FP8_MAX=torch.finfo(fp8_dtype).max, ) # Add middle dimension for compatibility with set_mla_kv_buffer_triton diff --git a/python/sglang/srt/layers/attention/nsa/quant_k_cache_v4.py b/python/sglang/srt/layers/attention/nsa/quant_k_cache_v4.py new file mode 100644 index 000000000000..f5d1c9b92d70 --- /dev/null +++ b/python/sglang/srt/layers/attention/nsa/quant_k_cache_v4.py @@ -0,0 +1,179 @@ +import torch +import triton +import triton.language as tl + +from sglang.srt.layers.attention.nsa.index_buf_accessor_v4 import NopeFp8RopeBf16Pack +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn +from sglang.srt.utils import is_hip + + +@triton.jit +def _quant_k_cache_fused_kernel( + k_bf16_ptr, + k_nope_fp8_ptr, + k_rope_bf16_ptr, + scale_k_nope_uint8_ptr, + k_bf16_stride_0, + k_nope_fp8_stride_0, + k_rope_bf16_stride_0, + scale_stride_0, + DIM_NOPE: tl.constexpr, + DIM_ROPE: tl.constexpr, + TILE_SIZE: tl.constexpr, + NUM_TILES: tl.constexpr, + FP8_MIN: tl.constexpr, + FP8_MAX: tl.constexpr, + EPS: tl.constexpr, +): + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + if tile_id == NUM_TILES: + # copy rope part (last 64 dims) + rope_range = tl.arange(0, TILE_SIZE) + rope_mask = rope_range < DIM_ROPE + + # load k_bf16[token_id, 448:512] + in_rope_offsets = token_id * k_bf16_stride_0 + DIM_NOPE + rope_range + rope_data = tl.load(k_bf16_ptr + in_rope_offsets, mask=rope_mask, other=0.0) + + # store to k_rope_bf16[token_id, :] + out_rope_offsets = token_id * k_rope_bf16_stride_0 + rope_range + tl.store(k_rope_bf16_ptr + out_rope_offsets, rope_data, mask=rope_mask) + else: + # do nope quantization (tile_id < NUM_TILES) + tile_range = tl.arange(0, TILE_SIZE) + + # load k_bf16[token_id, tile_id*64:(tile_id+1)*64] + in_tile_offsets = token_id * k_bf16_stride_0 + tile_id * TILE_SIZE + tile_range + x_bf16 = tl.load(k_bf16_ptr + in_tile_offsets) + x_fp32 = x_bf16.to(tl.float32) + + abs_x = tl.abs(x_fp32) + max_abs = tl.max(abs_x) + max_abs_clamped = tl.maximum(max_abs, EPS) + scale = max_abs_clamped / FP8_MAX + + # cast scale to ue8m0 format + # log2_scale = t_hip/l.log2(scale) + log2_scale = tl.log2(scale) + # if is_hip(): + # ceil_log2 = tl.math.ceil(log2_scale+1e-5) + # else: + # ceil_log2 = tl.math.ceil(log2_scale) + # ceil_log2 = tl.math.ceil(log2_scale+1e-5) + ceil_log2 = tl.math.ceil(log2_scale) + scale_pow2_fp32 = tl.exp2(ceil_log2) + scale_inv = 1.0 / scale_pow2_fp32 + x_scaled = x_fp32 * scale_inv + x_fp8 = tl.clamp(x_scaled, FP8_MIN, FP8_MAX).to(k_nope_fp8_ptr.dtype.element_ty) + + out_fp8_offsets = ( + token_id * k_nope_fp8_stride_0 + tile_id * TILE_SIZE + tile_range + ) + tl.store(k_nope_fp8_ptr + out_fp8_offsets, x_fp8) + + exponent = ceil_log2.to(tl.int32) + scale_uint8 = (exponent + 127).to(tl.uint8) + + out_scale_offset = token_id * scale_stride_0 + tile_id + tl.store(scale_k_nope_uint8_ptr + out_scale_offset, scale_uint8) + + +def quant_to_nope_fp8_rope_bf16_pack_triton( + k_bf16: torch.Tensor, +) -> NopeFp8RopeBf16Pack: + """ + Quantize nope part (0:448) to fp8 and keep rope part (448:512) still in bf16. + Scaling factor is in ue8m0 format and stored as uint8. 
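+ + Example (illustrative shapes): + k = torch.randn(4, 512, dtype=torch.bfloat16, device="cuda") + pack = quant_to_nope_fp8_rope_bf16_pack_triton(k) + # pack.k_nope_fp8: (4, 448) fp8 + # pack.k_rope_bf16: (4, 64) bf16 + # pack.scale_k_nope_ue8m0: (4, 7) uint8, one ue8m0 exponent per 64-element tile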
+ """ + assert k_bf16.dtype == torch.bfloat16 + num_tokens, hidden_dim = k_bf16.shape + assert hidden_dim == 512 + dim_nope = 448 + dim_rope = 64 + tile_size = 64 + num_tiles = dim_nope // tile_size # 7 tiles + + k_bf16 = k_bf16.contiguous() + + k_nope_fp8 = torch.empty( + (num_tokens, dim_nope), dtype=fp8_dtype, device=k_bf16.device + ) + k_rope_bf16 = torch.empty( + (num_tokens, dim_rope), dtype=torch.bfloat16, device=k_bf16.device + ) + scale_k_nope_ue8m0 = torch.empty( + (num_tokens, num_tiles), dtype=torch.uint8, device=k_bf16.device + ) + + fp8_dtype_info = torch.finfo(fp8_dtype) + + # additional block to handle rope copy + grid = (num_tokens, num_tiles + 1) + _quant_k_cache_fused_kernel[grid]( + k_bf16, + k_nope_fp8, + k_rope_bf16, + scale_k_nope_ue8m0, + k_bf16.stride(0), + k_nope_fp8.stride(0), + k_rope_bf16.stride(0), + scale_k_nope_ue8m0.stride(0), + DIM_NOPE=dim_nope, + DIM_ROPE=dim_rope, + TILE_SIZE=tile_size, + NUM_TILES=num_tiles, + FP8_MIN=fp8_dtype_info.min, + FP8_MAX=fp8_dtype_info.max, + EPS=1e-8, + ) + + return NopeFp8RopeBf16Pack( + k_nope_fp8=k_nope_fp8, + k_rope_bf16=k_rope_bf16, + scale_k_nope_ue8m0=scale_k_nope_ue8m0, + ) + + +# Torch implementation for accuracy baseline +def quant_to_nope_fp8_rope_bf16_pack(k_bf16: torch.Tensor) -> NopeFp8RopeBf16Pack: + assert k_bf16.dtype == torch.bfloat16 + _num_tokens, hidden_dim = k_bf16.shape + assert hidden_dim == 512 + dim_nope = 448 + dim_rope = 64 + + k_nope_bf16, k_rope_bf16 = k_bf16.split([dim_nope, dim_rope], dim=-1) + + tile_size = 64 + num_tiles = dim_nope // tile_size + + # FIXME: Check here later + x = k_nope_bf16.contiguous().view(-1, num_tiles, tile_size) + scale = x.abs().amax(dim=-1).float() / 448.0 + scale_pow2_fp32 = _cast_scale_inv_to_ue8m0(scale, out_dtype=torch.float32) + scale_k_nope_ue8m0 = scale_pow2_fp32.to(torch.float8_e8m0fnu) + k_nope_fp8 = (x.float() / scale_pow2_fp32.unsqueeze(-1)).to(fp8_dtype) + k_nope_fp8 = k_nope_fp8.view(-1, tile_size * num_tiles) + # ue8m0 is float8_e4m3fn, but can be also viewed as uint8 integer + scale_k_nope_ue8m0 = scale_k_nope_ue8m0.view(torch.uint8) + + return NopeFp8RopeBf16Pack( + k_nope_fp8=k_nope_fp8, + k_rope_bf16=k_rope_bf16.contiguous(), + scale_k_nope_ue8m0=scale_k_nope_ue8m0, + ) + + +def _cast_scale_inv_to_ue8m0( + scales_inv: torch.Tensor, out_dtype=torch.float32 +) -> torch.Tensor: + # if is_hip(): + # log2_val = torch.clamp_min(scales_inv, 1e-4).log2() + # return torch.pow(2, (log2_val + 1e-5).ceil()).to(out_dtype) + # else: + # return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype) + return torch.pow(2, torch.clamp_min(scales_inv, 1e-4).log2().ceil()).to(out_dtype) diff --git a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py index 1088bd3d171b..66ff709c7ef4 100644 --- a/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py +++ b/python/sglang/srt/layers/attention/nsa/tilelang_kernel.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple +import functools +from typing import Any, Optional, Tuple import tilelang import tilelang.language as T @@ -12,20 +13,26 @@ pass_configs = { tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + # tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, } -# TL_DISABLE_FAST_MATH has deprecated in v0.1.7.post1 tilelang -if hasattr(tilelang.PassConfigKey, "TL_DISABLE_FAST_MATH"): - pass_configs[tilelang.PassConfigKey.TL_DISABLE_FAST_MATH] = True -elif 
hasattr(tilelang.PassConfigKey, "TL_ENABLE_FAST_MATH"): - pass_configs[tilelang.PassConfigKey.TL_ENABLE_FAST_MATH] = False _is_hip = is_hip() _is_gfx95_supported = is_gfx95_supported() _is_fp8_fnuz = is_fp8_fnuz() -BF16 = "bfloat16" + FP8 = "float8_e4m3fnuz" if _is_fp8_fnuz else "float8_e4m3" +BF16 = "bfloat16" +if _is_fp8_fnuz: + FP8 = "float8_e4m3fnuz" + FP8_ = torch.float8_e4m3fnuz +else: + FP8 = "float8_e4m3" + FP8_ = torch.float8_e4m3fn FP32 = "float32" +INT32 = "int32" + +_is_hip = is_hip() def fast_log2_ceil(x): @@ -49,8 +56,11 @@ def act_quant_kernel( N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False ): M = T.symbolic("M") + # fp8_min = -448.0 fp8_min = -224.0 if _is_fp8_fnuz else -448.0 + # fp8_max = 448.0 fp8_max = 224.0 if _is_fp8_fnuz else 448.0 + # fp8_max_inv = 1 / fp8_max fp8_max_inv = 1 / fp8_max num_stages = 0 if round_scale else 2 blk_m = 32 @@ -115,6 +125,7 @@ def act_quant( x.size(-1) % block_size == 0 ), f"Last dimension size must be divisible by block_size (block_size={block_size})" N = x.size(-1) + # y = torch.empty_like(x, dtype=FP8_) if _is_fp8_fnuz: y = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) else: @@ -807,3 +818,112 @@ def tilelang_sparse_fwd( num_heads, d_v, tail_dim, topk, sm_scale=sm_scale ) return kernel(q.unsqueeze(0), kv.unsqueeze(0), indices.unsqueeze(0)) # type: ignore + + +@functools.cache +def fp8_paged_mqa_logits_kernel( + head_dim: int = 128, + num_heads: int = 64, + block_size: int = 64, + clear_accum: bool = True, +) -> Any: + N = T.symbolic("batch_size") + L = T.symbolic("max_table_length") + S = T.symbolic("max_seq_len") + C = T.symbolic("num_blocks") + B = block_size + D = head_dim + H = num_heads + d_0, d_1 = T.dynamic("d_0, d_1") + + assert D % 4 == 0 + assert H % 4 == 0 + assert D == 128 + + @tilelang.jit + def fp8_paged_mqa_logits( + q: T.Tensor[(N, H, D), FP8], + kvcache: T.StridedTensor[(C, B, D), (d_0, D, 1), FP8], + kvcache_scale: T.StridedTensor[(C, B), (d_1, 1), FP32], + weight: T.Tensor[(N, H), FP32], + seq_lens: T.Tensor[(N,), INT32], + page_table: T.Tensor[(N, L), INT32], + o: T.Tensor[(N, S), FP32], + ) -> None: + _ = N, L, S, C, D, H, B, d_0, d_1 + with T.Kernel(N) as bx: + seq_len = seq_lens[bx] + q_smem = T.alloc_shared((H, D), FP8) + q_s_frag = T.alloc_fragment((H,), FP32) + T.copy(q[bx, 0, 0], q_smem) + T.copy(weight[bx, 0], q_s_frag) + + for i in T.Pipelined(T.ceildiv(seq_len, B), num_stages=2): + page = page_table[bx, i] + k_smem = T.alloc_shared((B, D), FP8) + k_s_frag = T.alloc_fragment((B,), FP32) + # first B * D FP8 are cache; last 4 * D are D FP32 scales + T.copy(kvcache[page, 0, 0], k_smem) + T.copy(kvcache_scale[page, 0], k_s_frag) + + # shape: [B, H] + logits = T.alloc_fragment((B, H), FP32) + if not clear_accum: + T.fill(logits, 0.0) + T.gemm( + k_smem, + q_smem, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=clear_accum, + ) + + # post processing + for h, j in T.Parallel(H, B): + logits[j, h] = T.max(logits[j, h], 0.0) * q_s_frag[h] + logits_sum = T.alloc_fragment((B,), FP32) + T.reduce_sum(logits, logits_sum, dim=1) + for j in T.Parallel(B): + logits_sum[j] *= k_s_frag[j] + T.copy(logits_sum, o[bx, i * B]) + + return fp8_paged_mqa_logits + + +def tilelang_fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kvcache_fp8: torch.Tensor, + weight: torch.Tensor, + seq_lens: torch.Tensor, + page_table: torch.Tensor, + deep_gemm_metadata: Any, + max_seq_len: int, + clean_logits: bool = True, +) -> torch.Tensor: + _ = deep_gemm_metadata + batch_size, _, num_heads, head_dim = 
q_fp8.shape + block_size = kvcache_fp8.shape[1] + assert head_dim == 128, "TODO" + assert block_size == 64, "TODO" + assert q_fp8.shape == (batch_size, 1, num_heads, head_dim) + assert kvcache_fp8.shape[1:] == (block_size, 1, head_dim + 4) + assert weight.shape == (batch_size, num_heads) + assert seq_lens.shape == (batch_size,) + assert page_table.shape[0] == batch_size + assert clean_logits == False + + logits = page_table.new_empty((batch_size, max_seq_len), dtype=torch.float32) + kernel = fp8_paged_mqa_logits_kernel( + head_dim=head_dim, + num_heads=num_heads, + block_size=block_size, + clear_accum=clean_logits, + ) + q_fp8 = q_fp8.view(batch_size, num_heads, head_dim) + kvcache_fp8 = kvcache_fp8.view(-1, block_size * (head_dim + 4)) + kvcache = kvcache_fp8[..., : block_size * head_dim].view(dtype=torch.float8_e4m3fn) + kvcache = kvcache.view(-1, block_size, head_dim) + kvcache_scale = kvcache_fp8[..., block_size * head_dim :].view(dtype=torch.float32) + kernel(q_fp8, kvcache, kvcache_scale, weight, seq_lens, page_table, logits) + return logits diff --git a/python/sglang/srt/layers/attention/nsa/triton_kernel.py b/python/sglang/srt/layers/attention/nsa/triton_kernel.py index 9d970b83a96a..65c52082a5c0 100644 --- a/python/sglang/srt/layers/attention/nsa/triton_kernel.py +++ b/python/sglang/srt/layers/attention/nsa/triton_kernel.py @@ -5,6 +5,16 @@ import triton.language as tl +def _is_hip() -> bool: + """Check if running on AMD ROCm/HIP.""" + return hasattr(torch.version, "hip") and torch.version.hip is not None + + +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + + # Triton implementation @triton.jit def _act_quant_kernel( @@ -109,7 +119,7 @@ def act_quant( M = x_flat.size(0) # Allocate output tensors - y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + y = torch.empty_like(x, dtype=fp8_dtype) y_flat = y.view(-1, N) s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32) s_flat = s.view(-1, N // block_size) @@ -120,6 +130,14 @@ def act_quant( grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, block_size)) round_scale = scale_fmt is not None + # Determine num_stages based on round_scale and platform + # Note: num_stages=0 causes TritonAMDGPUStreamPipeline to fail on HIP + # Use num_stages=1 on HIP when round_scale is True + if round_scale: + num_stages = 1 if _is_hip() else 0 + else: + num_stages = 2 + _act_quant_kernel[grid]( x_flat, y_flat, @@ -130,7 +148,8 @@ def act_quant( round_scale=round_scale, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - num_stages=0 if round_scale else 2, + # num_stages=0 if round_scale else 2, + num_stages=num_stages, ) return y, s diff --git a/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py b/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py new file mode 100644 index 000000000000..d3c3275ea098 --- /dev/null +++ b/python/sglang/srt/layers/attention/triton_ops/compressed_metadata.py @@ -0,0 +1,464 @@ +""" +Triton kernels for fused compressed metadata initialization. + +These kernels replace the fragmented tensor operations in DSV4AttnMetadataRadix.init_compressed_metadata, +reducing kernel launch overhead from 10+ launches to 1. + +Set environment variable SGLANG_USE_TORCH_COMPRESS_METADATA=1 to use the original PyTorch implementation +instead of the Triton kernel (useful for debugging or compatibility). 
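+ + For example, a query token at seq_len=260 (position 259) yields: + c4: should_compress (260 % 4 == 0), c4_seq_lens_raw = 65, c4_positions = 256 (position & ~3); + c128: not compressed (260 % 128 != 0), so c128_out_loc = 0, + c128_seq_lens_clamp1 = max(260 // 128, 1) = 2, c128_positions = 256.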
+""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from sglang.srt.environ import envs + +# Environment variable to control implementation dispatch +# Set to "1" to use PyTorch implementation, otherwise use Triton +_USE_TRITON_IMPL = envs.SGLANG_OPT_USE_TRITON_CA_METADATA.get() + + +# ============================================================================= +# Triton Implementation +# ============================================================================= + + +@triton.jit +def _init_compressed_attn_metadata_kernel( + # Inputs + seq_lens_ptr, + positions_ptr, + raw_out_loc_ptr, + page_table_ptr, # Only used when COMPUTE_PAGE_INDICES=True + # Outputs (C4) + c4_out_loc_ptr, + c4_positions_ptr, + c4_seq_lens_raw_ptr, + c4_seq_lens_clamp1_ptr, + # Outputs (C128) + c128_out_loc_ptr, + c128_positions_ptr, + c128_seq_lens_clamp1_ptr, + c128_page_indices_ptr, # Only used when COMPUTE_PAGE_INDICES=True + # Meta + bs, + max_pages, + page_size: tl.constexpr, + c128_max_seq_len: tl.constexpr, + c128_page_size: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + COMPUTE_PAGE_INDICES: tl.constexpr, +): + """ + Unified fused kernel for both C4 and C128 compressed metadata computation. + + This kernel computes metadata for both compression ratios (4 and 128) in a single launch. + + For ratio=4: + should_compress = seq_lens % 4 == 0 + c4_out_loc = (raw_out_loc // 4) if should_compress else 0 + c4_positions = (positions // 4) * 4 + c4_seq_lens_raw = seq_lens // 4 + c4_seq_lens_clamp1 = max(c4_seq_lens_raw, 1) + + For ratio=128: + should_compress = seq_lens % 128 == 0 + c128_out_loc = (raw_out_loc // 128) if should_compress else 0 + c128_positions = (positions // 128) * 128 + c128_seq_lens_raw = seq_lens // 128 + c128_seq_lens_clamp1 = max(c128_seq_lens_raw, 1) + c128_page_indices[pos] = page_table[pos // c_page_size] * c_page_size + (pos % c_page_size) + if pos < c128_seq_lens_raw else -1 + """ + batch_id = tl.program_id(0) + if batch_id >= bs: + return + + # Load inputs for this batch element + seq_len = tl.load(seq_lens_ptr + batch_id) + position = tl.load(positions_ptr + batch_id) + raw_out_loc = tl.load(raw_out_loc_ptr + batch_id) + + # ========== C4 Metadata Computation ========== + # Compute compressed metadata for ratio=4 + c4_should_compress = (seq_len % 4) == 0 + c4_out_loc = tl.where(c4_should_compress, raw_out_loc // 4, 0) + # Use bit masking for efficiency: positions & ~3 == (positions // 4) * 4 + c4_positions = position & (~3) + c4_seq_lens_raw = seq_len // 4 + c4_seq_lens_clamp1 = tl.maximum(c4_seq_lens_raw, 1) + + # Store C4 outputs + tl.store(c4_out_loc_ptr + batch_id, c4_out_loc) + tl.store(c4_positions_ptr + batch_id, c4_positions) + tl.store(c4_seq_lens_raw_ptr + batch_id, c4_seq_lens_raw) + tl.store(c4_seq_lens_clamp1_ptr + batch_id, c4_seq_lens_clamp1) + + # ========== C128 Metadata Computation ========== + # Compute compressed metadata for ratio=128 + c128_should_compress = (seq_len % 128) == 0 + c128_out_loc = tl.where(c128_should_compress, raw_out_loc // 128, 0) + # Use bit masking: positions & ~127 == (positions // 128) * 128 + c128_positions = position & (~127) + c128_seq_lens_raw = seq_len // 128 + c128_seq_lens_clamp1 = tl.maximum(c128_seq_lens_raw, 1) + + # Store C128 scalar outputs + tl.store(c128_out_loc_ptr + batch_id, c128_out_loc) + tl.store(c128_positions_ptr + batch_id, c128_positions) + tl.store(c128_seq_lens_clamp1_ptr + batch_id, c128_seq_lens_clamp1) + + # ========== C128 Page Indices Computation (conditional) 
========== + if COMPUTE_PAGE_INDICES: + # Compute page_indices for this batch element + # Process in blocks for efficiency + page_indices_base = batch_id * c128_max_seq_len + for block_start in range(0, c128_max_seq_len, BLOCK_SIZE): + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < c128_max_seq_len + + # Compute page index and offset within page + page_idx = offsets // c128_page_size + offset_in_page = offsets % c128_page_size + + # Load page table entries (with bounds check) + page_mask = mask & (page_idx < max_pages) + page_table_vals = tl.load( + page_table_ptr + batch_id * max_pages + page_idx, + mask=page_mask, + other=0, + ) + + # Compute c_page_indices = page_table[page_idx] * c_page_size + offset_in_page + c_page_indices_vals = page_table_vals * c128_page_size + offset_in_page + + # Mask out positions >= c128_seq_lens_raw with -1 + valid_mask = offsets < c128_seq_lens_raw + c_page_indices_vals = tl.where(valid_mask, c_page_indices_vals, -1) + + # Store page indices + tl.store( + c128_page_indices_ptr + page_indices_base + offsets, + c_page_indices_vals, + mask=mask, + ) + + +def _init_compressed_attn_metadata_triton( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, # C4 outputs + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], # C128 outputs +]: + """Triton implementation of compressed metadata computation.""" + bs = seq_lens.shape[0] + device = seq_lens.device + + # Allocate C4 output tensors + c4_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c4_positions = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_raw = torch.empty(bs, dtype=torch.int32, device=device) + c4_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + # Allocate C128 output tensors + c128_out_loc = torch.empty(bs, dtype=torch.int32, device=device) + c128_positions = torch.empty(bs, dtype=torch.int32, device=device) + c128_seq_lens_clamp1 = torch.empty(bs, dtype=torch.int32, device=device) + + # Handle page indices computation + if compute_page_indices: + assert ( + page_table is not None + ), "page_table required when compute_page_indices=True" + assert page_size > 0, "page_size required when compute_page_indices=True" + max_pages = page_table.shape[1] + c128_page_size = page_size // 128 + c128_max_seq_len = c128_page_size * max_pages + c128_page_indices = torch.empty( + bs, c128_max_seq_len, dtype=torch.int32, device=device + ) + BLOCK_SIZE = triton.next_power_of_2(max(c128_page_size, 64)) + else: + max_pages = 0 + c128_page_size = 1 + c128_max_seq_len = 0 + c128_page_indices = None + BLOCK_SIZE = 64 + # Create dummy page_table pointer if not provided + if page_table is None: + page_table = torch.empty(0, dtype=torch.int32, device=device) + + # Launch unified kernel + grid = (bs,) + _init_compressed_attn_metadata_kernel[grid]( + # Inputs + seq_lens, + positions, + raw_out_loc, + page_table, + # C4 outputs + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + # C128 outputs + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + ( + c128_page_indices + if c128_page_indices is not None + else torch.empty(0, dtype=torch.int32, device=device) + ), + # Meta + bs, + max_pages, + page_size if page_size > 0 else 128, # Default to avoid division by zero + c128_max_seq_len, + c128_page_size, + 
BLOCK_SIZE, + compute_page_indices, + ) + + return ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) + + +# ============================================================================= +# PyTorch Reference Implementation (Original) +# ============================================================================= + + +def _init_compressed_attn_metadata_torch( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, # C4 outputs + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], # C128 outputs +]: + """ + Original PyTorch implementation of compressed metadata computation. + + This is the reference implementation that the Triton kernel replaces. + It's kept for debugging and compatibility purposes. + """ + cuda_int32_kwargs = {"device": seq_lens.device, "dtype": torch.int32} + + # ========== C4 Metadata Computation ========== + c4_should_compress = seq_lens % 4 == 0 + c4_seq_lens_raw = seq_lens // 4 + c4_positions = positions // 4 * 4 + c4_out_loc = torch.where(c4_should_compress, raw_out_loc // 4, 0) + c4_seq_lens_clamp1 = torch.clamp(c4_seq_lens_raw, min=1) + + # ========== C128 Metadata Computation ========== + c128_should_compress = seq_lens % 128 == 0 + c128_seq_lens_raw = seq_lens // 128 + c128_positions = positions // 128 * 128 + c128_out_loc = torch.where(c128_should_compress, raw_out_loc // 128, 0) + c128_seq_lens_clamp1 = torch.clamp(c128_seq_lens_raw, min=1) + + # ========== C128 Page Indices Computation ========== + if compute_page_indices: + assert ( + page_table is not None + ), "page_table required when compute_page_indices=True" + assert page_size > 0, "page_size required when compute_page_indices=True" + + c128_page_size = page_size // 128 + max_pages = page_table.size(1) + c128_max_seq_len = c128_page_size * max_pages + + # [bs, max_pages] -> [bs, max_pages, c_page_size] -> [bs, c_max_seq_len] + c_offsets = torch.arange(c128_max_seq_len, **cuda_int32_kwargs) + c128_page_indices = ( + (page_table.unsqueeze(2) * c128_page_size + c_offsets[:c128_page_size]) + .to(torch.int32) + .contiguous() + .view(-1, c128_max_seq_len) + ) + # Mask out positions >= c128_seq_lens_raw with -1 + mask = c_offsets.unsqueeze(0) >= c128_seq_lens_raw.unsqueeze(1) + c128_page_indices.masked_fill_(mask, -1) + else: + c128_page_indices = None + + return ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) + + +# ============================================================================= +# Public API (dispatches based on environment variable) +# ============================================================================= + + +def init_compressed_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: Optional[torch.Tensor] = None, + page_size: int = 0, + compute_page_indices: bool = True, +) -> Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, # C4 outputs + torch.Tensor, + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], # C128 outputs +]: + """ + Unified function for compressed metadata computation. + + Computes both C4 and C128 metadata. 
Uses the Triton kernel by default for + better performance. Set SGLANG_OPT_USE_TRITON_CA_METADATA=0 to fall back to + the original PyTorch implementation (useful for debugging). + + Args: + seq_lens: [bs] int32, sequence lengths + positions: [bs] int32, positions (expanded causally) + raw_out_loc: [bs] int32, raw output locations + page_table: [bs, max_pages] int32, page table (required if compute_page_indices=True) + page_size: int, page size (required if compute_page_indices=True) + compute_page_indices: bool, whether to compute c128_page_indices + + Returns: + Tuple of: + - c4_out_loc: [bs] int32 + - c4_positions: [bs] int32 + - c4_seq_lens_raw: [bs] int32 + - c4_seq_lens_clamp1: [bs] int32 + - c128_out_loc: [bs] int32 + - c128_positions: [bs] int32 + - c128_seq_lens_clamp1: [bs] int32 + - c128_page_indices: [bs, c_max_seq_len] int32 or None + """ + if not _USE_TRITON_IMPL: + return _init_compressed_attn_metadata_torch( + seq_lens, + positions, + raw_out_loc, + page_table, + page_size, + compute_page_indices, + ) + else: + return _init_compressed_attn_metadata_triton( + seq_lens, + positions, + raw_out_loc, + page_table, + page_size, + compute_page_indices, + ) + + +# ============================================================================= +# Backward Compatibility Wrappers +# ============================================================================= + + +def init_c4_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Backward compatible wrapper for C4-only metadata computation. + + Note: This still computes C128 metadata internally but discards it. + For better performance, use init_compressed_metadata() directly. + """ + ( + c4_out_loc, + c4_positions, + c4_seq_lens_raw, + c4_seq_lens_clamp1, + _, + _, + _, + _, + ) = init_compressed_metadata( + seq_lens, + positions, + raw_out_loc, + page_table=None, + page_size=0, + compute_page_indices=False, + ) + return c4_out_loc, c4_positions, c4_seq_lens_raw, c4_seq_lens_clamp1 + + +def init_c128_metadata( + seq_lens: torch.Tensor, + positions: torch.Tensor, + raw_out_loc: torch.Tensor, + page_table: torch.Tensor, + page_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Backward compatible wrapper for C128-only metadata computation. + + Note: This still computes C4 metadata internally but discards it. + For better performance, use init_compressed_metadata() directly. 
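+ + Example (illustrative variable names): + c128_out_loc, c128_positions, c128_lengths_clamp1, c128_page_indices = init_c128_metadata( + seq_lens, positions, raw_out_loc, page_table, page_size + ) + # c128_page_indices: [bs, max_pages * (page_size // 128)] int32, padded with -1 + # beyond each request's compressed length.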
+ """ + ( + _, + _, + _, + _, + c128_out_loc, + c128_positions, + c128_seq_lens_clamp1, + c128_page_indices, + ) = init_compressed_metadata( + seq_lens, + positions, + raw_out_loc, + page_table=page_table, + page_size=page_size, + compute_page_indices=True, + ) + return c128_out_loc, c128_positions, c128_seq_lens_clamp1, c128_page_indices diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index de8f7983f360..998989caac70 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -12,6 +12,10 @@ import triton import triton.language as tl +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz + +fp8_dtype = torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn + from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph from sglang.srt.layers.attention.flashinfer_mla_backend import ( FlashInferMLAAttnBackend, @@ -197,7 +201,7 @@ def unpad_draft_extend_output_kernel( def _quantize_fp8_qkv(q, k, v, layer): - q = q.to(torch.float8_e4m3fn) + q = q.to(fp8_dtype) k_scale = getattr(layer, "k_scale_float", None) if k_scale is None: @@ -209,7 +213,7 @@ def _quantize_fp8_qkv(q, k, v, layer): ) k = k_2d.reshape(k.shape) else: - k = k.to(torch.float8_e4m3fn) + k = k.to(fp8_dtype) v_scale = getattr(layer, "v_scale_float", None) if v_scale is None: @@ -221,7 +225,7 @@ def _quantize_fp8_qkv(q, k, v, layer): ) v = v_2d.reshape(v.shape) else: - v = v.to(torch.float8_e4m3fn) + v = v.to(fp8_dtype) return q, k, v, k_scale, v_scale @@ -702,7 +706,7 @@ def quantize_and_rope_for_fp8( - k_nope_out: [seq_len, num_heads, kv_lora_rank], dtype=torch.float8_e4m3fn - k_rope_out: [seq_len, num_heads, qk_rope_head_dim], dtype=torch.float8_e4m3fn """ - attn_dtype = torch.float8_e4m3fn + attn_dtype = fp8_dtype q_len, num_heads = q_rope.shape[0], q_rope.shape[1] # Allocate output tensors with FP8 dtype @@ -840,7 +844,7 @@ def forward_decode( ) -> torch.Tensor: """Run forward for decode using TRTLLM MLA kernel.""" merge_query = q_rope is not None - if self.data_type == torch.float8_e4m3fn: + if self.data_type == fp8_dtype: # For FP8 path, we quantize the query and rope parts and merge them into a single tensor # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend assert all( @@ -965,7 +969,7 @@ def forward_extend( # TODO refactor to avoid code duplication merge_query = q_rope is not None if ( - self.data_type == torch.float8_e4m3fn + self.data_type == fp8_dtype ) and forward_batch.forward_mode.is_target_verify(): # For FP8 path, we quantize the query and rope parts and merge them into a single tensor # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend @@ -1130,7 +1134,7 @@ def forward_extend( v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim) q_scale = k_scale = v_scale = 1.0 - if self.data_type == torch.float8_e4m3fn: + if self.data_type == fp8_dtype: q, k, v, k_scale, v_scale = _quantize_fp8_qkv(q, k, v, layer) common_trtllm_args = { diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py index 88d0a959b156..5d5544be45b0 100644 --- a/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py +++ b/python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py @@ -4,6 +4,7 @@ import torch +from sglang.srt.environ 
import envs from sglang.srt.layers.deep_gemm_wrapper import compile_utils from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401 DEEPGEMM_BLACKWELL, @@ -39,6 +40,13 @@ def grouped_gemm_nt_f8f8bf16_masked( _sanity_check_input(lhs) _sanity_check_input(rhs) + if envs.SGLANG_HACK_SKIP_FP4_FP8_GEMM.get(): + out.zero_() + return + + lhs = _ensure_cuda(lhs) + rhs = _ensure_cuda(rhs) + with compile_utils.deep_gemm_execution_hook( expected_m, n, k, num_groups, kernel_type ): @@ -46,12 +54,20 @@ def grouped_gemm_nt_f8f8bf16_masked( overlap_args.num_sms if overlap_args is not None else None ): + fp4_kwargs = ( + dict(recipe_a=(1, 128), recipe_b=(1, 32)) + if envs.SGLANG_DSV4_MODE.get() == "2604" + and envs.SGLANG_DSV4_FP4_EXPERTS.get() + else {} + ) + return deep_gemm.fp8_m_grouped_gemm_nt_masked( lhs, rhs, out, masked_m, expected_m, + **fp4_kwargs, **( dict( enable_overlap=True, @@ -64,6 +80,15 @@ def grouped_gemm_nt_f8f8bf16_masked( ) +def _ensure_cuda( + pair: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[torch.Tensor, torch.Tensor]: + return ( + pair[0].cuda() if not pair[0].is_cuda else pair[0], + pair[1].cuda() if not pair[1].is_cuda else pair[1], + ) + + def grouped_gemm_nt_f8f8bf16_contig( lhs: Tuple[torch.Tensor, torch.Tensor], rhs: Tuple[torch.Tensor, torch.Tensor], @@ -74,11 +99,25 @@ def grouped_gemm_nt_f8f8bf16_contig( num_groups, n, _ = rhs[0].shape kernel_type = compile_utils.DeepGemmKernelType.GROUPED_GEMM_NT_F8F8BF16_CONTIG + if m == 0: + return + _sanity_check_input(lhs) _sanity_check_input(rhs) + if envs.SGLANG_HACK_SKIP_FP4_FP8_GEMM.get(): + out.zero_() + return + fp4_kwargs = ( + dict(recipe_a=(1, 128), recipe_b=(1, 32)) + if envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + else {} + ) + with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type): - deep_gemm.m_grouped_fp8_gemm_nt_contiguous(lhs, rhs, out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous( + lhs, rhs, out, m_indices, **fp4_kwargs + ) def gemm_nt_f8f8bf16( diff --git a/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py b/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py new file mode 100644 index 000000000000..a313f991ac7b --- /dev/null +++ b/python/sglang/srt/layers/deep_gemm_wrapper/paged_mqa_logits.py @@ -0,0 +1,77 @@ +import deep_gemm + +from dataclasses import dataclass +from typing import List, Union + +import torch + +from sglang.srt.environ import envs + +@dataclass +class _PagedMqaLogitsMetadataChunk: + start: int + end: int + schedule_meta: torch.Tensor + + +@dataclass +class _PagedMqaLogitsMetadata: + chunks: List[_PagedMqaLogitsMetadataChunk] + + def copy_(self, other: "_PagedMqaLogitsMetadata"): + raise Exception("Not expect to be copied") + + +def get_paged_mqa_logits_metadata_chunked( + context_lens: torch.Tensor, + block_kv: int, + num_sms: int, +) -> Union[_PagedMqaLogitsMetadata, torch.Tensor]: + chunk_size = envs.SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE.get() + batch_size = context_lens.shape[0] + + if batch_size <= chunk_size: + return deep_gemm.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) + + chunks: List[_PagedMqaLogitsMetadataChunk] = [] + for start in range(0, batch_size, chunk_size): + end = min(start + chunk_size, batch_size) + schedule_meta = deep_gemm.get_paged_mqa_logits_metadata( + context_lens[start:end], block_kv, num_sms, + ) + chunks.append(_PagedMqaLogitsMetadataChunk(start=start, end=end, schedule_meta=schedule_meta)) + + return 
_PagedMqaLogitsMetadata(chunks=chunks) + + +def fp8_paged_mqa_logits_chunked( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_table: torch.Tensor, + schedule_meta: Union[_PagedMqaLogitsMetadata, torch.Tensor], + max_context_len: int, + clean_logits: bool, +) -> torch.Tensor: + if not isinstance(schedule_meta, _PagedMqaLogitsMetadata): + return deep_gemm.fp8_paged_mqa_logits( + q, kv_cache, weights, context_lens, block_table, + schedule_meta, max_context_len, clean_logits, + ) + + all_logits = [] + for chunk_meta in schedule_meta.chunks: + chunk_logits = deep_gemm.fp8_paged_mqa_logits( + q[chunk_meta.start:chunk_meta.end], + kv_cache, + weights[chunk_meta.start:chunk_meta.end], + context_lens[chunk_meta.start:chunk_meta.end], + block_table[chunk_meta.start:chunk_meta.end], + chunk_meta.schedule_meta, + max_context_len, + clean_logits, + ) + all_logits.append(chunk_logits) + + return torch.cat(all_logits, dim=0) diff --git a/python/sglang/srt/layers/deepseek_v4_rope.py b/python/sglang/srt/layers/deepseek_v4_rope.py new file mode 100644 index 000000000000..88a057befe29 --- /dev/null +++ b/python/sglang/srt/layers/deepseek_v4_rope.py @@ -0,0 +1,241 @@ +import math +from functools import lru_cache +from typing import Optional + +import tilelang +import torch +import triton +import triton.language as tl + +from sglang.srt.utils.common import maybe_torch_compile + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + + +@lru_cache(2) +def precompute_freqs_cis( + dim, seqlen, original_seq_len, base, factor, beta_fast, beta_slow +) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + + def find_correction_dim(num_rotations, dim, base, max_seq_len): + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min, max, dim): + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + if original_seq_len > 0: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + t = torch.arange(seqlen) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +@maybe_torch_compile +def apply_rotary_emb( + x: torch.Tensor, freqs_cis: torch.Tensor, inverse: bool = False +) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. 
+ + Adopted from DeepSeek's reference implementation, but adapted to sglang input formats: + - 2D: x [bs, rope_dim], freqs_cis [bs, rope_dim // 2] + - 3D: x [bs, n_heads, rope_dim], freqs_cis [bs, rope_dim // 2] + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + y = x + x = torch.view_as_complex(x.float().unflatten(-1, (-1, 2))) + if inverse: + freqs_cis = freqs_cis.conj() + if x.ndim == 3: + # x: [bs, n_heads, rope_dim // 2], freqs_cis: [bs, rope_dim // 2] + # -> reshape freqs_cis to [bs, 1, rope_dim // 2] to broadcast over n_heads + freqs_cis = freqs_cis.unsqueeze(1) + # For 2D case should directly match: x [bs, rope_dim // 2], freqs_cis [bs, rope_dim // 2] + x = torch.view_as_real(x * freqs_cis).flatten(-2) + y.copy_(x) + return y + + +@triton.jit +def apply_rotary_emb_triton_kernel( + x_ptr, + freqs_ptr, + positions_ptr, + rope_dim, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_freq_pos, + stride_freq_dim, + USE_POS: tl.constexpr, + IS_INVERSE: tl.constexpr, + IS_3D: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + pid_dim = tl.program_id(2) + + # Get position: from tensor or directly use pid_batch + if USE_POS: + position = tl.load(positions_ptr + pid_batch) + else: + position = pid_batch + + if IS_3D: + # [bs, n_heads, rope_dim] + base_offset = pid_batch * stride_x_batch + pid_head * stride_x_head + else: + # [bs, rope_dim] + base_offset = pid_batch * stride_x_batch + + offs_pair = pid_dim * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs_pair < (rope_dim // 2) + + # real is even, imag is odd + offs_x_real = base_offset + offs_pair * 2 * stride_x_dim + offs_x_imag = base_offset + (offs_pair * 2 + 1) * stride_x_dim + + x_real = tl.load(x_ptr + offs_x_real, mask=mask, other=0.0).to(tl.float32) + x_imag = tl.load(x_ptr + offs_x_imag, mask=mask, other=0.0).to(tl.float32) + + offs_freq_real = position * stride_freq_pos + offs_pair * 2 * stride_freq_dim + offs_freq_imag = position * stride_freq_pos + (offs_pair * 2 + 1) * stride_freq_dim + + freq_real = tl.load(freqs_ptr + offs_freq_real, mask=mask, other=0.0) + freq_imag = tl.load(freqs_ptr + offs_freq_imag, mask=mask, other=0.0) + + if IS_INVERSE: + # (a + bi) * (c - di) = (ac + bd) + (bc - ad)i + out_real = x_real * freq_real + x_imag * freq_imag + out_imag = x_imag * freq_real - x_real * freq_imag + else: + # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i + out_real = x_real * freq_real - x_imag * freq_imag + out_imag = x_real * freq_imag + x_imag * freq_real + + tl.store(x_ptr + offs_x_real, out_real, mask=mask) + tl.store(x_ptr + offs_x_imag, out_imag, mask=mask) + + +def apply_rotary_emb_triton( + x: torch.Tensor, + freqs_cis: torch.Tensor, + positions: Optional[torch.Tensor] = None, + inverse: bool = False, +) -> torch.Tensor: + """ + Args: + x: 2d [bs, rope_dim] or 3d [bs, n_heads, rope_dim] + freqs_cis: + - If positions is None: [bs, rope_dim // 2] (already indexed) + - If positions is not None: [max_seqlen, rope_dim // 2] (full table) + positions: Optional[bs], if provided will index into freqs_cis + inverse: bool, if True, apply inverse rotation (conjugate) + Returns: + x with rotary embeddings applied (inplace) + """ + is_3d = x.ndim == 3 + + if is_3d: + batch_size, n_heads, rope_dim = x.shape + else: + batch_size, rope_dim = x.shape + n_heads = 1 + + 
freqs_real = torch.view_as_real(freqs_cis).flatten(-2) + + BLOCK_SIZE = 128 + + num_blocks_dim = triton.cdiv(rope_dim // 2, BLOCK_SIZE) + grid = (batch_size, n_heads if is_3d else 1, num_blocks_dim) + + if positions is not None: + # use positions to index into freqs_cis + assert positions.shape == ( + batch_size, + ), f"positions shape {positions.shape} != ({batch_size},)" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + positions, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=True, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + # freqs_cis already indexed, use pid_batch as position + assert ( + freqs_real.shape[0] == batch_size + ), f"freqs_cis batch size {freqs_real.shape[0]} != x batch size {batch_size}" + + apply_rotary_emb_triton_kernel[grid]( + x, + freqs_real, + None, + rope_dim, + x.stride(0), + x.stride(1) if is_3d else 0, + x.stride(-1), + freqs_real.stride(0), + freqs_real.stride(1), + USE_POS=False, + IS_INVERSE=inverse, + IS_3D=is_3d, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return x diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 7fde05894b59..7e5730c1583d 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -97,7 +97,11 @@ def __init__( self.variance_size_override = ( None if var_hidden_size == hidden_size else var_hidden_size ) - if _use_aiter: + # if _use_aiter: + # self._forward_method = self.forward_aiter + if get_bool_env_var("SGLANG_USE_NATIVE_LAYERNORM"): + self._forward_method = self.forward_native + elif _use_aiter: self._forward_method = self.forward_aiter def forward_cuda( @@ -154,11 +158,28 @@ def forward_aiter( residual: Optional[torch.Tensor] = None, post_residual_addition: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if _is_hip: + ### amd dsk + # Handle empty tensor case + if x.shape[0] == 0: + if residual is not None: + return x, residual + return x + + # # DEBUG: skip rms_norm,return tensor + # if residual is not None: + # # return same shape 的 tensor + # # output = torch.empty_like(x) + # output = torch.zeros_like(x) + # # residual_out = torch.empty_like(x) + # residual_out = torch.zeros_like(x) + # return output, residual_out + # # return torch.empty_like(x) + # return torch.zeros_like(x) + # else: if residual is not None: residual_out = torch.empty_like(x) output = torch.empty_like(x) - if post_residual_addition is not None: - residual = residual + post_residual_addition fused_add_rms_norm( output, x, diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 39919eedd72c..cd492e7fab80 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -252,7 +252,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.dtype == loaded_weight.dtype ), "init para dtype and loaded weight dtype should be the same" - assert param.size() == loaded_weight.size() + assert ( + param.size() == loaded_weight.size() + ), f"{param.shape=} {param.dtype=} {loaded_weight.shape=} {loaded_weight.dtype=}" param.data.copy_(loaded_weight) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py new file mode 100644 index 000000000000..08526a982d63 --- /dev/null +++ b/python/sglang/srt/layers/mhc.py @@ -0,0 +1,686 @@ +import 
functools +import math +from typing import Tuple + +import tilelang +import tilelang.language as T +import torch + +from sglang.jit_kernel.utils import is_arch_support_pdl +from sglang.srt.layers.attention.nsa.utils import is_nsa_prefill_cp_round_robin_split + +tilelang.set_log_level("WARNING") + +pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, +} + +FP8 = "float8_e4m3" +BF16 = "bfloat16" +FP32 = "float32" +INT32 = "int32" + + +@tilelang.jit(pass_configs=pass_configs) +def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float): + n = T.symbolic("n") + mix_hc = (2 + hc) * hc + threads = 64 + + ENABLE_PDL = is_arch_support_pdl() + + @T.prim_func + def hc_split_sinkhorn_kernel_( + mixes: T.Tensor[(n, mix_hc), FP32], + hc_scale: T.Tensor[(3,), T.float32], + hc_base: T.Tensor[(mix_hc,), T.float32], + pre: T.Tensor[(n, hc), FP32], + post: T.Tensor[(n, hc), FP32], + comb: T.Tensor[(n, hc, hc), FP32], + ): + with T.Kernel(n, threads=threads) as i: + if ENABLE_PDL: + T.pdl_sync() + + mixes_shared = T.alloc_shared(mix_hc, FP32) + comb_frag = T.alloc_fragment((hc, hc), FP32) + T.copy(mixes[i, :], mixes_shared) + + for j in T.Parallel(hc): + pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps + for j in T.Parallel(hc): + post[i, j] = 2 * T.sigmoid( + mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc] + ) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = ( + mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + + hc_base[j * hc + k + hc * 2] + ) + + row_sum = T.alloc_fragment(hc, FP32) + col_sum = T.alloc_fragment(hc, FP32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc, FP32) + T.reduce_max(comb_frag, row_max, dim=1) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j]) + T.reduce_sum(comb_frag, row_sum, dim=1) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(comb_frag, col_sum, dim=0) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps) + + for _ in T.serial(sinkhorn_iters - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(comb_frag, row_sum, dim=1) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps) + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(comb_frag, col_sum, dim=0) + for j, k in T.Parallel(hc, hc): + comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps) + + T.copy(comb_frag, comb[i, :, :]) + if ENABLE_PDL: + T.pdl_trigger() + + return hc_split_sinkhorn_kernel_ + + +def hc_split_sinkhorn( + mixes: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + hc_mult: int = 4, + sinkhorn_iters: int = 20, + eps: float = 1e-6, +): + b, s, _ = mixes.size() + pre = mixes.new_empty(b, s, hc_mult) + post = mixes.new_empty(b, s, hc_mult) + comb = mixes.new_empty(b, s, hc_mult, hc_mult) + kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps) + kernel( + mixes.view(-1, (2 + hc_mult) * hc_mult), + hc_scale, + hc_base, + pre.view(-1, hc_mult), + post.view(-1, hc_mult), + comb.view(-1, hc_mult, hc_mult), + ) + return pre, post, comb + + +# Adapted from https://github.com/tile-ai/tilelang/blob/5fe8b84313083d0a4035849c9282f06586c93d58/examples/deepseek_mhc/example_mhc_pre.py +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + 
tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual, + post_mix, + comb_mix, + layer_input, + hidden_size: int, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + # outputs + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(num_tokens, threads=96) as i: + ################################################################## + # _pre_norm_fn_fwd_norm + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + + if ENABLE_PDL: + T.pdl_sync() + + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + ################################################################## + # _pre_split_mixes_fwd (post & comb) + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = ( + T.sigmoid( + mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] + ) + * hc_post_mult_value + ) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = ( + mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + + hc_base[j * hc_mult + k + hc_mult * 2] + ) + + ################################################################## + # _sinkhorn_fwd + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + # save comb_mix to global memory + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + 
################################################################## + # _pre_split_mixes_fwd (pre) + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + ################################################################### + # _pre_apply_mix_fwd + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + if ENABLE_PDL: + T.pdl_trigger() + + +# Adapted from https://github.com/tile-ai/tilelang/blob/5fe8b84313083d0a4035849c9282f06586c93d58/examples/deepseek_mhc/example_mhc_pre.py +@tilelang.jit +def mhc_pre_gemm_sqrsum_tilelang( + x, + fn, + out, + sqrsum, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 256, +) -> tilelang.JITKernel: + """Not highly optimized TileLang implementation of fused gemm and sqrsum in mHC pre block.""" + assert hc_mult3 <= 32 # should be 24 usually + num_tokens = T.dynamic("num_tokens") + assert hc_hidden_size % hidden_block == 0 + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) + fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32) + out: T.Tensor((num_tokens, hc_mult3), T.float32) + sqrsum: T.Tensor((num_tokens), T.float32) + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(T.ceildiv(num_tokens, token_block)) as px: + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sqrsum_part = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sqrsum_part) + if ENABLE_PDL: + T.pdl_sync() + for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2): + x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout( + {x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)} + ) + + T.copy(x[px * token_block, pz * hidden_block], x_smem_16) + T.copy(fn[0, pz * hidden_block], fn_smem) + + x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem_16, x_frag_16) + x_frag = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_frag_16, x_frag) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j] + + # should be TF32 gemm + T.gemm( + x_frag, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + sqrsum_l = T.alloc_fragment(token_block, T.float32) + T.reduce_sum(sqrsum_part, sqrsum_l) + for i in T.Parallel(token_block): + sqrsum[px * token_block + i] = sqrsum_l[i] + for i, j in T.Parallel(token_block, 32): + if j < hc_mult3: + out[px * token_block + i, j] = out_frag[i, j] + if ENABLE_PDL: + T.pdl_trigger() + + +@functools.cache +def mhc_pre_gemm_sqrsum_splitk_kernel( + hc_mult3: int, + hc_hidden_size: int, + split_k: int, + token_block: int = 32, + hidden_block: int = 256, + threads: int = 128, +) -> Tuple[tilelang.JITKernel, tilelang.JITKernel]: + assert hc_mult3 <= 32 + assert hc_hidden_size % hidden_block == 0 + assert hc_hidden_size % 
split_k == 0 + split_size = hc_hidden_size // split_k + assert split_size % hidden_block == 0 + + num_tokens = T.dynamic("num_tokens") + + ENABLE_PDL = is_arch_support_pdl() + + @tilelang.jit + def mhc_pre_gemm_sqrsum_splitk_stage_0( + x: T.Tensor[(num_tokens, hc_hidden_size), T.bfloat16], + fn: T.Tensor[(hc_mult3, hc_hidden_size), T.float32], + out_partial: T.Tensor[(split_k, num_tokens, 32), T.float32], + sqrsum_partial: T.Tensor[(split_k, num_tokens), T.float32], + ): + with T.Kernel(T.ceildiv(num_tokens, token_block), split_k, threads=threads) as ( + px, + bz, + ): + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sq_part4 = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sq_part4) + + k_base = bz * split_size + + if ENABLE_PDL: + T.pdl_sync() + + for pz in T.Pipelined(split_size // hidden_block, num_stages=2): + x_smem = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout( + {x_smem: tilelang.layout.make_swizzled_layout(x_smem)} + ) + + T.copy(x[px * token_block, k_base + pz * hidden_block], x_smem) + T.copy(fn[0, k_base + pz * hidden_block], fn_smem) + + x_f16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem, x_f16) + x_f = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_f16, x_f) + + # partial sqrsum for this tile + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + v = x_f[i, jj * 4 + j] + sq_part4[i, j] += v * v + + # partial GEMM accumulate + T.gemm( + x_f, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + wg_wait=0, + clear_accum=False, + ) + + # reduce 4 lanes -> scalar per token + sq_l = T.alloc_fragment((token_block,), T.float32) + T.reduce_sum(sq_part4, sq_l) + + # write to workspace (NO atomic; each (px,bz) unique) + for i in T.Parallel(token_block): + t = px * token_block + i + if t < num_tokens: + sqrsum_partial[bz, t] = sq_l[i] + + for i, j in T.Parallel(token_block, 32): + t = px * token_block + i + if t < num_tokens: + out_partial[bz, t, j] = out_frag[i, j] + + if ENABLE_PDL: + T.pdl_trigger() + + @tilelang.jit + def mhc_pre_gemm_sqrsum_splitk_stage_1( + out_partial: T.Tensor[(split_k, num_tokens, 32), T.float32], + sqrsum_partial: T.Tensor[(split_k, num_tokens), T.float32], + out: T.Tensor[(num_tokens, hc_mult3), T.float32], + sqrsum: T.Tensor[(num_tokens,), T.float32], + ): + warps_per_cta = threads // 32 + num_reduce = T.ceildiv(split_k, 32) + with T.Kernel(T.ceildiv(num_tokens, warps_per_cta), threads=threads) as (px,): + tx = T.get_thread_binding() + warp = tx // 32 + lane = tx % 32 + t = px * warps_per_cta + warp + s = T.alloc_local((1,), T.float32) + acc = T.alloc_local((1,), T.float32) + s[0] = 0 + acc[0] = 0 + if ENABLE_PDL: + T.pdl_sync() + + if t < num_tokens: + for r in T.serial(num_reduce): + bz = r * 32 + lane + s[0] += T.if_then_else(bz < split_k, sqrsum_partial[bz, t], 0.0) + sqrsum[t] = T.warp_reduce_sum(s[0]) + if lane < hc_mult3: + for bz in T.serial(split_k): + acc[0] += out_partial[bz, t, lane] + out[t, lane] = acc[0] + + if ENABLE_PDL: + T.pdl_trigger() + + return ( # type: ignore + mhc_pre_gemm_sqrsum_splitk_stage_0, + mhc_pre_gemm_sqrsum_splitk_stage_1, + ) + + +# Adapted from https://github.com/tile-ai/tilelang/blob/5fe8b84313083d0a4035849c9282f06586c93d58/examples/deepseek_mhc/example_mhc_pre.py +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + 
hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, + n_splits_pre: int = 32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for mHC pre block. + + Args: + residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 + fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 + hc_scale: shape (3,), dtype torch.float32 + hc_base: shape (hc_mult3,), dtype torch.float32 + rms_eps: RMS normalization epsilon + hc_pre_eps: pre-mix epsilon + hc_sinkhorn_eps: sinkhorn epsilon + hc_post_mult_value: post-mix multiplier value + sinkhorn_repeat: number of sinkhorn iterations + n_splits: split-k factor; TileLang version of mhc_pre_gemm_sqrsum doesn't support this + + Returns: + post_mix: shape (..., hc_mult), dtype torch.float32 + comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 + layer_input: shape (..., hidden_size), dtype torch.bfloat16 + """ + + # Validate shapes + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + post_mix = torch.empty( + num_tokens, hc_mult, dtype=torch.float32, device=residual.device + ) + comb_mix = torch.empty( + num_tokens, hc_mult2, dtype=torch.float32, device=residual.device + ) + layer_input = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device + ) + + gemm_out_mul = torch.empty( + n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device + ) + gemm_out_sqrsum = torch.empty( + n_splits, num_tokens, dtype=torch.float32, device=residual.device + ) + + if num_tokens <= 2048: + assert n_splits == 1 + kernel_0, kernel_1 = mhc_pre_gemm_sqrsum_splitk_kernel( + hc_mult3, + hc_hidden_size, + split_k=n_splits_pre, + token_block=32, + ) + partial_out = gemm_out_mul.new_empty(n_splits_pre, num_tokens, 32) + partial_sqrsum = gemm_out_sqrsum.new_empty(n_splits_pre, num_tokens) + kernel_0( + residual_flat.view(num_tokens, hc_hidden_size), + fn_flat, + partial_out, + partial_sqrsum, + ) + kernel_1( + partial_out, + partial_sqrsum, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + ) + del partial_out, partial_sqrsum + else: + assert ( + n_splits == 1 + ), "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +# Adapted from 
https://github.com/tile-ai/tilelang/blob/5fe8b84313083d0a4035849c9282f06586c93d58/examples/deepseek_mhc/example_mhc_post.py +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, + }, +) +def mhc_post_tilelang( + a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024 +) -> tilelang.JITKernel: + # rename for shorter code + n = T.dynamic("num_tokens") + h = hidden + + h_blk = math.gcd(hidden, h_blk) + a: T.Tensor((n, hc, hc), T.float32) + b: T.Tensor((n, hc, h), T.bfloat16) + c: T.Tensor((n, hc), T.float32) + d: T.Tensor((n, h), T.bfloat16) + x: T.Tensor((n, hc, h), T.bfloat16) + + ENABLE_PDL = is_arch_support_pdl() + with T.Kernel(n, threads=n_thr) as i_n: + if ENABLE_PDL: + T.pdl_sync() + + x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + d_shared = T.alloc_shared(h_blk, T.bfloat16) + + x_local = T.alloc_fragment((hc, h_blk), T.float32) + b_local = T.alloc_fragment((hc, h_blk), T.float32) + d_local = T.alloc_fragment(h_blk, T.float32) + + a_local = T.alloc_fragment((hc, hc), T.float32) + c_local = T.alloc_fragment(hc, T.float32) + T.copy(a[i_n, 0, 0], a_local) + T.copy(c[i_n, 0], c_local) + + for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2): + T.copy(b[i_n, 0, i0_h * h_blk], b_shared) + T.copy(d[i_n, i0_h * h_blk], d_shared) + + T.copy(b_shared, b_local) + T.copy(d_shared, d_local) + for i_hco, i1_h in T.Parallel(hc, h_blk): + x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h] + for i_hci in T.serial(hc): + x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] + T.copy(x_local, x_shared) + + T.copy(x_shared, x[i_n, 0, i0_h * h_blk]) + + if ENABLE_PDL: + T.pdl_trigger() + + +# Adapted from https://github.com/tile-ai/tilelang/blob/5fe8b84313083d0a4035849c9282f06586c93d58/examples/deepseek_mhc/example_mhc_post.py +def mhc_post( + x: torch.Tensor, + residual: torch.Tensor, + post_layer_mix: torch.Tensor, + comb_res_mix: torch.Tensor, +) -> torch.Tensor: + if is_nsa_prefill_cp_round_robin_split(): + x = x.contiguous() + residual = residual.contiguous() + post_layer_mix = post_layer_mix.contiguous() + comb_res_mix = comb_res_mix.contiguous() + out = torch.empty_like(residual) + mhc_post_tilelang( + comb_res_mix, + residual, + post_layer_mix.squeeze(-1), + x, + out, + residual.shape[-2], + residual.shape[-1], + ) + return out diff --git a/python/sglang/srt/layers/moe/deepseek_v4_topk.py b/python/sglang/srt/layers/moe/deepseek_v4_topk.py new file mode 100644 index 000000000000..41f7d848a4bf --- /dev/null +++ b/python/sglang/srt/layers/moe/deepseek_v4_topk.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import logging +from typing import Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.environ import envs +from sglang.srt.eplb.expert_location_dispatch import ( + ExpertLocationDispatchInfo, + topk_ids_logical_to_physical, +) +from sglang.srt.utils import ( + cpu_has_amx_support, + get_bool_env_var, + get_compiler_backend, + is_cpu, + is_cuda, + is_hip, + is_npu, +) + +logger = logging.getLogger(__name__) +_is_cuda = is_cuda() +_is_hip = is_hip() +_is_cpu = is_cpu() +_is_cpu_amx_available = cpu_has_amx_support() +_is_npu = is_npu() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + + +from sglang.srt.layers.moe.topk import StandardTopKOutput, _mask_topk_ids_padded_region + + +class HashTopK(nn.Module): 
+ def __init__( + self, + topk, + num_experts, + num_fused_shared_experts, + vocab_size, + scoring_func="sqrtsoftplus", + routed_scaling_factor=1.5, + apply_routed_scaling_factor_on_output=False, + ): + super().__init__() + self.num_experts = num_experts + self.topk = topk + self.routed_scaling_factor = routed_scaling_factor + self.num_fused_shared_experts = num_fused_shared_experts + self.score_func = scoring_func + self.tid2eid = nn.Parameter( + torch.empty(vocab_size, topk - num_fused_shared_experts, dtype=torch.int32), + requires_grad=False, + ) + + if get_bool_env_var("SGLANG_HACK_TID2EID_INIT_ZERO"): + print("hack: tid2eid init to zero") + nn.init.constant_(self.tid2eid, 0) + + assert not apply_routed_scaling_factor_on_output, "not implemented" + + def empty_topk_output(self, device: torch.device): + topk = self.topk - self.num_fused_shared_experts + topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device) + topk_ids = torch.full((0, topk), -1, dtype=torch.int32, device=device) + router_logits = torch.empty((0, topk), dtype=torch.float32, device=device) + return StandardTopKOutput(topk_weights, topk_ids, router_logits) + + def _forward_torch( + self, router_logits: torch.Tensor, input_ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.score_func == "softmax": + scores = router_logits.softmax(dim=-1) + elif self.score_func == "sigmoid": + scores = router_logits.sigmoid() + else: + scores = torch.nn.functional.softplus(router_logits).sqrt() + + num_token = scores.shape[0] + + topk_ids = torch.zeros( + (num_token, self.topk), dtype=torch.int32, device=scores.device + ) + topk_weights = torch.zeros( + (num_token, self.topk), dtype=scores.dtype, device=scores.device + ) + + if self.num_fused_shared_experts == 1: + # Hash MoE: get routed expert IDs and weights + topk_ids[:, :-1] = self.tid2eid[input_ids] + topk_weights[:, :-1] = scores.gather(1, topk_ids[:, :-1]) + + if self.score_func != "softmax": + topk_weights[:, :-1] /= topk_weights[:, :-1].sum(dim=-1, keepdim=True) + + # reference: biased_grouped_topk_impl in topk.py + topk_ids[:, -1] = torch.randint( + low=self.num_experts, + high=self.num_experts + self.num_fused_shared_experts, + size=(num_token,), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + + # don't apply routed scaling factor here + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / self.routed_scaling_factor + ) + else: + topk_ids[:, :] = self.tid2eid[input_ids] + topk_weights[:, :] = scores.gather(1, topk_ids[:, :]) + if self.score_func != "softmax": + topk_weights[:, :] /= topk_weights[:, :].sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + input_ids: torch.Tensor, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ): + assert ( + input_ids.shape[0] == hidden_states.shape[0] == router_logits.shape[0] + ), f"{input_ids.shape=} {hidden_states.shape=} {router_logits.shape=}" + + if envs.SGLANG_OPT_USE_FUSED_HASH_TOPK.get(): + from sglang.jit_kernel.deepseek_v4 import hash_topk + + topk_weights, topk_ids = hash_topk( + router_logits=router_logits, + input_ids=input_ids, + tid2eid=self.tid2eid, + num_fused_shared_experts=self.num_fused_shared_experts, + routed_scaling_factor=self.routed_scaling_factor, + scoring_func=self.score_func, + ) + else: + topk_weights, topk_ids = self._forward_torch(router_logits, input_ids) + + if is_hip(): + 
topk_weights = topk_weights.to(torch.float32) + + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + topk_output = StandardTopKOutput( + topk_weights=topk_weights, topk_ids=topk_ids, router_logits=router_logits + ) + return topk_output + + +@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu) +def biased_topk_impl( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + if scoring_func == "sigmoid": + scores = gating_output.sigmoid() + elif scoring_func == "sqrtsoftplus": + scores = torch.nn.functional.softplus(gating_output).sqrt() + + num_token = scores.shape[0] + num_experts = scores.shape[1] + + scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0) + _, topk_ids = torch.topk( + scores_for_choice, + k=topk, + dim=-1, + sorted=(True if num_fused_shared_experts > 0 else False), + ) + topk_weights = scores.gather(1, topk_ids) + + if num_fused_shared_experts: + topk_ids[:, -1] = torch.randint( + low=num_experts, + high=num_experts + num_fused_shared_experts, + size=(topk_ids.size(0),), + dtype=topk_ids.dtype, + device=topk_ids.device, + ) + if routed_scaling_factor is not None: + topk_weights[:, -1] = ( + topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor + ) + + if renormalize: + topk_weights_sum = ( + topk_weights.sum(dim=-1, keepdim=True) + if num_fused_shared_experts == 0 + else topk_weights[:, :-1].sum(dim=-1, keepdim=True) + ) + topk_weights = topk_weights / topk_weights_sum + if apply_routed_scaling_factor_on_output: + topk_weights *= routed_scaling_factor + + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + _mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids + + +def biased_topk_jit_kernel_impl( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + scoring_func: str = "sigmoid", + num_fused_shared_experts: int = 0, + routed_scaling_factor: Optional[float] = None, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + apply_routed_scaling_factor_on_output: Optional[bool] = False, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + from sglang.jit_kernel.moe_fused_gate import moe_fused_gate + + topk_weights, topk_ids = moe_fused_gate( + gating_output, + correction_bias, + topk=topk, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + renormalize=renormalize, + routed_scaling_factor=routed_scaling_factor, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32) + topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info) + 
_mask_topk_ids_padded_region(topk_ids, num_token_non_padded) + return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..eacde3f6b8fb --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..b60f7dc039df --- /dev/null +++ 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_5_1/E=257,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index a1885fade143..2cf0a5670517 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -13,6 +13,8 @@ import torch.nn.functional as F import triton.language as tl +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker +from sglang.srt.environ import envs from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.utils import ( cpu_has_amx_support, @@ -87,6 +89,7 @@ def inplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> None: fused_experts_impl( 
hidden_states, @@ -117,6 +120,7 @@ def inplace_fused_experts( gemm1_alpha, gemm1_limit, filter_expert, + swiglu_limit=swiglu_limit, ) @@ -149,6 +153,7 @@ def outplace_fused_experts( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -179,6 +184,7 @@ def outplace_fused_experts( gemm1_alpha=gemm1_alpha, gemm1_limit=gemm1_limit, filter_expert=filter_expert, + swiglu_limit=swiglu_limit, ) @@ -237,6 +243,7 @@ def fused_experts( moe_runner_config.gemm1_alpha, moe_runner_config.gemm1_clamp_limit, filter_expert, + moe_runner_config.swiglu_limit, ) return hidden_states else: @@ -268,6 +275,7 @@ def fused_experts( gemm1_alpha=moe_runner_config.gemm1_alpha, gemm1_limit=moe_runner_config.gemm1_clamp_limit, filter_expert=filter_expert, + swiglu_limit=moe_runner_config.swiglu_limit, ) @@ -319,6 +327,7 @@ def fused_experts_impl( gemm1_alpha: Optional[float] = None, gemm1_limit: Optional[float] = None, filter_expert: bool = True, + swiglu_limit: Optional[float] = None, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -469,6 +478,7 @@ def fused_experts_impl( filter_expert=filter_expert, ) + # Activation function with multiplication if activation == "silu" and is_gated: if gemm1_alpha is not None: @@ -478,23 +488,51 @@ def fused_experts_impl( gemm1_alpha, gemm1_limit, ) - elif _is_cuda or _is_hip: - if not filter_expert: - silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" + assert is_2604b == (swiglu_limit is not None), \ + f"swiglu_limit must be non-None iff submode=2604B " \ + f"(got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, swiglu_limit={swiglu_limit!r})" + + swiglu_limit_for_triton: Optional[float] = None + if is_2604b: + assert swiglu_limit == 10 + assert intermediate_cache1.shape == (total_tokens, N) + assert (_is_cuda or _is_hip), \ + "DSV4 2604 submode 2604B only supports CUDA/HIP downstream" + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + # Fusion path passes the limit into act_and_mul_triton, which + # only runs on the filter_expert=True branch below. + assert filter_expert, \ + "SGLANG_OPT_SWIGLU_CLAMP_FUSION requires filter_expert=True (downstream must be act_and_mul_triton)" + swiglu_limit_for_triton = swiglu_limit + else: + # In-place clamp works with either downstream kernel because the + # clamped values are already written back to intermediate_cache1. 
+ half = N // 2 + intermediate_cache1[:, :half].clamp_(max=swiglu_limit) + intermediate_cache1[:, half:].clamp_(min=-swiglu_limit, max=swiglu_limit) + deepseek_v4_moe_code_path_checker.observed += 1 + + if _is_cuda or _is_hip: + if not filter_expert: + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + else: + act_and_mul_triton( + intermediate_cache1.view(-1, N), + intermediate_cache2, + config, + topk_ids, + expert_ids, + down_moe_use_tma, + activation, + swiglu_limit=swiglu_limit_for_triton, + ) else: - act_and_mul_triton( - intermediate_cache1.view(-1, N), - intermediate_cache2, - config, - topk_ids, - expert_ids, - down_moe_use_tma, - activation, + vllm_ops.silu_and_mul( + intermediate_cache2, intermediate_cache1.view(-1, N) ) - else: - vllm_ops.silu_and_mul( - intermediate_cache2, intermediate_cache1.view(-1, N) - ) elif activation == "gelu" and is_gated: assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" assert gemm1_limit is None, "gemm1_limit is not supported for gelu" @@ -525,6 +563,7 @@ def fused_experts_impl( else: raise ValueError(f"Unsupported activation: {activation=}, with {is_gated=}") + invoke_fused_moe_kernel( intermediate_cache2, w2, @@ -557,6 +596,7 @@ def fused_experts_impl( filter_expert=filter_expert, ) + if routed_scaling_factor is None: routed_scaling_factor = 1.0 @@ -587,7 +627,8 @@ def fused_experts_impl( ) elif _is_hip: - if _use_aiter: + _force_triton = envs.SGLANG_FORCE_TRITON_MOE_FP8.get() + if _use_aiter and not _force_triton: moe_sum( intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx], diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py index 230b64057ab4..5fa6249d5345 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py @@ -27,6 +27,7 @@ is_hip, is_sm90_supported, ) +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker try: from triton.tools.tensor_descriptor import TensorDescriptor @@ -871,6 +872,8 @@ def act_and_mul_kernel( expert_step: tl.constexpr, BLOCK_SIZE: tl.constexpr, ACTIVATION_TYPE: tl.constexpr, + SWIGLU_LIMIT: tl.constexpr = 0.0, + HAS_SWIGLU_LIMIT: tl.constexpr = False, ): """ Unified activation and multiply kernel that handles both sorted and unsorted routing, @@ -899,6 +902,10 @@ def act_and_mul_kernel( gate_output = tl.load(gate_output_ptr + offset, mask=mask) up_output = tl.load(up_output_ptr + offset, mask=mask) + if HAS_SWIGLU_LIMIT: + gate_output = tl.minimum(gate_output, SWIGLU_LIMIT) + up_output = tl.maximum(tl.minimum(up_output, SWIGLU_LIMIT), -SWIGLU_LIMIT) + gate_output_activated = _apply_activation(gate_output, ACTIVATION_TYPE) gate_output_activated = gate_output_activated.to(InDtype) @@ -915,6 +922,7 @@ def act_and_mul_triton( expert_ids: Optional[torch.Tensor] = None, down_moe_use_tma: bool = False, activation: str = "silu", + swiglu_limit: Optional[float] = None, ) -> None: """ Args: @@ -925,11 +933,16 @@ def act_and_mul_triton( expert_ids: Expert IDs for sorted routing (used when down_moe_use_tma=True) down_moe_use_tma: Whether to use sorted routing layout activation: Activation type ("silu" or "gelu") + swiglu_limit: if not None, clamp gate to [-inf, L] and up to [-L, L] before activation + (compiles a separate kernel variant via tl.constexpr). 
""" grid = (down_input.shape[0],) hidden_size = gateup_output.shape[1] expert_ids_row = topk_ids.view(-1) if not down_moe_use_tma else expert_ids expert_step = 1 if not down_moe_use_tma else config["BLOCK_SIZE_M"] + has_swiglu_limit = swiglu_limit is not None + if has_swiglu_limit: + deepseek_v4_moe_code_path_checker.observed += 1 act_and_mul_kernel[grid]( gateup_output, down_input, @@ -938,6 +951,8 @@ def act_and_mul_triton( expert_step, BLOCK_SIZE=512, ACTIVATION_TYPE=activation, + SWIGLU_LIMIT=float(swiglu_limit) if has_swiglu_limit else 0.0, + HAS_SWIGLU_LIMIT=has_swiglu_limit, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 019843ae0365..1647473469e0 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -181,6 +181,7 @@ def __init__( routed_scaling_factor: Optional[float] = None, gemm1_alpha: Optional[float] = None, gemm1_clamp_limit: Optional[float] = None, + swiglu_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, routing_method_type: Optional[RoutingMethodType] = None, @@ -255,6 +256,7 @@ def __init__( routed_scaling_factor=routed_scaling_factor, gemm1_alpha=gemm1_alpha, gemm1_clamp_limit=gemm1_clamp_limit, + swiglu_limit=swiglu_limit, is_gated=is_gated, routing_method_type=routing_method_type, ) diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py index 12dd2ba6a237..088fbfbef33d 100644 --- a/python/sglang/srt/layers/moe/moe_runner/base.py +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -48,6 +48,7 @@ class MoeRunnerConfig: routed_scaling_factor: Optional[float] = None gemm1_alpha: Optional[float] = None gemm1_clamp_limit: Optional[float] = None + swiglu_limit: Optional[float] = None @dataclass diff --git a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py index 7fa8193fb328..8c54cb66aed2 100644 --- a/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py +++ b/python/sglang/srt/layers/moe/moe_runner/deep_gemm.py @@ -1,10 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Tuple +import einops import torch +from sglang.jit_kernel.deepseek_v4 import silu_and_mul_masked_post_quant +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker +from sglang.srt.environ import envs from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.moe_runner.base import ( MoeQuantInfo, @@ -106,6 +110,7 @@ def __init__(self, config: MoeRunnerConfig): super().__init__(config) assert self.config.activation == "silu" assert self.config.is_gated + self.swiglu_limit = self.config.swiglu_limit def run( self, @@ -146,6 +151,7 @@ def _run_contiguous_gemm( K = hidden_states_shape[1] scale_block_size = 128 + # TODO: this can be fp4 indeed for new model, we should rename it w13_weight_fp8 = ( quant_info.w13_weight, quant_info.w13_scale, @@ -169,6 +175,10 @@ def _run_contiguous_gemm( dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + gateup_output = _apply_swiglu_limit(gateup_output, swiglu_limit=self.swiglu_limit) + deepseek_v4_moe_code_path_checker.observed += 1 + down_input = torch.empty( ( all_tokens, @@ -213,12 +223,6 @@ def _run_masked_gemm( 
running_state: dict, ) -> torch.Tensor: from sglang.srt.layers import deep_gemm_wrapper - from sglang.srt.layers.moe.ep_moe.kernels import ( - silu_and_mul_masked_post_quant_fwd, - ) - from sglang.srt.layers.quantization.fp8_kernel import ( - sglang_per_token_group_quant_8bit, - ) hidden_states = runner_input.hidden_states hidden_states_scale = runner_input.hidden_states_scale @@ -262,47 +266,33 @@ def _run_masked_gemm( dispose_tensor(hidden_states) dispose_tensor(hidden_states_scale) + is_2604b = envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" + assert is_2604b == (self.swiglu_limit is not None), \ + f"swiglu_limit must be non-None iff submode=2604B (got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, swiglu_limit={self.swiglu_limit!r})" + + swiglu_limit_arg: Optional[float] = None + if is_2604b: + assert not _MASKED_GEMM_FAST_ACT, \ + "DSV4 2604 submode 2604B does not support SGLANG_MASKED_GEMM_FAST_ACT" + assert envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get(), \ + "DSV4 2604 submode 2604B requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + + if envs.SGLANG_OPT_SWIGLU_CLAMP_FUSION.get(): + swiglu_limit_arg = self.swiglu_limit + else: + gateup_output = einops.rearrange(gateup_output, 'grp tok hidden -> (grp tok) hidden') + gateup_output = _apply_swiglu_limit(gateup_output, swiglu_limit=self.swiglu_limit) + gateup_output = einops.rearrange(gateup_output, '(grp tok) hidden -> grp tok hidden', grp=num_groups) + deepseek_v4_moe_code_path_checker.observed += 1 + # Act - scale_block_size = 128 - if _MASKED_GEMM_FAST_ACT: - down_input, down_input_scale = sglang_per_token_group_quant_8bit( - x=gateup_output, - dst_dtype=torch.float8_e4m3fn, - group_size=scale_block_size, - masked_m=masked_m, - column_major_scales=True, - scale_tma_aligned=True, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - fuse_silu_and_mul=True, - enable_v2=True, - ) - else: - down_input = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2, - ), - device=hidden_states_device, - dtype=torch.float8_e4m3fn, - ) - down_input_scale = torch.empty( - ( - gateup_output.shape[0], - gateup_output.shape[1], - gateup_output.shape[2] // 2 // scale_block_size, - ), - device=hidden_states_device, - dtype=torch.float32, - ) - silu_and_mul_masked_post_quant_fwd( - gateup_output, - down_input, - down_input_scale, - scale_block_size, - masked_m, - scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, - ) + down_input, down_input_scale = _varlen_deep_gemm_silu_mul_quant( + gateup_output, + masked_m, + group_size=128, + topk=self.config.top_k, + swiglu_limit=swiglu_limit_arg, + ) del gateup_output # GroupGemm-1 @@ -604,3 +594,101 @@ def post_permute_deep_gemm_to_deepep_normal( topk_ids=running_state["topk_ids"], topk_weights=running_state["topk_weights"], ) + + +def _varlen_deep_gemm_silu_mul_quant( + gateup_output: torch.Tensor, + masked_m: Optional[torch.Tensor], + group_size: int, + topk: int, + swiglu_limit: Optional[float] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd + from sglang.srt.layers.quantization.fp8_kernel import ( + sglang_per_token_group_quant_8bit, + ) + + if _MASKED_GEMM_FAST_ACT: + assert swiglu_limit is None, \ + "swiglu_limit (DSV4 2604 submode 2604B) is not supported together with SGLANG_MASKED_GEMM_FAST_ACT" + return sglang_per_token_group_quant_8bit( + x=gateup_output, + dst_dtype=torch.float8_e4m3fn, + group_size=group_size, + masked_m=masked_m, + column_major_scales=True, + 
scale_tma_aligned=True, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + fuse_silu_and_mul=True, + enable_v2=True, + ) + + assert masked_m is not None + hidden_states_device = gateup_output.device + E, N, D_2 = gateup_output.shape + D = D_2 // 2 + del D_2 + G = D // group_size + down_input = torch.empty( + (E, N, D), + device=hidden_states_device, + dtype=torch.float8_e4m3fn, + ) + + if envs.SGLANG_OPT_USE_JIT_EP_ACTIVATION.get(): + assert N % 4 == 0 and G % 4 == 0 + down_input_scale = torch.empty( + (E, G // 4, N), + device=hidden_states_device, + dtype=torch.int32, + ) + silu_and_mul_masked_post_quant( + gateup_output, + down_input, + down_input_scale, + group_size, + masked_m, + scale_ue8m0=True, + topk=topk, + transposed=True, + swiglu_limit=swiglu_limit, + ) + down_input_scale = down_input_scale.transpose(-1, -2) + else: + assert swiglu_limit is None, \ + "swiglu_limit (DSV4 2604 submode 2604B) requires SGLANG_OPT_USE_JIT_EP_ACTIVATION=True" + down_input_scale = torch.empty( + (E, N, G), + device=hidden_states_device, + dtype=torch.float32, + ) + silu_and_mul_masked_post_quant_fwd( + gateup_output, + down_input, + down_input_scale, + group_size, + masked_m, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ) + return down_input, down_input_scale + + +# TODO: it the weight non-interleaved? +# TODO: also, is the weight gate first, up later? +def _apply_swiglu_limit(gateup_output: torch.Tensor, swiglu_limit: float) -> torch.Tensor: + assert swiglu_limit == 10 + + num_tokens, hidden_size_x2 = gateup_output.shape + assert hidden_size_x2 == 2048 * 2 + assert gateup_output.dtype == torch.bfloat16 + + gate, up = torch.chunk(gateup_output, chunks=2, dim=-1) + assert gate.shape == (num_tokens, hidden_size_x2 // 2) + assert up.shape == (num_tokens, hidden_size_x2 // 2) + + up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit) + gate = torch.clamp(gate, max=swiglu_limit) + + out = torch.cat([gate, up], dim=-1) + assert out.shape == (num_tokens, hidden_size_x2) + return out diff --git a/python/sglang/srt/layers/moe/routed_experts_capturer.py b/python/sglang/srt/layers/moe/routed_experts_capturer.py index 00bd68755587..37b1505f6261 100644 --- a/python/sglang/srt/layers/moe/routed_experts_capturer.py +++ b/python/sglang/srt/layers/moe/routed_experts_capturer.py @@ -8,10 +8,13 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.dp_attention import ( + attn_tp_all_gather_into_tensor, get_attention_dp_rank, + get_attention_tp_size, get_dp_local_info, is_dp_attention_enabled, ) +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import get_global_server_args @@ -181,6 +184,17 @@ def __init__( device=device, ) + if get_moe_a2a_backend().is_deepep(): + attn_tp_size = get_attention_tp_size() if is_dp_attention_enabled() else 1 + self.gather_buffer = torch.empty( + ( + self.device_cache.buffer.shape[0] * attn_tp_size, + self.device_cache.buffer.shape[2], + ), + dtype=torch.int32, + device=device, + ) + def _sync_fwd_experts_buffer_DtoH( self, forward_batch: ForwardBatch, @@ -206,6 +220,12 @@ def _sync_fwd_experts_buffer_DtoH( ].cpu() def capture(self, layer_id: int, topk_ids: torch.Tensor): + if get_moe_a2a_backend().is_deepep(): + local_topk_ids = topk_ids + topk_ids = self.gather_buffer[ + : local_topk_ids.size(0) * get_attention_tp_size() + ] + attn_tp_all_gather_into_tensor(topk_ids, local_topk_ids) 
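+            # Descriptive note: gather_buffer was allocated above with
+            # local_rows * attn_tp_size rows for the DeepEP a2a backend, so each
+            # attention-TP rank contributes only its local shard of topk_ids and
+            # the all-gather reassembles the full batch before it is written
+            # into the device cache below.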
self.device_cache.capture_fwd_routed_experts(layer_id, topk_ids) def get_routed_experts( @@ -281,7 +301,7 @@ def set_global_experts_capturer(capturer: RoutedExpertsCapturer): def extract_routed_experts_from_meta_info(data): # To solve the performance issue, we return the experts_ids in base64 # We left this function for user to change it back to normal int32 - # See detokenizer_manager::_extract_routed_experts + # See detokenizer_manager::_extract_topk_base64 routed_experts_base64 = data["meta_info"].get("routed_experts", None) routed_experts = np.frombuffer( pybase64.b64decode(routed_experts_base64.encode("utf-8")), dtype=np.int32 diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 419786c2f06e..21e10e6e4253 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -39,6 +39,7 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) +from sglang.srt.environ import envs from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( @@ -996,17 +997,39 @@ def select_experts( ) elif custom_routing_function is None: assert not apply_routed_scaling_factor_on_output, "Not implemented" - # Qwen3MOE uses fused_topk - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=num_routed_topk if _use_aiter else top_k, - renormalize=renormalize, - correction_bias=correction_bias, - num_token_non_padded=num_token_non_padded, - expert_location_dispatch_info=expert_location_dispatch_info, - scoring_func=scoring_func, - ) + if scoring_func == "sqrtsoftplus": + if envs.SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK.get(): + from sglang.srt.layers.moe.deepseek_v4_topk import ( + biased_topk_jit_kernel_impl as biased_topk_impl, + ) + else: + from sglang.srt.layers.moe.deepseek_v4_topk import biased_topk_impl + + topk_weights, topk_ids = biased_topk_impl( + hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + scoring_func=scoring_func, + num_fused_shared_experts=num_fused_shared_experts, + routed_scaling_factor=routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output, + ) + else: + # Qwen3MOE uses fused_topk + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + correction_bias=correction_bias, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + scoring_func=scoring_func, + ) else: assert ( num_token_non_padded is None diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index 565f5b9fd202..fe8cd56318cd 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -33,6 +33,7 @@ def _dtype_rank(dtype: torch.dtype) -> Optional[int]: torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, + torch.float8_e8m0fnu, ): return 0 if dtype in (torch.float16, torch.bfloat16): @@ -69,6 +70,9 @@ def copy_with_check(target: torch.Tensor, loaded_weight: torch.Tensor): raise ValueError( f"Downcasting not 
allowed: {target.dtype=}, {loaded_weight.dtype=}" ) + # safety extra check + if loaded_rank == torch.float8_e8m0fnu: + assert target_rank in {torch.float8_e8m0fnu, torch.float32} target.copy_(loaded_weight) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 573f69a3c4e9..18708a05a357 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -92,6 +92,14 @@ _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT") and _is_hip _use_aiter = envs.SGLANG_USE_AITER.get() and _is_hip + +def _use_aiter_moe() -> bool: + # with SGLANG_FORCE_TRITON_MOE_FP8=1, FP8 MoE skips + # aiter (no weight shuffle, no aiter dispatch) while the rest of the aiter paths + # keep running. Must be checked at call time because the flag is a runtime toggle. + return _use_aiter and not envs.SGLANG_FORCE_TRITON_MOE_FP8.get() + + if _use_aiter or _use_hip_int4: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe @@ -373,6 +381,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: input_scale=None, ) layer.input_scale = None + elif _is_cpu: assert ( _is_cpu_amx_available @@ -403,14 +412,17 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: ) and (not layer.weight_scale_inv.format_ue8m0) ): + requant_weight_ue8m0_inplace( layer.weight, layer.weight_scale_inv, self.quant_config.weight_block_size, ) layer.weight_scale_inv.format_ue8m0 = True + weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data + layer.weight.data = weight.data layer.weight_scale_inv.data = weight_scale.data @@ -601,6 +613,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + self.is_fp4_expert = ( + envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + ) if get_moe_runner_backend().is_cutlass(): assert ( cutlass_fp8_supported() @@ -660,7 +675,28 @@ def create_weights( ) # WEIGHTS - if _is_hip and _use_hip_int4: + if self.is_fp4_expert: + # FP4 E2M1 packed as uint8: 2 FP4 values per byte, K dim halved + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + # The provided .safetensors uses `int8` + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + elif _is_hip and _use_hip_int4: # INT4 MoE weight - INT32 packed w13_weight = torch.nn.Parameter( torch.empty( @@ -707,7 +743,34 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - if self.block_quant: + if self.is_fp4_expert: + # TODO: temp for the ckpt + assert hidden_size == 4096 + assert intermediate_size_per_partition == 2048 + + # FP4: per-row on N, per-32 on K + fp4_block_k = 32 + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + 
layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + elif self.block_quant: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, @@ -834,7 +897,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: w2_weight_scale, requires_grad=False ) layer.w2_input_scale = None - if _use_aiter: + if _use_aiter_moe(): # add this section for MI300 # Pre-shuffle weights layer.w13_weight.data = shuffle_weight( @@ -843,7 +906,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: layer.w2_weight.data = shuffle_weight( layer.w2_weight.contiguous(), (16, 16) ) - elif _use_aiter: + elif _use_aiter_moe(): # Pre-shuffle weights layer.w13_weight.data = shuffle_weight( layer.w13_weight.contiguous(), (16, 16) @@ -858,6 +921,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) else: # For fp8 moe run with deepgemm, the expert weights and scales need be requantized to ue8m0 + from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE from sglang.srt.model_loader.utils import ( should_deepgemm_weight_requant_ue8m0, @@ -866,8 +930,44 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: # Check if MoE will actually use DeepGEMM runner will_use_deepgemm = self.is_deepgemm_moe_runner_backend_enabled() + if self.is_fp4_expert: + layer.w13_weight.data = layer.w13_weight.data.view(torch.uint8) + layer.w2_weight.data = layer.w2_weight.data.view(torch.uint8) + + # Pre-convert FP4 weight scales from FP32 to UE8M0 packed INT32 at init time, + # eliminating the runtime transpose_and_pack_fp32_into_ue8m0 kernel in deep_gemm. + # Only on Blackwell (SM100+) where DEEPGEMM_SCALE_UE8M0 is True. + if ( + envs.SGLANG_OPT_DEEPGEMM_SCALE_CONVERT_AT_INIT.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + and will_use_deepgemm + ): + from deep_gemm import transform_sf_into_required_layout + + for scale_param, weight_param in [ + (layer.w13_weight_scale_inv, layer.w13_weight), + (layer.w2_weight_scale_inv, layer.w2_weight), + ]: + num_experts, n, _ = scale_param.data.shape + # FP4 weight already viewed as uint8: shape (E, N, K/2), so K = shape[2] * 2 + k = weight_param.shape[2] * 2 + scale_param.data = transform_sf_into_required_layout( + scale_param.data, + mn=n, + k=k, + recipe=None, + recipe_ab=(1, 32), + num_groups=num_experts, + is_sfa=False, + disable_ue8m0_cast=False, + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + if ( - should_deepgemm_weight_requant_ue8m0( + not self.is_fp4_expert + and should_deepgemm_weight_requant_ue8m0( weight_block_size=getattr( self.quant_config, "weight_block_size", None ), @@ -1067,7 +1167,7 @@ def process_weights_hip_scale_padding(self, layer: Module): padding_size, # Avoid circular import ) - if _use_aiter: + if _use_aiter_moe(): layer.w13_weight = torch.nn.Parameter( shuffle_weight(layer.w13_weight.data, (16, 16)), requires_grad=False, @@ -1378,8 +1478,16 @@ def maybe_apply_hip_fused_experts( ), ) - if _use_aiter: + if _use_aiter_moe(): assert not no_combine, f"{no_combine=} is not supported." + # Keep deepseek_v4.py's per-layer "observed == 1" sanity check in sync + # when the aiter dispatch bypasses the Triton/deep_gemm clamp logic. 
+ if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + from sglang.srt.debug_utils.deepseek_v4_debug_utils import ( + deepseek_v4_moe_code_path_checker, + ) + + deepseek_v4_moe_code_path_checker.observed += 1 if self.block_quant: return fused_moe( x, diff --git a/python/sglang/srt/layers/topk_capturer_base.py b/python/sglang/srt/layers/topk_capturer_base.py new file mode 100644 index 000000000000..da354451d662 --- /dev/null +++ b/python/sglang/srt/layers/topk_capturer_base.py @@ -0,0 +1,147 @@ +import logging +from typing import Optional + +import torch + +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + +logger = logging.getLogger(__name__) + +_GB = 1024 * 1024 * 1024 +_MB = 1024 * 1024 + + +def get_tensor_size_bytes(t: torch.Tensor): + import numpy as np + + return int(np.prod(t.shape)) * t.dtype.itemsize + + +class BaseDeviceCache: + def __init__( + self, max_batch_size: int, num_layers: int, topk_size: int, device: str + ): + self.buffer = torch.zeros( + (max_batch_size, num_layers, topk_size), + dtype=torch.int32, + device=device, + ) + self.num_layers = num_layers + self.topk_size = topk_size + + def capture(self, layer_id: int, topk_indices: torch.Tensor): + batch = topk_indices.shape[0] + topk_dim = min(topk_indices.shape[1], self.topk_size) + self.buffer[:batch, layer_id, :topk_dim] = topk_indices[:, :topk_dim] + + def get_buffer_size_bytes(self): + return get_tensor_size_bytes(self.buffer) + + +class BaseHostCache: + def __init__(self, num_tokens: int, num_layers: int, topk_size: int): + self.buffer = torch.zeros( + (num_tokens, num_layers, topk_size), + dtype=torch.int32, + device="cpu", + pin_memory=True, + ) + self.num_tokens = num_tokens + self.num_layers = num_layers + self.topk_size = topk_size + + def get_buffer_size_bytes(self): + return get_tensor_size_bytes(self.buffer) + + +class BaseTopkCapturer: + def __init__( + self, + num_tokens: int, + max_batch_size: int, + num_layers: int, + topk_size: int, + device: str, + ): + self.num_layers = num_layers + self.topk_size = topk_size + + self.host_cache = BaseHostCache(num_tokens, num_layers, topk_size) + self.device_cache = BaseDeviceCache( + max_batch_size, num_layers, topk_size, device + ) + + def capture(self, layer_id: int, topk_indices: torch.Tensor): + self.device_cache.capture(layer_id, topk_indices) + + def _sync_to_host( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: Optional[int], + ): + from sglang.srt.layers.dp_attention import ( + get_attention_dp_rank, + get_dp_local_info, + is_dp_attention_enabled, + ) + + if is_dp_attention_enabled(): + local_start_pos, local_num_tokens = get_dp_local_info(forward_batch) + if can_run_graph: + local_start_pos = get_attention_dp_rank() * cuda_graph_batch + local_end_pos = local_start_pos + local_num_tokens + else: + local_end_pos = local_start_pos + local_num_tokens + else: + local_start_pos = 0 + local_end_pos = forward_batch.out_cache_loc.shape[0] + + out_cache_loc_cpu = forward_batch.out_cache_loc.cpu() + self.host_cache.buffer[out_cache_loc_cpu] = self.device_cache.buffer[ + local_start_pos:local_end_pos, :, : self.topk_size + ].cpu() + + def get_topk( + self, + req_pool_idx: int, + seqlen: int, + req_to_token_pool: ReqToTokenPool, + ) -> torch.Tensor: + cache_pool_idx = ( + req_to_token_pool.req_to_token[req_pool_idx][: seqlen - 1].cpu().clone() + ) + return self.host_cache.buffer[cache_pool_idx] + + def on_forward_end( + self, + forward_batch: 
ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: Optional[int], + ): + self._sync_to_host(forward_batch, can_run_graph, cuda_graph_batch) + + def is_enabled(self) -> bool: + return True + + +class BaseTopkCapturerNoop: + def capture(self, layer_id: int, topk_indices: torch.Tensor): + pass + + def get_topk( + self, req_pool_idx: int, seqlen: int, req_to_token_pool: ReqToTokenPool + ): + return None + + def on_forward_end( + self, + forward_batch: ForwardBatch, + can_run_graph: bool, + cuda_graph_batch: Optional[int], + ): + pass + + def is_enabled(self) -> bool: + return False diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index a65b0dd28b2a..17a305e45c01 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -342,6 +342,18 @@ def _decode_batch_token_id_output(self, recv_obj: BatchTokenIDOutput): return output_strs + def _extract_topk_base64(self, data_list) -> List[List[int]]: + if data_list is None: + return None + return [ + ( + pybase64.b64encode(item.numpy().tobytes()).decode("utf-8") + if item is not None + else [] + ) + for item in data_list + ] + def _extract_routed_experts( self, recv_obj: BatchTokenIDOutput ) -> list[str | None] | None: @@ -364,7 +376,8 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): if len(recv_obj.rids) > 0 else [] ) - routed_experts = self._extract_routed_experts(recv_obj) + routed_experts = self._extract_topk_base64(recv_obj.routed_experts) + indexer_topk = self._extract_topk_base64(recv_obj.indexer_topk) return BatchStrOutput( rids=recv_obj.rids, @@ -391,6 +404,7 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): output_token_ids_logprobs_idx=recv_obj.output_token_ids_logprobs_idx, output_token_entropy_val=recv_obj.output_token_entropy_val, output_hidden_states=recv_obj.output_hidden_states, + indexer_topk=indexer_topk, routed_experts=routed_experts, customized_info=recv_obj.customized_info, placeholder_tokens_idx=None, diff --git a/python/sglang/srt/managers/hisparse_coordinator.py b/python/sglang/srt/managers/hisparse_coordinator.py new file mode 100644 index 000000000000..ca88a0b16a4b --- /dev/null +++ b/python/sglang/srt/managers/hisparse_coordinator.py @@ -0,0 +1,387 @@ +# to be combined with the sparse coordinator class and sparse algorithm family + +from typing import List, NamedTuple + +import torch + +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.mem_cache.hisparse_memory_pool import ( + DeepSeekV4SingleKVPoolHost, + HiSparseTokenToKVPoolAllocator, +) +from sglang.srt.utils import get_device_module + +device_module = get_device_module() + +from sglang.jit_kernel.hisparse import load_cache_to_device_buffer_mla +from sglang.srt.mem_cache.memory_pool import ReqToTokenPool + + +class HiSparseAct(NamedTuple): + start_event: device_module.Event + finish_event: device_module.Event + req: Req + + +class HiSparseCoordinator: + def __init__( + self, + req_to_token_pool: ReqToTokenPool, + token_to_kv_pool_allocator: HiSparseTokenToKVPoolAllocator, + top_k: int, + device_buffer_size: int, + device: str, + tp_group, + ): + self.req_to_token_pool = req_to_token_pool + self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.top_k = top_k + self.device_buffer_size = device_buffer_size + self.device = device + self.compress_ratio = self.token_to_kv_pool_allocator.compress_ratio + + self.mem_pool_device = self.token_to_kv_pool_allocator.hisparse_kvcache 
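+        # Descriptive note (inferred from the sizing below): the host pool backs up
+        # the compressed (hisparse) KV, so it holds one entry per `compress_ratio`
+        # full tokens, i.e. size_full // compress_ratio slots.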
+ self.mem_pool_host = DeepSeekV4SingleKVPoolHost( + self.mem_pool_device, + self.token_to_kv_pool_allocator.size_full // self.compress_ratio, + 1, + ) + self.item_size_bytes = ( + self.mem_pool_host.kv_cache_total_dim * self.mem_pool_host.dtype.itemsize + ) + + max_num_reqs = req_to_token_pool.size + max_context_len = req_to_token_pool.max_context_len + + # to have an extra page for new tokens + self.padded_buffer_size = ( + self.device_buffer_size + self.mem_pool_device.page_size + ) + + self.req_to_device_buffer = torch.zeros( + (max_num_reqs, self.padded_buffer_size), dtype=torch.int64, device=device + ) + self.req_to_host_pool = torch.zeros( + (max_num_reqs, max_context_len // self.compress_ratio), + dtype=torch.int64, + device=device, + ) + + self.write_staging_stream = device_module.Stream() + self.write_decoding_stream = device_module.Stream() + self.ack_staging_queue: List[HiSparseAct] = [] + self.ack_decoding_queue: List[HiSparseAct] = [] + + self.tp_group = tp_group + self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) + + # initialize data structures for swap-in kernel + layer_num = self.mem_pool_device.layer_num + self.req_device_buffer_tokens = torch.full( + (max_num_reqs, layer_num, self.padded_buffer_size), + -1, + dtype=torch.int32, + device=device, + ) + self.req_device_buffer_token_locs = torch.full( + (max_num_reqs, layer_num, self.padded_buffer_size), + -1, + dtype=torch.int32, + device=device, + ) + self.bitmap = torch.full( + (max_num_reqs, max_context_len // self.compress_ratio), + -1, + dtype=torch.int16, + device=device, + ) + self._lru_init = torch.arange( + self.device_buffer_size, dtype=torch.int16, device=device + ) + self.lru_slots = ( + self._lru_init.view(1, 1, -1) + .repeat(max_num_reqs, layer_num, 1) + .contiguous() + ) + self.transfer_tasks_src = torch.full( + (max_num_reqs * (self.top_k + 1),), + -1, + dtype=torch.int64, + device=device, + ) + self.transfer_tasks_dst = torch.full( + (max_num_reqs * (self.top_k + 1),), + -1, + dtype=torch.int64, + device=device, + ) + + def admit_request_into_staging(self, req: Req) -> None: + req.staging = True + + full_kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ].to(dtype=torch.int64, copy=True) + device_indices = ( + self.mem_pool_device.translate_loc_from_full_to_hisparse_device( + full_kv_indices + ) + ) + # req.c4_indices = device_indices + + prefill_len = len(device_indices) + host_indices = self.mem_pool_host.alloc(prefill_len).to(device=self.device) + assert host_indices is not None, "Host mem pool alloc failed" + self.req_to_host_pool[req.req_pool_idx, :prefill_len] = host_indices + + start_event = device_module.Event() + finish_event = device_module.Event() + start_event.record() + with device_module.stream(self.write_staging_stream): + start_event.wait(self.write_staging_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_indices, + device_indices, + ) + finish_event.record() + # NOTE: We must save the host indices and device indices here, + # this is because we need to guarantee that these tensors are + # still alive when the write stream is executing. 
+ if host_indices.is_cuda: + host_indices.record_stream(self.write_staging_stream) + if device_indices.is_cuda: + device_indices.record_stream(self.write_staging_stream) + + self.ack_staging_queue.append(HiSparseAct(start_event, finish_event, req)) + + def alloc_device_buffer(self, req: Req) -> None: + compressed_logical_indices = ( + self.mem_pool_device.translate_loc_from_full_to_compressed( + self.req_to_token_pool.req_to_token[req.req_pool_idx, : req.seqlen] + ) + ) + buffer_indices = self.token_to_kv_pool_allocator.alloc_device_buffer( + compressed_logical_indices, self.padded_buffer_size + ).to(torch.int32) + assert ( + len(buffer_indices) == self.padded_buffer_size + ), "Device buffer alloc failed" + self.req_to_device_buffer[req.req_pool_idx, : self.padded_buffer_size] = ( + buffer_indices + ) + # initialize the token locs for the device buffer + self.req_device_buffer_tokens[ + req.req_pool_idx, :, : self.device_buffer_size + ] = torch.arange(self.device_buffer_size, device=self.device) + self.req_device_buffer_token_locs[ + req.req_pool_idx, :, : self.padded_buffer_size + ] = buffer_indices[: self.padded_buffer_size] + + def testing_backup(self, req): + device_indices = req.c4_indices + host_indices = self.req_to_host_pool[req.req_pool_idx, : len(device_indices)] + + self.mem_pool_host.testing_backup_to_device_all_layer( + self.mem_pool_device, host_indices, device_indices + ) + torch.cuda.current_stream().synchronize() + + def collect_ready_batch(self) -> List[Req]: + ready_batch = None + if len(self.ack_staging_queue) == 0: + return ready_batch + + finish_count = 0 + for _, finish_event, _ in self.ack_staging_queue: + if not finish_event.query(): + break + finish_count += 1 + queue_size = torch.tensor(finish_count, dtype=torch.int, device="cpu") + if self.tp_world_size > 1: + # synchronize TP workers to make sure the same update to scheduler + torch.distributed.all_reduce( + queue_size, + op=torch.distributed.ReduceOp.MIN, + group=self.tp_group, + ) + finish_count = int(queue_size.item()) + while finish_count > 0: + _, _, req = self.ack_staging_queue.pop(0) + # prepare device buffer and update req + self.alloc_device_buffer(req) + req.staging = False + finish_count -= 1 + if len(self.ack_staging_queue) == 0: + ready_batch = req.batch + elif self.ack_staging_queue[0][2].batch != req.batch: + ready_batch = req.batch + # to break the circular reference + req.batch = None + # self.testing_backup(req) + return ready_batch + + def map_last_loc_to_buffer( + self, + out_cache_loc: torch.Tensor, + seq_lens: torch.Tensor, + req_pool_indices: torch.Tensor, + ) -> None: + active_reqs = seq_lens % self.compress_ratio == 0 + new_out_cache_loc = out_cache_loc[active_reqs] + active_req_pool_indices = req_pool_indices[active_reqs] + + # point output locations to the reserved buffer locations + compressed_locs = self.token_to_kv_pool_allocator.get_last_loc_compressed( + new_out_cache_loc + ) + reserved_buffer_loc = self.req_to_device_buffer[ + active_req_pool_indices, self.device_buffer_size + ] + # todo, maybe clear the prior mapping as well + self.mem_pool_device.full_to_hisparse_device_index_mapping[compressed_locs] = ( + reserved_buffer_loc + ) + # proceed only if the backup is finished for new generated tokens + self.wait_for_decode_writes() + + def get_front_topk_tokens( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + raw_indices: torch.Tensor, + ) -> torch.Tensor: + # a dummy selection for testing + num_reqs = req_pool_indices.size(0) + top_k_indices = torch.full( + 
(num_reqs, self.top_k), -1, dtype=torch.int32, device=self.device + ) + for i in range(num_reqs): + top_n = min( + seq_lens[i] // self.compress_ratio, + self.top_k, + ) + if top_n == 0: + continue + top_k_indices[i, :top_n] = self.req_to_device_buffer[req_pool_indices[i]][ + raw_indices[i, :top_n] + ] + return top_k_indices + + def wait_for_decode_writes(self) -> None: + if len(self.ack_decoding_queue) == 0: + return + _, finish_event, _ = self.ack_decoding_queue.pop(0) + finish_event.synchronize() + + def retract_req(self, req: Req) -> None: + # todo + raise NotImplementedError + + def update_requests_after_decode(self, reqs: List[Req]) -> None: + reqs_to_backup = [r for r in reqs if r.seqlen % self.compress_ratio == 0] + if len(reqs_to_backup) == 0: + return + + req_pool_indices = torch.tensor( + [r.req_pool_idx for r in reqs_to_backup], device=self.device + ) + req_seq_lens = torch.tensor( + [r.seqlen // self.compress_ratio for r in reqs_to_backup], + device=self.device, + ) + buffer_indices = self.req_to_device_buffer[ + req_pool_indices, self.device_buffer_size + ] + + # for short requests, copy the new token from reserved buffer to normal buffer + short_reqs = req_seq_lens <= self.device_buffer_size + if torch.any(short_reqs): + new_token_buffer_indices = self.req_to_device_buffer[ + req_pool_indices[short_reqs], req_seq_lens[short_reqs] - 1 + ] + # todo, need to do the same transfer after prefill as well + self.mem_pool_device.transfer_values_on_device( + buffer_indices[short_reqs], new_token_buffer_indices + ) + + # for all requests, backup the new token to host for future use + host_indices = self.mem_pool_host.alloc(len(buffer_indices)).to( + device=self.device + ) + assert host_indices is not None, "Host mem pool alloc failed" + self.req_to_host_pool[req_pool_indices, req_seq_lens - 1] = host_indices + + start_event = device_module.Event() + finish_event = device_module.Event() + start_event.record() + with device_module.stream(self.write_decoding_stream): + start_event.wait(self.write_decoding_stream) + self.mem_pool_host.backup_from_device_all_layer( + self.mem_pool_device, + host_indices, + buffer_indices.contiguous(), + ) + finish_event.record() + # NOTE: We must save the host indices and device indices here, + # this is because we need to guarantee that these tensors are + # still alive when the write stream is executing. + if host_indices.is_cuda: + host_indices.record_stream(self.write_decoding_stream) + if buffer_indices.is_cuda: + buffer_indices.record_stream(self.write_decoding_stream) + + self.ack_decoding_queue.append(HiSparseAct(start_event, finish_event, None)) + + def request_finished(self, req: Req): + compressed_len = req.seqlen // self.compress_ratio + # release memory + buffer_indices = self.req_to_device_buffer[req.req_pool_idx] + self.token_to_kv_pool_allocator.free_hisparse_indices(buffer_indices) + host_indices = self.req_to_host_pool[req.req_pool_idx, :compressed_len] + self.mem_pool_host.free(host_indices) + # clear req info + self.req_device_buffer_tokens[req.req_pool_idx, :, :] = -1 + self.req_device_buffer_token_locs[req.req_pool_idx, :, :] = -1 + self.req_to_device_buffer[req.req_pool_idx, :] = 0 + self.req_to_host_pool[req.req_pool_idx, :] = 0 + self.lru_slots[req.req_pool_idx].copy_(self._lru_init) + + def swap_in_selected_pages( + self, + req_pool_indices, + top_k_result, + top_k_device_locs, + seq_lens, + layer_id, + ): + """ + Swap in selected top-k pages/tokens from host to device memory. 
+ First step: Using diff kernel to identify the top-k pages/tokens that need to be swapped in. + Second step: Using the io kernel to load the pages/tokens from host to device. + Returns: + Device indices of the selected pages/tokens + """ + block_size = 512 + load_cache_to_device_buffer_mla( + top_k_tokens=top_k_result, + device_buffer_tokens=self.req_device_buffer_tokens, + host_cache_locs=self.req_to_host_pool, + device_buffer_locs=self.req_device_buffer_token_locs, + host_cache=self.mem_pool_host.kv_buffer[layer_id], + device_buffer=self.mem_pool_device.kv_buffer[layer_id], + top_k_device_locs=top_k_device_locs, + diff_map=self.bitmap, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens // self.compress_ratio, + lru_slots=self.lru_slots, + transfer_tasks_src=self.transfer_tasks_src, + transfer_tasks_dst=self.transfer_tasks_dst, + page_size=1, + layer_id=layer_id, + item_size_bytes=self.item_size_bytes, + block_size=block_size, + num_top_k=self.top_k, + hot_buffer_size=self.device_buffer_size, + ) + return top_k_device_locs diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index fad02e0a0112..4e33b2334812 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -202,6 +202,8 @@ class GenerateReqInput(BaseReq, APIServingTimingMixin): return_hidden_states: Union[List[bool], bool] = False # Whether to return captured routed experts return_routed_experts: bool = False + # Whether to return captured indexer topk (layers with indexer) + return_indexer_topk: bool = False # The start location in the prompt for returning routed experts. routed_experts_start_len: int = 0 @@ -639,6 +641,7 @@ def __getitem__(self, i): else self.return_hidden_states ), return_routed_experts=self.return_routed_experts, + return_indexer_topk=self.return_indexer_topk, modalities=self.modalities[i] if self.modalities else None, session_params=self.session_params, lora_path=self.lora_path[i] if self.lora_path is not None else None, @@ -714,6 +717,9 @@ class TokenizedGenerateReqInput(BaseReq): # The start location in the prompt for returning routed experts. routed_experts_start_len: int = 0 + # Whether to return captured indexer topk (layers with indexer) + return_indexer_topk: bool = False + # The input embeds input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None @@ -989,6 +995,9 @@ class BatchTokenIDOutput( # routed_experts[i] is a tensor of shape (token, layer, top_k) for request i routed_experts: List[Optional[torch.Tensor]] + # The indexer topk for each output token (layers with indexer) + indexer_topk: List[torch.Tensor] + # The information of placeholder tokens (e.g., image token) # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. @@ -1077,6 +1086,9 @@ class BatchStrOutput( # routed_experts[i] is a tensor of shape (token, layer, top_k) for request i routed_experts: List[Optional[torch.Tensor]] + # The indexer topk for each output token (layers with indexer) + indexer_topk: List[List[int]] + # The information of placeholder tokens (e.g., image token) # idx is the index of the token in the prompt after expansion. # val is the length of padded tokens after expansion. 
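Note on consuming the new `indexer_topk` / `routed_experts` fields: the detokenizer returns both as base64-encoded int32 buffers inside `meta_info` (see `_extract_topk_base64` above and `extract_routed_experts_from_meta_info` in routed_experts_capturer.py). A minimal client-side decode sketch; the helper name and the `num_layers` / `top_k` reshape arguments are illustrative (the caller would take them from the model config), not part of the API:

    from typing import Optional

    import numpy as np
    import pybase64

    def decode_topk_field(
        meta_info: dict, key: str, num_layers: int, top_k: int
    ) -> Optional[np.ndarray]:
        # base64 string -> raw bytes -> flat int32 buffer, mirroring
        # extract_routed_experts_from_meta_info.
        b64 = meta_info.get(key)
        if not b64:
            return None
        flat = np.frombuffer(pybase64.b64decode(b64.encode("utf-8")), dtype=np.int32)
        # The capturers store one (layers, top_k) slab per output token.
        return flat.reshape(-1, num_layers, top_k)

    # e.g. routed = decode_topk_field(data["meta_info"], "routed_experts", num_moe_layers, moe_top_k)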
diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 4f2e3fb19197..b0e66948407d 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -281,6 +281,9 @@ def _handle_output_by_index(output, i): routed_experts=_extract_field_by_index( output, "routed_experts", i, check_length=False ), + indexer_topk=_extract_field_by_index( + output, "indexer_topk", i, check_length=False + ), customized_info=_extract_field_by_index( output, "customized_info", i, check_length=False ), diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9681a1d70dcf..320e29bcbce4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -90,6 +90,7 @@ from typing import Any, Dict from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm @@ -511,6 +512,7 @@ def __init__( require_reasoning: bool = False, return_hidden_states: bool = False, return_routed_experts: bool = False, + return_indexer_topk: bool = False, eos_token_ids: Optional[Set[int]] = None, bootstrap_host: Optional[str] = None, bootstrap_port: Optional[int] = None, @@ -716,6 +718,12 @@ def __init__( self.routed_experts: Optional[torch.Tensor] = ( None # cpu tensor: shape (seqlen, topk) ) + + # capture indexer topk (layers with indexer) + self.return_indexer_topk = return_indexer_topk + self.indexer_topk: Optional[torch.Tensor] = ( + None # cpu tensor: shape (seqlen, num_indexer_layers, index_topk) + ) # Customized info self.customized_info: Optional[Dict[str, List[Any]]] = None @@ -779,6 +787,10 @@ def __init__( self.dllm_block_offset = 0 self.dllm_config = dllm_config + # For hisparse + self.staging = False + self.batch = None + @property def seqlen(self) -> int: """Get the current sequence length of the request.""" @@ -1075,6 +1087,7 @@ def reset_for_retract(self): self.prefix_indices = torch.empty((0,), dtype=torch.int64) self.routed_experts = None + self.indexer_topk = None self.last_node = None self.swa_uuid_for_lock = None self.extend_input_len = 0 @@ -1096,6 +1109,9 @@ def reset_for_retract(self): self.kv_committed_len = 0 self.kv_committed_freed = False self.kv_overallocated_freed = False + self.swa_evicted_seqlen = 0 + self.extend_batch_idx = 0 + self.decode_batch_idx = 0 def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator): token_indices = req_to_token_pool.req_to_token[ @@ -1332,6 +1348,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Whether to return captured experts return_routed_experts: bool = False + # Whether to return captured indexer topk (layers with indexer) + return_indexer_topk: bool = False + # Whether this batch is prefill-only (no token generation needed) is_prefill_only: bool = False @@ -1345,6 +1364,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Metrics dp_cooperation_info: Optional[DPCooperationInfo] = None + # HiSparse + hisparse_coordinator: Optional[HiSparseCoordinator] = None + @classmethod def init_new( cls, @@ -1365,7 +1387,7 @@ def init_new( if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator): is_hybrid_swa = True - return cls( + batch = cls( reqs=reqs, req_to_token_pool=req_to_token_pool, 
token_to_kv_pool_allocator=token_to_kv_pool_allocator, @@ -1380,11 +1402,16 @@ def init_new( spec_algorithm=spec_algorithm, return_hidden_states=any(req.return_hidden_states for req in reqs), return_routed_experts=any(req.return_routed_experts for req in reqs), + return_indexer_topk=any(req.return_indexer_topk for req in reqs), is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, dllm_staging_reqs=dllm_staging_reqs, dllm_config=dllm_config, ) + # FIXME: hack for staging batch + for r in reqs: + r.batch = batch + return batch def batch_size(self): return len(self.reqs) @@ -1957,6 +1984,7 @@ def prepare_for_decode(self): return if self.sampling_info.penalizer_orchestrator.is_required: + # todo hisparse, potential compatibility issue if self.enable_overlap: # TODO: this can be slow, optimize this. delayed_output_ids = torch.tensor( @@ -1988,7 +2016,12 @@ def prepare_for_decode(self): # Allocate memory self.out_cache_loc = alloc_for_decode(self, token_per_req=1) + if self.hisparse_coordinator is not None: + self.hisparse_coordinator.map_last_loc_to_buffer( + self.out_cache_loc, self.seq_lens, self.req_pool_indices + ) + # todo hisparse: be careful about meta data modification # Update req-level memory management fields for req in self.reqs: req.decode_batch_idx += 1 @@ -2298,8 +2331,13 @@ def _evict_swa(self, req: Req, pre_len: int): ), "cache_protected_len must be page aligned" req.swa_evicted_seqlen = max(req.swa_evicted_seqlen, req.cache_protected_len) + # NOTE: the swa_evicted_seqlen is based on pre_len - sliding_window_size, + # not guaranteed there is at least one page out of the swa_evicted_seqlen (page_size > window_size case). + # As radix cache is page-aligned, then there may be no page inserted as non-tomb node. + # Even the full tokens are matched, it is useless when there no swa tokens cached on tree. 
new_swa_evicted_seqlen = max( - req.swa_evicted_seqlen, pre_len - sliding_window_size + req.swa_evicted_seqlen, + pre_len - sliding_window_size - self.tree_cache.page_size, ) if self.tree_cache.page_size > 1: diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 65ffc7198a0d..4aee1536900c 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -423,6 +423,8 @@ def __init__( ) self.is_hybrid_ssm_cache = self.tree_cache.supports_mamba() + self.rem_swa_token_offset = 0 + self.priority_scheduling_preemption_threshold = ( priority_scheduling_preemption_threshold ) @@ -449,11 +451,9 @@ def _get_running_request_total_token_offset(self, req: Req) -> int: @property def rem_total_tokens(self): if self.is_hybrid_swa: - available_and_evictable = min( + available_and_evictable = ( self.token_to_kv_pool_allocator.full_available_size() - + self.tree_cache.full_evictable_size(), - self.token_to_kv_pool_allocator.swa_available_size() - + self.tree_cache.swa_evictable_size(), + + self.tree_cache.full_evictable_size() ) elif self.is_hybrid_ssm_cache: available_and_evictable = ( @@ -467,14 +467,20 @@ def rem_total_tokens(self): ) return available_and_evictable - self.rem_total_token_offset + @property + def rem_swa_tokens(self): + return ( + self.token_to_kv_pool_allocator.swa_available_size() + + self.tree_cache.swa_evictable_size() + - self.rem_swa_token_offset + ) + @property def cur_rem_tokens(self): if self.is_hybrid_swa: - available_and_evictable = min( + available_and_evictable = ( self.token_to_kv_pool_allocator.full_available_size() - + self.tree_cache.full_evictable_size(), - self.token_to_kv_pool_allocator.swa_available_size() - + self.tree_cache.swa_evictable_size(), + + self.tree_cache.full_evictable_size() ) elif self.is_hybrid_ssm_cache: available_and_evictable = ( @@ -489,11 +495,31 @@ def cur_rem_tokens(self): return available_and_evictable - self.cur_rem_token_offset + def _swa_budget_for_req(self, extend_input_len: int) -> int: + """SWA pool budget per request. Only valid when is_hybrid_swa is True. + + With chunked prefill + overlap scheduler, the peak SWA occupancy is: + chunk N (running, not yet in tree) + sliding window (locked in tree) + + chunk N+1 (new allocation) + Since chunk N and locked tokens are already excluded from + swa_available + swa_evictable, the budget only needs to cover the + chunk N+1 allocation. We floor at sliding_window_size to reserve + room for the decode phase. 
+ """ + if self.rem_chunk_tokens is not None: + alloc = min(extend_input_len, self.rem_chunk_tokens) + else: + alloc = extend_input_len + return max(alloc, self.tree_cache.sliding_window_size) + self.page_size + def ceil_paged_tokens(self, tokens: int) -> int: return -(-tokens // self.page_size) * self.page_size def budget_state(self): - if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: + no_token = self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0 + if not no_token and self.is_hybrid_swa: + no_token = self.rem_swa_tokens <= 0 + if no_token: return AddReqResult.NO_TOKEN if self.rem_input_tokens <= 0: @@ -518,6 +544,9 @@ def _update_prefill_budget( self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len + if self.is_hybrid_swa: + self.rem_swa_token_offset += self._swa_budget_for_req(extend_input_len) + if self.dllm_config is not None: self.rem_dllm_tokens -= extend_input_len elif self.rem_chunk_tokens is not None: @@ -567,9 +596,15 @@ def add_chunked_req(self, req: Req): _rem_tokens = self._get_dllm_remain_tokens() else: _rem_tokens = min(self.rem_chunk_tokens, int(self.rem_total_tokens)) - # The chunked_req must be added to the list; otherwise, it will cause a memory leak. - # Therefore, in certain cases where _rem_tokens <= 0, it should be replaced with rem_chunk_tokens. + if self.is_hybrid_swa: + # alloc_extend needs extend_num_tokens + page_size per request, + # so reserve one page here to avoid OOM + _rem_tokens = min( + _rem_tokens, int(self.rem_swa_tokens) - self.page_size + ) if _rem_tokens <= 0: + if self.is_hybrid_swa: + return req _rem_tokens = self.rem_chunk_tokens truncated = req.extend_input_len > _rem_tokens @@ -604,11 +639,12 @@ def _lock_node(self, last_node: TreeNode): self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req): - # Early exit if no enough tokens for the input tokens - if self.ceil_paged_tokens(req.extend_input_len) > min( - self.cur_rem_tokens, self.rem_total_tokens - ): + paged_input = self.ceil_paged_tokens(req.extend_input_len) + if paged_input > min(self.cur_rem_tokens, self.rem_total_tokens): return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + if self._swa_budget_for_req(req.extend_input_len) > self.rem_swa_tokens: + return AddReqResult.NO_TOKEN def add_req_state(r, insert_sort=False): new_token_ratio = ( @@ -705,10 +741,12 @@ def add_one_req( if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True): return self.add_one_req_ignore_eos(req) - total_tokens = req.extend_input_len + min( + # FIXME: overestimation of total tokens for the request, as the allocator does the same way + max_new = min( max(req.sampling_params.max_new_tokens - len(req.output_ids), 0), CLIP_MAX_NEW_TOKENS, ) + total_tokens = req.extend_input_len + max_new + self.page_size # adjusting the input_tokens based on host_hit_length and page_size real_input_tokens = req.extend_input_len - req.host_hit_length @@ -718,6 +756,11 @@ def add_one_req( if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + swa_needed = self._swa_budget_for_req(req.extend_input_len) + if swa_needed >= self.rem_swa_tokens: + return AddReqResult.NO_TOKEN + if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0: return AddReqResult.OTHER @@ -726,6 +769,11 @@ def add_one_req( if total_tokens >= self.rem_total_tokens: return AddReqResult.NO_TOKEN + if self.is_hybrid_swa: + swa_needed = self._swa_budget_for_req(req.extend_input_len) + if swa_needed >= 
self.rem_swa_tokens: + return AddReqResult.NO_TOKEN + if req.host_hit_length > 0: new_indices, req.last_node = self.tree_cache.init_load_back( req.last_host_node, req.host_hit_length @@ -788,6 +836,13 @@ def add_one_req( trunc_len // truncation_align_size ) + now_input_len = trunc_len + len(req.prefix_indices) + now_input_len = now_input_len // self.page_size * self.page_size + trunc_len = now_input_len - len(req.prefix_indices) + + if trunc_len <= 0: + return AddReqResult.OTHER + # Chunked prefill req.set_extend_input_len(trunc_len) req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len] diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 18fd50130581..f2d2c29dedb2 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.fp4_utils import initialize_fp4_gemm_config from sglang.srt.layers.quantization.fp8_utils import initialize_fp8_gemm_config from sglang.srt.lora.lora_overlap_loader import LoRAOverlapLoader +from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator from sglang.srt.managers.io_struct import ( AbortReq, ActiveRanksOutput, @@ -170,6 +171,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin from sglang.srt.parser.reasoning_parser import ReasoningParser +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( @@ -308,6 +310,8 @@ def __init__( self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hicache_storage = server_args.hicache_storage_backend is not None self.max_recv_per_poll = envs.SGLANG_SCHEDULER_MAX_RECV_PER_POLL.get() + self.enable_hisparse = server_args.enable_hisparse + self.hisparse_coordinator: Optional[HiSparseCoordinator] = None # Distributed rank info self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = ( @@ -690,6 +694,22 @@ def init_cache_with_memory_pool(self): else: self.tree_cache = RadixCache(params) + if self.enable_hisparse: + # FIXME, hardcode some hisparse config here for now + self.hisparse_coordinator = HiSparseCoordinator( + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + top_k=512, + device_buffer_size=1024, + device=self.device, + tp_group=( + self.attn_tp_cpu_group + if self.server_args.enable_dp_attention + else self.tp_cpu_group + ), + ) + self.tp_worker.register_hisparse_coordinator(self.hisparse_coordinator) + if ( server_args.disaggregation_mode == "decode" and server_args.disaggregation_decode_enable_offload_kvcache @@ -1447,6 +1467,7 @@ def handle_generate_request( require_reasoning=recv_req.require_reasoning, return_hidden_states=recv_req.return_hidden_states, return_routed_experts=recv_req.return_routed_experts, + return_indexer_topk=recv_req.return_indexer_topk, eos_token_ids=self.model_config.hf_eos_token_id, bootstrap_host=recv_req.bootstrap_host, bootstrap_port=recv_req.bootstrap_port, @@ -1790,6 +1811,47 @@ def stash_chunked_request(self, req: Req): else: self.req_to_token_pool.free(req.req_pool_idx) + def create_hisparse_ready_batch(self, reqs: List[Req]) -> ScheduleBatch: + batch = ScheduleBatch.init_new( + reqs=reqs, + req_to_token_pool=self.req_to_token_pool, + 
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + tree_cache=self.tree_cache, + model_config=self.model_config, + enable_overlap=self.enable_overlap, + spec_algorithm=self.spec_algorithm, + chunked_req=None, + dllm_config=self.dllm_config, + ) + batch.forward_mode = ForwardMode.DECODE + batch.device = self.device + + seq_lens = [len(r.origin_input_ids) for r in reqs] + + batch.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64) + batch.seq_lens = batch.seq_lens_cpu.to(self.device, non_blocking=True) + batch.orig_seq_lens = torch.tensor( + seq_lens, dtype=torch.int32, device=self.device + ) + batch.seq_lens_sum = sum(seq_lens) + + batch.req_pool_indices = torch.tensor( + [r.req_pool_idx for r in reqs], + dtype=torch.int32, + device=self.device, + ) + + batch.multimodal_inputs = [r.multimodal_inputs for r in reqs] + if batch.return_logprob: + batch.top_logprobs_nums = [r.top_logprobs_num for r in reqs] + batch.token_ids_logprobs = [r.token_ids_logprob for r in reqs] + batch.sampling_info = SamplingBatchInfo.from_schedule_batch( + batch, self.model_config.vocab_size + ) + + batch.hisparse_coordinator = self.hisparse_coordinator + return batch + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: self._abort_on_queued_timeout() if self.dllm_config is not None: @@ -1814,31 +1876,48 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: chunked_req_to_exclude.add(self.chunked_req) self.stash_chunked_request(self.chunked_req) - if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.last_batch.chunked_req is not None: - # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. - # We need to discard it. - chunked_req_to_exclude.add(self.last_batch.chunked_req) - - if self.last_batch.dllm_staging_reqs.non_empty(): - chunked_req_to_exclude.update(self.last_batch.dllm_staging_reqs) - - # Filter batch - last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch( - chunked_req_to_exclude=list(chunked_req_to_exclude) - ) - if self.last_batch.batch_size() < last_bs: - self.running_batch.batch_is_full = False - - # Merge the new batch into the running batch. - # For prefill-only batch, we can avoid going through decoding step. - if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only: + if self.enable_hisparse: + hisparse_batch = self.hisparse_coordinator.collect_ready_batch() + if hisparse_batch is not None: + if hisparse_batch.chunked_req is not None: + chunked_req_to_exclude.add(hisparse_batch.chunked_req) + hisparse_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) if self.running_batch.is_empty(): - self.running_batch = self.last_batch + self.running_batch = hisparse_batch else: - # Merge running_batch with prefill batch - self.running_batch.merge_batch(self.last_batch) + self.running_batch.merge_batch(hisparse_batch) + self.running_batch.hisparse_coordinator = self.hisparse_coordinator + else: + if self.last_batch and self.last_batch.forward_mode.is_extend(): + if self.last_batch.chunked_req is not None: + # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. + # We need to discard it. 
+ chunked_req_to_exclude.add(self.last_batch.chunked_req) + + if self.last_batch.dllm_staging_reqs.non_empty(): + chunked_req_to_exclude.update(self.last_batch.dllm_staging_reqs) + + # Filter batch + last_bs = self.last_batch.batch_size() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) + if self.last_batch.batch_size() < last_bs: + self.running_batch.batch_is_full = False + + # Merge the new batch into the running batch. + # For prefill-only batch, we can avoid going through decoding step. + if ( + not self.last_batch.is_empty() + and not self.last_batch.is_prefill_only + ): + if self.running_batch.is_empty(): + self.running_batch = self.last_batch + else: + # Merge running_batch with prefill batch + self.running_batch.merge_batch(self.last_batch) new_batch = self.get_new_batch_prefill() @@ -1876,8 +1955,7 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: def get_num_allocatable_reqs(self, running_bs): res = get_global_server_args().pp_max_micro_batch_size - running_bs - if self.pp_size > 1: - res = min(res, self.req_to_token_pool.available_size()) + res = min(res, self.req_to_token_pool.available_size()) return res def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: @@ -2153,6 +2231,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: if (kv_full_retract_flag := not batch.check_decode_mem()) or ( TEST_RETRACT and self.forward_ct % TEST_RETRACT_INTERVAL == 0 ): + # todo hisparse: retract for hisparse if no sufficient memory as well, this includes no sufficient device or host memory left for C4 old_available_tokens = self.token_to_kv_pool_allocator.available_size() old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode( diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index c4728b714b57..4e89d845db72 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -8,6 +8,9 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.environ import envs +from sglang.srt.layers.attention.indexer_topk_capturer import ( + get_global_indexer_capturer, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.routed_experts_capturer import get_global_experts_capturer from sglang.srt.managers.io_struct import ( @@ -71,6 +74,14 @@ def maybe_collect_routed_experts(self: Scheduler, req: Req): req_to_token_pool=self.req_to_token_pool, ) + def maybe_collect_indexer_topk(self: Scheduler, req: Req): + """Collect indexer topk for a finished request (layers with indexer).""" + req.indexer_topk = get_global_indexer_capturer().get_topk( + req_pool_idx=req.req_pool_idx, + seqlen=req.seqlen, + req_to_token_pool=self.req_to_token_pool, + ) + def maybe_collect_customized_info( self: Scheduler, i: int, req: Req, logits_output: LogitsProcessorOutput ): @@ -137,11 +148,14 @@ def process_batch_result_prefill( if req.finished(): self.maybe_collect_routed_experts(req) + self.maybe_collect_indexer_topk(req) release_kv_cache(req, self.tree_cache) req.time_stats.completion_time = time.perf_counter() elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match self.tree_cache.cache_unfinished_req(req) + if self.enable_hisparse: + self.hisparse_coordinator.admit_request_into_staging(req) 
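+                    # Descriptive note: admit_request_into_staging() marks the request
+                    # as staging and snapshots its prefill KV to the host pool on a
+                    # side stream; the request re-enters the running batch via
+                    # HiSparseCoordinator.collect_ready_batch() in get_next_batch_to_run
+                    # once its recorded finish event has completed.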
self.maybe_collect_customized_info(i, req, logits_output) @@ -413,6 +427,7 @@ def process_batch_result_decode( if req.finished(): self.maybe_collect_routed_experts(req) + self.maybe_collect_indexer_topk(req) if self.server_args.disaggregation_decode_enable_offload_kvcache: # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes @@ -420,6 +435,8 @@ def process_batch_result_decode( release_kv_cache(req, self.tree_cache) else: release_kv_cache(req, self.tree_cache) + if self.enable_hisparse: + self.hisparse_coordinator.request_finished(req) req.time_stats.completion_time = time.perf_counter() @@ -479,6 +496,11 @@ def process_batch_result_decode( batch, num_accepted_tokens=result.num_accepted_tokens ) + if self.enable_hisparse: + self.hisparse_coordinator.update_requests_after_decode( + [req for req in batch.reqs if not req.finished()] + ) + def _mamba_prefix_cache_update( self, req: Req, batch: ScheduleBatch, result: GenerationBatchResult, i: int ) -> None: @@ -854,6 +876,7 @@ def stream_output_generation( retraction_counts = [] output_hidden_states = None load = self.get_load() + indexer_topk = None routed_experts = None customized_info = {} @@ -1054,7 +1077,10 @@ def stream_output_generation( if routed_experts is None: routed_experts = [] routed_experts.append(req.routed_experts) - + if req.return_indexer_topk: + if indexer_topk is None: + indexer_topk = [] + indexer_topk.append(req.indexer_topk) if req.customized_info is not None: for k, v in req.customized_info.items(): if k not in customized_info: @@ -1109,6 +1135,7 @@ def stream_output_generation( output_token_entropy_val=None, output_hidden_states=output_hidden_states, routed_experts=routed_experts, + indexer_topk=indexer_topk, customized_info=customized_info, placeholder_tokens_idx=None, placeholder_tokens_val=None, diff --git a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py index 484a949f5b23..c2a76f0eda98 100644 --- a/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py +++ b/python/sglang/srt/managers/scheduler_runtime_checker_mixin.py @@ -68,6 +68,15 @@ def _get_swa_token_info(self: Scheduler): swa_num_used = self.swa_tokens_per_layer - ( swa_available_size + swa_evictable_size ) + # if swa_num_used != 0: + # print( + # f"[DEBUG-{get_tp_group().rank}] {self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}" + # ) + if swa_num_used < 0: + # print( + # f"[WRONG-{get_tp_group().rank}] {self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}" + # ) + raise ValueError(f"swa_num_used < 0") full_token_usage = full_num_used / self.full_tokens_per_layer swa_token_usage = swa_num_used / self.swa_tokens_per_layer return ( @@ -176,7 +185,131 @@ def _get_batch_uncached_size(self: Scheduler, batch: ScheduleBatch) -> int: return ret + def _get_batch_swa_uncached_sizes( + self: Scheduler, batch: ScheduleBatch + ) -> tuple[int, int]: + """ + Get uncached sizes for both full and SWA pools. + + Returns: + (full_uncached_size, swa_uncached_size) + + For Full pool: uncached = allocated - cache_protected_len + For SWA pool: uncached = allocated - max(cache_protected_len, swa_evicted_seqlen) + + Note: swa_evicted_seqlen is NOT always >= cache_protected_len. 
+ In some cases (e.g., first extend batch with overlap, or decode batch where + decode_batch_idx % sliding_window_size != 1), _evict_swa() is not called, + leaving swa_evicted_seqlen at its old value while cache_protected_len may + have increased. + + When swa_evicted_seqlen < cache_protected_len: + - Tokens in [0, cache_protected_len) are in radix tree (protected) + - Tokens in [cache_protected_len, kv_allocated_len) are uncached for both pools + """ + full_uncached = 0 + swa_uncached = 0 + for req in batch.reqs: + assert req.kv_committed_freed == req.kv_overallocated_freed + if req.kv_committed_freed: + continue + + allocated_len = req.kv_allocated_len + if self.page_size > 1: + allocated_len = ceil_align(allocated_len, self.page_size) + assert req.cache_protected_len % self.page_size == 0 + # Note: swa_evicted_seqlen may not be page aligned if _evict_swa() was not called + + # Full: uncached = allocated - cache_protected_len + full_uncached += allocated_len - req.cache_protected_len + # SWA: uncached = allocated - max(cache_protected_len, swa_evicted_seqlen) + # Use max() because swa_evicted_seqlen is not always >= cache_protected_len + swa_uncached += allocated_len - max( + req.cache_protected_len, req.swa_evicted_seqlen + ) + + return full_uncached, swa_uncached + + def self_check_swa_during_busy(self: Scheduler): + """ + Check SWA memory invariant during busy periods. + + Invariant for each pool: + total_size = available + evictable + protected + uncached + + For SWA pool, tombstone nodes' tokens are in 'available' (freed via free_swa), + not in 'evictable'. + """ + current_batch: ScheduleBatch = self.last_batch + + if current_batch is None: + return + + spec_topk = self.server_args.speculative_eagle_topk or 1 + if spec_topk > 1: + warnings.warn( + "Runtime memory check (busy) is not supported when speculation topk > 1." + ) + return + + # Get pool info + ( + _, + _, + _, + _, + full_available, + full_evictable, + swa_available, + swa_evictable, + ) = self._get_swa_token_info() + + full_protected = self.tree_cache.full_protected_size() + swa_protected = self.tree_cache.swa_protected_size() + + # Calculate uncached for current batch + full_uncached, swa_uncached = self._get_batch_swa_uncached_sizes(current_batch) + + # Also add running_batch if it exists and is different from current_batch. + # This is needed because during overlap scheduling, running_batch may have + # extend requests whose SWA tokens were freed via _evict_swa(), affecting + # swa_available but not counted in current_batch's uncached. + if ( + self.running_batch is not None + and self.running_batch is not current_batch + and not self.running_batch.is_empty() + ): + f_unc, s_unc = self._get_batch_swa_uncached_sizes(self.running_batch) + full_uncached += f_unc + swa_uncached += s_unc + + # Verify invariants + full_total = full_available + full_evictable + full_protected + full_uncached + swa_total = swa_available + swa_evictable + swa_protected + swa_uncached + + if envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.get() > 1: + log_msg = ( + f"[SWA Mem Check (BUSY)] " + f"full: ({full_available=} + {full_evictable=} + {full_protected=} + {full_uncached=}) = {full_total=} " + f"swa: ({swa_available=} + {swa_evictable=} + {swa_protected=} + {swa_uncached=}) = {swa_total=}" + ) + logger.info(log_msg) + + assert full_total == self.full_tokens_per_layer, ( + f"Full Pool Mem Leak Detected! 
{full_total=} vs {self.full_tokens_per_layer=}, " + f"{full_available=}, {full_evictable=}, {full_protected=}, {full_uncached=}" + ) + assert swa_total == self.swa_tokens_per_layer, ( + f"SWA Pool Mem Leak Detected! {swa_total=} vs {self.swa_tokens_per_layer=}, " + f"{swa_available=}, {swa_evictable=}, {swa_protected=}, {swa_uncached=}" + ) + def self_check_during_busy(self: Scheduler): + # Dispatch to SWA checker for hybrid SWA mode + if self.is_hybrid_swa: + self.self_check_swa_during_busy() + return + current_batch: ScheduleBatch = self.last_batch if current_batch is None: @@ -218,7 +351,15 @@ def _check_req_pool(self: Scheduler): else: req_total_size = self.req_to_token_pool.size - if len(self.req_to_token_pool.free_slots) != req_total_size: + # Non-PD ReqToTokenPool skips slot 0 (for SWA/KV State padding), + # so expect free_slots == total_size - 1. + # PD decode's HybridReqToTokenPool does not skip slot 0, + # so expect free_slots == total_size. + if self.disaggregation_mode == DisaggregationMode.DECODE: + expected_free = req_total_size + else: + expected_free = req_total_size - 1 + if len(self.req_to_token_pool.free_slots) != expected_free: msg = ( "req_to_token_pool memory leak detected!" f"available_size={len(self.req_to_token_pool.free_slots)}, " @@ -239,6 +380,8 @@ def check_memory(self: Scheduler): else: memory_leak, token_msg = self._check_radix_cache_memory() + # todo hisparse, check memory leak for hisparse layers + if memory_leak: msg = "token_to_kv_pool_allocator memory leak detected! " f"{token_msg}" raise_error_or_warn( diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index ae6211887b44..778057df4e60 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -937,6 +937,7 @@ def _create_tokenized_object( require_reasoning=obj.require_reasoning, return_hidden_states=obj.return_hidden_states, return_routed_experts=obj.return_routed_experts, + return_indexer_topk=obj.return_indexer_topk, data_parallel_rank=obj.data_parallel_rank, priority=obj.priority, extra_key=obj.extra_key, @@ -1536,6 +1537,8 @@ def _handle_batch_output( meta_info["hidden_states"] = recv_obj.output_hidden_states[i] if getattr(recv_obj, "routed_experts", None): meta_info["routed_experts"] = recv_obj.routed_experts[i] + if getattr(recv_obj, "indexer_topk", None): + meta_info["indexer_topk"] = recv_obj.indexer_topk[i] if getattr(recv_obj, "customized_info", None): for k, v in recv_obj.customized_info.items(): meta_info[k] = v[i] diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 37416ba8b5af..c7f1dfcabf5e 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -386,6 +386,9 @@ def set_hicache_consumer(self, consumer_index: int): if self.hicache_layer_transfer_counter is not None: self.hicache_layer_transfer_counter.set_consumer(consumer_index) + def register_hisparse_coordinator(self, coordinator): + self.model_runner.hisparse_coordinator = coordinator + def get_worker_info(self): return ( self.max_total_num_tokens, diff --git a/python/sglang/srt/mem_cache/allocator.py b/python/sglang/srt/mem_cache/allocator.py index eaf29628bf8e..714abb1ed0d4 100644 --- a/python/sglang/srt/mem_cache/allocator.py +++ b/python/sglang/srt/mem_cache/allocator.py @@ -55,6 +55,10 @@ def __init__( self.is_not_in_free_group = True self.free_group = [] + @property + def size_full(self): + return self.size + def 
debug_print(self) -> str: return "" diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index a4377989b4ba..97619cb39a4d 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -16,7 +16,6 @@ MatchPrefixParams, MatchResult, ) -from sglang.srt.mem_cache.swa_memory_pool import SWATokenToKVPoolAllocator if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req @@ -99,7 +98,8 @@ class SWAChunkCache(ChunkCache): """ChunkCache with support for sliding window attention.""" def __init__(self, params: CacheInitParams): - assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) + # hisparse would override this to use HiSparseTokenToKVPoolAllocator + # assert isinstance(params.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator) super().__init__(params) self.sliding_window_size = params.sliding_window_size diff --git a/python/sglang/srt/mem_cache/compress_state.py b/python/sglang/srt/mem_cache/compress_state.py new file mode 100644 index 000000000000..d4ebb7d34cc5 --- /dev/null +++ b/python/sglang/srt/mem_cache/compress_state.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import dataclasses + +import torch + +from sglang.srt.environ import envs + + +@dataclasses.dataclass +class KVAndScoreOld: + kv: torch.Tensor + score: torch.Tensor + + def __post_init__(self): + assert self.kv.shape == self.score.shape + + @staticmethod + def empty_like(new_shape, old: KVAndScoreOld) -> KVAndScoreOld: + return KVAndScoreOld( + kv=torch.empty(*new_shape, dtype=old.kv.dtype, device=old.kv.device), + score=torch.empty( + *new_shape, dtype=old.score.dtype, device=old.score.device + ), + ) + + @property + def shape(self): + return self.kv.shape + + def __getitem__(self, index) -> KVAndScoreOld: + return KVAndScoreOld(kv=self.kv[index], score=self.score[index]) + + def __setitem__(self, index, value: KVAndScore): + self.kv[index] = value.kv + self.score[index] = value.score + + def clear(self): + self.kv.zero_() + self.score.fill_(float("-inf")) + + def view(self, *args): + return KVAndScoreOld( + kv=self.kv.view(*args), + score=self.score.view(*args), + ) + + def clone(self) -> KVAndScoreOld: + return KVAndScoreOld(kv=self.kv.clone(), score=self.score.clone()) + + +@dataclasses.dataclass +class KVAndScore: + # [..., 2 * d], don't directly construct this class + kv_score: torch.Tensor + + @property + def kv(self) -> torch.Tensor: + return self.kv_score[..., : self._item_size] + + @property + def score(self) -> torch.Tensor: + return self.kv_score[..., self._item_size :] + + @property + def shape(self): + return self.kv_score.shape + + def __post_init__(self): + self._item_size = self.kv_score.shape[-1] // 2 + + @staticmethod + def from_kv_score(*, kv: torch.Tensor, score: torch.Tensor) -> KVAndScore: + assert kv.shape == score.shape + return KVAndScore(torch.cat([kv, score], dim=-1)) + + def new_empty(self, new_shape) -> KVAndScore: + assert new_shape[-1] == self._item_size + new_shape = list(new_shape) + new_shape[-1] = 2 * self._item_size + return KVAndScore(self.kv_score.new_empty(new_shape, requires_grad=False)) + + def __getitem__(self, index) -> KVAndScore: + return KVAndScore(self.kv_score[index]) + + def __setitem__(self, index, value: KVAndScore): + self.kv_score[index] = value.kv_score + + def clear(self): + self.kv.zero_() + self.score.fill_(float("-inf")) + + def view(self, *args): + args = list(args) + if isinstance(args[-1], int) and args[-1] != -1: + args[-1] 
= 2 * self._item_size + return KVAndScore(self.kv_score.view(*args)) + + def clone(self) -> KVAndScore: + return KVAndScore(self.kv_score.clone()) + + @staticmethod + def cat(tensors: list[KVAndScore], dim: int) -> KVAndScore: + assert dim != -1, "Concatenation along last dim is not supported." + assert len(tensors) > 0, "At least one tensor is required for concatenation." + item_size = tensors[0]._item_size + for v in tensors: + assert ( + v._item_size == item_size + ), "All tensors must have the same item size." + + return KVAndScore(torch.cat([v.kv_score for v in tensors], dim=dim)) + + +class DeepSeekV4CompressState: + def __init__( + self, + max_num_reqs: int, + ratio: int, + overlap: bool, + head_dim: int, + device: str, + dtype: torch.dtype, + ): + self.max_num_reqs = max_num_reqs + self.ratio = ratio + self.overlap = overlap + self.head_dim = head_dim + self.device = device + self.dtype = dtype + coff = 1 + self.overlap + + state_shape = (max_num_reqs, ratio * coff, 2 * head_dim * coff) + self.kv_score_state = torch.empty(state_shape, dtype=dtype, device=device) + + def get_state(self) -> KVAndScore: + if envs.SGLANG_OPT_USE_OLD_COMPRESSOR.get(): + half_dim = self.head_dim * (1 + self.overlap) + return KVAndScoreOld( + self.kv_score_state[..., :half_dim], + self.kv_score_state[..., half_dim:], + ) + return KVAndScore(self.kv_score_state) + + +class CompressStatePool: + def __init__( + self, + size: int, + swa_page_size: int, + ring_size: int, + overlap: bool, + head_dim: int, + dtype: torch.dtype, + device: str, + enable_memory_saver: bool, + ratio: int, + ): + self.swa_page_size = swa_page_size + self.ring_size = ring_size + self.enable_memory_saver = enable_memory_saver + + # NOTE: page(ring) 0 is always to store dummy data, + # and make -1 location as clean state for handling edge cases when compressing + self._size = size + self.ring_size + 1 + # NOTE(dark): fused compressor need to ceil_align size to ratio + self._size = (self._size + ratio - 1) // ratio * ratio + + self.kv_score_buffer = KVAndScore( + torch.empty( + (self._size, 2 * (1 + overlap) * head_dim), dtype=dtype, device=device + ) + ) + self.kv_score_buffer[-1].clear() + + def translate_from_swa_loc_to_state_loc( + self, swa_loc: torch.Tensor + ) -> torch.Tensor: + swa_pages = swa_loc // self.swa_page_size + state_loc = swa_pages * self.ring_size + (swa_loc % self.ring_size) + # NOTE: -1 means padding location, map it to -1 in state loc as well + state_loc = torch.where(swa_loc < 0, -1, state_loc) + return state_loc + + def get_state_by_state_loc(self, state_loc: torch.Tensor) -> KVAndScore: + return self.kv_score_buffer[state_loc] + + def set_state_by_state_loc(self, state_loc: torch.Tensor, value: KVAndScore): + self.kv_score_buffer[state_loc] = value + self.kv_score_buffer[-1].clear() # keep -1 location as clean state diff --git a/python/sglang/srt/mem_cache/deepseekv4_memory_pool.py b/python/sglang/srt/mem_cache/deepseekv4_memory_pool.py new file mode 100644 index 000000000000..bf9c03380759 --- /dev/null +++ b/python/sglang/srt/mem_cache/deepseekv4_memory_pool.py @@ -0,0 +1,860 @@ +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import List, Literal, NamedTuple, Optional, Tuple, Union + +import torch +from sgl_kernel.kvcacheio import transfer_kv_all_layer_mla + +from sglang.jit_kernel.deepseek_v4 import fused_store_cache +from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE +from sglang.srt.environ import envs +from sglang.srt.layers.attention.nsa import 
index_buf_accessor, index_buf_accessor_v4 +from sglang.srt.layers.attention.nsa.index_buf_accessor_v4 import NopeFp8RopeBf16Pack +from sglang.srt.mem_cache.compress_state import ( + CompressStatePool, + DeepSeekV4CompressState, + KVAndScore, +) +from sglang.srt.mem_cache.memory_pool import KVCache +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ceil_div + +logger = logging.getLogger(__name__) + + +def get_compress_state_ring_size( + compress_ratio: int, is_speculative: bool = False +) -> int: + """Get ring size for given compression ratio. + + This is the single source of truth for ring size calculation. + All other code should call this function instead of duplicating the logic. + + Args: + compress_ratio: Compression ratio (4 or 128) + is_speculative: Whether speculative decoding is enabled + + Returns: + Ring size for the given compression ratio + """ + assert compress_ratio in [4, 128], f"Unsupported {compress_ratio = }" + if is_speculative: + return 8 if compress_ratio == 4 else 128 + else: + return 16 if compress_ratio == 4 else 256 + + +class DeepSeekV4SingleKVPool(KVCache): + # FIXME: rename to something like PartialRoPEKVPool and combine with NSA KVPool + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + is_swa_pool: Optional[bool] = False, + ): + super().__init__( + size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + + self.scale_pad = 1 + self.quantize_block_size = 64 + self.rope_storage_dtype = torch.bfloat16 + self.k_with_scale_buffer_dtype = torch.int8 + self.is_swa_pool = is_swa_pool + self._create_buffers() + + @property + def page_size(self): + if self.is_swa_pool: + assert ( + envs.SGLANG_OPT_DPSK_V4_RADIX.get() + and (self._page_size == 256) + or not envs.SGLANG_OPT_DPSK_V4_RADIX.get() + and (self._page_size == 128) + ), "SWA KV pool page size not correct!" + + return self._page_size + + @page_size.setter + def page_size(self, value: int): + self._page_size = value + + def _create_buffers(self): + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.kv_buffer = [ + self.create_buffer( + num_pages=(self.size + self.page_size + 1) // self.page_size, + ) + for _ in range(self.layer_num) + ] + + def get_bytes_per_token(self) -> int: + # The padded slot 0 is used for writing dummy outputs from padded tokens. + # Layout: + # shape: (num_pages, page_size * (nope_dim 448 + rope_dim 128 * 2) + + # page_size * (nope_dim / quant_block_size + scale_pad) * fp32_nbytes 4) + # data: for page i, + # * buf[i, :page_size * head_dim] for fp8 data + # * buf[i, page_size * head_dim:].view(float32) for scale + # + # Raw description from FlashMLA flash_mla_with_kvcache: + # head_dim should be 512 while head_dim_v is also 512. + # In FP8+sparse mode, every block can be divided into two parts. + # The first parts stores NoPE0, RoPE0, NoPE1, RoPE1, ... + # while the second part stores scale factors: 7xue8m0, 1Bpad, 7xue8m0, 1Bpad, ... 
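+        # Worked example with the sizes asserted below (qk_nope_head_dim = 448,
+        # qk_rope_head_dim = 64 stored as bf16, quantize_block_size = 64, scale_pad = 1):
+        #   448 (nope fp8) + 64 * 2 (rope bf16) + 448 // 64 (ue8m0 scales) + 1 (pad)
+        #   = 584 bytes per token.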
+ dim_per_token = ( + self.qk_nope_head_dim + + self.qk_rope_head_dim * self.rope_storage_dtype.itemsize + + self.qk_nope_head_dim // self.quantize_block_size + + self.scale_pad + ) + return dim_per_token + + def create_buffer(self, *, num_pages: int): + bytes_per_token = self.get_bytes_per_token() + self.kv_cache_total_dim = bytes_per_token + bytes_per_page_non_padded = self.page_size * bytes_per_token + self.bytes_per_page_padded = ceil_div(bytes_per_page_non_padded, 576) * 576 + + assert bytes_per_token == 448 + 64 * 2 + 8 + assert self.store_dtype == torch.uint8 + + return torch.zeros( + num_pages, + self.bytes_per_page_padded, + dtype=self.store_dtype, + device=self.device, + ) + + def set_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ): + index_buf_accessor_v4.SetKAndS.execute( + pool=self, + buf=self.kv_buffer[layer_id], + loc=loc, + nope_fp8_rope_bf16_pack=cache_nope_fp8_rope_bf16_pack, + ) + + def set_key_buffer_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + return fused_store_cache( + input=cache_k, + cache=self.kv_buffer[layer_id], + indices=loc, + page_size=self.page_size, + type="flashmla", + ) + + def get_key_buffer(self, layer_id: int): + if self.store_dtype != self.dtype: + return self.kv_buffer[layer_id - self.start_layer].view(self.dtype) + + return self.kv_buffer[layer_id] + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError("Use get_key_buffer instead.") + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("Use get_key_buffer instead.") + + +class HiSparseC4DevicePool(DeepSeekV4SingleKVPool): + + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: int | None = None, + end_layer: int | None = None, + ): + super().__init__( + size, + page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.kv_buffer], + dtype=torch.uint64, + device=self.device, + ) + self.compress_ratio = 4 + + def register_mapping(self, full_to_hisparse_device_index_mapping: torch.Tensor): + self.full_to_hisparse_device_index_mapping = ( + full_to_hisparse_device_index_mapping + ) + + def translate_loc_from_full_to_compressed(self, full_indices: torch.Tensor): + mask = (full_indices + 1) % self.compress_ratio == 0 + compressed_indices = full_indices[mask] // self.compress_ratio + return compressed_indices + + def translate_loc_from_compressed_to_hisparse_device( + self, compressed_indices: torch.Tensor + ): + return self.full_to_hisparse_device_index_mapping[compressed_indices].to( + torch.int32 + ) + + def _translate_loc_from_compressed_to_hisparse_device( + self, compressed_indices: torch.Tensor + ): + return self.full_to_hisparse_device_index_mapping[compressed_indices] + + def translate_loc_from_full_to_hisparse_device(self, full_indices: torch.Tensor): + return self._translate_loc_from_compressed_to_hisparse_device( + self.translate_loc_from_full_to_compressed(full_indices) + ) + + def set_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack, + ): + loc = 
self.translate_loc_from_compressed_to_hisparse_device(loc) + super().set_key_buffer(layer_id, loc, cache_nope_fp8_rope_bf16_pack) + + def transfer_values_on_device(self, dst_indices, src_indices): + # FIXME, page padding to be handled in the custom op + transfer_kv_all_layer_mla( + src_layers=self.data_ptrs, + dst_layers=self.data_ptrs, + src_indices=src_indices, + dst_indices=dst_indices, + item_size=self.kv_cache_total_dim, + num_layers=self.layer_num, + ) + + def get_cpu_copy(self, indices): + raise NotImplementedError("HiSparseC4DevicePool does not support get_cpu_copy") + + def load_cpu_copy(self, kv_cache_cpu, indices): + raise NotImplementedError("HiSparseC4DevicePool does not support load_cpu_copy") + + +class DeepSeekV4IndexerPool(KVCache): + quant_block_size = 128 + index_k_with_scale_buffer_dtype = torch.uint8 + + def __init__( + self, + size: int, + page_size: int, + dtype: torch.dtype, + index_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + ): + super().__init__( + size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + self.index_head_dim = index_head_dim + + self._create_buffer() + + def _create_buffer(self): + num_scales_per_token = self.index_head_dim // self.quant_block_size + # NOTE: weight in fp8, and scale in fp32 + page_bytes = self.page_size * self.index_head_dim + page_bytes += self.page_size * num_scales_per_token * 4 + with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): + with ( + torch.cuda.use_mem_pool(self.custom_mem_pool) + if self.custom_mem_pool + else nullcontext() + ): + self.index_k_with_scale_buffer = [ + torch.zeros( + (self.size + self.page_size + 1) // self.page_size, + page_bytes, + dtype=self.index_k_with_scale_buffer_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + def get_key_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: + return self.index_k_with_scale_buffer[layer_id] + + # copied from NSATokenToKVPool, theoretically can be directly reused + def get_index_k_scale_buffer( + self, + layer_id: int, + seq_len: int, + page_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Fused method to get both index K and scale data in a single call using Triton. + More efficient than calling get_index_k_continuous and get_index_k_scale_continuous separately. 
+ + :param layer_id: Layer index + :param seq_len: Sequence length + :param page_indices: Page indices tensor + :return: tuple of (k_fp8, k_scale) where + k_fp8: (seq_len, index_head_dim), uint8 + k_scale: (seq_len, 4), uint8 + """ + buf = self.index_k_with_scale_buffer[layer_id] + return index_buf_accessor.GetKAndS.execute( + self, buf, seq_len=seq_len, page_indices=page_indices + ) + + def set_index_k_scale_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + ) -> None: + buf = self.index_k_with_scale_buffer[layer_id - self.start_layer] + index_buf_accessor.SetKAndS.execute( + pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale + ) + + def set_index_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + return fused_store_cache( + input=cache_k, + cache=self.index_k_with_scale_buffer[layer_id - self.start_layer], + indices=loc, + page_size=self.page_size, + type="indexer", + ) + + +class DeepSeekV4LayerItem(NamedTuple): + compress_ratio: Literal[0, 4, 128] + compress_layer_id: int + compress_kv_pool: Optional[DeepSeekV4SingleKVPool] = None + + +class DeepSeekV4TokenToKVPool(KVCache): + + def __init__( + self, + max_num_reqs: int, + swa_size: int, + c4_size: int, + c128_size: int, + c4_state_pool_size: int, + c128_state_pool_size: int, + page_size: int, + swa_page_size: int, + dtype: torch.dtype, + state_dtype: torch.dtype, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + indexer_head_dim: int, + layer_num: int, + device: str, + enable_memory_saver: bool, + compression_ratios: List[int], + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, + enable_hisparse: bool = False, + ): + super().__init__( + swa_size, + page_size, + dtype, + layer_num, + device, + enable_memory_saver, + start_layer, + end_layer, + ) + + logger.info( + "Initialize DeepSeekV4TokenToKVPool with " + f"{max_num_reqs=} {swa_size=} {c4_size=} {c128_size=} " + f"{c4_state_pool_size=} {c128_state_pool_size=}" + ) + + self.max_num_reqs = max_num_reqs + self.c4_size = c4_size + self.c128_size = c128_size + self.c4_state_pool_size = c4_state_pool_size + self.c128_state_pool_size = c128_state_pool_size + self.state_dtype = state_dtype + self.compression_ratios = compression_ratios + + assert page_size % swa_page_size == 0 + + self.swa_size = swa_size + self.swa_window_size = swa_page_size + self.swa_page_size = swa_page_size + self.scale_pad = 1 + + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.indexer_head_dim = indexer_head_dim + + c4_layer_num = sum(1 for r in compression_ratios if r == 4) + c128_layer_num = sum(1 for r in compression_ratios if r == 128) + c4_page_size = page_size // 4 + c128_page_size = page_size // 128 + self.swa_kv_pool = DeepSeekV4SingleKVPool( + swa_size, + swa_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + layer_num, + device, + enable_memory_saver, + is_swa_pool=True, + ) + + c4_kv_pool_type = DeepSeekV4SingleKVPool + if enable_hisparse: + c4_kv_pool_type = HiSparseC4DevicePool + self.c4_kv_pool = c4_kv_pool_type( + c4_size, + c4_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c4_layer_num, + device, + enable_memory_saver, + ) + + self.c128_kv_pool = DeepSeekV4SingleKVPool( + c128_size, + c128_page_size, + dtype, + qk_nope_head_dim, + qk_rope_head_dim, + c128_layer_num, + device, + enable_memory_saver, + ) + + self.c4_indexer_kv_pool = DeepSeekV4IndexerPool( + c4_size, + c4_page_size, + 
dtype, # indexer kv: fp8 + fp32 scale + indexer_head_dim, + c4_layer_num, + device, + enable_memory_saver, + ) + + self._init_compressed_layer_mapping() + + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + self._init_paged_compress_states() + else: + self._init_compress_states() + + self._should_cache_swa = envs.SGLANG_OPT_CACHE_SWA_TRANSLATION.get() + + def register_mapping(self, full_to_swa_index_mapping: torch.Tensor): + self.full_to_swa_index_mapping = full_to_swa_index_mapping + + def get_ring_size(self, compress_ratio: int) -> int: + server_args = get_global_server_args() + is_speculative = server_args.speculative_algorithm is not None + return get_compress_state_ring_size(compress_ratio, is_speculative) + + def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor): + assert self.full_to_swa_index_mapping is not None + + # Note: kv_indices could have -1 values (from alloc_extend), which will be mapped to -1 + # since the last item of full_to_swa_index_mapping is -1. + return self.full_to_swa_index_mapping[kv_indices].to(torch.int32) + + def get_contiguous_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + """Channel 1: C4 KV + C4 indexer + C128 KV (source page indices).""" + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for bufs in [ + self.c4_kv_pool.kv_buffer, + self.c4_indexer_kv_pool.index_k_with_scale_buffer, + self.c128_kv_pool.kv_buffer, + ]: + for buf in bufs: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + return data_ptrs, data_lens, item_lens + + def get_state_buf_infos(self) -> Tuple[List[int], List[int], List[int]]: + """Channel 2: SWA KV + compress states + indexer compress states (SWA page indices). + + Compress state ring buffers are bundled per SWA page: + item_lens = ring_size * per_slot_bytes, so one SWA page index + copies the entire ring region for that page. 
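+
+        For example, with ring_size = 16 (the non-speculative c4 value) and a per-slot
+        row of B bytes, the item_lens entry is 16 * B and SWA page p covers state rows
+        [p * 16, (p + 1) * 16).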
+ """ + data_ptrs: List[int] = [] + data_lens: List[int] = [] + item_lens: List[int] = [] + + for buf in self.swa_kv_pool.kv_buffer: + assert buf.ndim == 2, f"expected 2D buffer, got {buf.ndim}D" + data_ptrs.append(buf.data_ptr()) + data_lens.append(buf.nbytes) + item_lens.append(buf[0].nbytes) + + for pools in [ + self.compress_state_pools, + self.indexer_compress_state_pools, + ]: + for pool in pools: + if pool is None: + continue + t = pool.kv_score_buffer.kv_score + assert t.ndim == 2, f"expected 2D buffer, got {t.ndim}D" + data_ptrs.append(t.data_ptr()) + data_lens.append(t.nbytes) + item_lens.append(t[0].nbytes * pool.ring_size) + + return data_ptrs, data_lens, item_lens + + def _init_paged_compress_states(self): + # Use pre-calculated pool sizes from memory profiler + c4_state_pool_size = self.c4_state_pool_size + c128_state_pool_size = self.c128_state_pool_size + self.compress_state_pools: List[CompressStatePool] = [] + self.indexer_compress_state_pools: List[CompressStatePool] = [] + + for ratio in self.compression_ratios: + overlap = ratio == 4 + compress_state_pool = indexer_compress_state_pool = None + size = c4_state_pool_size if ratio == 4 else c128_state_pool_size + ring_size = self.get_ring_size(ratio) if ratio != 0 else 0 + + # NOTE: c1 layer has no compress state + if ratio != 0: + compress_state_pool = CompressStatePool( + size=size, + swa_page_size=self.swa_page_size, + ring_size=ring_size, + overlap=overlap, + head_dim=self.qk_nope_head_dim + self.qk_rope_head_dim, + dtype=self.state_dtype, + device=self.device, + enable_memory_saver=False, + ratio=ratio, + ) + + if ratio == 4: + indexer_compress_state_pool = CompressStatePool( + size=size, + swa_page_size=self.swa_page_size, + ring_size=ring_size, + overlap=overlap, + head_dim=self.indexer_head_dim, + device=self.device, + dtype=self.state_dtype, + enable_memory_saver=False, + ratio=ratio, + ) + + self.compress_state_pools.append(compress_state_pool) + self.indexer_compress_state_pools.append(indexer_compress_state_pool) + + def _init_compressed_layer_mapping(self): + c1_cnt, c4_cnt, c128_cnt = 0, 0, 0 + self.layer_mapping: List[DeepSeekV4LayerItem] = [] + + for ratio in self.compression_ratios: + if ratio == 0: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=0, + compress_layer_id=c1_cnt, + ) + ) + c1_cnt += 1 + elif ratio == 4: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=4, + compress_layer_id=c4_cnt, + compress_kv_pool=self.c4_kv_pool, + ) + ) + c4_cnt += 1 + elif ratio == 128: + self.layer_mapping.append( + DeepSeekV4LayerItem( + compress_ratio=128, + compress_layer_id=c128_cnt, + compress_kv_pool=self.c128_kv_pool, + ) + ) + c128_cnt += 1 + else: + raise ValueError(f"Unsupported compression ratio: {ratio}") + + def _init_compress_states(self): + self.compress_states: List[Optional[DeepSeekV4CompressState]] = [] + self.indexer_compress_states: List[Optional[DeepSeekV4CompressState]] = [] + for ratio in self.compression_ratios: + overlap = ratio == 4 + attn_kv_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + compress_state = indexer_compress_state = None + # NOTE: c1 layer has no compress state + if ratio != 0: + compress_state = DeepSeekV4CompressState( + max_num_reqs=self.max_num_reqs, + ratio=ratio, + overlap=overlap, + head_dim=attn_kv_head_dim, + device=self.device, + dtype=self.state_dtype, + ) + # NOTE: only c4 needs indexer + if ratio == 4: + indexer_compress_state = DeepSeekV4CompressState( + max_num_reqs=self.max_num_reqs, + ratio=ratio, + 
overlap=overlap, + head_dim=self.indexer_head_dim, + device=self.device, + dtype=self.state_dtype, + ) + self.compress_states.append(compress_state) + self.indexer_compress_states.append(indexer_compress_state) + + def get_attention_compress_states(self, layer_id: int) -> KVAndScore: + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + compress_state_pool = self.compress_state_pools[layer_id] + assert ( + compress_state_pool is not None + ), "Only c4/c128 layers have attention states." + return compress_state_pool + else: + compress_state = self.compress_states[layer_id] + assert ( + compress_state is not None + ), "Only c4/c128 layers have attention states." + return compress_state.get_state() + + def get_indexer_compress_states( + self, layer_id: int + ) -> Union[KVAndScore, CompressStatePool]: + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + indexer_compress_state_pool = self.indexer_compress_state_pools[layer_id] + assert ( + indexer_compress_state_pool is not None + ), "Only c4 layers have indexer states." + return indexer_compress_state_pool + else: + compress_state = self.indexer_compress_states[layer_id] + assert compress_state is not None, "Only c4 layers have indexer states." + return compress_state.get_state() + + def get_swa_key_buffer(self, layer_id: int) -> torch.Tensor: + return self.swa_kv_pool.get_key_buffer(layer_id) + + # TODO seems no need to have this and can remove + # def get_swa_key_buffer_by_loc( + # self, layer_id: int, loc: torch.Tensor + # ) -> torch.Tensor: + # return self.swa_kv_pool.get_key_buffer_by_loc(layer_id, loc) + + def set_swa_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + self.swa_kv_pool.set_key_buffer(layer_id, loc, cache_nope_fp8_rope_bf16_pack) + + def get_extra_key_buffer(self, layer_id: int) -> torch.Tensor | None: + # c4/c128 -> extra_cache_k + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + return compress_kv_pool.get_key_buffer(compress_layer_id) + + def set_extra_key_buffer( + self, + layer_id: int, + loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + compress_kv_pool.set_key_buffer( + compress_layer_id, loc, cache_nope_fp8_rope_bf16_pack + ) + + def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.get_index_k_with_scale_buffer(compress_layer_id) + + def get_index_k_scale_buffer( + self, + layer_id: int, + seq_len: int, + page_indices: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.get_index_k_scale_buffer( + compress_layer_id, seq_len, page_indices + ) + + def set_index_k_scale_buffer( + self, + layer_id: int, + loc: torch.Tensor, + index_k: torch.Tensor, + index_k_scale: torch.Tensor, + ) -> None: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + self.c4_indexer_kv_pool.set_index_k_scale_buffer( + compress_layer_id, loc, index_k, index_k_scale + ) + + def get_key_buffer(self, layer_id: 
int) -> torch.Tensor: + raise NotImplementedError() + + def get_value_buffer(self, layer_id: int) -> torch.Tensor: + raise NotImplementedError() + + def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError() + + def set_kv_buffer(self, *args, **kwargs) -> None: + raise NotImplementedError() + + # ---- APIs for radix cache compatible branch ---- + def set_swa_key_buffer_radix( + self, + layer_id: int, + raw_loc: torch.Tensor, + cache_nope_fp8_rope_bf16_pack: NopeFp8RopeBf16Pack, + ) -> None: + swa_loc = self.translate_loc_from_full_to_swa(raw_loc) + self.swa_kv_pool.set_key_buffer( + layer_id, swa_loc, cache_nope_fp8_rope_bf16_pack + ) + + def get_swa_key_buffer_radix(self, layer_id: int) -> torch.Tensor: + return self.swa_kv_pool.get_key_buffer(layer_id) + + # --- Fused APIs of setting key buffers ---- + def set_swa_key_buffer_radix_fused( + self, + layer_id: int, + raw_loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + if self._should_cache_swa: + if layer_id == 0: + self.cached_loc = self.translate_loc_from_full_to_swa(raw_loc) + swa_loc = self.cached_loc + else: + swa_loc = self.translate_loc_from_full_to_swa(raw_loc) + return self.swa_kv_pool.set_key_buffer_fused(layer_id, swa_loc, cache_k) + + def set_extra_key_buffer_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + _, compress_layer_id, compress_kv_pool = self.layer_mapping[layer_id] + assert compress_kv_pool is not None + return compress_kv_pool.set_key_buffer_fused(compress_layer_id, loc, cache_k) + + def set_index_k_fused( + self, + layer_id: int, + loc: torch.Tensor, + cache_k: torch.Tensor, + ) -> None: + compress_ratio, compress_layer_id, _ = self.layer_mapping[layer_id] + assert compress_ratio == 4, f"only c4 has indexer, got {compress_ratio = }" + return self.c4_indexer_kv_pool.set_index_fused(compress_layer_id, loc, cache_k) + + # final branch: + # - DeepSeekV4TokenToKVPool + # - c4 / c128 / c4_indexer + # - swa_kv_pool: shape (num_pages, pages * bytes_per_page), where num_pages = max_num_reqs + # - PagedTokenToKVAllocator diff --git a/python/sglang/srt/mem_cache/hisparse_memory_pool.py b/python/sglang/srt/mem_cache/hisparse_memory_pool.py new file mode 100644 index 000000000000..286b0ffc4f88 --- /dev/null +++ b/python/sglang/srt/mem_cache/hisparse_memory_pool.py @@ -0,0 +1,372 @@ +# mapping on device memory, host memory and memory allocator + +import weakref +from typing import Optional + +import torch +from sgl_kernel.kvcacheio import transfer_kv_all_layer_mla + +from sglang.srt.mem_cache.allocator import ( + BaseTokenToKVPoolAllocator, + PagedTokenToKVPoolAllocator, +) +from sglang.srt.mem_cache.deepseekv4_memory_pool import ( + DeepSeekV4TokenToKVPool, + HiSparseC4DevicePool, +) + + +class DeepSeekV4SingleKVPoolHost: + # simplified host KV pool for hisparse C4 device pool + + def __init__( + self, + device_pool: HiSparseC4DevicePool, + host_size: int, + page_size: int, + pin_memory: bool = True, + device: str = "cpu", + ): + + assert host_size > 0, "Host size must be specified and greater than 0" + # use page size 1 for simplicity + assert page_size == 1, "Host page size must be 1 for DeepSeekV4SingleKVPoolHost" + + self.device_pool = device_pool + self.size = host_size + self.page_size = page_size + self.num_pages = (self.size + self.page_size - 1) // self.page_size + self.pin_memory = pin_memory + self.device = device + + self.dtype = device_pool.store_dtype + self.layer_num = device_pool.layer_num + 
self.kv_cache_total_dim = device_pool.kv_cache_total_dim + + self.kv_buffer = self.init_kv_buffer() + self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)] + self.data_ptrs = torch.tensor( + [x.data_ptr() for x in self.data_refs], + dtype=torch.uint64, + device=self.device_pool.device, + ) + self.clear() + + def clear(self): + self.free_slots = torch.arange( + 1, self.num_pages + 1, dtype=torch.int64, device="cpu" + ) + + def init_kv_buffer(self): + dims = (self.layer_num, self.size + self.page_size, self.kv_cache_total_dim) + host_pool = torch.empty(dims, dtype=self.dtype, device=self.device) + assert self.pin_memory, "DeepSeekV4SingleKVPoolHost requires pin_memory=True" + if self.pin_memory: + torch.cuda.cudart().cudaHostRegister( + host_pool.data_ptr(), host_pool.numel() * host_pool.element_size(), 0 + ) + return host_pool + + def backup_from_device_all_layer(self, device_pool, host_indices, device_indices): + # todo: direct io backend + # FIXME, page padding to be handled in the custom op + transfer_kv_all_layer_mla( + src_layers=device_pool.data_ptrs, + dst_layers=self.data_ptrs, + src_indices=device_indices, + dst_indices=host_indices, + item_size=self.kv_cache_total_dim, + num_layers=self.layer_num, + ) + + def testing_backup_to_device_all_layer( + self, device_pool, host_indices, device_indices + ): + transfer_kv_all_layer_mla( + src_layers=self.data_ptrs, + dst_layers=device_pool.data_ptrs, + src_indices=host_indices, + dst_indices=device_indices, + item_size=self.kv_cache_total_dim, + num_layers=self.layer_num, + ) + + def available_size(self): + return len(self.free_slots) + + def alloc(self, need_size: int) -> Optional[torch.Tensor]: + if need_size > self.available_size(): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + return select_index + + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices.cpu()]) + return len(indices) + + +class HiSparseTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator): + def __init__( + self, + logical_attn_allocator: BaseTokenToKVPoolAllocator, + ): + assert isinstance(logical_attn_allocator._kvcache, DeepSeekV4TokenToKVPool) + assert isinstance( + logical_attn_allocator._kvcache.c4_kv_pool, HiSparseC4DevicePool + ) + self.compress_ratio = 4 + + self.hisparse_kvcache = logical_attn_allocator._kvcache.c4_kv_pool + self._size_full = logical_attn_allocator.size_full + self._size_hisparse = self.hisparse_kvcache.size + + self.dtype = self.hisparse_kvcache.dtype + self.device = self.hisparse_kvcache.device + self.page_size = self.hisparse_kvcache.page_size + + self.logical_attn_allocator = logical_attn_allocator + self.hisparse_attn_allocator = PagedTokenToKVPoolAllocator( + self._size_hisparse, + self.page_size, + self.dtype, + self.device, + self.hisparse_kvcache, + logical_attn_allocator.need_sort, + ) + + self.full_to_hisparse_device_index_mapping = torch.cat( + [ + torch.zeros( + self._size_hisparse + self.page_size, + dtype=torch.int64, + device=self.device, + ), + torch.tensor([-1], dtype=torch.int64, device=self.device), + ] + ) + + self.need_sort = logical_attn_allocator.need_sort + self.free_pages = None + self.release_pages = None + self.is_not_in_free_group = True + self.free_group = [] + self.clear() + + self.hisparse_kvcache.register_mapping( + weakref.proxy(self.full_to_hisparse_device_index_mapping) + ) + + @property + def size_full(self) -> int: + return self._size_full + + def full_available_size(self): + return 
self.logical_attn_allocator.full_available_size() + + def swa_available_size(self): + return self.logical_attn_allocator.swa_available_size() + + def free_swa(self, free_indices: torch.Tensor): + self.logical_attn_allocator.free_swa(free_indices) + + def available_size(self) -> int: + return min( + self.logical_attn_allocator.available_size(), + self.hisparse_attn_allocator.available_size() * self.compress_ratio, + ) + + def alloc(self, need_size: int): + raise NotImplementedError( + "Page size = 1 is not supported in HiSparse allocator" + ) + + def alloc_device_buffer(self, allocated_indices, need_size: int): + assert need_size % self.page_size == 0 + # clear original reference and isolate the buffer from outside addressing, allocate new buffer if needed + hisparse_indices = self.full_to_hisparse_device_index_mapping[allocated_indices] + self.full_to_hisparse_device_index_mapping[allocated_indices] = 0 + if len(hisparse_indices) >= need_size: + buffer_indices = hisparse_indices[:need_size] + self.free_hisparse_indices(hisparse_indices[need_size:]) + else: + # page alignment, claiming the residual space for an incomplete page + page_residual_length = len(hisparse_indices) % self.page_size + if page_residual_length != 0: + hisparse_indices = torch.cat( + [ + hisparse_indices, + torch.arange( + hisparse_indices[-1] + 1, + hisparse_indices[-1] + + self.page_size + - page_residual_length + + 1, + device=self.device, + ), + ] + ) + extra_indices = self.hisparse_attn_allocator.alloc( + need_size - len(hisparse_indices) + ) + assert ( + extra_indices is not None + ), "Hisparse allocation failed in alloc_device_buffer" + buffer_indices = torch.cat([hisparse_indices, extra_indices]) + return buffer_indices + + def free_hisparse_indices(self, buffer_indices: torch.Tensor): + # disable free group mechanism for device buffer free + self.hisparse_attn_allocator.is_not_in_free_group = True + self.hisparse_attn_allocator.free(buffer_indices[buffer_indices > 0]) + + def get_last_loc_compressed(self, last_locs: torch.Tensor): + return (last_locs - 3) // self.compress_ratio + + def get_last_loc_hisparse_device(self, last_locs: torch.Tensor): + hisparse_last_locs = ( + self.hisparse_kvcache._translate_loc_from_compressed_to_hisparse_device( + self.get_last_loc_compressed(last_locs) + ) + ) + return hisparse_last_locs + + def alloc_extend( + self, + prefix_lens: torch.Tensor, + prefix_lens_cpu: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, # last_loc for full layers + extend_num_tokens: int, + ): + assert self.page_size > 1 + num_tokens = extend_num_tokens + len(seq_lens) * self.page_size + + if num_tokens > self.available_size(): + return None + + logical_indices = self.logical_attn_allocator.alloc_extend( + prefix_lens, + prefix_lens_cpu, + seq_lens, + seq_lens_cpu, + last_loc, + extend_num_tokens, + ) + compressed_logical_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(logical_indices) + ) + + hisparse_last_loc = self.get_last_loc_hisparse_device(last_loc) + hisparse_indices = self.hisparse_attn_allocator.alloc_extend( + prefix_lens // self.compress_ratio, + prefix_lens_cpu // self.compress_ratio, + seq_lens // self.compress_ratio, + seq_lens_cpu // self.compress_ratio, + hisparse_last_loc, + len(compressed_logical_indices), + ) + + assert logical_indices is not None, "Logical allocation failed in alloc_extend" + assert ( + hisparse_indices is not None + ), "Hisparse allocation failed in alloc_extend" + + 
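+        # With compress_ratio = 4, only every 4th full slot (full index 3, 7, 11, ...)
+        # has a compressed counterpart (compressed = full // 4); the assignment below
+        # records which hisparse device slot backs each of those compressed slots.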
self.full_to_hisparse_device_index_mapping[compressed_logical_indices] = ( + hisparse_indices + ) + return logical_indices + + def alloc_decode( + self, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, # last_loc for full layers + ): + logical_indices = self.logical_attn_allocator.alloc_decode( + seq_lens, seq_lens_cpu, last_loc + ) + + return logical_indices + + def alloc_decode_regular( + self, + seq_lens: torch.Tensor, + seq_lens_cpu: torch.Tensor, + last_loc: torch.Tensor, # last_loc for full layers + ): + logical_indices = self.logical_attn_allocator.alloc_decode( + seq_lens, seq_lens_cpu, last_loc + ) + + hisparse_last_loc = self.get_last_loc_hisparse_device(last_loc) + active_reqs = seq_lens % self.compress_ratio == 0 + hisparse_indices = self.hisparse_attn_allocator.alloc_decode( + seq_lens[active_reqs] // self.compress_ratio, + seq_lens_cpu[active_reqs.cpu()] // self.compress_ratio, + hisparse_last_loc[active_reqs], + ) + + if logical_indices is None or hisparse_indices is None: + return None + + compressed_logical_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(logical_indices) + ) + + assert len(compressed_logical_indices) == len( + hisparse_indices + ), "Mismatch in allocated indices length in alloc_decode" + self.full_to_hisparse_device_index_mapping[compressed_logical_indices] = ( + hisparse_indices + ) + + return logical_indices + + def free_compressed(self, compressed_indices: torch.Tensor): + hisparse_indices = ( + self.hisparse_kvcache.translate_loc_from_compressed_to_hisparse_device( + compressed_indices + ) + ) + hisparse_indices = hisparse_indices[hisparse_indices > 0] + self.free_hisparse_indices(hisparse_indices) + self.full_to_hisparse_device_index_mapping[compressed_indices] = 0 + + def free_hisparse(self, free_indices: torch.Tensor): + compressed_indices = ( + self.hisparse_kvcache.translate_loc_from_full_to_compressed(free_indices) + ) + self.free_compressed(compressed_indices) + + def clear(self): + self.logical_attn_allocator.clear() + self.hisparse_attn_allocator.clear() + + # Note: the last item is -1, we don't clear it, see the comment in __init__ + self.full_to_hisparse_device_index_mapping[:-1].fill_(0) + self.is_not_in_free_group = True + self.free_group = [] + + def free(self, free_index: torch.Tensor): + if free_index.numel() == 0: + return + + # NOTE: the API is not idempotent. 
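+        # (Freeing the same indices twice would push duplicates onto the underlying
+        # free lists, so callers must ensure every index is freed exactly once.)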
+ if self.is_not_in_free_group: + self.logical_attn_allocator.free(free_index) + # free activities will be associated with device buffers + # self.free_hisparse(free_index) + else: + self.free_group.append(free_index) + assert ( + self.logical_attn_allocator.available_size() + <= self.logical_attn_allocator.size + ) + assert ( + self.hisparse_attn_allocator.available_size() + <= self.hisparse_attn_allocator.size + ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 0afbb15fd7e8..48bf7be6f6f3 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -137,7 +137,10 @@ def __init__( (size, max_context_len), dtype=torch.int32, device=device ) - self.free_slots = list(range(size)) + # NOTE: must also change the one in `clear` + # temporarily skip index 0 because SWA cache and KV State don't have padding + # self.free_slots = list(range(size)) # Old code + self.free_slots = list(range(1, size)) def write(self, indices, values): self.req_to_token[indices] = values @@ -161,7 +164,9 @@ def free(self, free_index: Union[int, List[int]]): self.free_slots.extend(free_index) def clear(self): - self.free_slots = list(range(self.size)) + # temporarily skip index 0 because SWA cache and KV State don't have padding + # self.free_slots = list(range(self.size)) + self.free_slots = list(range(1, self.size)) class MambaPool: diff --git a/python/sglang/srt/mem_cache/swa_memory_pool.py b/python/sglang/srt/mem_cache/swa_memory_pool.py index 0faf201cbd48..4d118b745bba 100644 --- a/python/sglang/srt/mem_cache/swa_memory_pool.py +++ b/python/sglang/srt/mem_cache/swa_memory_pool.py @@ -1,15 +1,17 @@ import logging import weakref -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import torch +from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool from sglang.srt.mem_cache.utils import maybe_init_custom_mem_pool @@ -230,29 +232,33 @@ def __init__( page_size: int, dtype: torch.dtype, device: str, - kvcache: SWAKVPool, + kvcache: Union[SWAKVPool, DeepSeekV4TokenToKVPool], need_sort: bool, ): - assert isinstance(kvcache, SWAKVPool) + assert isinstance(kvcache, (SWAKVPool, DeepSeekV4TokenToKVPool)) self._size_full = size self._size_swa = size_swa self.dtype = dtype self.device = device self.page_size = page_size + # FIXME: kv cache should not be passed to allocator + full_kv_pool = getattr(kvcache, "full_kv_pool", None) + swa_kv_pool = getattr(kvcache, "swa_kv_pool", None) + if page_size == 1: self.full_attn_allocator = TokenToKVPoolAllocator( size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = TokenToKVPoolAllocator( size_swa, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) else: @@ -261,7 +267,7 @@ def __init__( page_size, dtype, device, - kvcache.full_kv_pool, + full_kv_pool, need_sort, ) self.swa_attn_allocator = PagedTokenToKVPoolAllocator( @@ -269,7 +275,7 @@ def __init__( page_size, dtype, device, - kvcache.swa_kv_pool, + swa_kv_pool, need_sort, ) # Note: append one more item of value -1 in the end so -1 maps to -1. 
@@ -293,6 +299,7 @@ def __init__( self.free_group = [] self.clear() + # FIXME: the mapping should be maintained by the allocator? self._kvcache = kvcache self._kvcache.register_mapping(weakref.proxy(self.full_to_swa_index_mapping)) @@ -361,6 +368,9 @@ def alloc_extend( ): assert self.page_size > 1 num_tokens = extend_num_tokens + len(seq_lens) * self.page_size + msg = f"[ALLOC-EXTEND-{get_tp_group().rank}] {num_tokens=}, {extend_num_tokens=}, {len(seq_lens)=}, {self.page_size=}" + msg += f", {self.full_attn_allocator.available_size()=}, {self.swa_attn_allocator.available_size()=}" + # print(msg) if num_tokens > self.full_attn_allocator.available_size(): return None if num_tokens > self.swa_attn_allocator.available_size(): @@ -430,6 +440,7 @@ def free(self, free_index: torch.Tensor): assert self.swa_attn_allocator.available_size() <= self.swa_attn_allocator.size def free_swa(self, free_index: torch.Tensor): + # print(f"[FREE-SWA-{get_tp_group().rank}] {free_index=}", flush=True) swa_indices = self.full_to_swa_index_mapping[free_index] swa_indices = swa_indices[swa_indices > 0] self.swa_attn_allocator.free(swa_indices) diff --git a/python/sglang/srt/mem_cache/swa_radix_cache.py b/python/sglang/srt/mem_cache/swa_radix_cache.py index 4b07b841f2aa..547830496b42 100644 --- a/python/sglang/srt/mem_cache/swa_radix_cache.py +++ b/python/sglang/srt/mem_cache/swa_radix_cache.py @@ -28,6 +28,7 @@ import torch from numpy import float64 +from sglang.srt.environ import envs from sglang.srt.mem_cache.base_prefix_cache import ( BasePrefixCache, EvictParams, @@ -592,6 +593,10 @@ def cache_unfinished_req(self, req: Req, chunked=False) -> None: # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later if self.page_size != 1: + # req.prefix_indices = new_indices + # # req.kv_allocated_len = len(new_indices) + # if len(kv_indices[len(new_indices) :]) > 0: + # self.token_to_kv_pool_allocator.free(kv_indices[len(new_indices) :]) req.prefix_indices = torch.cat( [new_indices, kv_indices[len(new_indices) :]] ) @@ -846,9 +851,13 @@ def _match_prefix_helper( match_len_since_tombstone = float("inf") best_value_len = 0 best_last_node = node + enable_compact = envs.SGLANG_OPT_SWA_RADIX_CACHE_COMPACT.get() while len(key) > 0 and child_key in node.children.keys(): child = node.children[child_key] + if enable_compact: + self._compact_single_child_chain(child) + if child.swa_tombstone: # update best_value_len and best_last_node if needed if match_len_since_tombstone >= self.sliding_window_size: @@ -897,6 +906,37 @@ def _match_prefix_helper( return value[:best_value_len], best_last_node + def _compact_single_child_chain(self, node: TreeNode) -> None: + """Path compression: merge consecutive single-child internal nodes + with the same tombstone and lock ref state into one node. 
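+        For example, if n1's only child n2 and n2's only child n3 are both internal
+        nodes with the same tombstone and lock state, n1 absorbs them: its key becomes
+        n1+n2+n3 and its value becomes torch.cat([v1, v2, v3]).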
+ Leaf nodes are never absorbed because req.last_node may point to them.""" + while len(node.children) == 1: + child = next(iter(node.children.values())) + # Never absorb a leaf since req.last_node may reference it + if len(child.children) == 0: + break + if ( + child.swa_tombstone != node.swa_tombstone + or child.full_lock_ref != node.full_lock_ref + or child.swa_lock_ref != node.swa_lock_ref + ): + break + + node.key = RadixKey( + node.key.token_ids + child.key.token_ids, node.key.extra_key + ) + node.value = torch.cat([node.value, child.value]) + node.children = child.children + for grandchild in node.children.values(): + grandchild.parent = node + + if child.swa_uuid is not None: + node.swa_uuid = child.swa_uuid + + self.full_lru_list.remove_node(child) + if not child.swa_tombstone: + self.swa_lru_list.remove_node(child) + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: # new_node -> child new_node = TreeNode() @@ -1018,6 +1058,32 @@ def _insert_helper( child_key = self.get_child_key_fn(key) if len(key): + logger.debug( + f"Has Additional Node: len(key)={len(key)}, total_prefix_length={total_prefix_length}, swa_evicted_seqlen={swa_evicted_seqlen}, len(value)={len(value)}" + ) + + # All lengths are page aligned. + # Layout: |--- total_prefix_length ---|--- len(key) ---| + # ^ ^ ^ + # 0 total_prefix_length total_length + # + # Cases based on swa_evicted_seqlen position: + # 1. swa_evicted_seqlen <= total_prefix_length: + # Already handled in the previous loop. All of len(key) inserted as non-tombstone. + # 2. total_prefix_length < swa_evicted_seqlen < total_length: + # - [total_prefix_length, swa_evicted_seqlen): insert as tombstone node + # - [swa_evicted_seqlen, total_length): insert as non-tombstone node + # 3. swa_evicted_seqlen == total_length: + # Handled above (early return): no new node created since leaf cannot be tombstone. 
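+            # Example for case 2 (all lengths page aligned): with total_prefix_length = 128,
+            # len(key) = 192 and swa_evicted_seqlen = 192, the range [128, 192) is inserted
+            # as a tombstone node and [192, 320) as a regular (non-tombstone) node.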
+ + if swa_evicted_seqlen == total_prefix_length + len(key): + # If page size > window size, swa_evicted_seqlen may == total_prefix_length + len(key), since window in the partial left page + # In this case, we don't need to add new node for the remaining key and value (leaf should not be tombstone) + # Better solution: when evict swa during decoding, keep additional non-evicted page and then it can be inserted in the tree + self.token_to_kv_pool_allocator.free(value) + return total_prefix_length + + # Insert tombstone nodes before the swa evicted seqlen if ( swa_evicted_seqlen > total_prefix_length and swa_evicted_seqlen < total_prefix_length + len(key) @@ -1032,6 +1098,7 @@ def _insert_helper( key = key[swa_tombstone_len:] value = value[swa_tombstone_len:] + # Insert non-tombstone nodes after the swa evicted seqlen self._add_new_node(node, key, value, swa_tombstone=False) return total_prefix_length diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b0b2ede6dbde..748d18aac64e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -726,11 +726,13 @@ def run_once(): self.device_module.synchronize() self.model_runner.tp_group.barrier() run_once() + attn_backend.on_after_cuda_graph_warmup_pass() if get_global_graph_memory_pool() is None: set_global_graph_memory_pool(self.device_module.graph_pool_handle()) # Set graph pool id globally to be able to use symmetric memory set_graph_pool_id(get_global_graph_memory_pool()) + out = self._capture_graph( graph, get_global_graph_memory_pool(), stream, run_once ) @@ -832,6 +834,8 @@ def replay_prepare( self.capture_forward_mode, forward_batch.spec_info, seq_lens_cpu=seq_lens_cpu, + out_cache_loc=forward_batch.out_cache_loc, + actual_forward_mode=forward_batch.forward_mode, ) # Store fields @@ -860,6 +864,7 @@ def replay( else: graph_key = self.bs self.graphs[graph_key].replay() + output = self.output_buffers[graph_key] if isinstance(output, LogitsProcessorOutput): diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index efd9f07d3d1b..229f0c61692d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -61,6 +61,7 @@ if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput + from sglang.srt.managers.hisparse_coordinator import HiSparseCoordinator from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -98,8 +99,8 @@ class ForwardMode(IntEnum): # Used in diffusion LLM inference DLLM_EXTEND = auto() - def is_prefill(self): - return self.is_extend() + def is_prefill(self, include_draft_extend_v2: bool = False): + return self.is_extend(include_draft_extend_v2=include_draft_extend_v2) def is_extend(self, include_draft_extend_v2: bool = False): return ( @@ -375,6 +376,9 @@ class ForwardBatch(ForwardBatchDeepSeekMHAMixin): # For hidden states before normal return_hidden_states_before_norm: bool = False + # For hisparse + hisparse_coordinator: Optional[HiSparseCoordinator] = None + @classmethod def init_new( cls, diff --git a/python/sglang/srt/model_executor/input_buffers.py 
b/python/sglang/srt/model_executor/input_buffers.py index f4468a70c634..f354bd9b1d03 100644 --- a/python/sglang/srt/model_executor/input_buffers.py +++ b/python/sglang/srt/model_executor/input_buffers.py @@ -144,6 +144,10 @@ def populate_from_forward_batch( enable_num_token_non_padded_flag: bool, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Optional[torch.Tensor]: + # We need invalid reqs to have idx=0, b/c SWA cache and KV State use this object + # Future optimization: may put inside `if bs != raw_bs` + self.req_pool_indices.zero_() + if bs != raw_bs: self.seq_lens.fill_(seq_len_fill_value) self.out_cache_loc.zero_() diff --git a/python/sglang/srt/model_executor/memory_profiler.py b/python/sglang/srt/model_executor/memory_profiler.py new file mode 100644 index 000000000000..0b5f9c0013af --- /dev/null +++ b/python/sglang/srt/model_executor/memory_profiler.py @@ -0,0 +1,228 @@ +"""Memory profiler for DeepSeekV4 and other models.""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed.parallel_state import get_world_group +from sglang.srt.mem_cache.deepseekv4_memory_pool import get_compress_state_ring_size + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + +logger = logging.getLogger(__name__) + + +@dataclass +class DSv4PoolSizes: + """Pool sizes for DeepSeekV4 memory allocation.""" + + full_max_total_num_tokens: int + swa_max_total_num_tokens: int + c4_max_total_num_tokens: int + c128_max_total_num_tokens: int + c4_state_pool_size: int + c128_state_pool_size: int + + +class DSv4MemoryCalculator: + """Calculate pool sizes for DeepSeekV4 memory allocation. 
+ + Memory pools for DSv4: + - SWA KV pool: size=F*R, layers=num_layers_total + - C4 KV pool: size=F/4, layers=num_layers_ca4 + - C128 KV pool: size=F/128, layers=num_layers_ca128 + - C4 Indexer pool: size=F/4, layers=num_layers_ca4 + - C4 State pool (paged): size=F*R/swa_page_size*c4_ring_size, layers=num_layers_ca4 + - C128 State pool (paged): size=F*R/swa_page_size*c128_ring_size, layers=num_layers_ca128 + - C4 Indexer State pool: size=F*R/swa_page_size*c4_ring_size, layers=num_layers_ca4 + + Where F = full_token, R = swa_ratio + Ring sizes: c4_ring_size=16 (or 8 for speculative), c128_ring_size=256 (or 128 for speculative) + """ + + def __init__( + self, + model_config: ModelConfig, + page_size: int, + swa_ratio: float, + is_speculative: bool = False, + ): + self.qk_nope_head_dim = model_config.qk_nope_head_dim + self.qk_rope_head_dim = model_config.qk_rope_head_dim + self.indexer_head_dim = model_config.index_head_dim + self.compression_ratios = model_config.compress_ratios + # NOTE: Hardcoded swa page size from the swa window size + self.swa_page_size = model_config.window_size + self.page_size = page_size + self.swa_ratio = swa_ratio + self.is_speculative = is_speculative + + # Get ring sizes based on speculative mode + self.c4_ring_size = get_compress_state_ring_size(4, self.is_speculative) + self.c128_ring_size = get_compress_state_ring_size(128, self.is_speculative) + + # Count layers by compression type + self.num_layers_total = len(self.compression_ratios) + self.num_layers_ca4 = sum(1 for r in self.compression_ratios if r == 4) + self.num_layers_ca128 = sum(1 for r in self.compression_ratios if r == 128) + + # Bytes per full token + self.bytes_per_full_token = self.get_bytes_per_full_token() + + def get_bytes_per_full_token(self) -> float: + """Calculate total memory bytes per full_token.
+ + Returns: + Total memory bytes per full_token (across all pools and layers) + """ + # KV pool bytes per token (fp8 nope + bf16 rope + scale) + # Layout: nope_fp8 (448) + rope_bf16 (64*2) + scale (8) + kv_bytes = self.qk_nope_head_dim + self.qk_rope_head_dim * 2 + 8 + + # Indexer bytes per token (fp8 + fp32 scale) + # Layout: index_k (512) + scale (512/128*4) + quant_block_size = 128 + indexer_bytes = ( + self.indexer_head_dim + self.indexer_head_dim // quant_block_size * 4 + ) + + # State bytes per token (float32) + # KVAndScore layout: (size, 2 * (1 + overlap) * head_dim) + attn_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + state_dtype_size = 4 # float32 + c4_state_bytes = 2 * 2 * attn_head_dim * state_dtype_size # overlap=True + c128_state_bytes = 2 * 1 * attn_head_dim * state_dtype_size # overlap=False + c4_indexer_state_bytes = 2 * 2 * self.indexer_head_dim * state_dtype_size + + # State paged expansion: state size = num_swa_pages * ring_size + # where num_swa_pages = swa_tokens / swa_page_size + c4_state_ratio = self.c4_ring_size / self.swa_page_size + c128_state_ratio = self.c128_ring_size / self.swa_page_size + + # Calculate total bytes per full_token + bytes_per_full_token = ( + # SWA KV pool: size = full_token * swa_ratio + self.swa_ratio * kv_bytes * self.num_layers_total + # C4 KV pool: size = full_token / 4 + + 1 / 4 * kv_bytes * self.num_layers_ca4 + # C128 KV pool: size = full_token / 128 + + 1 / 128 * kv_bytes * self.num_layers_ca128 + # C4 indexer pool: size = full_token / 4 + + 1 / 4 * indexer_bytes * self.num_layers_ca4 + # C4 compress state pool (paged): size = num_swa_pages * c4_ring_size + + self.swa_ratio * c4_state_ratio * c4_state_bytes * self.num_layers_ca4 + # C128 compress state pool (paged): size = num_swa_pages * c128_ring_size + + self.swa_ratio + * c128_state_ratio + * c128_state_bytes + * self.num_layers_ca128 + # C4 indexer compress state pool (paged): size = num_swa_pages * c4_ring_size + + self.swa_ratio + * c4_state_ratio + * c4_indexer_state_bytes + * self.num_layers_ca4 + ) + + return bytes_per_full_token + + def calculate_pool_sizes(self, available_bytes: int) -> DSv4PoolSizes: + """Calculate pool sizes based on available memory. 
+ + Args: + available_bytes: Available memory bytes for KV cache + + Returns: + DSv4PoolSizes containing all pool sizes + """ + full_token = int(available_bytes / self.bytes_per_full_token) + + # Align to page_size + full_token = full_token // self.page_size * self.page_size + + # Calculate each pool's size + swa_tokens = int(full_token * self.swa_ratio) // self.page_size * self.page_size + + pool_sizes = DSv4PoolSizes( + full_max_total_num_tokens=full_token, + swa_max_total_num_tokens=swa_tokens, + c4_max_total_num_tokens=full_token // 4, + c128_max_total_num_tokens=full_token // 128, + c4_state_pool_size=swa_tokens // self.swa_page_size * self.c4_ring_size, + c128_state_pool_size=swa_tokens // self.swa_page_size * self.c128_ring_size, + ) + + logger.info( + f"DSv4 memory calculation: " + f"bytes_per_full_token={self.bytes_per_full_token:.2f}, " + f"available_bytes={available_bytes / (1 << 30):.2f} GB, " + f"full_token={full_token}" + ) + + return pool_sizes + + def get_pool_sizes_by_profiling(self, mr: ModelRunner) -> DSv4PoolSizes: + # Profile available memory bytes directly (standalone path for DSv4) + available_bytes = profile_available_bytes( + device=mr.device, + gpu_id=mr.gpu_id, + total_gpu_memory=mr.total_gpu_memory, + mem_fraction_static=mr.mem_fraction_static, + distributed=get_world_group().world_size > 1, + cpu_group=get_world_group().cpu_group, + ) + + if self.is_speculative: + draft_layers = 1 + target_layers = self.num_layers_total + target_ratio = target_layers / (target_layers + draft_layers) + available_bytes = int(available_bytes * target_ratio) + + return self.calculate_pool_sizes(available_bytes) + + def get_pool_sizes_by_configuration(self, max_total_tokens: int) -> DSv4PoolSizes: + available_bytes = max_total_tokens * self.bytes_per_full_token + return self.calculate_pool_sizes(available_bytes) + + +def profile_available_bytes( + device: str, + gpu_id: int, + total_gpu_memory: float, + mem_fraction_static: float, + distributed: bool = False, + cpu_group=None, +) -> int: + """Profile available memory bytes for KV cache. + + Args: + device: Device type (cuda, etc.) 
+ gpu_id: GPU ID + total_gpu_memory: Total GPU memory in GB + mem_fraction_static: Static memory fraction + distributed: Whether running in distributed mode + cpu_group: CPU group for distributed + + Returns: + Available memory bytes for KV cache + """ + from sglang.srt.utils.common import get_available_gpu_memory + + available_gpu_memory = get_available_gpu_memory( + device, gpu_id, distributed=distributed, cpu_group=cpu_group + ) + rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static) + + available_bytes = int(rest_memory * (1 << 30)) + + logger.info( + f"Memory profiling: available_gpu_memory={available_gpu_memory:.2f} GB, " + f"total_gpu_memory={total_gpu_memory:.2f} GB, " + f"mem_fraction_static={mem_fraction_static:.2f}, " + f"rest_memory={rest_memory:.2f} GB" + ) + + return available_bytes diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 275f5164ee02..2d59e133a84d 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -84,6 +84,11 @@ ATTENTION_BACKENDS, attn_backend_wrapper, ) +from sglang.srt.layers.attention.indexer_topk_capturer import ( + create_indexer_capturer, + get_global_indexer_capturer, + set_global_indexer_capturer, +) from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp from sglang.srt.layers.attention.tbo_backend import TboAttnBackend from sglang.srt.layers.dp_attention import ( @@ -310,12 +315,15 @@ def __init__( self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.is_hybrid_swa = model_config.is_hybrid_swa - self.is_hybrid_swa_compress = model_config.is_hybrid_swa_compress + self.is_hybrid_swa_compress = getattr( + model_config, "is_hybrid_swa_compress", False + ) self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA self.attention_chunk_size = model_config.attention_chunk_size self.forward_pass_id = 0 self.init_new_workspace = False self.draft_model_idx = draft_model_idx + self.enable_hisparse = server_args.enable_hisparse self.remote_instance_transfer_engine = None self.remote_instance_transfer_engine_session_id = "" @@ -366,7 +374,7 @@ def __init__( self.init_threads_binding() # Get memory before model loading - min_per_gpu_memory = self.init_torch_distributed() + self.total_gpu_memory = self.init_torch_distributed() # Init forward stream for overlap schedule self.forward_stream = torch.get_device_module(self.device).Stream() @@ -387,7 +395,7 @@ def __init__( deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args) # Initialize the model runner - self.initialize(min_per_gpu_memory) + self.initialize() self.check_quantized_moe_compatibility() # Temporary cached values @@ -404,6 +412,9 @@ def __init__( self._model_update_group = {} self._weights_send_group = {} + # For hisparse + self.hisparse_coordinator = None + def init_mindspore_runner(self): # Init the mindspore runner # for now, there is only some communication initialization work @@ -418,7 +429,7 @@ def init_mindspore_runner(self): port=self.dist_port, ) - def initialize(self, min_per_gpu_memory: float): + def initialize(self): server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( @@ -493,6 +504,8 @@ def initialize(self, min_per_gpu_memory: float): self.end_layer = getattr(self.model, "end_layer", model_num_layers) self.num_effective_layers = self.end_layer - self.start_layer + 
self.adjust_hybrid_swa_layers_for_pp() + # For LoopCoder models, each loop has its own layer_id, so we need to multiply by loop_num loop_num = getattr(self.model_config.hf_config, "loop_num", 1) if loop_num > 1: @@ -507,23 +520,6 @@ def initialize(self, min_per_gpu_memory: float): ) ), "PP is not compatible with MTP models." - # Consider PP, so use start_layer and end_layer. - full_attention_layer_ids = [ - layer_idx - for layer_idx in range(self.start_layer, self.end_layer + 1) - if hasattr(self.model_config, "full_attention_layer_ids") - and layer_idx in self.model_config.full_attention_layer_ids - ] - swa_attention_layer_ids = [ - layer_idx - for layer_idx in range(self.start_layer, self.end_layer + 1) - if hasattr(self.model_config, "swa_attention_layer_ids") - and layer_idx in self.model_config.swa_attention_layer_ids - ] - # Update back to model_config. - self.model_config.swa_attention_layer_ids = swa_attention_layer_ids - self.model_config.full_attention_layer_ids = full_attention_layer_ids - # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -559,7 +555,7 @@ def initialize(self, min_per_gpu_memory: float): self.configure_kv_cache_dtype() # Init memory pool and attention backends - self.init_memory_pool(min_per_gpu_memory) + self.init_memory_pool() # Init max running requests self.max_running_requests = min( @@ -575,6 +571,9 @@ def initialize(self, min_per_gpu_memory: float): # Init routed experts capturer self.init_routed_experts_capturer() + # Init indexer topk capturer + self.init_indexer_capturer() + if self.device == "cuda": self.init_cublas() self.init_attention_backend() @@ -601,6 +600,30 @@ def initialize(self, min_per_gpu_memory: float): self.prealloc_symmetric_memory_pool() + def adjust_hybrid_swa_layers_for_pp(self): + if not self.is_hybrid_swa: + return + + if self.model_config.is_swa_with_compressed_attention: + return + + # Consider PP, so use start_layer and end_layer. + full_attention_layer_ids = [ + layer_idx + for layer_idx in range(self.start_layer, self.end_layer + 1) + if hasattr(self.model_config, "full_attention_layer_ids") + and layer_idx in self.model_config.full_attention_layer_ids + ] + swa_attention_layer_ids = [ + layer_idx + for layer_idx in range(self.start_layer, self.end_layer + 1) + if hasattr(self.model_config, "swa_attention_layer_ids") + and layer_idx in self.model_config.swa_attention_layer_ids + ] + # Update back to model_config. 
+ self.model_config.swa_attention_layer_ids = swa_attention_layer_ids + self.model_config.full_attention_layer_ids = full_attention_layer_ids + def init_routed_experts_capturer(self): if not self.server_args.disable_shared_experts_fusion and hasattr( self.model, "num_fused_shared_experts" @@ -620,6 +643,17 @@ def init_routed_experts_capturer(self): ) ) + def init_indexer_capturer(self): + set_global_indexer_capturer( + create_indexer_capturer( + enable=get_global_server_args().enable_return_indexer_topk, + model_config=self.model_config, + num_tokens=self.max_total_num_tokens + self.page_size, + max_running_requests=self.max_running_requests, + device=self.device, + ) + ) + def remote_instance_init_transfer_engine(self): try: from mooncake.engine import TransferEngine @@ -1510,7 +1544,7 @@ def mamba2_config(self): def max_token_pool_size(self): """Return the max token pool size considering hybrid swa settings.""" if self.is_hybrid_swa: - return min(self.swa_max_total_num_tokens, self.max_total_num_tokens) + return self.full_max_total_num_tokens else: return self.max_total_num_tokens @@ -2285,6 +2319,13 @@ def forward( cuda_graph_batch=getattr(self.graph_runner, "bs", None), ) + # Copy cached indexer topk buffers back to CPU cache + get_global_indexer_capturer().on_forward_end( + forward_batch=forward_batch, + can_run_graph=output.can_run_graph, + cuda_graph_batch=getattr(self.graph_runner, "bs", None), + ) + if self.eplb_manager is not None: self.eplb_manager.on_forward_pass_end() @@ -2334,6 +2375,7 @@ def _forward_raw( server_args=self.server_args, ) + forward_batch.hisparse_coordinator = self.hisparse_coordinator if forward_batch.forward_mode.is_decode(): ret = self.forward_decode( forward_batch, @@ -2411,6 +2453,24 @@ def sample( else forward_batch.seq_lens - 1 ), ) + + # no need for this, just look at warmup request output (by hacking the flags there to return logprob) + # if get_bool_env_var("SGLANG_HACK_PRINT_NEXT_TOKEN_LOGPROBS"): + # next_token_logprobs = logits_output.next_token_logprobs + # print(f"hi {next_token_logprobs=}", flush=True) + # # NOTE: need `return_logprob` to make it not-none + # if next_token_logprobs is None: + # top_logprobs, top_ids = None, None + # else: + # top_logprobs, top_ids = torch.topk(next_token_logprobs, k=5, dim=-1) + # top_logprobs, top_ids = top_logprobs.tolist(), top_ids.tolist() + # print( + # f"[{torch.distributed.get_rank()}] " + # f"{top_ids=} " + # f"{top_logprobs=} " + # f"{get_tensor_info(next_token_logprobs)=}" + # ) + return next_token_ids def compute_logprobs_only( diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 6cfa91e87c9a..804dcfb67b81 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -5,13 +5,23 @@ import torch -from sglang.srt.configs.model_config import get_nsa_index_head_dim, is_deepseek_nsa +from sglang.srt.configs.model_config import ( + get_nsa_index_head_dim, + is_deepseek_compressed, + is_deepseek_nsa, +) from sglang.srt.distributed.parallel_state import get_world_group +from sglang.srt.environ import envs from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.mem_cache.allocator import ( PagedTokenToKVPoolAllocator, TokenToKVPoolAllocator, ) +from sglang.srt.mem_cache.deepseekv4_memory_pool import ( + DeepSeekV4IndexerPool, + DeepSeekV4TokenToKVPool, +) +from sglang.srt.mem_cache.hisparse_memory_pool 
import HiSparseTokenToKVPoolAllocator from sglang.srt.mem_cache.memory_pool import ( DoubleSparseTokenToKVPool, HybridLinearKVPool, @@ -46,9 +56,33 @@ class ModelRunnerKVCacheMixin: def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: kv_size = torch._utils._element_size(self.kv_cache_dtype) - if self.use_mla_backend: + if is_deepseek_compressed(self.model_config.hf_config): + # TODO: more accurate compute, now we assume every token has indexer for simplicity + assert kv_size == 1, kv_size # uint8 + + cell_size = ( + ( + self.model_config.qk_nope_head_dim + + self.model_config.qk_rope_head_dim * 2 + ) + * num_layers + * kv_size + ) + index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config) + indexer_size_per_token = ( + index_head_dim + + index_head_dim // DeepSeekV4IndexerPool.quant_block_size * 4 + ) + element_size = torch._utils._element_size( + DeepSeekV4IndexerPool.index_k_with_scale_buffer_dtype + ) + cell_size += indexer_size_per_token * num_layers * element_size + elif self.use_mla_backend: cell_size = ( - (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) + ( + self.model_config.qk_nope_head_dim + + self.model_config.qk_rope_head_dim + ) * num_layers * kv_size ) @@ -108,7 +142,9 @@ def get_cell_size_per_token(self: ModelRunner, num_layers: int) -> int: ) return cell_size - def profile_max_num_token(self: ModelRunner, total_gpu_memory: int): + def profile_max_num_token(self: ModelRunner): + # FIXME: As the cache types contains various dtypes and layout, + # consider use num_bytes directly instead of estimating via fixed dtype and cell size. available_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, @@ -135,7 +171,7 @@ def profile_max_num_token(self: ModelRunner, total_gpu_memory: int): cell_size = self.get_cell_size_per_token(num_layers) - rest_memory = available_gpu_memory - total_gpu_memory * ( + rest_memory = available_gpu_memory - self.total_gpu_memory * ( 1 - self.mem_fraction_static ) if self.mambaish_config is not None: @@ -206,57 +242,145 @@ def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): def set_num_tokens_hybrid_swa(self: ModelRunner): page_size = self.server_args.page_size - assert self.sliding_window_size is not None and self.sliding_window_size > 0 full_layers_num = len(self.model_config.full_attention_layer_ids) swa_layers_num = len(self.model_config.swa_attention_layer_ids) - assert swa_layers_num > 0, "Hybrid SWA model must have at least one SWA layer" - def align_page_size(x: int) -> int: + def align_to_page(x: int) -> int: return (x // page_size) * page_size + # Special case: all layers are SWA (no full attention layers) + # FIXME: maybe remove this special case as it can also be handled by the general case if full_layers_num == 0: - # all layers are SWA - self.swa_max_total_num_tokens = align_page_size(self.max_total_num_tokens) + self.swa_max_total_num_tokens = align_to_page(self.max_total_num_tokens) self.full_max_total_num_tokens = 0 self.max_total_num_tokens = self.swa_max_total_num_tokens logger.info( - f"Use sliding window memory pool (all SWA). swa_layer_tokens={self.swa_max_total_num_tokens}" + f"Use sliding window memory pool (all SWA). " + f"swa_layer_tokens={self.swa_max_total_num_tokens}" ) return - # Algorithm: - # Existing max_total_num_tokens is per layer and assume all layers have the same number of tokens. - # - Find total # of tokens available across layers. 
- # - Calculate full_max_total_num_tokens and swa_max_total_num_tokens based on the given swa_full_tokens_ratio. + # Memory allocation for hybrid SWA models: + # + # Given: + # - ratio = swa_full_tokens_ratio (SWA tokens per layer / full tokens per layer) + # - total_tokens_across_layers = max_total_num_tokens * num_layers + # + # Let full_tokens = tokens per full attention layer, then: + # - swa_tokens = ratio * full_tokens + # - total_tokens_across_layers = full_tokens * full_layers_num + swa_tokens * swa_layers_num + # = full_tokens * (full_layers_num + ratio * swa_layers_num) + # + # Solving for full_tokens: + # full_tokens = total_tokens_across_layers / (full_layers_num + ratio * swa_layers_num) + total_tokens = self.max_total_num_tokens * self.model_config.num_hidden_layers - swa_full_tokens_ratio = self.server_args.swa_full_tokens_ratio + ratio = self.server_args.swa_full_tokens_ratio + denominator = full_layers_num + ratio * swa_layers_num + assert denominator > 0, ( + f"Invalid denominator={denominator}: " + f"ratio={ratio}, swa_layers={swa_layers_num}, full_layers={full_layers_num}" + ) - # Solve the equations: - # 1. swa_max_total_num_tokens * swa_layers_num + full_max_total_num_tokens * full_layers_num == total_tokens - # 2. full_max_total_num_tokens * swa_full_tokens_ratio == swa_max_total_num_tokens - denominator = swa_full_tokens_ratio * swa_layers_num + full_layers_num - assert ( - denominator > 0 - ), f"Invalid denominator={denominator} for swa_full_tokens_ratio={swa_full_tokens_ratio} and swa_layers_num={swa_layers_num} and full_layers_num={full_layers_num}" self.full_max_total_num_tokens = int(total_tokens / denominator) - self.swa_max_total_num_tokens = int( - self.full_max_total_num_tokens * swa_full_tokens_ratio - ) + self.swa_max_total_num_tokens = int(self.full_max_total_num_tokens * ratio) - self.full_max_total_num_tokens = align_page_size(self.full_max_total_num_tokens) - self.swa_max_total_num_tokens = align_page_size(self.swa_max_total_num_tokens) + # Align to page boundaries + self.full_max_total_num_tokens = align_to_page(self.full_max_total_num_tokens) + self.swa_max_total_num_tokens = align_to_page(self.swa_max_total_num_tokens) self.max_total_num_tokens = self.full_max_total_num_tokens logger.info( - f"Use sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}" + f"Use sliding window memory pool. " + f"full_layer_tokens={self.full_max_total_num_tokens}, " + f"swa_layer_tokens={self.swa_max_total_num_tokens}" + ) + + def set_num_tokens_hybrid_swa_compress(self: ModelRunner): + """Set memory pool sizes for DSv4 compressed attention. 
+ + This is a standalone code path for DSv4 that directly profiles available bytes + and calculates all pool sizes considering: + - SWA KV pool + - C4/C128 KV pools + - C4 Indexer pool + - C4/C128 compress state pools (paged) + - C4 indexer compress state pool + """ + from sglang.srt.model_executor.memory_profiler import DSv4MemoryCalculator + + self.state_dtype = torch.float32 + logger.info(f"DSv4 compressed attention: kv_cache_dtype={self.kv_cache_dtype}") + logger.info(f"DSv4 compressed attention: state_dtype={self.state_dtype}") + + page_size = self.server_args.page_size + assert ( + page_size % 128 == 0 + ), "page_size must be multiple of 128 for compressed attention" + + if not self.spec_algorithm.is_none() and self.is_draft_worker: + config = getattr(self.server_args, "_draft_pool_config", None) + assert ( + config is not None + ), "Draft worker requires target's pool config but _draft_pool_config is not set." + self.full_max_total_num_tokens = config["full_max_total_num_tokens"] + self.swa_max_total_num_tokens = config["swa_max_total_num_tokens"] + self.c4_max_total_num_tokens = 0 + self.c128_max_total_num_tokens = 0 + self.c4_state_pool_size = 0 + self.c128_state_pool_size = 0 + + logger.info( + f"DSv4 pool sizes (DRAFT): using TARGET's pool sizes - " + f"full={self.full_max_total_num_tokens}, " + f"swa={self.swa_max_total_num_tokens}" + ) + return + + # Calculate all pool sizes using DSv4MemoryCalculator + calculator = DSv4MemoryCalculator( + model_config=self.model_config, + page_size=page_size, + swa_ratio=self.server_args.swa_full_tokens_ratio, + is_speculative=self.server_args.speculative_algorithm is not None, + ) + + pool_sizes = calculator.get_pool_sizes_by_profiling(self) + if pool_sizes.full_max_total_num_tokens > self.max_total_num_tokens: + pool_sizes = calculator.get_pool_sizes_by_configuration( + max_total_tokens=self.max_total_num_tokens + ) + + # Set pool sizes to self + self.full_max_total_num_tokens = pool_sizes.full_max_total_num_tokens + self.swa_max_total_num_tokens = pool_sizes.swa_max_total_num_tokens + self.c4_max_total_num_tokens = pool_sizes.c4_max_total_num_tokens + self.c128_max_total_num_tokens = pool_sizes.c128_max_total_num_tokens + self.c4_state_pool_size = pool_sizes.c4_state_pool_size + self.c128_state_pool_size = pool_sizes.c128_state_pool_size + + if not self.spec_algorithm.is_none() and not self.is_draft_worker: + self.server_args._draft_pool_config = { + "full_max_total_num_tokens": self.full_max_total_num_tokens, + "swa_max_total_num_tokens": self.swa_max_total_num_tokens, + } + + logger.info( + f"DSv4 pool sizes: " + f"full={self.full_max_total_num_tokens}, " + f"swa={self.swa_max_total_num_tokens}, " + f"c4={self.c4_max_total_num_tokens}, " + f"c128={self.c128_max_total_num_tokens}, " + f"c4_state={self.c4_state_pool_size}, " + f"c128_state={self.c128_state_pool_size}" ) - def init_memory_pool(self: ModelRunner, total_gpu_memory: int): + def init_memory_pool(self: ModelRunner): max_num_reqs = self.server_args.max_running_requests - max_total_tokens = self.server_args.max_total_tokens - self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + max_total_tokens_configured = self.server_args.max_total_tokens + self.max_total_num_tokens = self.profile_max_num_token() if max_num_reqs is None: max_num_reqs = min( @@ -292,14 +416,16 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): max_num_reqs, self.server_args.max_running_requests // self.dp_size ) - if max_total_tokens is not None: - if max_total_tokens > 
self.max_total_num_tokens: + if max_total_tokens_configured is not None: + if max_total_tokens_configured > self.max_total_num_tokens: logging.warning( - f"max_total_tokens={max_total_tokens} is larger than the profiled value " + f"max_total_tokens={max_total_tokens_configured} is larger than the profiled value " f"{self.max_total_num_tokens}. " f"Use the profiled value instead." ) - self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens) + self.max_total_num_tokens = min( + self.max_total_num_tokens, max_total_tokens_configured + ) self.max_total_num_tokens = ( self.max_total_num_tokens @@ -322,7 +448,11 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): # create token size for hybrid cache if self.is_hybrid_swa: - self.set_num_tokens_hybrid_swa() + assert self.sliding_window_size is not None and self.sliding_window_size > 0 + if self.model_config.is_swa_with_compressed_attention: + self.set_num_tokens_hybrid_swa_compress() + else: + self.set_num_tokens_hybrid_swa() if not self.spec_algorithm.is_none() and not self.is_draft_worker: # Draft worker should use SWA adjusted max_total_num_tokens for cache size, otherwise it may cause oob in kv cache store @@ -399,7 +529,42 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): # Initialize token_to_kv_pool is_nsa_model = is_deepseek_nsa(self.model_config.hf_config) - if self.server_args.attention_backend == "ascend": + is_v4_model = is_deepseek_compressed(self.model_config.hf_config) + if is_v4_model: + + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + swa_page_size = self.page_size + assert swa_page_size == 256, "In paged swa mode, page_size must be 256." + else: + swa_page_size = self.model_config.window_size + assert ( + swa_page_size == 128 + ), "In ring buffer swa mode, page_size must be 128." 
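# Illustrative sketch (not part of the patch): how the sizes handed to
# DeepSeekV4TokenToKVPool below follow from full_max_total_num_tokens, using the
# DSv4PoolSizes / calculate_pool_sizes formulas. All numbers are hypothetical and
# assume the non-speculative ring-size defaults (c4=16, c128=256) and ring-buffer mode.
full_tokens = 1_048_576            # full_max_total_num_tokens (already page aligned)
swa_ratio = 0.5                    # server_args.swa_full_tokens_ratio
swa_page_size = 128                # ring-buffer mode: model_config.window_size
c4_ring, c128_ring = 16, 256       # get_compress_state_ring_size(4/128, speculative=False)

swa_tokens = int(full_tokens * swa_ratio)              # 524_288
c4_tokens = full_tokens // 4                           # 262_144
c128_tokens = full_tokens // 128                       # 8_192
c4_state = swa_tokens // swa_page_size * c4_ring       # 4_096 swa pages * 16  = 65_536
c128_state = swa_tokens // swa_page_size * c128_ring   # 4_096 swa pages * 256 = 1_048_576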
+ + self.token_to_kv_pool = DeepSeekV4TokenToKVPool( + max_num_reqs=self.server_args.max_running_requests, + swa_size=self.swa_max_total_num_tokens, + c4_size=self.c4_max_total_num_tokens, + c128_size=self.c128_max_total_num_tokens, + c4_state_pool_size=self.c4_state_pool_size, + c128_state_pool_size=self.c128_state_pool_size, + page_size=self.page_size, + swa_page_size=swa_page_size, + dtype=self.kv_cache_dtype, + state_dtype=self.state_dtype, + qk_nope_head_dim=self.model_config.qk_nope_head_dim, + qk_rope_head_dim=self.model_config.qk_rope_head_dim, + indexer_head_dim=self.model_config.index_head_dim, + layer_num=self.num_effective_layers, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + compression_ratios=self.model_config.compress_ratios, + start_layer=self.start_layer, + end_layer=self.end_layer, + enable_hisparse=self.enable_hisparse, + ) + + elif self.server_args.attention_backend == "ascend": if self.use_mla_backend: from sglang.srt.hardware_backend.npu.memory_pool_npu import ( NPUMLATokenToKVPool, @@ -638,6 +803,10 @@ def init_memory_pool(self: ModelRunner, total_gpu_memory: int): kvcache=self.token_to_kv_pool, need_sort=need_sort, ) + if self.enable_hisparse: + self.token_to_kv_pool_allocator = HiSparseTokenToKVPoolAllocator( + self.token_to_kv_pool_allocator + ) else: assert self.is_draft_worker diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index cde44eb93d61..2209af5419dc 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -89,6 +89,7 @@ get_moe_runner_backend, should_use_flashinfer_cutlass_moe_fp4_allgather, ) +from sglang.srt.layers.moe.deepseek_v4_topk import HashTopK from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.kt_ep_wrapper import KTEPWrapperMethod @@ -291,13 +292,16 @@ def __init__( quant_config, prefix: str = "", is_nextn: bool = False, + is_hash_moe: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.is_nextn = is_nextn self.weight = nn.Parameter( torch.empty((config.n_routed_experts, config.hidden_size)) ) - if config.topk_method == "noaux_tc": + + if config.topk_method == "noaux_tc" and not is_hash_moe: correction_bias_dtype = ( torch.bfloat16 if quant_config is not None @@ -321,18 +325,23 @@ def forward( forward_batch: ForwardBatch = None, ): if use_intel_amx_backend(self): - return torch.ops.sgl_kernel.weight_packed_linear( + logits = torch.ops.sgl_kernel.weight_packed_linear( hidden_states, self.weight, None, # bias True, # is_vnni ) + return logits if get_global_server_args().enable_deterministic_inference: - return F.linear(hidden_states, self.weight, None) - - if forward_batch is not None and nsa_use_prefill_cp(forward_batch): logits = F.linear(hidden_states, self.weight, None) + return logits + + # downstream do not support bf16 output; and for safety let's use the code path same as non-CP firstly + # if forward_batch is not None and nsa_use_prefill_cp(forward_batch): + # logits = F.linear(hidden_states, self.weight, None) + if False: + pass else: # NOTE: For some unknown reason, router_gemm seems degrade accept length. 
if ( @@ -352,7 +361,10 @@ def forward( hidden_states, self.weight, gemm_output_zero_allocator ) else: - logits = F.linear(hidden_states, self.weight, None) + # reference implementation convert both to float before compute + from sglang.jit_kernel.deepseek_v4 import linear_bf16_fp32 + + logits = linear_bf16_fp32(hidden_states, self.weight) return logits @@ -367,6 +379,7 @@ def __init__( prefix: str = "", alt_stream: Optional[torch.cuda.Stream] = None, is_nextn: bool = False, + is_deepseek_v4: bool = False, ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -383,6 +396,13 @@ def __init__( self.alt_stream = alt_stream self.is_nextn = is_nextn + # Special case: For DeepSeek V4 MTP layer, it does not use hash moe + if envs.SGLANG_DSV4_MODE.get() == "2604": + n_hash_layers = config.num_hash_layers + else: + n_hash_layers = getattr(config, "n_hash_layers", 0) + self.is_hash = layer_id < n_hash_layers and not (is_deepseek_v4 and is_nextn) + if self.tp_size > config.n_routed_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " @@ -400,6 +420,8 @@ def __init__( quant_config=quant_config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn, + is_hash_moe=self.is_hash, + is_deepseek_v4=is_deepseek_v4, ) # scaling factor for fused shared experts on AMD-platform. @@ -424,31 +446,51 @@ def __init__( routing_method_type=getattr( config, "routing_method_type", RoutingMethodType.DeepSeekV3 ), + swiglu_limit=getattr(config, "swiglu_limit", None), prefix=add_prefix("experts", prefix), ) - self.topk = TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, - layer_id=self.layer_id, - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - quant_config=quant_config, - routed_scaling_factor=self.routed_scaling_factor, - apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, - fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, - # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized - # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. 
- output_format=( - TopKOutputFormat.STANDARD - if (quant_config is None) - and (not get_moe_runner_backend().is_flashinfer_trtllm()) - else None - ), - ) + self.use_grouped_topk = config.n_group > config.topk_group + # Remove this b/c it seems both field always exists, and config object cannot be `get` + # if config.get("topk_group", None) and config.get("n_group", None): + # self.use_grouped_topk = config.n_group > config.topk_group + # else: + # self.use_grouped_topk = False + + if self.is_hash and not (is_nextn and is_deepseek_v4): + self.topk = HashTopK( + topk=config.num_experts_per_tok + self.num_fused_shared_experts, + num_experts=config.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, + vocab_size=config.vocab_size, + scoring_func=config.scoring_func, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + ) + else: + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + layer_id=self.layer_id, + renormalize=config.norm_topk_prob, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + scoring_func=config.scoring_func, + correction_bias=self.gate.e_score_correction_bias, + quant_config=quant_config, + routed_scaling_factor=self.routed_scaling_factor, + apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, + fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, + # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized + # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. 
+ output_format=( + TopKOutputFormat.STANDARD + if (quant_config is None) + and (not get_moe_runner_backend().is_flashinfer_trtllm()) + else None + ), + ) self.shared_experts_is_int8 = False self.shared_experts_is_fp8 = False @@ -535,6 +577,9 @@ def __init__( ) self._fuse_shared_experts_inside_sbo = SboFlags.fuse_shared_experts_inside_sbo() + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert hasattr(self, "shared_experts") + def get_moe_weights(self): return [ x.data @@ -552,12 +597,16 @@ def forward( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: if not self._enable_a2a_moe: from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if ( - self.alt_stream is not None + # NOTE temporarily disable dual stream + False + and self.alt_stream is not None and self.num_fused_shared_experts == 0 and hidden_states.shape[0] > 0 and get_is_capture_mode() @@ -574,9 +623,13 @@ def forward( should_allreduce_fusion, use_reduce_scatter, gemm_output_zero_allocator, + input_ids, + input_ids_global=input_ids_global, ) else: - return self.forward_deepep(hidden_states, forward_batch) + return self.forward_deepep( + hidden_states, forward_batch, input_ids_global=input_ids_global + ) def forward_normal_dual_stream( self, @@ -617,6 +670,8 @@ def forward_normal( should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, gemm_output_zero_allocator: BumpAllocator = None, + input_ids: Optional[torch.Tensor] = None, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -632,7 +687,8 @@ def forward_normal( ) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states, gemm_output_zero_allocator) - topk_output = self.topk(hidden_states, router_logits) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} + topk_output = self.topk(hidden_states, router_logits, **topk_kwargs) else: shared_output = None topk_output = self.topk.empty_topk_output(hidden_states.device) @@ -671,6 +727,12 @@ def _post_combine_hook( hidden_states, topk_output, ) + # Apply routed_scaling_factor BEFORE dumping moe_expert_output so the + # dump semantics match the CUDA path (where scaling is applied inside + # fused_experts_impl via moe_sum_reduce). Without this, AMD dumps were + # pre-scaling while CUDA dumps were post-scaling, producing a spurious + # ~7.7% rel_diff that masked real differences. + _skip_scaling = envs.SGLANG_FORCE_TRITON_MOE_FP8.get() if ( not _is_cuda and not _use_aiter @@ -678,6 +740,14 @@ def _post_combine_hook( ): # fused in biased_grouped_topk so we can skip here final_hidden_states *= self.routed_scaling_factor + elif ( + _use_aiter + and not self.experts.should_fuse_routed_scaling_factor_in_topk + and not _skip_scaling + ): + final_hidden_states *= self.routed_scaling_factor + # For _skip_scaling: scaling was already applied inside + # fused_experts_impl's moe_sum_reduce_triton (CUDA-parity path). 
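# Illustrative note (not part of the patch): the ordering matters because
# routed_scaling_factor must scale only the routed-expert output, i.e. it has to be
# applied before shared_output is added just below. Tiny numeric check with made-up values:
routed_out, shared_out, scale = 2.0, 3.0, 2.5
scaled_then_add = routed_out * scale + shared_out      # 8.0  -> intended result
add_then_scale = (routed_out + shared_out) * scale     # 12.5 -> what scaling after the add would give
assert scaled_then_add != add_then_scale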
if shared_output is not None: final_hidden_states += shared_output if ( @@ -751,6 +821,7 @@ def forward_deepep( self, hidden_states: torch.Tensor, forward_batch: ForwardBatch, + input_ids_global: Optional[torch.Tensor] = None, ) -> torch.Tensor: shared_output = None sbo_enabled_flag = self._fuse_shared_experts_inside_sbo and not self.is_nextn @@ -773,6 +844,7 @@ def forward_deepep( shared_event = self.alt_stream.record_event() else: shared_output = self._forward_shared_experts(hidden_states) + topk_kwargs = {"input_ids": input_ids_global} if self.is_hash else {} topk_output = self.topk( hidden_states, router_logits, @@ -780,6 +852,7 @@ def forward_deepep( expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, ), + **topk_kwargs, ) else: topk_output = self.topk.empty_topk_output(hidden_states.device) diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py new file mode 100644 index 000000000000..4bb5cea84b1c --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4.py @@ -0,0 +1,2805 @@ +from __future__ import annotations + +import concurrent.futures +import logging +import os +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl + +import sglang.srt.models.deepseek_v2 as deepseek_v2 +from sglang.jit_kernel.deepseek_v4 import fused_rope, linear_bf16_fp32 +from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config +from sglang.srt.debug_utils.deepseek_v4_debug_utils import deepseek_v4_moe_code_path_checker +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.environ import envs +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.attention.nsa.nsa_indexer import rotate_activation +from sglang.srt.layers.attention.nsa.utils import ( + can_cp_split, + cp_all_gather_rerange_output, + cp_split_and_rebuild_data, + cp_split_and_rebuild_position, + is_nsa_enable_prefill_cp, + nsa_use_prefill_cp, + prepare_input_dp_with_cp_dsa, +) +from sglang.srt.layers.communicator import LayerScatterModes, get_attn_tp_context +from sglang.srt.layers.deepseek_v4_rope import apply_rotary_emb_triton +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + dp_gather_partial, + dp_scatter, + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, + get_global_dp_buffer, + get_local_dp_buffer, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8 +from sglang.srt.layers.rotary_embedding import get_rope_wrapper +from sglang.srt.layers.utils import get_layer_id +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding +from sglang.srt.mem_cache.compress_state import ( + CompressStatePool, + KVAndScore, + KVAndScoreOld, +) +from sglang.srt.mem_cache.deepseekv4_memory_pool import DeepSeekV4TokenToKVPool +from sglang.srt.mem_cache.memory_pool 
import RadixAttention +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_loader.utils import maybe_executor_submit, should_async_load +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.dbrx import ReplicatedLinear +from sglang.srt.models.deepseek_v2 import ParallelLMHead, _is_cuda, _is_hip, _is_npu +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + BumpAllocator, + LazyValue, + add_prefix, + get_bool_env_var, + log_info_on_rank0, + make_layers, + maybe_torch_compile, +) + +logger = logging.getLogger(__name__) + +from sglang.srt.environ import envs + +MOE_BIT_WISE_EQUAL_MODE = False +ATTN_BIT_WISE_EQUAL_MODE = False +COMPRESSOR_BIT_WISE_EQUAL_MODE = False +_FP8_WO_A_GEMM = envs.SGLANG_OPT_FP8_WO_A_GEMM.get() + + +if TYPE_CHECKING: + from sglang.srt.layers.attention.deepseek_v4_backend import DeepseekV4Backend + from sglang.srt.layers.quantization import QuantizationConfig + from sglang.srt.layers.rotary_embedding import RotaryEmbedding + from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + PPProxyTensors, + ) + + +class DeepseekRefRMSNorm(nn.Module): + """ + Root Mean Square Layer Normalization (RMSNorm). + + Args: + dim (int): Dimension of the input tensor. + eps (float): Epsilon value for numerical stability. Defaults to 1e-6. + """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + # rmsnorm in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for convenient. + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + """ + Forward pass for RMSNorm. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor with the same shape as input. + """ + out = rms_normalize_triton(x, self.eps, self.weight) + return out + + +@maybe_torch_compile +def rms_normalize(x: torch.Tensor, eps: float) -> torch.Tensor: + x *= torch.rsqrt(x.square().mean(-1, keepdim=True) + eps) + return x + + +@triton.jit +def _rms_normalize_kernel( + x_ptr, + weight_ptr, + eps, + stride_row, + dim, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, +): + pid = tl.program_id(0) + + offs = tl.arange(0, BLOCK_SIZE) + mask = offs < dim + + base = pid * stride_row + x = tl.load(x_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + + # x / sqrt(mean(x^2) + eps) + mean_sq = tl.sum(x * x, axis=0) / dim + rms_inv = tl.rsqrt(mean_sq + eps) + out = x * rms_inv + + if HAS_WEIGHT: + weight = tl.load(weight_ptr + offs, mask=mask, other=0.0) + out = out * weight + + tl.store(x_ptr + base + offs, out, mask=mask) + + +def rms_normalize_triton( + x: torch.Tensor, eps: float, weight: torch.Tensor = None +) -> torch.Tensor: + """RMS normalize with optional weight. 
+ + Args: + x: Input tensor of shape (..., dim), normalizes over last dimension + eps: Epsilon for numerical stability + weight: Optional weight tensor of shape (dim,) + """ + dim = x.shape[-1] + x_flat = x.view(-1, dim) + num_rows = x_flat.shape[0] + + BLOCK_SIZE = triton.next_power_of_2(dim) + grid = (num_rows,) + + _rms_normalize_kernel[grid]( + x_flat, + weight, + eps, + x_flat.stride(0), + dim, + BLOCK_SIZE=BLOCK_SIZE, + HAS_WEIGHT=(weight is not None), + ) + return x + + +class Compressor(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + is_in_indexer: bool, + rotary_emb: RotaryEmbedding, + freqs_cis: torch.Tensor, # TODO: remove it after using rotary embedding + compress_ratio: Literal[0, 4, 128], + head_dim: int, + rotate: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.layer_id = layer_id + self.is_in_indexer = is_in_indexer + self.dim = config.hidden_size + self.head_dim = head_dim + self.rope_head_dim = getattr(config, "qk_rope_head_dim", 64) + self.nope_head_dim = head_dim - self.rope_head_dim + assert compress_ratio != 0, "compress_ratio should not be 0" + self.ratio = compress_ratio + self.overlap = self.ratio == 4 + self.rotate = rotate + self.coff = coff = 1 + self.overlap + + self.ape = nn.Parameter( + torch.empty(self.ratio, coff * self.head_dim, dtype=torch.float32) + ) + # fuse wkv and wgate into wkv_gate, merge the last dim + wkv_gate_dtype = torch.bfloat16 + self.wkv_gate = ReplicatedLinear( + self.dim, + 2 * coff * self.head_dim, + bias=False, + quant_config=None, + prefix=add_prefix("wkv_gate", prefix), + params_dtype=wkv_gate_dtype, + ) + # self.norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.norm = DeepseekRefRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.rotary_emb = rotary_emb + self.freqs_cis = freqs_cis + + self.ape_converted = False + + @cached_property + def use_fused_compress(self) -> bool: + if ( + envs.SGLANG_OPT_USE_FUSED_PAGED_COMPRESS.get() + and envs.SGLANG_OPT_DPSK_V4_RADIX.get() + ): + return True + return ( + envs.SGLANG_OPT_USE_FUSED_COMPRESS.get() + and not envs.SGLANG_OPT_DPSK_V4_RADIX.get() + ) + + def apply_ape_hotfix(self): + assert not self.ape_converted + self.ape_converted = True + + # ========== copied from the hotfix in "260119-updated" of ref code ========== + is_model_2604 = envs.SGLANG_DSV4_MODE.get() == "2604" + if ( + self.overlap + and not is_model_2604 + and get_bool_env_var("SGLANG_ENABLE_APE_HOTFIX", "1") + ): + # NOTE: We reorder the parameters here to match the layout of the provided checkpoint. + # This is only required for compatibility with this checkpoint; the official version + # does not need this reordering. 
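# Illustrative sketch (not part of the patch) of the reordering performed below, shown
# for the non-fused (dim=-1) path on a hypothetical 1x4 tensor: chunk the last dim in
# half and swap the halves, turning [a, b | c, d] into [c, d | a, b].
import torch
_ape_demo = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
_halves = torch.chunk(_ape_demo, 2, dim=-1)            # ([[1, 2]], [[3, 4]])
assert torch.equal(
    torch.cat([_halves[1], _halves[0]], dim=-1),
    torch.tensor([[3.0, 4.0, 1.0, 2.0]]),
)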
+ ape = torch.chunk(self.ape.data, 2, dim=-1) + if self.use_fused_compress: + ape = torch.cat([ape[1], ape[0]], dim=0) + else: + ape = torch.cat([ape[1], ape[0]], dim=-1) + self.ape.data.copy_(ape.view(self.ratio, -1)) + # ============================================================================ + + def _get_states(self, forward_batch: ForwardBatch) -> KVAndScore: + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + if self.is_in_indexer: + return token_to_kv_pool.get_indexer_compress_states(self.layer_id) + else: + return token_to_kv_pool.get_attention_compress_states(self.layer_id) + + def _get_state_pool(self, forward_batch: ForwardBatch) -> CompressStatePool: + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + if self.is_in_indexer: + ret = token_to_kv_pool.get_indexer_compress_states(self.layer_id) + else: + ret = token_to_kv_pool.get_attention_compress_states(self.layer_id) + + assert isinstance(ret, CompressStatePool) + + return ret + + def overlap_transform(self, tensor: torch.Tensor, fill_value: Any) -> torch.Tensor: + # tensor: [block_num, r, 2 * d] + assert tensor.dim() == 3 + assert tensor.shape[1:] == (self.ratio, 2 * self.head_dim) + + s, r, d = tensor.size(0), self.ratio, self.head_dim + new_tensor = tensor.new_full((s, 2 * r, d), fill_value) + new_tensor[:, r:] = tensor[:, :, d:] + new_tensor[1:, :r] = tensor[:-1, :, :d] + return new_tensor + + def overlap_transform_decode(self, tensor: torch.Tensor) -> torch.Tensor: + # NOTE: the default value has been initialized when creating the states + # tensor: [bs, 2 * r, 2 * d] + assert tensor.dim() == 3 + assert tensor.shape[1:] == (2 * self.ratio, 2 * self.head_dim) + r, d = self.ratio, self.head_dim + ret = torch.cat((tensor[:, :r, :d], tensor[:, r:, d:]), dim=1) + return ret + + @staticmethod + def compute_state_len(seq_len: int, ratio: int): + """Tailing length for the valid states in kv cache. + When overlap is enabled, there is always an extra block: [extra block, compressing part] + """ + return seq_len % ratio + (ratio == 4) * ratio + + @staticmethod + def compute_state_len_indices(seq_len: int, ratio: int): + state_len = seq_len % ratio + (ratio == 4) * ratio + # NOTE: -1 here means invalid position + return torch.arange(seq_len - state_len, seq_len).clamp(min=-1) + + def print_tensor(self, y: torch.Tensor, name: str): + enable = int(os.environ.get("SGLANG_ENABLE_PRINT_TENSOR", 0)) + if enable: + print(f"[sgl] {name}: shape={y.shape}, dtype={y.dtype}, device={y.device}") + print(f"{y.flatten()[:10]}...{y.flatten()[-10:]}") + + def compress_extend_paged( + self, + kv_and_scores: KVAndScore, + forward_batch: ForwardBatch, + ): + backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(backend, DeepseekV4Backend) + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + + # extract some info + state_pool = self._get_state_pool(forward_batch) + prefix_lens = forward_batch.extend_prefix_lens_cpu + extend_lens = forward_batch.extend_seq_lens_cpu + req_pool_indices = forward_batch.req_pool_indices + req_to_token = forward_batch.req_to_token_pool.req_to_token + assert not self.forward_mode.is_target_verify() + + assert extend_lens is not None and prefix_lens is not None + device = kv_and_scores.kv.device + + # Deliberately fill w/ huge values, s.t. 
when misuse and access the unfilled values, + # we have higher probability to see something very weird + assert kv_and_scores.kv.shape[-1] == self.head_dim * self.coff + compressed_kv_output = torch.full( + (kv_and_scores.kv.size(0), self.head_dim), + fill_value=10000.0, + dtype=kv_and_scores.kv.dtype, + device=device, + ) + + bs = forward_batch.batch_size + pt = 0 + for i in range(bs): + kv_and_score = kv_and_scores[pt : pt + extend_lens[i]] + pre_state_indices = self.compute_state_len_indices( + seq_len=prefix_lens[i], ratio=self.ratio + ).to(device) + raw_loc = torch.where( + pre_state_indices < 0, + -1, + req_to_token[req_pool_indices[i], pre_state_indices], + ) + swa_loc = token_to_kv_pool.translate_loc_from_full_to_swa(raw_loc) + state_loc = state_pool.translate_from_swa_loc_to_state_loc(swa_loc) + pre_kv_state = state_pool.get_state_by_state_loc(state_loc) + kv_and_score_buffer = KVAndScore.cat([pre_kv_state, kv_and_score], dim=0) + valid_kv_len = kv_and_score_buffer.kv.size(0) + + post_state_indices = self.compute_state_len_indices( + seq_len=prefix_lens[i] + extend_lens[i], ratio=self.ratio + ).to(device) + post_state_len = post_state_indices.size(0) + + # write to kv_and_score_states + assert post_state_len <= valid_kv_len + post_raw_loc = torch.where( + post_state_indices < 0, + -1, + req_to_token[req_pool_indices[i], post_state_indices], + ) + post_swa_loc = token_to_kv_pool.translate_loc_from_full_to_swa(post_raw_loc) + post_state_loc = state_pool.translate_from_swa_loc_to_state_loc( + post_swa_loc + ) + post_state_to_set = kv_and_score_buffer[valid_kv_len - post_state_len :] + state_pool.set_state_by_state_loc(post_state_loc, post_state_to_set) + + # Get the part that can be compressed (ratio-aligned) + compress_len = valid_kv_len // self.ratio * self.ratio + if compress_len == 0: + # Nothing to compress yet, just update pointers + pt += extend_lens[i] + continue + + # kv to compress: [compressed_len, ratio, head_dim * coff] + kv_and_score_to_compress = kv_and_score_buffer[:compress_len].view( + compress_len // self.ratio, self.ratio, -1 + ) + # NOTE: apply ape only when compressing + kv_and_score_to_compress.score.add_(self.ape.unsqueeze(0)) + + # Apply overlap transformation if enabled + if self.overlap: + new_kv = self.overlap_transform( + kv_and_score_to_compress.kv, fill_value=0 + ) + new_score = self.overlap_transform( + kv_and_score_to_compress.score, fill_value=float("-inf") + ) + kv_and_score_to_compress = KVAndScore.from_kv_score( + kv=new_kv, score=new_score + ) + del new_kv, new_score + # remove the first block before compression + kv_and_score_to_compress = kv_and_score_to_compress[1:] + + if kv_and_score_to_compress.kv.size(0) == 0: + pt += extend_lens[i] + continue + + kv_compressed = ( + kv_and_score_to_compress.kv + * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + + # NOTE: ref code requires dtype as the same as hidden states (float32) + # the raw output of kv_compressed is float32 already + assert kv_compressed.dtype == torch.float32 + kv_compressed = self.norm(kv_compressed) + + beg_idx = prefix_lens[i] // self.ratio * self.ratio + end_idx = (prefix_lens[i] + extend_lens[i]) // self.ratio * self.ratio + freqs_cis = self.freqs_cis[beg_idx : end_idx : self.ratio] + assert freqs_cis.size(0) == kv_compressed.size( + 0 + ), f"{freqs_cis.shape=} {kv_compressed.shape=}" + apply_rotary_emb_triton( + kv_compressed[..., -self.rope_head_dim :], freqs_cis + ) + del beg_idx, end_idx + + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + 
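# Illustrative sketch (not part of the patch) of the compression step above: each block
# of `ratio` cached tokens is collapsed into one vector by a softmax(score)-weighted
# average along the ratio axis. Shapes here are hypothetical.
import torch
_ratio, _head_dim = 4, 8
_kv = torch.randn(3, _ratio, _head_dim)      # [num_blocks, ratio, head_dim]
_score = torch.randn(3, _ratio, _head_dim)   # in the real code this already includes the `ape` bias
_compressed = (_kv * _score.softmax(dim=1)).sum(dim=1)
assert _compressed.shape == (3, _head_dim)   # one compressed vector per block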
+ # get all the pos: ratio * n + (ratio - 1) > prefix_len - 1 + start = prefix_lens[i] + start = start + self.ratio - 1 - start % self.ratio + indices_in_seq = torch.arange( + start, + prefix_lens[i] + extend_lens[i], + self.ratio, + device=kv_and_scores.kv.device, + ) + assert indices_in_seq.size(0) == kv_compressed.size(0) + compressed_kv_output[indices_in_seq - prefix_lens[i] + pt] = kv_compressed + + pt += extend_lens[i] + + return compressed_kv_output + + def compress_decode_paged( + self, + kv_and_scores: KVAndScore, + forward_batch: ForwardBatch, + ): + """Paged and cudagraph compatible version of compress_decode""" + assert self.ape_converted + state_pool = self._get_state_pool(forward_batch) + token_to_kv_pool = forward_batch.token_to_kv_pool + assert isinstance(token_to_kv_pool, DeepSeekV4TokenToKVPool) + req_pool_indices = forward_batch.req_pool_indices + req_to_token = forward_batch.req_to_token_pool.req_to_token + seq_lens = forward_batch.seq_lens + + if forward_batch.forward_mode.is_target_verify(): + draft_tokens = forward_batch.attn_backend.speculative_num_draft_tokens + offsets = torch.arange(1, draft_tokens + 1, device=seq_lens.device) + seq_lens_2d = seq_lens[:, None] + offsets[None, :] + seq_lens = seq_lens_2d.view(-1) + req_pool_indices = req_pool_indices.repeat_interleave(draft_tokens) + + raw_locs = req_to_token[req_pool_indices, seq_lens - 1] + + # Update the new decode states + swa_locs = token_to_kv_pool.translate_loc_from_full_to_swa(raw_locs) + state_locs = state_pool.translate_from_swa_loc_to_state_loc(swa_locs) + state_pool.set_state_by_state_loc(state_locs, kv_and_scores) + + compress_bulk_len = self.ratio * self.coff + compress_indices = seq_lens[:, None] + torch.arange( + -compress_bulk_len, 0, device=seq_lens.device + ) + compress_indices.clamp_(min=-1) + compress_indices_raw = torch.where( + compress_indices < 0, + -1, + req_to_token[req_pool_indices[:, None], compress_indices], + ) + compress_indices_swa = token_to_kv_pool.translate_loc_from_full_to_swa( + compress_indices_raw + ) + compress_indices_state = state_pool.translate_from_swa_loc_to_state_loc( + compress_indices_swa + ) + kv_and_score_to_compress = state_pool.get_state_by_state_loc( + compress_indices_state.view(-1) + ).view(-1, self.ratio, self.coff * self.head_dim) + kv_and_score_to_compress.score.add_(self.ape.unsqueeze(0)) + + bs = seq_lens.size(0) + if self.overlap: + # shape: [bs, coff * ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.coff * self.ratio, self.coff * self.head_dim + ) + kv_and_score_to_compress = KVAndScore.from_kv_score( + kv=self.overlap_transform_decode(kv_and_score_to_compress.kv), + score=self.overlap_transform_decode(kv_and_score_to_compress.score), + ) + + self.print_tensor(kv_and_score_to_compress.kv, "kv_to_compress") + self.print_tensor(kv_and_score_to_compress.score, "score_to_compress") + + # kv_to_compress: [bs, ratio * coff, head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.ratio * self.coff, self.head_dim + ) + + kv_compressed = ( + kv_and_score_to_compress.kv * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + self.print_tensor(kv_compressed, "kv_before_norm") + kv_compressed = self.norm(kv_compressed) + self.print_tensor(kv_compressed, "kv_after_norm") + freqs_cis = self.freqs_cis[(seq_lens - 1) // self.ratio * self.ratio] + self.print_tensor(freqs_cis, "freqs_cis") + apply_rotary_emb_triton(kv_compressed[..., -self.rope_head_dim :], freqs_cis) + 
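+ # The RoPE applied above indexes freqs_cis by the first token of the block just
+ # completed: e.g. with ratio = 4 and seq_len = 12 the block covers tokens 8..11,
+ # so freqs_cis[(12 - 1) // 4 * 4] = freqs_cis[8] is used, matching the stride-ratio
+ # slice freqs_cis[beg_idx:end_idx:ratio] taken in the extend path.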
self.print_tensor(kv_compressed, "kv_after_rope") + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + + # `new_compressed_list` format is only used for testing + self.print_tensor(kv_compressed, "compressed_kv_output") + return kv_compressed + + def compress_extend( + self, + kv_and_scores: KVAndScore, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + assert self.ape_converted # Please keep this assertion + + # kv_and_score_states: [max_num_reqs, compress_ratio * coff, head_dim * coff] + kv_and_score_states = self._get_states(forward_batch) + _, _, head_dim_times_coff = kv_and_score_states.kv.shape + + # extract some info + prefix_lens = forward_batch.extend_prefix_lens_cpu + extend_lens = forward_batch.extend_seq_lens_cpu + req_pool_indices = forward_batch.req_pool_indices + assert extend_lens is not None and prefix_lens is not None + + # compress info + # TODO: reuse the buffer across layers and reduce the sizes + max_buffer_size = 2 * kv_and_score_states.shape[1] + kv_and_scores.shape[0] + temp_buffer_shape = [max_buffer_size, head_dim_times_coff] + temp_buffer = kv_and_scores.new_empty(temp_buffer_shape) + + # Deliberately fill w/ huge values, s.t. when misuse and access the unfilled values, + # we have higher probability to see something very weird + assert kv_and_scores.kv.shape[-1] == self.head_dim * self.coff + compressed_kv_output = torch.full( + (kv_and_scores.kv.size(0), self.head_dim), + fill_value=10000.0, + dtype=kv_and_scores.kv.dtype, + device=kv_and_scores.kv.device, + ) + + bs = forward_batch.batch_size + pt = 0 + for i in range(bs): + # Definitions of variables + # + # kv_and_score_state: (compress_ratio * coff, head_dim * coff) + # only it[:old_valid_state_len] has valid data + # + # kv_and_score_buffer: (old_valid_state_len + valid_kv_len, head_dim * coff) + # content is cat(kv_and_score_state[:old_valid_state_len], kv_and_score) + + kv_and_score = kv_and_scores[pt : pt + extend_lens[i]] + kv_and_score_state = kv_and_score_states[req_pool_indices[i]] + if prefix_lens[i] == 0: + # NOTE: padding with default values for overlap + kv_and_score_state.clear() + + # Create kv_and_score_buffer + pre_state_len = self.compute_state_len( + seq_len=prefix_lens[i], ratio=self.ratio + ) + valid_kv_len = pre_state_len + extend_lens[i] + kv_and_score_buffer = temp_buffer[:valid_kv_len] + kv_and_score_buffer[:pre_state_len] = kv_and_score_state[:pre_state_len] + kv_and_score_buffer[pre_state_len:valid_kv_len] = kv_and_score + + # Write to kv_and_score_states + post_state_len = self.compute_state_len( + seq_len=valid_kv_len, ratio=self.ratio + ) + kv_and_score_state[:post_state_len] = kv_and_score_buffer[ + valid_kv_len - post_state_len : valid_kv_len + ] + + # Get the part that can be compressed (ratio-aligned) + compress_len = valid_kv_len // self.ratio * self.ratio + if compress_len == 0: + # Nothing to compress yet, just update pointers + pt += extend_lens[i] + continue + + # kv to compress: [compressed_len, ratio, head_dim * coff] + kv_and_score_to_compress = kv_and_score_buffer[:compress_len].view( + compress_len // self.ratio, self.ratio, -1 + ) + # NOTE: apply ape only when compressing + kv_and_score_to_compress.score.add_(self.ape.unsqueeze(0)) + + # Apply overlap transformation if enabled + if self.overlap: + new_kv = self.overlap_transform( + kv_and_score_to_compress.kv, fill_value=0 + ) + new_score = self.overlap_transform( + kv_and_score_to_compress.score, fill_value=float("-inf") + ) + kv_and_score_to_compress = KVAndScore.from_kv_score( + kv=new_kv, 
score=new_score + ) + del new_kv, new_score + # remove the first block before compression + kv_and_score_to_compress = kv_and_score_to_compress[1:] + + if kv_and_score_to_compress.kv.size(0) == 0: + pt += extend_lens[i] + continue + + kv_compressed = ( + kv_and_score_to_compress.kv + * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + + # NOTE: ref code requires dtype as the same as hidden states (float32) + # the raw output of kv_compressed is float32 already + assert kv_compressed.dtype == torch.float32 + kv_compressed = self.norm(kv_compressed) + + beg_idx = prefix_lens[i] // self.ratio * self.ratio + end_idx = (prefix_lens[i] + extend_lens[i]) // self.ratio * self.ratio + freqs_cis = self.freqs_cis[beg_idx : end_idx : self.ratio] + assert freqs_cis.size(0) == kv_compressed.size( + 0 + ), f"{freqs_cis.shape=} {kv_compressed.shape=}" + apply_rotary_emb_triton( + kv_compressed[..., -self.rope_head_dim :], freqs_cis + ) + del beg_idx, end_idx + + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + + # get all the pos: ratio * n + (ratio - 1) > prefix_len - 1 + start = prefix_lens[i] + start = start + self.ratio - 1 - start % self.ratio + indices_in_seq = torch.arange( + start, + prefix_lens[i] + extend_lens[i], + self.ratio, + device=kv_and_scores.kv.device, + ) + assert indices_in_seq.size(0) == kv_compressed.size(0) + compressed_kv_output[indices_in_seq - prefix_lens[i] + pt] = kv_compressed + + pt += extend_lens[i] + + return compressed_kv_output + + @maybe_torch_compile + def compress_decode( + self, + kv_and_scores: KVAndScore, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + assert self.ape_converted # Please keep this assertion + + seq_lens = forward_batch.seq_lens + kv_and_score_states_pool = self._get_states(forward_batch) + req_pool_indices = forward_batch.req_pool_indices + + # NOTE: first, write to the states + bs = kv_and_scores.kv.size(0) + write_pos = (seq_lens - 1) % self.ratio + self.overlap * self.ratio + kv_and_score_states_pool[req_pool_indices, write_pos] = kv_and_scores + + # NOTE: need to copy out before modifying overlap states + # kv_states: [bs, coff * ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_states_pool[req_pool_indices] + + # Shift just compressed kv states left by ratio + if self.overlap: + should_shift = seq_lens % self.ratio == 0 + kv_and_score_states_pool[req_pool_indices, : self.ratio] = KVAndScore( + kv_score=torch.where( + should_shift[:, None, None], + kv_and_score_to_compress.kv_score[:, self.ratio :], + kv_and_score_to_compress.kv_score[:, : self.ratio], + ) + ) + + # shape: [bs * coff, ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + -1, self.ratio, self.coff * self.head_dim + ) + kv_and_score_to_compress.score.add_(self.ape.unsqueeze(0)) + + if self.overlap: + # shape: [bs, coff * ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.coff * self.ratio, self.coff * self.head_dim + ) + kv_and_score_to_compress = KVAndScore.from_kv_score( + kv=self.overlap_transform_decode(kv_and_score_to_compress.kv), + score=self.overlap_transform_decode(kv_and_score_to_compress.score), + ) + + self.print_tensor(kv_and_score_to_compress.kv, "kv_to_compress") + self.print_tensor(kv_and_score_to_compress.score, "score_to_compress") + + # kv_to_compress: [bs, ratio * coff, head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.ratio * self.coff, self.head_dim + ) + + kv_compressed = ( + 
kv_and_score_to_compress.kv * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + self.print_tensor(kv_compressed, "kv_before_norm") + kv_compressed = self.norm(kv_compressed) + self.print_tensor(kv_compressed, "kv_after_norm") + freqs_cis = self.freqs_cis[(seq_lens - 1) // self.ratio * self.ratio] + self.print_tensor(freqs_cis, "freqs_cis") + apply_rotary_emb_triton(kv_compressed[..., -self.rope_head_dim :], freqs_cis) + self.print_tensor(kv_compressed, "kv_after_rope") + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + + # `new_compressed_list` format is only used for testing + self.print_tensor(kv_compressed, "compressed_kv_output") + return kv_compressed + + def compress_fused( + self, + kv_score: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + # TODO: this should be the final implementation after verifying correctness + backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(backend, DeepseekV4Backend) + is_paged = envs.SGLANG_OPT_DPSK_V4_RADIX.get() + if is_paged: + kv_score_buffer = self._get_state_pool(forward_batch) + kv_score_buffer = kv_score_buffer.kv_score_buffer.kv_score + else: + kv_score_buffer = self._get_states(forward_batch).kv_score + return backend.forward_compress( + kv_score_buffer=kv_score_buffer, + kv_score_input=kv_score, + ape=self.ape.view(-1, self.head_dim), + head_dim=self.head_dim, + norm=self.norm, + freqs_cis_cache=self.freqs_cis, + rotate=self.rotate, + compress_ratio=self.ratio, + forward_batch=forward_batch, + is_paged=is_paged, + ) + + def compress_dispatch( + self, + kv_score: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + if self.use_fused_compress: + return self.compress_fused(kv_score, forward_batch) + + if envs.SGLANG_OPT_USE_OLD_COMPRESSOR.get(): + kv = kv_score[:, : self.coff * self.head_dim] + score = kv_score[:, self.coff * self.head_dim :] + kv_and_scores = KVAndScoreOld(kv=kv, score=score) + self.compress_decode = self.compress_decode_old + self.compress_extend = self.compress_extend_old + else: + if envs.SGLANG_OPT_DPSK_V4_RADIX.get(): + self.compress_decode = self.compress_decode_paged + self.compress_extend = self.compress_extend_paged + kv_and_scores = KVAndScore(kv_score) + if TYPE_CHECKING: + assert isinstance(kv_and_scores, KVAndScore) + + if ( + forward_batch.forward_mode.is_decode() + or forward_batch.forward_mode.is_target_verify() + ): + result = self.compress_decode( + kv_and_scores=kv_and_scores, + forward_batch=forward_batch, + ) + elif forward_batch.forward_mode.is_extend(): + result = self.compress_extend( + kv_and_scores=kv_and_scores, + forward_batch=forward_batch, + ) + else: + msg = f"Forward mode {forward_batch.forward_mode} not supported in Compressor." 
+ raise NotImplementedError(msg) + + return result + + def forward(self, x: torch.Tensor, forward_batch: ForwardBatch) -> torch.Tensor: + if forward_batch.forward_mode.is_idle(): + assert x.shape[0] == 0 + return x.new_empty(0, self.head_dim) + + self.forward_mode = forward_batch.forward_mode + + kv_score = linear_bf16_fp32(x, self.wkv_gate.weight) + return self.compress_dispatch(kv_score, forward_batch) + + def compress_extend_old( + self, kv_and_scores: KVAndScore, forward_batch: ForwardBatch + ) -> torch.Tensor: + assert self.ape_converted # Please keep this assertion + KVAndScore = KVAndScoreOld + + # kv_and_score_states: [max_num_reqs, compress_ratio * coff, head_dim * coff] + kv_and_score_states = self._get_states(forward_batch) + _, _, head_dim_times_coff = kv_and_score_states.kv.shape + + # extract some info + prefix_lens = forward_batch.extend_prefix_lens_cpu + extend_lens = forward_batch.extend_seq_lens_cpu + req_pool_indices = forward_batch.req_pool_indices + + # compress info + # TODO: reuse the buffer across layers and reduce the sizes + max_buffer_size = 2 * kv_and_score_states.shape[1] + kv_and_scores.shape[0] + temp_buffer_shape = [max_buffer_size, head_dim_times_coff] + temp_buffer = KVAndScore.empty_like(temp_buffer_shape, old=kv_and_scores) + + # Deliberately fill w/ huge values, s.t. when misuse and access the unfilled values, + # we have higher probability to see something very weird + assert kv_and_scores.kv.shape[-1] == self.head_dim * self.coff + compressed_kv_output = torch.full( + (kv_and_scores.kv.size(0), self.head_dim), + fill_value=10000.0, + dtype=kv_and_scores.kv.dtype, + device=kv_and_scores.kv.device, + ) + + bs = forward_batch.batch_size + pt = 0 + for i in range(bs): + # Definitions of variables + # + # kv_and_score_state: (compress_ratio * coff, head_dim * coff) + # only it[:old_valid_state_len] has valid data + # + # kv_and_score_buffer: (old_valid_state_len + valid_kv_len, head_dim * coff) + # content is cat(kv_and_score_state[:old_valid_state_len], kv_and_score) + + kv_and_score = kv_and_scores[pt : pt + extend_lens[i]] + kv_and_score_state = kv_and_score_states[req_pool_indices[i]] + if prefix_lens[i] == 0: + # NOTE: padding with default values for overlap + kv_and_score_state.clear() + + # Create kv_and_score_buffer + pre_state_len = self.compute_state_len( + seq_len=prefix_lens[i], ratio=self.ratio + ) + valid_kv_len = pre_state_len + extend_lens[i] + kv_and_score_buffer = temp_buffer[:valid_kv_len] + kv_and_score_buffer[:pre_state_len] = kv_and_score_state[:pre_state_len] + kv_and_score_buffer[pre_state_len:valid_kv_len] = kv_and_score + + # Write to kv_and_score_states + post_state_len = self.compute_state_len( + seq_len=valid_kv_len, ratio=self.ratio + ) + kv_and_score_state[:post_state_len] = kv_and_score_buffer[ + valid_kv_len - post_state_len : valid_kv_len + ] + + # Get the part that can be compressed (ratio-aligned) + compress_len = valid_kv_len // self.ratio * self.ratio + if compress_len == 0: + # Nothing to compress yet, just update pointers + pt += extend_lens[i] + continue + + # kv to compress: [compressed_len, ratio, head_dim * coff] + kv_and_score_to_compress = kv_and_score_buffer[:compress_len].view( + compress_len // self.ratio, self.ratio, -1 + ) + # NOTE: apply ape only when compressing + kv_and_score_to_compress.score = ( + kv_and_score_to_compress.score + self.ape.unsqueeze(0) + ) + + # Apply overlap transformation if enabled + if self.overlap: + kv_and_score_to_compress.kv = self.overlap_transform( + 
kv_and_score_to_compress.kv, 0 + ) + kv_and_score_to_compress.score = self.overlap_transform( + kv_and_score_to_compress.score, float("-inf") + ) + + # remove the first block before compression + kv_and_score_to_compress = kv_and_score_to_compress[1:] + + if kv_and_score_to_compress.kv.size(0) == 0: + pt += extend_lens[i] + continue + + kv_compressed = ( + kv_and_score_to_compress.kv + * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + + # NOTE: ref code requires dtype as the same as hidden states (float32) + # the raw output of kv_compressed is float32 already + assert kv_compressed.dtype == torch.float32 + kv_compressed = self.norm(kv_compressed) + + beg_idx = prefix_lens[i] // self.ratio * self.ratio + end_idx = (prefix_lens[i] + extend_lens[i]) // self.ratio * self.ratio + freqs_cis = self.freqs_cis[beg_idx : end_idx : self.ratio] + assert freqs_cis.size(0) == kv_compressed.size( + 0 + ), f"{freqs_cis.shape=} {kv_compressed.shape=}" + apply_rotary_emb_triton( + kv_compressed[..., -self.rope_head_dim :], freqs_cis + ) + del beg_idx, end_idx + + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + + # get all the pos: ratio * n + (ratio - 1) > prefix_len - 1 + start = prefix_lens[i] + start = start + self.ratio - 1 - start % self.ratio + indices_in_seq = torch.arange( + start, + prefix_lens[i] + extend_lens[i], + self.ratio, + device=kv_and_scores.kv.device, + ) + assert indices_in_seq.size(0) == kv_compressed.size(0) + compressed_kv_output[indices_in_seq - prefix_lens[i] + pt] = kv_compressed + + pt += extend_lens[i] + + return compressed_kv_output + + def compress_decode_old( + self, + kv_and_scores: KVAndScore, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + assert self.ape_converted # Please keep this assertion + KVAndScore = KVAndScoreOld + + seq_lens = forward_batch.seq_lens + kv_and_score_states_pool = self._get_states(forward_batch) + req_pool_indices = forward_batch.req_pool_indices + + bs = kv_and_scores.kv.size(0) + write_pos = (seq_lens - 1) % self.ratio + self.overlap * self.ratio + kv_and_score_states_pool[req_pool_indices, write_pos] = kv_and_scores + + # NOTE: need to copy out before modifying overlap states + # kv_states: [bs, coff * ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_states_pool[req_pool_indices] + + if self.overlap: + # Shift just compressed kv states left by ratio + should_shift = seq_lens % self.ratio == 0 + kv_and_score_states_pool[req_pool_indices, : self.ratio] = KVAndScore( + kv=torch.where( + should_shift[:, None, None], + kv_and_score_to_compress.kv[:, self.ratio :], + kv_and_score_to_compress.kv[:, : self.ratio], + ), + score=torch.where( + should_shift[:, None, None], + kv_and_score_to_compress.score[:, self.ratio :], + kv_and_score_to_compress.score[:, : self.ratio], + ), + ) + + # shape: [bs * coff, ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + -1, self.ratio, self.coff * self.head_dim + ) + kv_and_score_to_compress.score = ( + kv_and_score_to_compress.score + self.ape.unsqueeze(0) + ) + + if self.overlap: + # shape: [bs, coff * ratio, coff * head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.coff * self.ratio, self.coff * self.head_dim + ) + kv_and_score_to_compress.kv = self.overlap_transform_decode( + kv_and_score_to_compress.kv + ) + kv_and_score_to_compress.score = self.overlap_transform_decode( + kv_and_score_to_compress.score + ) + + self.print_tensor(kv_and_score_to_compress.kv, "kv_to_compress") + 
self.print_tensor(kv_and_score_to_compress.score, "score_to_compress") + + # kv_to_compress: [bs, ratio * coff, head_dim] + kv_and_score_to_compress = kv_and_score_to_compress.view( + bs, self.ratio * self.coff, self.head_dim + ) + + kv_compressed = ( + kv_and_score_to_compress.kv * kv_and_score_to_compress.score.softmax(dim=1) + ).sum(dim=1) + self.print_tensor(kv_compressed, "kv_before_norm") + kv_compressed = self.norm(kv_compressed) + self.print_tensor(kv_compressed, "kv_after_norm") + freqs_cis = self.freqs_cis[(seq_lens - 1) // self.ratio * self.ratio] + self.print_tensor(freqs_cis, "freqs_cis") + apply_rotary_emb_triton(kv_compressed[..., -self.rope_head_dim :], freqs_cis) + self.print_tensor(kv_compressed, "kv_after_rope") + if self.rotate: + kv_compressed = rotate_activation(kv_compressed) + + # `new_compressed_list` format is only used for testing + new_compressed_list = None + self.print_tensor(kv_compressed, "compressed_kv_output") + return kv_compressed + + +class C4Indexer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + rotary_emb: RotaryEmbedding, + freqs_cis: torch.Tensor, # TODO: remove it after using rotary embedding + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + ): + super().__init__() + self.layer_id = layer_id + self.dim = config.hidden_size + self.n_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.rope_head_dim = config.qk_rope_head_dim + self.index_topk = config.index_topk + self.q_lora_rank = config.q_lora_rank + self.softmax_scale = self.head_dim**-0.5 + # TODO: do we need to support TP indexer? + # currently, we duplicate indexer on all TP ranks + self.n_local_heads = self.n_heads + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + params_dtype=torch.bfloat16, + prefix=add_prefix("wq_b", prefix), + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=None, + params_dtype=torch.bfloat16, + prefix=add_prefix("weights_proj", prefix), + ) + self.compressor = Compressor( + config, + self.layer_id, + True, # is_in_indexer + rotary_emb, + freqs_cis, + compress_ratio=4, + head_dim=self.head_dim, + rotate=True, + prefix=add_prefix("compressor", prefix), + ) + self.rotary_emb = rotary_emb + self.freqs_cis = freqs_cis + self.weight_scale: float = self.softmax_scale * self.n_heads**-0.5 + self.alt_streams = alt_streams + + def compute_q(self, q_lora: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + # [bs, n_heads, head_dim] + q, _ = self.wq_b(q_lora) + q = q.view(-1, self.n_local_heads, self.head_dim) + fused_rope( + q[..., -self.rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + q = rotate_activation(q) + return q + + def compute_weights(self, x: torch.Tensor, skip_scale=False) -> torch.Tensor: + out, _ = self.weights_proj(x) + if not skip_scale: + out = out * self.weight_scale + return out + + def forward( + self, + x: torch.Tensor, + q_lora: torch.Tensor, + forward_batch: ForwardBatch, + x_for_compressor: Optional[torch.Tensor] = None, + enable_multi_stream: bool = False, + q_lora_ready: Optional[torch.cuda.Event] = None, + ) -> None: + if TYPE_CHECKING: + assert isinstance(forward_batch.attn_backend, DeepseekV4Backend) + return forward_batch.attn_backend.forward_c4_indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + c4_indexer=self, + 
x_for_compressor=x_for_compressor if x_for_compressor is not None else x, + alt_streams=self.alt_streams, + enable_multi_stream=enable_multi_stream, + q_lora_ready=q_lora_ready, + ) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class MQALayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + ) -> None: + super().__init__() + self.tp_rank = attn_tp_rank = get_attention_tp_rank() + self.tp_size = attn_tp_size = get_attention_tp_size() + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_tp_size() + self.tp_rank = attn_tp_rank = 0 + self.tp_size = attn_tp_size = 1 + self.layer_id = layer_id + self.dim = config.hidden_size + self.qk_rope_head_dim = config.qk_rope_head_dim + if envs.SGLANG_DSV4_MODE.get() == "2604": + self.qk_nope_head_dim = config.head_dim - config.qk_rope_head_dim + else: + self.qk_nope_head_dim = config.qk_nope_head_dim + self.head_dim = self.qk_rope_head_dim + self.qk_nope_head_dim + self.n_heads = config.num_attention_heads + self.n_local_heads = self.n_heads // attn_tp_size + self.n_groups = config.o_groups + self.n_local_groups = self.n_groups // attn_tp_size + self.rope_head_dim = config.qk_rope_head_dim + self.softmax_scale = self.head_dim**-0.5 + self.hidden_size = config.hidden_size + self.q_lora_rank = config.q_lora_rank + self.o_lora_rank = config.o_lora_rank + self.eps = config.rms_norm_eps + compress_ratio = config.compress_ratios[layer_id] + assert compress_ratio in [0, 4, 128] + self.compress_ratio: Literal[0, 4, 128] = compress_ratio # type: ignore + + if envs.SGLANG_DSV4_MODE.get() == "2604": + assert self.head_dim == config.head_dim + else: + assert self.head_dim == config.v_head_dim + assert config.num_key_value_heads == 1 + + # need a indexer for compress ratio = 4 + rope_scaling = config.rope_scaling + if rope_scaling: + rope_scaling["rope_type"] = "deepseek_yarn" + + # Please keep this assertion and not remove it + # NOTE: + # 1. 2601 + # The `260119-updated` code changed compress_rope_theta + # 2. 2604 + # `official_code_0409/code/config.json` is 160000 + # while `official_code_0409/config.json` is 40000 + # maybe the latter is buggy? 
b/c dpsk's official generate.py uses `code/config.json` + expected_compress_rope_theta = os.environ.get( + "SGLANG_HACK_ASSERT_COMPRESS_ROPE_THETA" + ) + if expected_compress_rope_theta is None: + expected_compress_rope_theta = "160000" + expected_compress_rope_theta = int(expected_compress_rope_theta) + assert ( + config.compress_rope_theta == expected_compress_rope_theta + ), f"{config.compress_rope_theta=} {expected_compress_rope_theta=}" + rope_base = ( + config.compress_rope_theta if self.compress_ratio else config.rope_theta + ) + + self.rotary_emb = get_rope_wrapper( + head_size=self.rope_head_dim, + rotary_dim=self.rope_head_dim, + max_position=config.max_position_embeddings, + base=rope_base, + rope_scaling=rope_scaling, + is_neox_style=False, + device=get_global_server_args().device, + ) + + # naive impl: copy from reference code + from sglang.srt.layers.deepseek_v4_rope import precompute_freqs_cis + + if envs.SGLANG_DSV4_MODE.get() == "2604": + assert rope_scaling["factor"] == 16 + elif envs.SGLANG_DSV4_MODE.get() == "2601": + assert rope_scaling["factor"] == 4 + else: + raise NotImplementedError + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert self.compress_ratio in {0, 4, 128} + if self.compress_ratio: + original_seq_len = rope_scaling["original_max_position_embeddings"] + assert original_seq_len == 65536 + else: + original_seq_len = 0 + else: + original_seq_len = rope_scaling["original_max_position_embeddings"] + + rope_scaling = config.rope_scaling + freqs_cis = precompute_freqs_cis( + dim=self.qk_rope_head_dim, + seqlen=config.max_position_embeddings, + original_seq_len=original_seq_len, + base=rope_base, + factor=rope_scaling["factor"], + beta_fast=rope_scaling["beta_fast"], + beta_slow=rope_scaling["beta_slow"], + ) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) + self.freqs_cis: torch.Tensor + + if envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() and alt_streams is not None: + self.alt_streams = alt_streams[:3] # use first 3 streams for mqa layer + self.alt_streams_indexer = alt_streams[ + -2: + ] # use last 2 streams for indexer + else: + self.alt_streams = None + self.alt_streams_indexer = None + + from sglang.srt.utils import is_blackwell_supported + + self._multi_stream_bs_limit = 128 if is_blackwell_supported() else 64 + + self.compressor = None + self.indexer = None + if self.compress_ratio: + self.compressor = Compressor( + config, + layer_id=self.layer_id, + is_in_indexer=False, + rotary_emb=self.rotary_emb, + freqs_cis=freqs_cis, + compress_ratio=self.compress_ratio, + head_dim=self.head_dim, + rotate=False, + prefix=add_prefix("compressor", prefix), + ) + if self.compress_ratio == 4: + self.indexer = C4Indexer( + config, + rotary_emb=self.rotary_emb, + freqs_cis=freqs_cis, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("indexer", prefix), + alt_streams=self.alt_streams_indexer, + ) + + # Note: attention sink should be replicated + self.attn_sink = nn.Parameter(torch.empty(self.n_heads, dtype=torch.float32)) + self.wq_a = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_a", prefix), + ) + self.q_norm = RMSNorm(self.q_lora_rank, eps=self.eps) + self.wq_b = ColumnParallelLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=add_prefix("wq_b", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + self.wkv = ReplicatedLinear( + self.hidden_size, + self.head_dim, + 
bias=False, + quant_config=quant_config, + prefix=add_prefix("wkv", prefix), + ) + self.kv_norm = RMSNorm(self.head_dim, eps=self.eps) + self.wo_a = ColumnParallelLinear( + self.n_heads * self.head_dim // self.n_groups, + self.n_groups * self.o_lora_rank, + bias=False, + quant_config=quant_config if _FP8_WO_A_GEMM else None, + prefix=add_prefix("wo_a", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + **({} if _FP8_WO_A_GEMM else {"params_dtype": torch.bfloat16}), + ) + if _FP8_WO_A_GEMM: + # fp8_einsum handles scale transform internally — skip UE8M0 conversion + assert hasattr( + self.wo_a, "weight_scale_inv" + ), "FP8 quant_config must create weight_scale_inv" + self.wo_a.weight_scale_inv.format_ue8m0 = True + self.wo_b = RowParallelLinear( + self.n_groups * self.o_lora_rank, + self.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=attn_tp_size > 1, + prefix=add_prefix("wo_b", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + + self.attn_mqa = RadixAttention( + self.n_local_heads, + self.head_dim, + self.softmax_scale, + num_kv_heads=1, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("attn_mqa", prefix), + ) + + self.overlap_store_cache = envs.SGLANG_OPT_USE_OVERLAP_STORE_CACHE.get() + + def _compute_q_a( + self, + x: torch.Tensor, + ) -> torch.Tensor: + # [bs, q_lora_rank] + q, _ = self.wq_a(x) + # [bs, q_lora_rank] + q = self.q_norm(q) + q_lora = q # only used for indexer + return q_lora + + def _compute_q_b( + self, + q: torch.Tensor, + positions: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # [bs, n_local_heads, head_dim] + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + q = rms_normalize_triton(q, self.eps) + + if positions is not None: + fused_rope( + q[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(q[..., -self.qk_rope_head_dim :], self.freqs_cis) + return q + + def _compute_kv( + self, + x: torch.Tensor, + positions: Optional[torch.Tensor] = None, + freqs_cis: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # [bs, head_dim] + kv, _ = self.wkv(x) + # [bs, head_dim] + kv = self.kv_norm(kv) + if positions is not None: + fused_rope( + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + None, + self.freqs_cis, + positions=positions, + ) + else: + apply_rotary_emb_triton(kv[..., -self.qk_rope_head_dim :], self.freqs_cis) + return kv + + def _forward_prepare_multi_stream( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4Backend, + freqs_cis: Optional[torch.Tensor] = None, + q_out: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.alt_streams is not None + assert len(self.alt_streams) >= 3 + + current_stream = torch.cuda.current_stream() + stream_kv = self.alt_streams[0] + stream_compressor = self.alt_streams[1] + stream_indexer = self.alt_streams[2] + + stream_kv.wait_stream(current_stream) + stream_compressor.wait_stream(current_stream) + stream_indexer.wait_stream(current_stream) + + # main stream: compute q + q_lora = self._compute_q_a(x) + q_lora_ready = current_stream.record_event() + q = self._compute_q_b(q_lora, positions, freqs_cis) + if q_out is not None: + q_out.copy_(q) + + # alt stream 2: compute indexer + if self.indexer is not None: + with torch.cuda.stream(stream_indexer): + self.indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + 
enable_multi_stream=True, + q_lora_ready=q_lora_ready, + ) + + # alt stream 0: compute kv + with torch.cuda.stream(stream_kv): + kv = self._compute_kv(x, positions, freqs_cis) + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + # alt stream 1: compute compressor + if self.compressor is not None: + with torch.cuda.stream(stream_compressor): + attn_backend.forward_core_compressor( + x, forward_batch, self.layer_id, self.compressor + ) + + current_stream.wait_stream(stream_kv) + current_stream.wait_stream(stream_compressor) + current_stream.wait_stream(stream_indexer) + + return q, kv + + def _forward_prepare( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + attn_backend: DeepseekV4Backend, + freqs_cis: Optional[torch.Tensor] = None, + q_out: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # [bs, q_lora_rank] + q, _ = self.wq_a(x) + # [bs, q_lora_rank] + q = self.q_norm(q) + q_lora = q # only used for indexer + # [bs, n_local_heads, head_dim] + q, _ = self.wq_b(q) + q = q.view(-1, self.n_local_heads, self.head_dim) + # [bs, n_local_heads, head_dim] + q = rms_normalize_triton(q, self.eps) + + # [bs, head_dim] + kv, _ = self.wkv(x) + # [bs, head_dim] + kv = self.kv_norm(kv) + + fused_rope( + q[..., -self.qk_rope_head_dim :], + kv[..., -self.qk_rope_head_dim :].unsqueeze(1), + self.freqs_cis, + positions=positions, + ) + + _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) + if _use_cp: + kv = cp_all_gather_rerange_output( + kv.contiguous(), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + x_for_compressor = ( + cp_all_gather_rerange_output( + x.contiguous(), + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + if self.compressor is not None + else x + ) + else: + x_for_compressor = x + + if self.overlap_store_cache: + attn_backend.store_cache( + layer_id=self.layer_id, + swa_k=kv, + forward_batch=forward_batch, + ) + + if self.indexer is not None: + self.indexer( + x=x, + q_lora=q_lora, + forward_batch=forward_batch, + x_for_compressor=x_for_compressor if _use_cp else None, + ) + if self.compressor is not None: + attn_backend.forward_core_compressor( + x_for_compressor, forward_batch, self.layer_id, self.compressor + ) + + if q_out is not None: + q_out.copy_(q) + return q, kv + + def forward( + self, + x: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + debug_return_kv: bool = False, + ) -> torch.Tensor: + if not get_attn_tp_context().input_scattered and x.shape[0] == 0: + assert ( + not self.wo_b.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return x + + attn_backend = forward_batch.attn_backend + if TYPE_CHECKING: + assert isinstance(attn_backend, DeepseekV4Backend) + + freqs_cis = None + + enable_multi_stream = ( + envs.SGLANG_OPT_USE_MULTI_STREAM_OVERLAP.get() + and self.alt_streams is not None + and get_is_capture_mode() + and x.shape[0] <= self._multi_stream_bs_limit + and not (self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch)) + ) + + tp_slice, q_padded, q_out = slice(None), None, None + if self.tp_size > 1: + # pad the q to [batch_size, n_heads] + q_padded = x.new_empty(x.shape[0], self.n_heads, self.head_dim) + rank = self.tp_rank + tp_slice = slice(rank * self.n_local_heads, (rank + 1) * self.n_local_heads) + q_out = q_padded[:, tp_slice, :] + + if enable_multi_stream: + q, kv = self._forward_prepare_multi_stream( 
+ x, positions, forward_batch, attn_backend, freqs_cis, q_out + ) + else: + q, kv = self._forward_prepare( + x, positions, forward_batch, attn_backend, freqs_cis, q_out + ) + + # for TP attention, use the padded q, since q_out is set to the correct slice + o = attn_backend.forward( + q=q_padded if q_padded is not None else q, + k=kv, + v=kv, + layer=self.attn_mqa, + forward_batch=forward_batch, + compress_ratio=self.compress_ratio, + attn_sink=self.attn_sink, + save_kv_cache=not self.overlap_store_cache, + ) + # NOTE: no-op for pure DP-attention + o = o[:, tp_slice, :] + fused_rope( + o[..., -self.qk_rope_head_dim :], + None, + self.freqs_cis, + positions=positions, + inverse=True, + ) + + o = o.view(o.shape[0], self.n_local_groups, -1) + + if _FP8_WO_A_GEMM: + import deep_gemm + + T, G, D = o.shape + R = self.o_lora_rank + o_fp8, o_s = sglang_per_token_group_quant_fp8( + o.reshape(T * G, D).contiguous(), + group_size=128, + ) + output = torch.empty(T, G, R, device=o.device, dtype=torch.bfloat16) + deep_gemm.fp8_einsum( + "bhr,hdr->bhd", + (o_fp8.view(T, G, D), o_s.view(T, G, -1)), + (self.wo_a.weight.view(G, R, D), self.wo_a.weight_scale_inv.data), + output, + recipe=(1, 1, 128), + ) + o = output + else: + wo_a = self.wo_a.weight.view(self.n_local_groups, self.o_lora_rank, -1) + o = torch.einsum("tgd,grd->tgr", o, wo_a) + + o, _ = self.wo_b(o.flatten(1)) + + return o + + +class DeepseekV4DecoderLayer(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + moe_quant_config_override: Optional[QuantizationConfig] = None, + is_nextn: bool = False, + prefix: str = "", + alt_streams: Optional[List[torch.cuda.Stream]] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.is_nextn = is_nextn + self.self_attn = MQALayer( + config=config, + layer_id=layer_id, + quant_config=quant_config, + prefix=add_prefix("self_attn", prefix), + alt_streams=alt_streams, + ) + self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn) + is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False) + is_next_layer_sparse = self._is_layer_sparse(layer_id + 1, is_nextn=False) + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=1 if is_nextn else config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + is_next_layer_sparse=is_next_layer_sparse, + ) + # TODO: check whether the implementation matches + # TODO: make necessary changes if possible + self.mlp = deepseek_v2.DeepseekV2MoE( + config=config, + quant_config=moe_quant_config_override or quant_config, + prefix=add_prefix("mlp", prefix), + layer_id=self.layer_id, + alt_stream=alt_streams[0] if alt_streams is not None else None, + is_nextn=is_nextn, + is_deepseek_v4=True, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + # self.layer_communicator = LayerCommunicator( + # layer_scatter_modes=self.layer_scatter_modes, + # input_layernorm=self.input_layernorm, + # post_attention_layernorm=self.post_attention_layernorm, + # allow_reduce_scatter=True, + # is_last_layer=( + # is_nextn or (self.layer_id == self.config.num_hidden_layers - 1) + # ), + # ) + + self.hc_mult = hc_mult = config.hc_mult + self.hc_sinkhorn_iters = config.hc_sinkhorn_iters + 
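+ # The hyper-connection parameters below hold (2 + hc_mult) * hc_mult mixing rows per
+ # projection; hc_split_sinkhorn later splits them into `pre` [hc_mult] (residual ->
+ # layer input), `post` [hc_mult] (layer output -> residual) and `comb`
+ # [hc_mult, hc_mult] (residual -> residual), i.e. hc_mult + hc_mult + hc_mult**2 rows.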
self.hc_eps = config.hc_eps + mix_hc = (2 + hc_mult) * hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_attn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_ffn_fn = nn.Parameter(torch.empty(mix_hc, hc_dim, dtype=torch.float32)) + self.hc_attn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_ffn_base = nn.Parameter(torch.empty(mix_hc, dtype=torch.float32)) + self.hc_attn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.hc_ffn_scale = nn.Parameter(torch.empty(3, dtype=torch.float32)) + self.rms_norm_eps = config.rms_norm_eps + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + + def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool: + if envs.SGLANG_DSV4_MODE.get() == "2604": + first_k_dense_replace = 0 + moe_layer_freq = 1 + else: + first_k_dense_replace = self.config.first_k_dense_replace + moe_layer_freq = self.config.moe_layer_freq + return is_nextn or ( + self.config.n_routed_experts is not None + and layer_id >= first_k_dense_replace + and layer_id % moe_layer_freq == 0 + ) + + def hc_pre( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + @maybe_torch_compile + def hc_pre_torch_impl(x, hc_fn): + x_flat = x.flatten(1).float() + rsqrt = torch.rsqrt( + x_flat.square().mean(-1, keepdim=True) + self.rms_norm_eps + ) + mixes = (F.linear(x_flat, hc_fn) * rsqrt).unsqueeze(1) + return x_flat, mixes + + # x: [n,hc,d] -> y: [n,d], where n=b*s + shape, dtype = x.size(), x.dtype + + # Handle empty batch + if x.shape[0] == 0: + y = torch.empty((0, shape[-1]), dtype=dtype, device=x.device) + post = torch.empty((0, self.hc_mult), dtype=dtype, device=x.device) + comb = torch.empty( + (0, self.hc_mult, self.hc_mult), dtype=dtype, device=x.device + ) + return y, post, comb + + if envs.SGLANG_OPT_USE_TILELANG_MHC_PRE.get(): + from sglang.srt.layers.mhc import mhc_pre + + post, comb, y = mhc_pre( + residual=x, + fn=hc_fn, + hc_scale=hc_scale, + hc_base=hc_base, + rms_eps=self.rms_norm_eps, + hc_pre_eps=self.hc_eps, + hc_sinkhorn_eps=self.hc_eps, + hc_post_mult_value=2.0, + sinkhorn_repeat=self.hc_sinkhorn_iters, + ) + # returned post should be [n, hc_mult] + return y, post.squeeze(-1), comb + + if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get(): + # DeepGEMM implementation + import deep_gemm + + x_flat = x.flatten(1).bfloat16() + + m, k = x_flat.shape + mix_hc = hc_fn.size(0) + d_out = torch.empty((m, mix_hc), dtype=torch.float, device=x.device) + s_out = torch.empty((m,), dtype=torch.float, device=x.device) + # TODO: maybe remove the contiguity requirement? 
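+ # The fused kernel below is expected to produce d_out = x_flat @ hc_fn.T and
+ # s_out[i] = sum(x_flat[i] ** 2), so that rsqrt(s_out / k + rms_norm_eps) reproduces
+ # the RMSNorm factor of the naive hc_pre_torch_impl path in a single pass over x_flat.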
+ deep_gemm.tf32_hc_prenorm_gemm( + x_flat, hc_fn.float().contiguous(), d_out, s_out, num_splits=None + ) + rsqrt = torch.rsqrt(s_out / k + self.rms_norm_eps) + mixes = (d_out * rsqrt.unsqueeze(1)).unsqueeze(1) + else: + # Naive Torch implementation + x_flat, mixes = hc_pre_torch_impl(x, hc_fn) + + from sglang.srt.layers.mhc import hc_split_sinkhorn + + pre, post, comb = hc_split_sinkhorn( + mixes, + hc_scale, + hc_base, + self.hc_mult, + self.hc_sinkhorn_iters, + self.hc_eps, + ) + y = (pre.squeeze(1).unsqueeze(-1) * x_flat.view(shape)).sum(dim=1) + return y.to(dtype), post.squeeze(1), comb.squeeze(1) + + def hc_post( + self, + x: torch.Tensor, + residual: torch.Tensor, + post: torch.Tensor, + comb: torch.Tensor, + ): + + # x: [n,d], residual: [n,hc,d] -> y: [n,hc,d] + # post: [n,hc], comb: [n,hc,hc] + + # Handle empty batch + if x.shape[0] == 0: + return torch.empty( + (0, self.hc_mult, x.shape[-1]), dtype=x.dtype, device=x.device + ) + + if envs.SGLANG_OPT_USE_TILELANG_MHC_POST.get(): + from sglang.srt.layers.mhc import mhc_post + + result = mhc_post(x, residual, post, comb) + return result + + assert residual.shape == (x.shape[0], self.hc_mult, x.shape[-1]) + assert post.shape == (x.shape[0], self.hc_mult) + assert comb.shape == (x.shape[0], self.hc_mult, self.hc_mult) + + @maybe_torch_compile + def hc_post_torch_impl(x, residual, post, comb): + return ( + post.unsqueeze(-1) * x.unsqueeze(1) + + (comb.unsqueeze(-1) * residual.unsqueeze(2)).sum(dim=1) + ).type_as(x) + + result = hc_post_torch_impl(x, residual, post, comb) + return result + + def forward( + self, + positions: torch.tensor, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + forward_batch: ForwardBatch, + input_ids_global: torch.Tensor, + ) -> torch.Tensor: + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert deepseek_v4_moe_code_path_checker.observed == 0 + + + residual = hidden_states + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base + ) # -> [n, d] + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + x=hidden_states, + positions=positions, + forward_batch=forward_batch, + ) + + hidden_states = self.hc_post(hidden_states, residual, post, comb) + residual = hidden_states # [n, hc, d] + hidden_states, post, comb = self.hc_pre( + hidden_states, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base + ) # -> [n, d] + hidden_states = self.post_attention_layernorm(hidden_states) + + # Communication logic (equivalent to LayerCommunicator): + # + # ======================== i. TP MoE ======================== + # DP attn + TP moe (moe_a2a_backend=none): + # * mlp_mode = FULL (each-rank-has-whole-world-tokens) + # * prepare_mlp -> _gather_hidden_states_and_residual -> dp_gather_partial + # * postprocess_layer -> _scatter_hidden_states -> dp_scatter + # Need Gather before MoE and Scatter after MoE. + # + # ======================== ii. DeepEP MoE ======================== + # DP attn + DeepEP moe (moe_a2a_backend=deepep/flashinfer/etc): + # * mlp_mode = SCATTERED (each-rank-only-has-this-rank-tokens) + # * prepare_mlp -> _simple (just layernorm, no gather) + # * postprocess_layer -> _trivial (no scatter) + # Because attn_tp_size==1 when tp==dp==ep, SCATTERED and TP_ATTN_FULL + # have the same group_size. Token dispatch/combine is handled by + # DeepEP inside MoE forward. No Gather/Scatter around MoE. 
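+ # Summary of the paths handled below:
+ #   CP prefill                  -> tokens stay local, input_ids sliced to the CP shard, DeepEP required
+ #   DP attn + TP MoE (a2a none) -> dp_gather_partial before self.mlp, dp_scatter after
+ #   DeepEP / other A2A backends -> no gather/scatter here, dispatch happens inside self.mlp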
+ _use_cp = self.nsa_enable_prefill_cp and nsa_use_prefill_cp(forward_batch) + _use_tp_moe_gather = ( + not _use_cp + and get_attention_dp_size() > 1 + and get_moe_a2a_backend().is_none() + ) + # ----------------------------------- CP: fix input_ids to LOCAL ---------------- + if _use_cp: + # CP requires DeepEP — TP MoE's all-reduce assumes identical tokens + # across ranks, which CP violates. (Analogous to NSACPLayerCommunicator's + # assert mlp_mode==SCATTERED when dp_size>1.) + assert get_moe_a2a_backend().is_deepep(), ( + "CP requires DeepEP (moe_a2a_backend == deepep). " + "Only DeepEP is tested with CP's per-rank token split." + ) + # DeepEP handles cross-rank MoE dispatch/combine internally. + # No gather/scatter needed — tokens stay LOCAL (SCATTERED). + # This matches DSV3.2's mlp_mode=SCATTERED behavior with DeepEP + CP. + # + # Hash gating (n_hash_layers=3) needs input_ids[i] to correspond to + # hidden_states[i]. hidden_states is LOCAL [N/cp_size] (round-robin). + # input_ids is ORIGINAL [N] on every rank (never CP-split). + # Slice to LOCAL to match hidden_states. + cp_rank = get_attention_tp_rank() + cp_size = get_attention_tp_size() + input_ids = input_ids[cp_rank::cp_size].contiguous() + # TODO: improve the name - it is indeed local in CP, but is only used by e.g. Hash gating + input_ids_global = input_ids + # ----------------------------------- DP: gather for TP MoE -------------------- + elif _use_tp_moe_gather: + hidden_states, local_hidden_states = get_global_dp_buffer(), hidden_states + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + # ----------------------------------- MoE ------------------------------------ + hidden_states = self.mlp( + hidden_states, + forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + # ----------------------------------- Scatter (DP only, not CP) ---------------- + if _use_tp_moe_gather: + hidden_states, global_hidden_states = get_local_dp_buffer(), hidden_states + dp_scatter(hidden_states, global_hidden_states, forward_batch) + + hidden_states = self.hc_post( + hidden_states, residual, post, comb + ) # [n, d] -> [n, hc, d] + + # if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B" and not _is_hip: + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + assert deepseek_v4_moe_code_path_checker.observed == 1 + deepseek_v4_moe_code_path_checker.observed = 0 + + return hidden_states + + +class DeepseekV4Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + self.pp_group = get_pp_group() + self.first_k_dense_replace = config.first_k_dense_replace + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + ) + self.rms_norm_eps = config.rms_norm_eps + self.alt_streams = ( + [torch.cuda.Stream() for _ in range(5)] if (_is_cuda or _is_hip) else None + ) + self.layers, self.start_layer, self.end_layer = make_layers( + config.num_hidden_layers, + lambda idx, prefix: DeepseekV4DecoderLayer( + config=config, + layer_id=idx, + quant_config=quant_config, + prefix=prefix, + alt_streams=self.alt_streams, + ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, + prefix=add_prefix("layers", prefix), + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + 
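+ # The hc_head_fn / hc_head_scale / hc_head_base parameters below collapse the final
+ # [num_tokens, hc_mult, hidden_size] hyper-connection residual back to
+ # [num_tokens, hidden_size] (see hc_head) before the output RMSNorm.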
self.gemm_output_zero_allocator_size = 0 + self.layers_to_capture = [] + if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + self.enable_a2a_moe = True + else: + self.enable_a2a_moe = False + + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + self.norm_eps = config.rms_norm_eps + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_size = get_attention_tp_size() + + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + # TODO + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor], + pp_proxy_tensors: Optional[PPProxyTensors], + ) -> torch.Tensor: + total_num_layers = self.end_layer - self.start_layer + device = input_embeds.device if input_embeds is not None else input_ids.device + zero_allocator = BumpAllocator( + buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), + dtype=torch.float32, + device=device, + ) + has_gemm_output_zero_allocator = hasattr( + self, "gemm_output_zero_allocator_size" + ) + gemm_output_zero_allocator = ( + BumpAllocator( + buffer_size=self.gemm_output_zero_allocator_size, + dtype=torch.float32, + device=device, + ) + if has_gemm_output_zero_allocator + and self.gemm_output_zero_allocator_size > 0 + else None + ) + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = input_ids_global.squeeze(-1) + else: + input_ids_global = input_ids + + if nsa_use_prefill_cp(forward_batch): + hidden_states = cp_split_and_rebuild_data(forward_batch, hidden_states) + positions = cp_split_and_rebuild_position(forward_batch, positions) + + for i in range(self.start_layer, self.end_layer): + # TODO: ctx? 
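+ # hidden_states carries the hyper-connection residual of shape
+ # [num_tokens, hc_mult, hidden_size]; each decoder layer collapses it to
+ # [num_tokens, hidden_size] via hc_pre, runs attention / MoE, and expands it back
+ # through hc_post.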
+ layer = self.layers[i] + hidden_states = layer( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + # zero_allocator, + # gemm_output_zero_allocator, + ) + + if nsa_use_prefill_cp(forward_batch): + hidden_states = cp_all_gather_rerange_output( + hidden_states, + self.cp_size, + forward_batch, + torch.cuda.current_stream(), + ) + + pre_hc_head = ( + hidden_states.flatten(1) + if envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + else None + ) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.norm(hidden_states) + + if pre_hc_head is not None: + return hidden_states, pre_hc_head + return hidden_states + + +class DeepseekV4ForCausalLM(nn.Module): + def __init__( + self, + config: DeepSeekV4Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.determine_num_fused_shared_experts() + self.model = DeepseekV4Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.pp_group = get_pp_group() + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + self.capture_aux_hidden_states = False + # TODO: is this true that compress is kind of NSA + get_attn_tp_context().init_context(config.q_lora_rank, is_nsa=True) + + self._routed_experts_weights_of_layer = LazyValue( + lambda: { + layer_id: layer.mlp.get_moe_weights() + for layer_id, layer in enumerate(self.model.layers) + if isinstance(layer.mlp, deepseek_v2.DeepseekV2MoE) + } + ) + + self.nsa_enable_prefill_cp = is_nsa_enable_prefill_cp() + if self.nsa_enable_prefill_cp: + self.cp_rank = get_attention_tp_rank() + self.cp_size = get_attention_tp_size() + + @property + def routed_experts_weights_of_layer(self): + return self._routed_experts_weights_of_layer.value + + def determine_num_fused_shared_experts(self): + self.num_fused_shared_experts = 0 + if get_global_server_args().disable_shared_experts_fusion: + return + + # Only Deepseek V3/R1 can use shared experts fusion optimization now. + disable_reason = None + if self.config.n_routed_experts != 256 or self.config.n_shared_experts != 1: + disable_reason = "Config not support fused shared expert(s)." + elif (not _is_cuda or torch.cuda.get_device_capability("cuda") < (8, 0)) and ( + not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4) + ): + disable_reason = ( + "Only Deepseek V3/R1 on NV-platform with capability >= 80 " + "or AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization." + ) + elif get_moe_expert_parallel_world_size() > 1 and ( + not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4) + ): + disable_reason = "Only Deepseek V3/R1 on AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization under expert parallelism." + elif disable_reason is None and get_moe_a2a_backend().is_deepep(): + disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under deepep expert parallelism." 
+ elif self.quant_config and self.quant_config.get_name() == "w4afp8": + disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts." + elif ( + envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get() + ): + disable_reason = "2604 routed experts use FP4 while shared experts remain FP8; fusion would incorrectly apply FP4 to shared experts." + + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B": + disable_reason = "2604B checkpoint requires different clamping for shared and routed experts" + + if disable_reason is not None: + get_global_server_args().disable_shared_experts_fusion = True + self.num_fused_shared_experts = 0 + log_info_on_rank0( + logger, + f"{disable_reason} Shared experts fusion optimization is disabled.", + ) + return + + self.num_fused_shared_experts = self.config.n_shared_experts + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + + if self.nsa_enable_prefill_cp: + if can_cp_split(len(input_ids), self.cp_size, True, forward_batch): + forward_batch.nsa_cp_metadata = prepare_input_dp_with_cp_dsa( + len(input_ids), + self.cp_rank, + self.cp_size, + forward_batch.seq_lens_cpu.tolist(), + ) + + with get_attn_tp_context().maybe_input_scattered(forward_batch): + hidden_states = self.model.forward( + input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors + ) + aux_hidden_states = None + pre_hc_head = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + hidden_states, pre_hc_head = hidden_states + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + # TODO: indeed ours is "hidden_states_before_hc_head" instead of "before norm" + # abuse the existing field temporarily to minimize code diff + # should rename and generalize later, e.g. 
"hidden_states_for_spec" + hidden_states_before_norm=pre_hc_head, + ) + + def _setup_fp8_wo_a_scales(self, is_nextn: bool) -> None: + from deep_gemm import transform_sf_into_required_layout + + layers = self.model.layers + for layer in layers: + attn = layer.self_attn + G = attn.n_local_groups + R = attn.o_lora_rank + D = attn.wo_a.weight.shape[1] + + # Pre-transform weight scale to DeepGEMM required layout (TMA-aligned / UE8M0 packed) + # fp8_einsum('bhr,hdr->bhd') maps B=[h,d,r]=[G,R,D], so N=R, K=D for the B-side scale + raw_scale = attn.wo_a.weight_scale_inv.data.view(G, R // 128, D // 128) + attn.wo_a.weight_scale_inv.data = transform_sf_into_required_layout( + raw_scale, + mn=R, + k=D, + recipe=(1, 128, 128), + num_groups=G, + is_sfa=False, + ) + + def post_load_weights(self, is_nextn=False, weight_names=None): + if _FP8_WO_A_GEMM: + self._setup_fp8_wo_a_scales(is_nextn) + + # ================ apply_ape_hotfix, should not be needed for final ckpt ================ + if is_nextn: + return + for layer in self.model.layers: + self_attn = layer.self_attn + if self_attn.compress_ratio != 0 and not self_attn.compressor.ape_converted: + self_attn.compressor.apply_ape_hotfix() + if ( + self_attn.compress_ratio == 4 + and not self_attn.indexer.compressor.ape_converted + ): + self_attn.indexer.compressor.apply_ape_hotfix() + + # This is used externally, please try to keep the API mostly unchanged + @staticmethod + def remap_weight_name_to_dpsk_hf_format( + name: str, is_nextn: bool = False, num_hidden_layers: Optional[int] = None + ) -> str: + if name == "embed.weight": + return "model.embed_tokens.weight" + if name == "head.weight": + return "lm_head.weight" + if name == "norm.weight": + return "model.norm.weight" + if name.startswith("hc_head_"): + return "model." + name + + if is_nextn and name.startswith("mtp."): + parts = name.split(".", 2) + if len(parts) >= 3: + rest = parts[2] + nextn_spec_prefixes = [ + "e_proj", + "h_proj", + "emb", + "enorm", + "hnorm", + "norm", + "head", + "hc_head", + ] + is_nextn_spec = any(rest.startswith(p) for p in nextn_spec_prefixes) + if is_nextn_spec: + if rest.startswith("emb.tok_emb"): + rest = rest.replace("emb.tok_emb", "embed_tokens") + elif rest == "norm.weight": + rest = "shared_head.norm.weight" + elif rest.startswith("head."): + rest = "shared_head.head.weight" + elif rest == "e_proj.scale": + rest = "e_proj.weight_scale_inv" + elif rest == "h_proj.scale": + rest = "h_proj.weight_scale_inv" + name = f"model.layers.{num_hidden_layers}." + rest + + if name.startswith("layers."): + name = "model." 
+ name + name = name.replace(".attn.", ".self_attn.") + name = name.replace(".ffn.", ".mlp.") + name = name.replace(".attn_norm.", ".input_layernorm.") + name = name.replace(".ffn_norm.", ".post_attention_layernorm.") + + if not ATTN_BIT_WISE_EQUAL_MODE: + if "self_attn" in name and ( + "compressor" not in name or not COMPRESSOR_BIT_WISE_EQUAL_MODE + ): + name = name.replace(".scale", ".weight_scale_inv") + + if not MOE_BIT_WISE_EQUAL_MODE: + name = name.replace(".gate.tid2eid", ".topk.tid2eid") + name = name.replace(".gate.bias", ".gate.e_score_correction_bias") + name = name.replace(".w1.", ".gate_proj.") + name = name.replace(".w2.", ".down_proj.") + name = name.replace(".w3.", ".up_proj.") + if "mlp" in name: + name = name.replace(".scale", ".weight_scale_inv") + + return name + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + assert envs.SGLANG_DSV4_MODE.get() in ["2601", "2604"] + if envs.SGLANG_DSV4_MODE.get() == "2604": + assert envs.SGLANG_DSV4_2604_SUBMODE.get() in ["2604A", "2604B"] + else: + assert envs.SGLANG_DSV4_2604_SUBMODE.get() == "" + + if MOE_BIT_WISE_EQUAL_MODE: + assert ( + self.num_fused_shared_experts == 0 + ), "use --disable-shared-experts-fusion for MoE bit-wise equal mode" + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + # compatible with old design + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + raise ValueError("num_nextn_predict_layers is not in the config") + + # Ignore this, b/c it is for nvfp4 ckpt + # weights = self._maybe_quant_weights_to_fp8_ue8m0( + # weights, NVFP4_CKPT_FP8_ATTN_QUANT_MODULES, is_nextn + # ) + + if ( + envs.SGLANG_DSV4_MODE.get() == "2604" + and not envs.SGLANG_OPT_FP8_WO_A_GEMM.get() + ): + if envs.SGLANG_DSV4_FP4_EXPERTS.get(): + weights = _dequant_fp8_wo_a(weights) + else: + # Converted FP8 checkpoint: wo_a is already bf16; drop stale wo_a.scale if present + weights = ((n, t) for n, t in weights if not n.endswith(".wo_a.scale")) + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, + ) + # Params for special naming rules in mixed-precision models, for example: + # model.layers.xx.mlp.experts.xx.w1.input_scale. For details, + # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main. 
+ + if self.quant_config and self.quant_config.get_name() == "w4afp8": + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts + ) + + # fuse compressor wkv and wgate weights into wkv_gate + cache_compressor_weight = {} + COMPRESSOR_PART = ".compressor.w" # match wkv and wgate, skip ape + + # use default weight loader if module has no custom weight_loader + def auto_weight_loader(module): + return getattr(module, "weight_loader", default_weight_loader) + + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names_out_of_layer = [ + "shared_head.norm", + "shared_head.head", + "embed_tokens", + ".e_proj", # Note that we need a . here to avoid confusion with gate_proj + "h_proj", + "enorm", + "hnorm", + "hc_head_base", + "hc_head_fn", + "hc_head_scale", + ] + + if self.num_fused_shared_experts > 0: + assert self.num_fused_shared_experts == 1 + log_info_on_rank0(logger, "Shared experts fusion optimization enabled.") + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + weight_names = [] + for name, loaded_weight in weights: + try: + use_async_loading = should_async_load(loaded_weight) + + # remap reference's temp ckpt weight -> deepseek hf format + name = self.remap_weight_name_to_dpsk_hf_format( + name, + is_nextn=is_nextn, + num_hidden_layers=self.config.num_hidden_layers, + ) + + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue + if ( + self.num_fused_shared_experts > 0 + and "mlp.shared_experts" in name + ): + name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts}", + ) + + weight_names.append(name) + + if not is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) + >= self.config.num_hidden_layers + ): + continue + + if name.startswith("mtp"): + continue + else: + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + # Skip target model weights + if not name.startswith(nextn_layer_prefix): + continue + + in_decoder = True + # For nextn specific weights (out of layer) + # The nextn layer prefix of these weights has been removed + for weight_name in nextn_spec_weight_names_out_of_layer: + if weight_name in name: + in_decoder = False + name = name.replace(nextn_layer_prefix, "model") + break + + # For decoder layer weights + if in_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." 
in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict and name.startswith("mtp"): # TODO + break + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight, shard_id), + ) + loaded_params.add(name) + break + else: + for mapping in expert_params_mapping: + if MOE_BIT_WISE_EQUAL_MODE: + continue + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + if _is_npu: + name = name.replace("weight_packed", "weight") + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=( + param, + loaded_weight, + name, + ), + func_kwargs={ + "shard_id": shard_id, + "expert_id": expert_id, + }, + ) + loaded_params.add(name) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading embed_tokens if not first rank in pipeline parallelism + if ( + ".embed_tokens." in name + and not self.pp_group.is_first_rank + ): + continue + # Skip loading norm if not last rank in pipeline parallelism + if ".norm." in name and not self.pp_group.is_last_rank: + continue + elif COMPRESSOR_PART in name: + is_kv = name.endswith(".wkv.weight") + is_wgate = name.endswith(".wgate.weight") + assert is_kv != is_wgate # exactly one is true + key = name.rsplit(".", 2)[0] + assert key.endswith(".compressor") + if key not in cache_compressor_weight: + cache_compressor_weight[key] = ( + is_kv, + loaded_weight, + ) + else: + assert key in cache_compressor_weight + cached_is_kv, cached_weight = ( + cache_compressor_weight[key] + ) + assert cached_is_kv != is_kv + kv = loaded_weight if is_kv else cached_weight + wgate = loaded_weight if is_wgate else cached_weight + fused_weight = torch.cat([kv, wgate], dim=0) + param_name = key + ".wkv_gate.weight" + param = params_dict[param_name] + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, fused_weight), + ) + loaded_params.add(param_name) + cache_compressor_weight.pop(key) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + # modelopt attn kv scale is named differently + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + # modelopt ckpt contains not needed weights for MTP module: + # model.decoder.self_attn.attn_mqa.v_scale and + # model.decoder.self_attn.attn_mqa.k_scale + if not name.startswith("mtp"): # TODO: mtp + logger.warning( + f"{name} not found in params_dict." 
+ ) + continue + param = params_dict[name] + + # if "attn_sink" in name: + # attn_tp_rank = get_attention_tp_rank() + # start = attn_tp_rank * param.numel() + # param.data.copy_( + # loaded_weight[start : start + param.numel()] + # ) + # loaded_params.add(name) + # continue + + weight_loader = auto_weight_loader(param) + maybe_executor_submit( + executor=executor, + futures=futures, + use_async=use_async_loading, + func=weight_loader, + func_args=(param, loaded_weight), + ) + loaded_params.add(name) + except Exception as e: + e.add_note(f"{name=} {loaded_weight.shape=}") + raise + + # Wait for all tasks to complete and raise any exceptions. + for future in concurrent.futures.as_completed(futures): + future.result() + + assert len(cache_compressor_weight) == 0 + unloaded_params = params_dict.keys() - loaded_params + + skipped_checking_patterns = ["attn_mqa.k_scale", "attn_mqa.v_scale"] + if is_nextn: + skipped_checking_patterns.extend(["lm_head", "embed_tokens"]) + unloaded_params = { + p + for p in unloaded_params + # hack to skip checking these in default ckpt. should have more rigorous check. + if all( + skipped_checking_pattern not in p + for skipped_checking_pattern in skipped_checking_patterns + ) + } + if os.environ.get("SGLANG_SKIP_CHECKPOINT_LOAD_CHECK", "0") == "0": + if unloaded_params: + raise RuntimeError( + f"Some weights are not initialized from checkpoints: {unloaded_params}" + ) + + self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, + num_groups=None, + ) + + +EntryClass = [DeepseekV4ForCausalLM] + + +def _dequant_fp8(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Dequant fp8 block-quantized wo_a weight: bf16 = fp8_weight * e8m0_scale. + + Specifically for wo_a in 2604 checkpoint: + weight: [8192, 4096] fp8_e4m3fn (64*128 x 32*128) + scale: [64, 32] fp8_e8m0fnu (per 128x128 block) + """ + from einops import rearrange + + assert ( + weight.dtype == torch.float8_e4m3fn + ), f"expected fp8_e4m3fn, got {weight.dtype}" + assert ( + scale.dtype == torch.float8_e8m0fnu + ), f"expected fp8_e8m0fnu, got {scale.dtype}" + assert weight.shape == (8192, 4096), f"unexpected weight shape {weight.shape}" + assert scale.shape == (64, 32), f"unexpected scale shape {scale.shape}" + + weight_f32 = rearrange( + weight.float(), "(sn bn) (sk bk) -> sn bn sk bk", bn=128, bk=128 + ) + result = rearrange( + weight_f32 * scale.float()[:, None, :, None], "sn bn sk bk -> (sn bn) (sk bk)" + ) + + assert result.shape == (8192, 4096) + return result.to(torch.bfloat16) + + +def _dequant_fp8_wo_a( + weights: Iterable[Tuple[str, torch.Tensor]], +) -> Iterable[Tuple[str, torch.Tensor]]: + """Dequant fp8 wo_a weights inline: pair (wo_a.scale, wo_a.weight) -> bf16 wo_a.weight. 
+ + 2601 checkpoint: + layers.0.attn.wo_a.weight torch.bfloat16 [8192, 4096] 64.00MB min=-0.375 max=0.3125 + + 2604 checkpoint: + layers.0.attn.wo_a.scale torch.float8_e8m0fnu [64, 32] 0.00MB + layers.0.attn.wo_a.weight torch.float8_e4m3fn [8192, 4096] 32.00MB + """ + weights_dict = dict(weights) + + for name in list(weights_dict.keys()): + if name not in weights_dict: + continue + if not name.endswith(".wo_a.weight"): + continue + scale_name = name.replace(".wo_a.weight", ".wo_a.scale") + assert scale_name in weights_dict + weight = weights_dict.pop(name) + scale = weights_dict.pop(scale_name) + yield name, _dequant_fp8(weight, scale) + + yield from weights_dict.items() diff --git a/python/sglang/srt/models/deepseek_v4_nextn.py b/python/sglang/srt/models/deepseek_v4_nextn.py new file mode 100644 index 000000000000..3323716e36e5 --- /dev/null +++ b/python/sglang/srt/models/deepseek_v4_nextn.py @@ -0,0 +1,255 @@ +"""Inference-only DeepSeek V4 NextN Speculative Decoding.""" + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.environ import envs +from sglang.srt.layers.dp_attention import ( + _DpGatheredBufferWrapper, + dp_gather_partial, + get_attention_dp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.utils import get_moe_a2a_backend +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.deepseek_v4 import DeepseekV4DecoderLayer, DeepseekV4ForCausalLM +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + + +class DeepseekV4ModelNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_id = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rms_norm_eps = config.rms_norm_eps + + self.layers_to_capture = [] + if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): + self.enable_a2a_moe = True + else: + self.enable_a2a_moe = False + + self.hc_eps = config.hc_eps + self.hc_mult = hc_mult = config.hc_mult + hc_dim = hc_mult * config.hidden_size + self.hc_head_fn = nn.Parameter( + torch.empty(hc_mult, hc_dim, dtype=torch.float32) + ) + self.hc_head_base = nn.Parameter(torch.empty(hc_mult, dtype=torch.float32)) + self.hc_head_scale = nn.Parameter(torch.empty(1, dtype=torch.float32)) + + self.e_proj = ReplicatedLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("e_proj", prefix), + ) + self.h_proj = ReplicatedLinear( + 
config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("h_proj", prefix), + ) + + layer_name = "decoder" + + # Multi-stream is disabled on the MTP layer + self.decoder = DeepseekV4DecoderLayer( + config, + layer_id=0, + quant_config=quant_config, + is_nextn=True, + prefix=add_prefix(layer_name, prefix), + alt_streams=None, + ) + + self.shared_head = nn.Module() + self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Copied from DeepseekV4Model + def hc_head( + self, + x: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + ): + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + self.rms_norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + self.hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + n_tokens = hidden_states.shape[0] + d = self.config.hidden_size + # spec_info.hidden_states: [n, hc*d] → reshape to [n*hc, d] for 2D kernels + hc_flat = forward_batch.spec_info.hidden_states.view( + n_tokens * self.hc_mult, d + ) + # hnorm + h_proj on each hc copy independently: [n*hc, d] → [n*hc, d] + h_proj_out, _ = self.h_proj(self.hnorm(hc_flat)) + # reshape back: [n*hc, d] → [n, hc, d] + h_proj_hidden_states = h_proj_out.view(n_tokens, self.hc_mult, d) + + # embed: [n, d] → enorm → e_proj → [n, d] + e_proj_hidden_states, _ = self.e_proj(self.enorm(hidden_states)) + # broadcast [n, 1, d] + [n, hc, d] → [n, hc, d] + hidden_states = e_proj_hidden_states[:, None, :] + h_proj_hidden_states + else: + e_proj_hidden_states, _ = self.e_proj(self.enorm(hidden_states)) + h_proj_hidden_states, _ = self.h_proj( + self.hnorm(forward_batch.spec_info.hidden_states) + ) + hidden_states = e_proj_hidden_states + h_proj_hidden_states + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + else: + hidden_states = hidden_states.unsqueeze(1).repeat(1, self.hc_mult, 1) + + if get_attention_dp_size() > 1 and get_moe_a2a_backend().is_none(): + input_ids_global = torch.empty( + (_DpGatheredBufferWrapper._global_dp_buffer_len, 1), + dtype=input_ids.dtype, + device=input_ids.device, + ) + dp_gather_partial(input_ids_global, input_ids[:, None], forward_batch) + input_ids_global = input_ids_global.squeeze(-1) + else: # Pure TP attention + input_ids_global = input_ids + + hidden_states = self.decoder( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + input_ids=input_ids, + input_ids_global=input_ids_global, + ) + + # decoder output: [n, hc, d] → flatten to [n, hc*d] for spec pipeline + pre_hc_head = ( + hidden_states.flatten(1) + if envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + else None + ) + + hidden_states = self.hc_head( + hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base + ) + hidden_states = self.shared_head.norm(hidden_states) + + if pre_hc_head is not None: + return hidden_states, pre_hc_head + return hidden_states + + +class
DeepseekV4ForCausalLMNextN(DeepseekV4ForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.pp_group = get_pp_group() + self.quant_config = quant_config + # if not set, model load will be broken in DeepseekV3ForCausalLM load_weights() + self.determine_num_fused_shared_experts() + + self.model = DeepseekV4ModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=get_global_server_args().enable_dp_lm_head, + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + result = self.model(input_ids, positions, forward_batch) + pre_hc_head = None + if ( + envs.SGLANG_FIX_MTP_HC_HIDDEN.get() + and envs.SGLANG_DSV4_MODE.get() == "2604" + ): + hidden_states, pre_hc_head = result + else: + hidden_states = result + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + hidden_states_before_norm=pre_hc_head, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights, is_nextn=True) + + +EntryClass = [DeepseekV4ForCausalLMNextN] diff --git a/python/sglang/srt/models/registry.py b/python/sglang/srt/models/registry.py index 066be3dc44b3..a28c15fa782b 100644 --- a/python/sglang/srt/models/registry.py +++ b/python/sglang/srt/models/registry.py @@ -104,7 +104,9 @@ def import_model_classes(package_name: str, strict: bool = False): except Exception as e: if strict: raise - logger.warning(f"Ignore import error when loading {name}: {e}") + logger.warning( + f"In import_model_classes: Ignore import error when loading {name}: {e}" + ) continue if hasattr(module, "EntryClass"): entry = module.EntryClass diff --git a/python/sglang/srt/parser/reasoning_parser.py b/python/sglang/srt/parser/reasoning_parser.py index 8949ba5d75b4..6b5a469f982b 100644 --- a/python/sglang/srt/parser/reasoning_parser.py +++ b/python/sglang/srt/parser/reasoning_parser.py @@ -304,6 +304,7 @@ class ReasoningParser: DetectorMap: Dict[str, Type[BaseReasoningFormatDetector]] = { "deepseek-r1": DeepSeekR1Detector, "deepseek-v3": Qwen3Detector, + "deepseek-v4": Qwen3Detector, "glm45": Qwen3Detector, "gpt-oss": GptOssDetector, "kimi": KimiDetector, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1a049ced4fcc..f9fc18a03e98 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -122,6 +122,7 @@ "torch_native", "flex_attention", "nsa", + "compressed", # NVIDIA specific "cutlass_mla", "fa3", @@ -519,6 +520,7 @@ class ServerArgs: hicache_storage_backend_extra_config: Optional[str] = None # Hierarchical sparse attention + enable_hisparse: bool = False hierarchical_sparse_attention_extra_config: Optional[str] = None # LMCache @@ -609,6 +611,7 @@ class ServerArgs: keep_mm_feature_on_device: bool = False enable_return_hidden_states: bool = False enable_return_routed_experts: bool = False + enable_return_indexer_topk: bool = False scheduler_recv_interval: int = 1 numa_node: Optional[List[int]] = None enable_deterministic_inference: bool = False @@ -1172,6 
+1175,48 @@ def _handle_model_specific_adjustments(self): ]: self.dtype = "bfloat16" + if model_arch in [ + "DeepseekV4ForCausalLM", + ]: + self.attention_backend = "compressed" + self.page_size = 256 + logger.info( + f"Use compressed attention backend for {model_arch}, setting page_size to 256." + ) + + if self.max_running_requests is None: + self.max_running_requests = 256 + logger.warning( + f"Setting max_running_requests to {self.max_running_requests} for {model_arch}." + ) + + if self.kv_cache_dtype == "auto": + self.kv_cache_dtype = "fp8_e4m3" + logger.warning( + f"Setting KV cache dtype to {self.kv_cache_dtype} for {model_arch}." + ) + assert self.kv_cache_dtype in [ + "fp8_e4m3" + ], f"{self.kv_cache_dtype} is not supported for {model_arch}" + + if self.speculative_algorithm is not None: + assert ( + self.speculative_algorithm == "EAGLE" + ), f"Only EAGLE speculative algorithm is supported for {model_arch}" + assert ( + self.speculative_eagle_topk == 1 + ), f"Only EAGLE speculative algorithm with topk == 1 is supported for {model_arch}" + + if not envs.SGLANG_ENABLE_SPEC_V2.get(): + envs.SGLANG_ENABLE_SPEC_V2.set(True) + logger.warning("Spec v2 is enabled for EAGLE speculative decoding.") + + if self.swa_full_tokens_ratio == ServerArgs.swa_full_tokens_ratio: + self.swa_full_tokens_ratio = 0.1 + logger.info( + f"Setting swa_full_tokens_ratio to {self.swa_full_tokens_ratio} for {model_arch}." + ) + if model_arch in [ "DeepseekV3ForCausalLM", "MistralLarge3ForCausalLM", @@ -1205,8 +1250,8 @@ def _handle_model_specific_adjustments(self): self.dp_size == 1 ), "For round-robin split mode, dp attention is not supported." assert ( - self.tp_size == 8 - ), "Current multi-machine CP support suffers from precision issues. So context parallel only support Single machine(tp_size == 8)" + self.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." logger.warning( f"Enable Context Parallel opt for deeeseekv3.2-DSA, Setting dp_size == {self.dp_size} and moe_dense_tp_size == {self.moe_dense_tp_size}, ep_size == {self.ep_size}, tp_size == {self.tp_size}, kv_cache_dtype == {self.kv_cache_dtype}, moe_a2a_backend {self.moe_a2a_backend} " @@ -1220,9 +1265,14 @@ def _handle_model_specific_adjustments(self): ) if is_hip(): - self.page_size = 1 + # self.page_size = 1 + # logger.warning( + # "Setting page size to 1 for DeepSeek DSA on ROCm." + # ) + self.page_size = 64 logger.warning( - "Setting page size to 1 for DeepSeek DSA on ROCm." + "Setting page size to 64 for DeepSeek DSA on torch implementation.\n" + "Need to be changed based on ROCm implementation.\n" ) else: # For CUDA GPU @@ -1301,6 +1351,33 @@ def _handle_model_specific_adjustments(self): "Use triton fused moe by default for bf16 nextn layer in deepseek fp4 checkpoint." ) + elif model_arch in [ + "DeepseekV4ForCausalLM", + ]: + # Mirrors the DeepseekV2ForCausalLM CP config above (line ~1240), + # adapted for V4: same round-robin-split guards (dp_size=1, + # tp_size<=8, tilelang disabled), but without V2-specific settings + # (kv_cache_dtype, moe_a2a_backend, etc.). + if self.enable_nsa_prefill_context_parallel: + if self.nsa_prefill_cp_mode == "round-robin-split": + self.moe_dense_tp_size = 1 + assert ( + self.dp_size == 1 + ), "For round-robin split mode, dp attention is not supported." + assert ( + self.tp_size <= 8 + ), "Context parallel only supports single machine (tp_size <= 8). Cross-machine CP has precision issues." 
+ logger.warning( + f"Enable Context Parallel for DeepSeekV4, " + f"dp_size={self.dp_size}, moe_dense_tp_size={self.moe_dense_tp_size}, " + f"ep_size={self.ep_size}, tp_size={self.tp_size}" + ) + else: + raise ValueError( + f"DeepSeekV4 only supports round-robin-split CP mode, " + f"got {self.nsa_prefill_cp_mode}" + ) + elif model_arch in ["GptOssForCausalLM"]: # Set attention backend for GPT-OSS if self.is_attention_backend_not_set(): @@ -2231,6 +2308,7 @@ def _handle_speculative_decoding(self): if model_arch in [ "DeepseekV32ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekV4ForCausalLM", "Glm4MoeForCausalLM", "Glm4MoeLiteForCausalLM", "BailingMoeForCausalLM", @@ -4054,6 +4132,11 @@ def add_cli_args(parser: argparse.ArgumentParser): help="A dictionary in JSON string format, or a string starting with a leading '@' and a config file in JSON/YAML/TOML format, containing extra configuration for the storage backend.", ) + parser.add_argument( + "--enable-hisparse", + action="store_true", + help="Enable hierarchical sparse attention", + ) # Hierarchical sparse attention parser.add_argument( "--hierarchical-sparse-attention-extra-config", @@ -4469,6 +4552,11 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable returning routed experts of each layer with responses.", ) + parser.add_argument( + "--enable-return-indexer-topk", + action="store_true", + help="Enable returning indexer topk indices of layers with indexer with responses.", + ) parser.add_argument( "--scheduler-recv-interval", type=int, diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py index 9c630da72fb1..e0a47d61c7cb 100644 --- a/python/sglang/srt/speculative/draft_utils.py +++ b/python/sglang/srt/speculative/draft_utils.py @@ -55,6 +55,7 @@ def create_decode_backend(self): "trtllm_mla": self._create_trtllm_mla_decode_backend, "nsa": self._create_nsa_decode_backend, "ascend": self._create_ascend_decode_backend, + "compressed": self._create_compressed_decode_backend, } return self._create_backend( @@ -79,6 +80,7 @@ def create_draft_extend_backend(self): "trtllm_mla": self._create_trtllm_mla_prefill_backend, "nsa": self._create_nsa_prefill_backend, "ascend": self._create_ascend_prefill_backend, + "compressed": self._create_compressed_prefill_backend, } backend_name = ( "decode_attention_backend" @@ -189,6 +191,15 @@ def _create_ascend_decode_backend(self): self.draft_model_runner, self.topk, self.speculative_num_steps ) + def _create_compressed_decode_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4MultiStepBackend, + ) + + return DeepseekV4MultiStepBackend( + self.draft_model_runner, self.topk, self.speculative_num_steps + ) + def _create_flashinfer_prefill_backend(self): if not get_global_server_args().use_mla_backend: from sglang.srt.layers.attention.flashinfer_backend import ( @@ -247,3 +258,10 @@ def _create_flashmla_prefill_backend(self): "flashmla prefill backend is not yet supported for draft extend." 
) return None + + def _create_compressed_prefill_backend(self): + from sglang.srt.layers.attention.deepseek_v4_backend_radix import ( + DeepseekV4BackendRadix, + ) + + return DeepseekV4BackendRadix(self.draft_model_runner, skip_prefill=False) diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 5fe45086ca4a..f7a6c2fbc871 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -102,7 +102,7 @@ def __init__(self, eagle_worker: EAGLEWorker): self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) self.hidden_states = torch.zeros( - (self.max_bs, self.model_runner.model_config.hidden_size), + (self.max_bs, self.model_runner.model_config.spec_hidden_size), dtype=self.model_runner.dtype, ) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index e1afdd84b547..38b3b4991620 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -121,7 +121,10 @@ def __init__(self, eagle_worker: EAGLEWorker): ) else: self.hidden_states = torch.zeros( - (self.max_num_token, self.model_runner.model_config.hidden_size), + ( + self.max_num_token, + self.model_runner.model_config.spec_hidden_size, + ), dtype=self.model_runner.dtype, ) self.seq_len_fill_value = ( diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index e22eeaee46cd..cf9ffb8855b7 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -235,7 +235,7 @@ def verify( return EagleVerifyOutput( draft_input=EagleDraftInput.create_idle_input( device=batch.device, - hidden_size=batch.model_config.hidden_size, + hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, @@ -597,7 +597,7 @@ def verify( else: draft_input = EagleDraftInput.create_idle_input( device=batch.device, - hidden_size=batch.model_config.hidden_size, + hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index b542e9615f2a..715d5ce2c3c8 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -267,7 +267,7 @@ def sample( (which contains spec decoding information). 
""" if batch.forward_mode.is_idle(): - predict = torch.empty(0, dtype=torch.long, device=batch.input_ids.device) + predict = torch.empty(0, dtype=torch.int32, device=batch.input_ids.device) accept_length = torch.empty( 0, dtype=torch.int32, device=batch.input_ids.device ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 0086e2aa700e..03ce349596db 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -515,7 +515,7 @@ def _draft_preprocess_decode(self, batch: ScheduleBatch): def _draft_preprocess_idle(self, batch: ScheduleBatch): batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.model_config.hidden_size, + hidden_size=self.model_config.spec_hidden_size, dtype=self.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, @@ -901,12 +901,12 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): if not input_is_idle and batch.spec_info.verified_id.numel() == 0: batch = batch.copy() batch.prepare_for_idle() - hidden_size = ( - self.model_config.hidden_size * 3 - if self.speculative_algorithm.is_eagle3() + hidden_size = self.model_config.spec_hidden_size + if ( + self.speculative_algorithm.is_eagle3() and self.eagle_use_aux_hidden_state - else self.model_config.hidden_size - ) + ): + hidden_size = self.model_config.hidden_size * 3 batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, hidden_size=hidden_size, @@ -954,6 +954,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch): forward_batch.spec_info.hidden_states = logits_output.hidden_states else: forward_batch.can_run_dp_cuda_graph = False + # print(f"forward_draft_extend_after_decode.spec_info: {forward_batch.spec_info.accept_length}") if not forward_batch.forward_mode.is_idle(): self.draft_model_runner.attn_backend.init_forward_metadata( forward_batch diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 1c90a3041f62..c4d2e1f20200 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -12,6 +12,7 @@ from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import ( EAGLEDraftNpuGraphRunner, ) +from sglang.srt.layers.attention.deepseek_v4_backend_radix import DeepseekV4BackendRadix from sglang.srt.layers.attention.triton_backend import TritonMultiStepDraftBackend from sglang.srt.layers.attention.trtllm_mla_backend import ( TRTLLMMLAMultiStepDraftBackend, @@ -279,6 +280,11 @@ def init_cuda_graphs(self): _is_cuda and isinstance(self.draft_attn_backend, TRTLLMMLAMultiStepDraftBackend) ) + or ( + _is_cuda + and isinstance(self.draft_extend_attn_backend, DeepseekV4BackendRadix) + and envs.SGLANG_OPT_V4_DRAFT_EXTEND_CUDA_GRAPH.get() + ) ): tic = time.perf_counter() before_mem = get_available_gpu_memory(self.device, self.gpu_id) @@ -659,7 +665,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): if model_worker_batch.spec_info is None: model_worker_batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.target_worker.model_config.hidden_size, + hidden_size=self.target_worker.model_config.spec_hidden_size, dtype=self.target_worker.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py 
b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index 44bb2f0de128..d5cb73e6354f 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -602,7 +602,7 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): if model_worker_batch.spec_info is None: model_worker_batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.target_worker.model_config.hidden_size, + hidden_size=self.target_worker.model_config.spec_hidden_size, dtype=self.target_worker.model_config.dtype, topk=self.topk * self.speculative_num_steps, capture_hidden_mode=CaptureHiddenMode.LAST, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 8e39ee4cad66..d6cb5c65e90a 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -4086,3 +4086,11 @@ def bind_to_closest_numa_node_cuda(): if is_numa_available() and nvgpu_available(): node_id = get_current_device_numa_node_cuda() numa_bind_to_node(node_id) + + +def maybe_torch_compile(func): + from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode + + if get_is_capture_mode(): + return torch.compile(func) + return func diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index f8743b416eaf..91aa7e3c7227 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -20,7 +20,7 @@ import tempfile import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Literal, Optional, Type, Union import torch from huggingface_hub import snapshot_download @@ -69,7 +69,6 @@ from sglang.srt.connector import create_remote_connector from sglang.srt.multimodal.customized_mm_processor_utils import _CUSTOMIZED_MM_PROCESSOR from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset, mistral_utils -from sglang.srt.utils.patch_tokenizer import patch_tokenizer _CONFIG_REGISTRY: List[Type[PretrainedConfig]] = [ AfmoeConfig, @@ -163,8 +162,10 @@ def get_hf_text_config(config: PretrainedConfig): # Temporary hack for DeepSeek-V3.2 model -def _load_deepseek_v32_model( +def _load_deepseek_temp_model( model_path: str, + model_type: Literal["deepseek_v32", "deepseek_v4"], + architecture: Literal["DeepseekV3ForCausalLM", "DeepseekV4ForCausalLM"], trust_remote_code: bool = False, revision: Optional[str] = None, **kwargs, @@ -179,13 +180,13 @@ def _load_deepseek_v32_model( with open(config_file, "r") as f: config_json = json.load(f) - config_json["architectures"] = ["DeepseekV3ForCausalLM"] + config_json["architectures"] = [architecture] config_json["model_type"] = "deepseek_v3" tmp_path = os.path.join(tempfile.gettempdir(), "_tmp_config_folder") os.makedirs(tmp_path, exist_ok=True) - unique_path = os.path.join(tmp_path, f"deepseek_v32_{os.getpid()}") + unique_path = os.path.join(tmp_path, f"{model_type}_{os.getpid()}") with open(unique_path, "w") as f: json.dump(config_json, f) @@ -277,11 +278,26 @@ def get_config( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) except ValueError as e: - if not "deepseek_v32" in str(e): + if "deepseek_v4" in str(e): + config = _load_deepseek_temp_model( + model, + model_type="deepseek_v4", + architecture="DeepseekV4ForCausalLM", + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + elif 
"deepseek_v32" in str(e): + config = _load_deepseek_temp_model( + model, + model_type="deepseek_v32", + architecture="DeepseekV3ForCausalLM", + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: raise e - config = _load_deepseek_v32_model( - model, trust_remote_code=trust_remote_code, revision=revision, **kwargs - ) if ( config.architectures is not None @@ -504,7 +520,7 @@ def get_tokenizer( ) attach_additional_stop_token_ids(tokenizer) - tokenizer = patch_tokenizer(tokenizer) + return tokenizer diff --git a/python/sglang/srt/utils/weight_checker.py b/python/sglang/srt/utils/weight_checker.py index ca5c2e08e24c..37d91425eee0 100644 --- a/python/sglang/srt/utils/weight_checker.py +++ b/python/sglang/srt/utils/weight_checker.py @@ -104,12 +104,16 @@ def _random_like(t: torch.Tensor): shape = t.shape dtype = t.dtype - if dtype.is_floating_point: + # FP8 types (float8_e4m3fn, etc.) have is_floating_point=False but are logically floats + if dtype.is_floating_point or "float" in str(dtype): return torch.rand(shape, device=device, dtype=torch.float32).to(dtype) if dtype == torch.bool: return torch.rand(shape, device=device) > 0.5 + if dtype.is_complex: + return torch.randn(shape, device=device, dtype=dtype) + info = torch.iinfo(dtype) return torch.randint( low=int(info.min), high=int(info.max), size=shape, device=device, dtype=dtype @@ -121,7 +125,18 @@ def _postprocess_tensors( ) -> Iterable[Tuple[str, bool, torch.Tensor]]: from sglang.srt.debug_utils.dumper import get_tensor_info - skip_compare_names = [] + # skip because megatron don't use k_scale/v_scale + skip_compare_names = [ + name + for name in raw + if any(pattern in name for pattern in ["attn_mqa.k_scale", "attn_mqa.v_scale"]) + ] + # rope parameters should not be updated from megatron + skip_compare_names += [ + name + for name in raw + if any(pattern in name for pattern in ["freqs_cis", "cos_sin_cache"]) + ] # dequant fp8 quant_names = [ @@ -130,19 +145,27 @@ def _postprocess_tensors( # Match: `something.weight`, `something.experts.w2_weight` if name.endswith("weight") and name.replace("weight", "weight_scale_inv") in raw ] + quant_scale_names = [ + name.replace("weight", "weight_scale_inv") for name in quant_names + ] skip_compare_names += quant_names + skip_compare_names += quant_scale_names for name in quant_names: w_q = raw[name] w_s = raw[name.replace("weight", "weight_scale_inv")] try: - # TODO this is only needed for Blackwell - w_s_inverse_transformed = inverse_transform_scale_ue8m0( - w_s, mn=w_q.shape[-2] - ) + # ue8m0 format has int32 dtype + # triton format has float32 dtype + if w_s.dtype == torch.int32: + # TODO this is only needed for Blackwell + w_s_for_dequant = inverse_transform_scale_ue8m0(w_s, mn=w_q.shape[-2]) + else: + w_s_for_dequant = w_s + w_dequant = block_quant_dequant( w_q, - w_s_inverse_transformed, + w_s_for_dequant, # TODO do not hardcode block_size=[128, 128], dtype=torch.bfloat16, diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index e637aaee90cb..0d1e0e7d530c 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,11 +1,34 @@ #!/bin/bash +# Usage: +# ./killall_sglang.sh - Kill SGLang processes only (NVIDIA mode) +# ./killall_sglang.sh rocm - Kill SGLang processes only (ROCm mode) +# ./killall_sglang.sh all - Kill all GPU processes (NVIDIA mode) +# ./killall_sglang.sh gpus 0,1,2,3 - Kill all processes on specific GPUs + if [ "$1" = "rocm" ]; then echo "Running in ROCm mode" # Clean SGLang processes pgrep -f 
'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_diffusion::' | xargs -r kill -9 +elif [ "$1" = "gpus" ] && [ -n "$2" ]; then + # Kill all processes on specific GPUs only + echo "Killing all processes on GPUs: $2" + + # Show current GPU status + nvidia-smi + + # Build device file list from GPU IDs (e.g., "0,1,2,3" -> "/dev/nvidia0 /dev/nvidia1 ...") + devices=$(echo "$2" | tr ',' '\n' | sed 's/^[[:space:]]*//;s/[[:space:]]*$//' | sed 's|^|/dev/nvidia|' | tr '\n' ' ') + echo "Targeting devices: $devices" + + # Kill all processes using specified GPU devices + lsof $devices 2>/dev/null | awk 'NR>1 {print $2}' | sort -u | xargs -r kill -9 2>/dev/null + + # Show GPU status after clean up + nvidia-smi + else # Show current GPU status nvidia-smi @@ -13,8 +36,8 @@ else # Clean SGLang processes pgrep -f 'sglang::|sglang\.launch_server|sglang\.bench|sglang\.data_parallel|sglang\.srt|sgl_diffusion::' | xargs -r kill -9 - # Clean all GPU processes if any argument is provided - if [ $# -gt 0 ]; then + # Clean all GPU processes if "all" argument is provided + if [ "$1" = "all" ]; then # Check if sudo is available if command -v sudo >/dev/null 2>&1; then sudo apt-get update diff --git a/sgl-kernel/cmake/flashmla.cmake b/sgl-kernel/cmake/flashmla.cmake index c17266af243f..8b7c1d5f0ca2 100644 --- a/sgl-kernel/cmake/flashmla.cmake +++ b/sgl-kernel/cmake/flashmla.cmake @@ -30,6 +30,11 @@ if(${CUDA_VERSION} VERSION_GREATER 12.8) "-gencode=arch=compute_100a,code=sm_100a" ) endif() +if(${CUDA_VERSION} VERSION_GREATER_EQUAL 13.0) +list(APPEND FLASHMLA_CUDA_FLAGS + "-gencode=arch=compute_103a,code=sm_103a" +) +endif() set(FlashMLA_SOURCES diff --git a/test/registered/function_call/test_function_call_parser.py b/test/registered/function_call/test_function_call_parser.py index 7263a492cffe..6348d3e6615a 100644 --- a/test/registered/function_call/test_function_call_parser.py +++ b/test/registered/function_call/test_function_call_parser.py @@ -5,6 +5,7 @@ from sglang.srt.function_call.base_format_detector import BaseFormatDetector from sglang.srt.function_call.core_types import StreamingParseResult from sglang.srt.function_call.deepseekv3_detector import DeepSeekV3Detector +from sglang.srt.function_call.deepseekv4_detector import DeepSeekV4Detector from sglang.srt.function_call.deepseekv32_detector import DeepSeekV32Detector from sglang.srt.function_call.glm4_moe_detector import Glm4MoeDetector from sglang.srt.function_call.glm47_moe_detector import Glm47MoeDetector @@ -1613,6 +1614,283 @@ def test_streaming_no_parameters_with_whitespace(self): self.assertEqual(params, {}) +class TestDeepSeekV4Detector(unittest.TestCase): + """DeepSeek V4 DSML tool-call tests. + + Mirrors TestDeepSeekV32Detector but targets the V4 outer block name + ``<|DSML|tool_calls>`` instead of ``<|DSML|function_calls>``. The V4 + reference encoder only emits XML-parameter form, so the V32 JSON-body + tests have no V4 analogue and are intentionally omitted. 
+ """ + + def setUp(self): + self.tools = [ + Tool( + type="function", + function=Function( + name="search", + description="Searches for information related to query and displays topn results.", + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query string", + }, + "topn": { + "type": "integer", + "description": "Number of top results to display", + "default": 10, + }, + "source": { + "type": "string", + "description": "Source to search within", + "enum": ["web", "news"], + "default": "web", + }, + }, + "required": ["query"], + }, + ), + ), + Tool( + type="function", + function=Function( + name="get_favorite_tourist_spot", + description="Return the favorite tourist spot for a given city.", + parameters={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + ), + ] + self.detector = DeepSeekV4Detector() + from transformers import AutoTokenizer + + # V3.2 tokenizer works for the chunk-split streaming test: it already + # has the DSML special tokens and decodes the test strings losslessly. + self.tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2") + self.interval = 1 + + def test_detect_and_parse_xml_format(self): + """Test parsing standard XML format (DSML)""" + text = """I'll help you with information about San Francisco and get its favorite tourist spot for you.\n\n + <|DSML|tool_calls>\n + <|DSML|invoke name="get_favorite_tourist_spot">\n + <|DSML|parameter name="city" string="true">San Francisco\n + \n + <|DSML|invoke name="search"> + <|DSML|parameter name="query" string="true">WebNav benchmark + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="source" string="true">web + + + """ + result = self.detector.detect_and_parse(text, self.tools) + + self.assertIn("I'll help you with information", result.normal_text) + self.assertEqual(len(result.calls), 2) + + call1 = result.calls[0] + self.assertEqual(call1.name, "get_favorite_tourist_spot") + params1 = json.loads(call1.parameters) + self.assertEqual(params1["city"], "San Francisco") + + call2 = result.calls[1] + self.assertEqual(call2.name, "search") + params2 = json.loads(call2.parameters) + self.assertEqual(params2["query"], "WebNav benchmark") + self.assertEqual(params2["topn"], 10) + self.assertEqual(params2["source"], "web") + + def test_streaming_xml_format(self): + """Test streaming parsing of XML format""" + text = """<|DSML|tool_calls> + <|DSML|invoke name="get_favorite_tourist_spot"> + <|DSML|parameter name="city" string="true">San Francisco + <|DSML|parameter name="another_city" string="true">London + <|DSML|parameter name="topn" string="false">10 + <|DSML|parameter name="obj" string="false">{"name": "John", "age": 30} + + """ + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, self.tools) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + 
self.assertEqual(len(tool_calls_by_index), 1) + self.assertEqual(tool_calls_by_index[0]["name"], "get_favorite_tourist_spot") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params["city"], "San Francisco") + self.assertEqual(params["another_city"], "London") + self.assertEqual(params["topn"], 10) + self.assertEqual(params["obj"]["name"], "John") + self.assertEqual(params["obj"]["age"], 30) + + def test_detect_and_parse_no_parameters(self): + """Test parsing function calls with no parameters (non-streaming)""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """Let me get the current date for you. + +<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + +""" + + result = self.detector.detect_and_parse(text, tools_with_no_param) + + self.assertIn("Let me get the current date", result.normal_text) + self.assertEqual(len(result.calls), 1) + + call = result.calls[0] + self.assertEqual(call.name, "get_date") + params = json.loads(call.parameters) + self.assertEqual(params, {}) + + def test_streaming_no_parameters(self): + """Test streaming parsing of function calls with no parameters.""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + +""" + + self.detector = DeepSeekV4Detector() + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + + params_str = tool_calls_by_index[0]["parameters"].strip() + params = json.loads(params_str) + self.assertEqual(params, {}) + + def test_streaming_no_parameters_with_whitespace(self): + """Test streaming parsing when invoke content has only whitespace (newlines).""" + tools_with_no_param = self.tools + [ + Tool( + type="function", + function=Function( + name="get_date", + description="Get the current date.", + parameters={"type": "object", "properties": {}}, + ), + ), + ] + + text = """<|DSML|tool_calls> +<|DSML|invoke name="get_date"> + + +""" + + self.detector = DeepSeekV4Detector() + + input_ids = self.tokenizer.encode(text, add_special_tokens=False) + chunk_ids = [ + input_ids[i : i + self.interval] + for i in range(0, len(input_ids), self.interval) + ] + chunks = [self.tokenizer.decode(chunk_id) for chunk_id in chunk_ids] + + tool_calls_by_index = {} + + for chunk in chunks: + result = self.detector.parse_streaming_increment(chunk, tools_with_no_param) + for call in result.calls: + if call.tool_index 
is not None: + if call.tool_index not in tool_calls_by_index: + tool_calls_by_index[call.tool_index] = { + "name": "", + "parameters": "", + } + + if call.name: + tool_calls_by_index[call.tool_index]["name"] = call.name + if call.parameters: + tool_calls_by_index[call.tool_index][ + "parameters" + ] += call.parameters + + self.assertEqual( + len(tool_calls_by_index), 1, "Should have exactly one tool call" + ) + self.assertEqual(tool_calls_by_index[0]["name"], "get_date") + params = json.loads(tool_calls_by_index[0]["parameters"]) + self.assertEqual(params, {}) + + class TestQwen3CoderDetector(unittest.TestCase): """Test suite for Qwen3CoderDetector.""" diff --git a/test/registered/openai_server/basic/test_serving_chat.py b/test/registered/openai_server/basic/test_serving_chat.py index d81f2efb051f..090207fc68ef 100644 --- a/test/registered/openai_server/basic/test_serving_chat.py +++ b/test/registered/openai_server/basic/test_serving_chat.py @@ -37,7 +37,7 @@ def __init__(self): tool_call_parser="hermes", reasoning_parser=None, ) - # Mock hf_config for _use_dpsk_v32_encoding check + # Mock hf_config for _resolve_chat_encoding_spec check mock_hf_config = Mock() mock_hf_config.architectures = ["LlamaForCausalLM"] self.model_config.hf_config = mock_hf_config @@ -614,18 +614,29 @@ def test_dpsk_v32_encoding_path(self): tokenizer_manager.tokenizer.chat_template = None serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertTrue(serving_chat.use_dpsk_v32_encoding) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv32") # Case 2: Chat template exists -> should NOT use dpsk encoding tokenizer_manager.tokenizer.chat_template = "some template" serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) # Case 3: Not DeepSeek V3.2 architecture -> should NOT use dpsk encoding tokenizer_manager.tokenizer.chat_template = None mock_hf_config.architectures = ["LlamaForCausalLM"] serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) - self.assertFalse(serving_chat.use_dpsk_v32_encoding) + self.assertIsNone(serving_chat.chat_encoding_spec) + + # Case 4: DeepseekV4 arch -> always dsv4, even with chat_template + # (release ships a stale V3 jinja we deliberately override). + mock_hf_config.architectures = ["DeepseekV4ForCausalLM"] + tokenizer_manager.tokenizer.chat_template = "stale v3 jinja" + serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") + + tokenizer_manager.tokenizer.chat_template = None + serving_chat = OpenAIServingChat(tokenizer_manager, TemplateManager()) + self.assertEqual(serving_chat.chat_encoding_spec, "dsv4") if __name__ == "__main__":