@@ -335,7 +335,7 @@ export const DeepSeekV4Deployment = () => {
// _LAUNCH_HEAD always prepends these:
// Per-hardware env (whitelist #1: NVSHMEM removed for B200).
const HW_ENV = {
h200: ["SGLANG_DSV4_FP4_EXPERTS=0"], // allinone _ENV_H200
h200: [], // allinone _ENV_H200
b200: [], // _ENV_B200 minus NVSHMEM
gb300: [], // _ENV_GB300
// GB200 multinode needs NCCL MNNVL for cross-node NVLink communication.
@@ -625,7 +625,7 @@ export const DeepSeekV4Deployment = () => {
const isBlackwell = hardware === "b200" || hardware === "gb200" || isGB300;

const HW_ENV = {
h200: ["SGLANG_DSV4_FP4_EXPERTS=0"],
h200: [],
b200: [],
gb300: [],
gb200: [],
96 changes: 46 additions & 50 deletions python/run_dsv4.sh
@@ -2,52 +2,6 @@
#/dockerx/data/models/DeepSeek-V4-Flash

#### FP8 model path ####
export SGLANG_REASONING_EFFORT=max

export SGLANG_OPT_USE_FUSED_COMPRESS=false #use PyTorch implemented compressor
export SGLANG_OPT_USE_OLD_COMPRESSOR=true #use old compressor
export SGLANG_OPT_USE_TILELANG_SWA_PREPARE=false #use old prepare
export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false #use old topk
export SGLANG_OPT_USE_FUSED_HASH_TOPK=false #AMD: hash_topk JIT needs CUDA toolchain

export SGLANG_HACK_FLASHMLA_BACKEND=torch
export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false #use old prenorm

export SGLANG_OPT_USE_TILELANG_MHC_PRE=false #use torch hc_pre
export SGLANG_OPT_USE_TILELANG_MHC_POST=false #use torch hc_post

export SGLANG_ENABLE_THINKING=1
export SGLANG_USE_AITER=1
export SGLANG_USE_ROCM700A=1
export SGLANG_TOPK_TRANSFORM_512_TORCH=1
export SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1

export SGLANG_DSV4_FP4_EXPERTS=false

export SGLANG_OPT_DPSK_V4_RADIX=0
export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false #non-radix backend has no store_cache method
export SGLANG_OPT_USE_FUSED_STORE_CACHE=false #fused_store_cache JIT needs CUDA toolchain

export SGLANG_FORCE_TRITON_MOE_FP8=1 # this is required to apply swiglu_limit clamp in fused_moe_triton

python3 -m sglang.launch_server \
--model-path /dockerx/data2/models/DeepSeek-V4-Flash-FP8 \
--trust-remote-code \
--tp 8 \
--dp 8 \
--enable-dp-attention \
--disable-radix-cache \
--attention-backend compressed \
--max-running-request 256 \
--page-size 256 \
--chunked-prefill-size 8192 \
--port 8000 \
--disable-shared-experts-fusion \
--tool-call-parser deepseekv4 \
--reasoning-parser deepseek-v4

#### FP4 model path ####
#export SGLANG_REASONING_EFFORT=max
#
#export SGLANG_OPT_USE_FUSED_COMPRESS=false #use PyTorch implemented compressor
@@ -56,6 +10,7 @@ python3 -m sglang.launch_server
#export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false #use old topk
#export SGLANG_OPT_USE_FUSED_HASH_TOPK=false #AMD: hash_topk JIT needs CUDA toolchain
#
#export SGLANG_HACK_FLASHMLA_BACKEND=torch
#export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
#export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false #use old prenorm
#
@@ -68,19 +23,18 @@ python3 -m sglang.launch_server
#export SGLANG_TOPK_TRANSFORM_512_TORCH=1
#export SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1
#
#export SGLANG_DSV4_FP4_EXPERTS=False
#
#export SGLANG_OPT_DPSK_V4_RADIX=0
#export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false #non-radix backend has no store_cache method
#export SGLANG_OPT_USE_FUSED_STORE_CACHE=false #fused_store_cache JIT needs CUDA toolchain
#
#export SGLANG_FORCE_TRITON_MOE_FP8=1 # this is required to apply swiglu_limit clamp in fused_moe_triton
#
#python3 -m sglang.launch_server \
# --model-path /dockerx/data2/models/DeepSeek-V4-Flash-FP8/ \
# --model-path /dockerx/data2/models/DeepSeek-V4-Pro-FP8 \
# --trust-remote-code \
# --tp 8 \
# --ep 8 \
# --dp 8 \
# --enable-dp-attention \
# --disable-radix-cache \
# --attention-backend compressed \
# --max-running-request 256 \
@@ -91,3 +45,45 @@ python3 -m sglang.launch_server
# --disable-cuda-graph \
# --tool-call-parser deepseekv4 \
# --reasoning-parser deepseek-v4

#### FP4 model path ####
export SGLANG_REASONING_EFFORT=max

export SGLANG_OPT_USE_FUSED_COMPRESS=false #use PyTorch implemented compressor
export SGLANG_OPT_USE_OLD_COMPRESSOR=true #use old compressor
export SGLANG_OPT_USE_TILELANG_SWA_PREPARE=false #use old prepare
export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false #use old topk
export SGLANG_OPT_USE_FUSED_HASH_TOPK=false #AMD: hash_topk JIT needs CUDA toolchain

export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false #use old prenorm

export SGLANG_OPT_USE_TILELANG_MHC_PRE=false #use torch hc_pre
export SGLANG_OPT_USE_TILELANG_MHC_POST=false #use torch hc_post

export SGLANG_ENABLE_THINKING=1
export SGLANG_USE_AITER=1
export SGLANG_USE_ROCM700A=1
export SGLANG_TOPK_TRANSFORM_512_TORCH=1
export SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1

export SGLANG_OPT_DPSK_V4_RADIX=0
export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false #non-radix backend has no store_cache method
export SGLANG_OPT_USE_FUSED_STORE_CACHE=false #fused_store_cache JIT needs CUDA toolchain

export SGLANG_FORCE_TRITON_MOE_FP8=0 # set to 1 to apply the swiglu_limit clamp in fused_moe_triton

python3 -m sglang.launch_server \
--model-path /dockerx/data/deepseek-ai/DeepSeek-V4-Pro \
--trust-remote-code \
--tp 8 \
--disable-radix-cache \
--attention-backend compressed \
--max-running-request 256 \
--page-size 256 \
--chunked-prefill-size 8192 \
--port 8000 \
--disable-shared-experts-fusion \
--disable-cuda-graph \
--tool-call-parser deepseekv4 \
--reasoning-parser deepseek-v4
14 changes: 14 additions & 0 deletions python/sglang/srt/configs/deepseek_v4.py
@@ -1,3 +1,17 @@
from typing import Optional

_fp4_experts: Optional[bool] = None


def set_fp4_experts(value: bool) -> None:
global _fp4_experts
_fp4_experts = value


def get_fp4_experts() -> bool:
return bool(_fp4_experts)


class DeepSeekV4Config:
"""Configuration holder for DeepSeek V4 model parameters.

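For orientation, here is a minimal sketch (not part of the diff) of how the new module-level flag behaves, assuming only what the hunk above shows: the flag starts out unset (None), get_fp4_experts() coerces that to False, and a single set_fp4_experts() call during model construction determines what every later caller sees.

from sglang.srt.configs.deepseek_v4 import get_fp4_experts, set_fp4_experts

# Before any model is constructed the flag is still None, and bool(None) is False,
# so readers default to the FP8-experts path.
assert get_fp4_experts() is False

# The DeepSeek-V4 model __init__ (see the models/deepseek_v4.py hunk below) sets
# the flag once from the checkpoint config; afterwards every reader sees the same value.
set_fp4_experts(True)
assert get_fp4_experts() is True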
1 change: 0 additions & 1 deletion python/sglang/srt/environ.py
@@ -558,7 +558,6 @@ class Envs:
# DeepSeek V4
SGLANG_DSV4_MODE = EnvStr("2604")
SGLANG_DSV4_2604_SUBMODE = EnvStr("2604B")
SGLANG_DSV4_FP4_EXPERTS = EnvBool(True)
SGLANG_OPT_DEEPGEMM_HC_PRENORM = EnvBool(True)
SGLANG_OPT_USE_TILELANG_MHC_PRE = EnvBool(True)
SGLANG_OPT_USE_TILELANG_MHC_POST = EnvBool(True)
12 changes: 3 additions & 9 deletions python/sglang/srt/layers/deep_gemm_wrapper/entrypoint.py
@@ -4,6 +4,7 @@

import torch

from sglang.srt.configs.deepseek_v4 import get_fp4_experts
from sglang.srt.environ import envs
from sglang.srt.layers.deep_gemm_wrapper import compile_utils
from sglang.srt.layers.deep_gemm_wrapper.configurer import ( # noqa: F401
@@ -55,10 +56,7 @@ def grouped_gemm_nt_f8f8bf16_masked(
):

fp4_kwargs = (
dict(recipe_a=(1, 128), recipe_b=(1, 32))
if envs.SGLANG_DSV4_MODE.get() == "2604"
and envs.SGLANG_DSV4_FP4_EXPERTS.get()
else {}
dict(recipe_a=(1, 128), recipe_b=(1, 32)) if get_fp4_experts() else {}
)

return deep_gemm.fp8_m_grouped_gemm_nt_masked(
@@ -108,11 +106,7 @@ def grouped_gemm_nt_f8f8bf16_contig(
if envs.SGLANG_HACK_SKIP_FP4_FP8_GEMM.get():
out.zero_()
return
fp4_kwargs = (
dict(recipe_a=(1, 128), recipe_b=(1, 32))
if envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get()
else {}
)
fp4_kwargs = dict(recipe_a=(1, 128), recipe_b=(1, 32)) if get_fp4_experts() else {}

with compile_utils.deep_gemm_execution_hook(m, n, k, num_groups, kernel_type):
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
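As a side note, both hunks above rely on the same conditional-kwargs idiom: build a dict of recipe overrides only when FP4 experts are active and splat it into the kernel call. A minimal sketch with a hypothetical stand-in callee (the real deep_gemm call and its remaining arguments are not fully shown in the diff):

from sglang.srt.configs.deepseek_v4 import get_fp4_experts

def grouped_gemm_stub(a, b, out, recipe_a=None, recipe_b=None):
    # Hypothetical stand-in for the deep_gemm grouped-GEMM entry points.
    # With FP4 experts, recipe_b=(1, 32) presumably requests the finer
    # per-32-element scaling granularity for the weight operand; without it,
    # the kernel's default FP8 recipe applies.
    print("recipe_a:", recipe_a, "recipe_b:", recipe_b)

# Same expression as in the hunks above: empty kwargs leave the defaults intact.
fp4_kwargs = dict(recipe_a=(1, 128), recipe_b=(1, 32)) if get_fp4_experts() else {}
grouped_gemm_stub(a=None, b=None, out=None, **fp4_kwargs)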
5 changes: 2 additions & 3 deletions python/sglang/srt/layers/quantization/fp8.py
@@ -10,6 +10,7 @@
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.configs.deepseek_v4 import get_fp4_experts
from sglang.srt.distributed import get_tensor_model_parallel_world_size, get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
@@ -820,9 +821,7 @@ def __init__(self, quant_config: Fp8Config):
self.use_mxfp8 or self.quant_config.weight_block_size is not None
)
self.with_bias = False
self.is_fp4_expert = (
envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get()
)
self.is_fp4_expert = get_fp4_experts()
if get_moe_runner_backend().is_cutlass():
assert (
cutlass_fp8_supported()
14 changes: 8 additions & 6 deletions python/sglang/srt/models/deepseek_v4.py
@@ -14,7 +14,10 @@

import sglang.srt.models.deepseek_v2 as deepseek_v2
from sglang.jit_kernel.deepseek_v4 import fused_rope, linear_bf16_fp32
from sglang.srt.configs.deepseek_v4 import DeepSeekV4Config
from sglang.srt.configs.deepseek_v4 import (
DeepSeekV4Config,
set_fp4_experts,
)
from sglang.srt.debug_utils.deepseek_v4_debug_utils import (
deepseek_v4_moe_code_path_checker,
)
@@ -2224,6 +2227,7 @@ def __init__(
) -> None:
super().__init__()
self.config = config
set_fp4_experts(getattr(config, "expert_dtype", None) == "fp4")
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts()
@@ -2287,10 +2291,8 @@ def determine_num_fused_shared_experts(self):
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under deepep expert parallelism."
elif self.quant_config and self.quant_config.get_name() == "w4afp8":
disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."
elif (
envs.SGLANG_DSV4_MODE.get() == "2604" and envs.SGLANG_DSV4_FP4_EXPERTS.get()
):
disable_reason = "2604 routed experts use FP4 while shared experts remain FP8; fusion would incorrectly apply FP4 to shared experts."
elif getattr(self.config, "expert_dtype", None) == "fp4":
disable_reason = "Routed experts use FP4 while shared experts remain FP8; fusion would incorrectly apply FP4 to shared experts."

if envs.SGLANG_DSV4_2604_SUBMODE.get() == "2604B":
disable_reason = "2604B checkpoint requires different clamping for shared and routed experts"
@@ -2492,7 +2494,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
envs.SGLANG_DSV4_MODE.get() == "2604"
and not envs.SGLANG_OPT_FP8_WO_A_GEMM.get()
):
if envs.SGLANG_DSV4_FP4_EXPERTS.get():
if getattr(self.config, "expert_dtype", None) == "fp4":
weights = _dequant_fp8_wo_a(weights)
else:
# Converted FP8 checkpoint: wo_a is already bf16; drop stale wo_a.scale if present
1 change: 1 addition & 0 deletions python/sglang/srt/utils/hf_transformers/common.py
@@ -419,6 +419,7 @@ def _load_deepseek_v4_model(
"hc_mult",
"hc_sinkhorn_iters",
"hc_eps",
"expert_dtype",
]:
if key in raw_config and not hasattr(config, key):
setattr(config, key, raw_config[key])
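Taken together, the hunks above replace the SGLANG_DSV4_FP4_EXPERTS environment variable with a checkpoint-driven flag. A minimal end-to-end sketch of that flow, using a hypothetical plain-object config instead of a real HF config object:

from types import SimpleNamespace

from sglang.srt.configs.deepseek_v4 import get_fp4_experts, set_fp4_experts

# 1. hf_transformers/common.py copies "expert_dtype" from the raw checkpoint
#    config onto the model config when the attribute is not already present.
raw_config = {"expert_dtype": "fp4"}  # hypothetical checkpoint field value
config = SimpleNamespace()
for key in ["expert_dtype"]:
    if key in raw_config and not hasattr(config, key):
        setattr(config, key, raw_config[key])

# 2. The model __init__ derives the flag once from the config, with the same
#    expression as in models/deepseek_v4.py; checkpoints without the field
#    fall back to FP8 experts.
set_fp4_experts(getattr(config, "expert_dtype", None) == "fp4")

# 3. Quantization and GEMM code (fp8.py, deep_gemm_wrapper) read the flag
#    instead of checking SGLANG_DSV4_MODE / SGLANG_DSV4_FP4_EXPERTS.
assert get_fp4_experts() is True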