diff --git a/docker/rocm.Dockerfile b/docker/rocm.Dockerfile
index 7ee4206e193b..82c46d17d81e 100644
--- a/docker/rocm.Dockerfile
+++ b/docker/rocm.Dockerfile
@@ -561,6 +561,10 @@ ENV SGLANG_TOPK_TRANSFORM_512_TORCH=0
 ENV SGLANG_OPT_USE_FUSED_COMPRESS=true
 ENV SGLANG_OPT_USE_TILELANG_INDEXER=true
 ENV SGLANG_HACK_FLASHMLA_BACKEND=tilelang
+ENV SGLANG_OPT_USE_AITER_MHC_PRE=true
+ENV SGLANG_OPT_USE_AITER_MHC_POST=true
+ENV SGLANG_OPT_USE_TILELANG_MHC_PRE=false
+ENV SGLANG_OPT_USE_TILELANG_MHC_POST=false
 ENV NCCL_MIN_NCHANNELS=112
 ENV ROCM_QUICK_REDUCE_QUANTIZATION=INT8

diff --git a/python/run_dsv4.sh b/python/run_dsv4.sh
index 2a4fe70b6565..440182692eaa 100755
--- a/python/run_dsv4.sh
+++ b/python/run_dsv4.sh
@@ -1,80 +1,37 @@
-#export CUDA_VISIBLE_DEVICES=0,1,2,3
-#/dockerx/data/models/DeepSeek-V4-Flash
+export SGLANG_OPT_USE_OLD_COMPRESSOR=true
+export SGLANG_OPT_USE_TILELANG_SWA_PREPARE=false
+export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false
+export SGLANG_OPT_USE_FUSED_HASH_TOPK=false

-#### FP8 model path ####
-#export SGLANG_REASONING_EFFORT=max
-#
-#export SGLANG_OPT_USE_FUSED_COMPRESS=false #use PyTorch implemented compressor
-#export SGLANG_OPT_USE_OLD_COMPRESSOR=true #use old compressor
-#export SGLANG_OPT_USE_TILELANG_SWA_PREPARE=false #use old prepare
-#export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false #use old topk
-#export SGLANG_OPT_USE_FUSED_HASH_TOPK=false #AMD: hash_topk JIT needs CUDA toolchain
-#
-#export SGLANG_HACK_FLASHMLA_BACKEND=torch
-#export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
-#export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false #use old prenorm
-#
-#export SGLANG_OPT_USE_TILELANG_MHC_PRE=false #use torch hc_pre
-#export SGLANG_OPT_USE_TILELANG_MHC_POST=false #use torch hc_post
-#
-#export SGLANG_ENABLE_THINKING=1
-#export SGLANG_USE_AITER=1
-#export SGLANG_USE_ROCM700A=1
-#export SGLANG_TOPK_TRANSFORM_512_TORCH=1
-#export SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1
-#
-#export SGLANG_OPT_DPSK_V4_RADIX=0
-#export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false #non-radix backend has no store_cache method
-#export SGLANG_OPT_USE_FUSED_STORE_CACHE=false #fused_store_cache JIT needs CUDA toolchain
-#
-#export SGLANG_FORCE_TRITON_MOE_FP8=1 # this is required to apply swiglu_limit clamp in fused_moe_triton
-#
-#python3 -m sglang.launch_server \
-#    --model-path /dockerx/data2/models/DeepSeek-V4-Pro-FP8 \
-#    --trust-remote-code \
-#    --tp 8 \
-#    --dp 8 \
-#    --enable-dp-attention \
-#    --disable-radix-cache \
-#    --attention-backend compressed \
-#    --max-running-request 256 \
-#    --page-size 256 \
-#    --chunked-prefill-size 8192 \
-#    --port 8000 \
-#    --disable-shared-experts-fusion \
-#    --disable-cuda-graph \
-#    --tool-call-parser deepseekv4 \
-#    --reasoning-parser deepseek-v4
+export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false

-#### FP4 model path ####
-export SGLANG_REASONING_EFFORT=max
-
-export SGLANG_OPT_USE_FUSED_COMPRESS=false #use PyTorch implemented compressor
-export SGLANG_OPT_USE_OLD_COMPRESSOR=true #use old compressor
-export SGLANG_OPT_USE_TILELANG_SWA_PREPARE=false #use old prepare
-export SGLANG_OPT_USE_JIT_KERNEL_FUSED_TOPK=false #use old topk
-export SGLANG_OPT_USE_FUSED_HASH_TOPK=false #AMD: hash_topk JIT needs CUDA toolchain
-
-export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
-export SGLANG_OPT_DEEPGEMM_HC_PRENORM=false #use old prenorm
-
-export SGLANG_OPT_USE_TILELANG_MHC_PRE=false #use torch hc_pre
-export SGLANG_OPT_USE_TILELANG_MHC_POST=false #use torch hc_post
+export SGLANG_OPT_USE_TILELANG_MHC_PRE=false
+export SGLANG_OPT_USE_TILELANG_MHC_POST=false

 export SGLANG_ENABLE_THINKING=1
 export SGLANG_USE_AITER=1
 export SGLANG_USE_ROCM700A=1
-export SGLANG_TOPK_TRANSFORM_512_TORCH=1
 export SGLANG_FP8_PAGED_MQA_LOGITS_TORCH=1

 export SGLANG_OPT_DPSK_V4_RADIX=0
-export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false #non-radix backend has no store_cache method
-export SGLANG_OPT_USE_FUSED_STORE_CACHE=false #fused_store_cache JIT needs CUDA toolchain
+export SGLANG_OPT_USE_OVERLAP_STORE_CACHE=false
+export SGLANG_OPT_USE_FUSED_STORE_CACHE=false
+
+# changed
+export SGLANG_OPT_USE_FUSED_COMPRESS=true
+export SGLANG_TOPK_TRANSFORM_512_TORCH=0
+export SGLANG_OPT_USE_TILELANG_INDEXER=true
+export SGLANG_HACK_FLASHMLA_BACKEND=tilelang
+export SGLANG_REASONING_EFFORT=max
+export SGLANG_FORCE_TRITON_MOE_FP8=0
+export SGLANG_OPT_USE_AITER_MHC_PRE=true
+export SGLANG_OPT_USE_AITER_MHC_POST=true

-export SGLANG_FORCE_TRITON_MOE_FP8=0 # this is required to apply swiglu_limit clamp in fused_moe_triton
+MODEL=/dockerx/data/deepseek-ai/DeepSeek-V4-Pro
+MODEL=/dockerx/data/sgl-project/DeepSeek-V4-Flash-FP8/

 python3 -m sglang.launch_server \
-    --model-path /dockerx/data/deepseek-ai/DeepSeek-V4-Pro \
+    --model-path ${MODEL} \
     --trust-remote-code \
     --tp 8 \
     --disable-radix-cache \
@@ -84,6 +41,5 @@ python3 -m sglang.launch_server \
     --chunked-prefill-size 8192 \
     --port 8000 \
     --disable-shared-experts-fusion \
-    --disable-cuda-graph \
     --tool-call-parser deepseekv4 \
     --reasoning-parser deepseek-v4
diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index a74fbac81fd6..c677965be6e0 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -589,6 +589,8 @@ class Envs:
     SGLANG_TOPK_TRANSFORM_512_TORCH = EnvBool(False)
     SGLANG_OPT_BF16_FP32_GEMM_ALGO = EnvBool(False)
     SGLANG_FORCE_TRITON_MOE_FP8 = EnvBool(False)
+    SGLANG_OPT_USE_AITER_MHC_PRE = EnvBool(True)
+    SGLANG_OPT_USE_AITER_MHC_POST = EnvBool(True)
     # fmt: on

     # EPD
diff --git a/python/sglang/srt/layers/mhc.py b/python/sglang/srt/layers/mhc.py
index 08526a982d63..4819d4e6b24d 100644
--- a/python/sglang/srt/layers/mhc.py
+++ b/python/sglang/srt/layers/mhc.py
@@ -543,11 +543,21 @@ def mhc_pre(

     if num_tokens <= 2048:
         assert n_splits == 1
+    if hc_hidden_size == 16384:
+        hidden_block = 256
+    elif hc_hidden_size == 28672:
+        hidden_block = 128
+    else:
+        raise NotImplementedError(
+            f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, "
+            f"got {hc_hidden_size}"
+        )
     kernel_0, kernel_1 = mhc_pre_gemm_sqrsum_splitk_kernel(
         hc_mult3,
         hc_hidden_size,
         split_k=n_splits_pre,
         token_block=32,
+        hidden_block=hidden_block,
     )
     partial_out = gemm_out_mul.new_empty(n_splits_pre, num_tokens, 32)
     partial_sqrsum = gemm_out_sqrsum.new_empty(n_splits_pre, num_tokens)
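Note on the environ.py and mhc.py hunks above: the two new flags default to EnvBool(True), so the AITER MHC path is active on ROCm even without the Dockerfile ENV lines or the exports in run_dsv4.sh; those only make the choice explicit and give an easy switch back to the TileLang/torch path. The mhc.py change hard-codes a per-shape hidden_block for the split-K pre-GEMM kernel; the sketch below restates that selection as a standalone function. Reading hidden_block as the tile width along hc_hidden_size (64 tiles for 16384/256, 224 tiles for 28672/128) is an assumption here, not something the patch states.

    def pick_hidden_block(hc_hidden_size: int) -> int:
        # Mirrors the dispatch added to mhc_pre(); only these two shapes are wired up.
        if hc_hidden_size == 16384:
            return 256
        if hc_hidden_size == 28672:
            return 128
        raise NotImplementedError(
            f"mhc_pre splitk kernel only supports hc_hidden_size in {{16384, 28672}}, "
            f"got {hc_hidden_size}"
        )

    assert pick_hidden_block(16384) == 256  # 16384 / 256 = 64 tiles per row
    assert pick_hidden_block(28672) == 128  # 28672 / 128 = 224 tiles per row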
diff --git a/python/sglang/srt/models/deepseek_v4.py b/python/sglang/srt/models/deepseek_v4.py
index 157596e08f16..fb03e3d66fbe 100644
--- a/python/sglang/srt/models/deepseek_v4.py
+++ b/python/sglang/srt/models/deepseek_v4.py
@@ -1889,6 +1889,23 @@ def hc_pre_torch_impl(x, hc_fn):
             # returned post should be [n, hc_mult]
             return y, post.squeeze(-1), comb

+        if _is_hip and envs.SGLANG_OPT_USE_AITER_MHC_PRE.get():
+            from aiter.ops.mhc import mhc_pre
+
+            post, comb, y = mhc_pre(
+                residual=x,
+                fn=hc_fn,
+                hc_scale=hc_scale,
+                hc_base=hc_base,
+                rms_eps=self.rms_norm_eps,
+                hc_pre_eps=self.hc_eps,
+                hc_sinkhorn_eps=self.hc_eps,
+                hc_post_mult_value=2.0,
+                sinkhorn_repeat=self.hc_sinkhorn_iters,
+            )
+            # returned post should be [n, hc_mult]
+            return y, post.squeeze(-1), comb
+
         if envs.SGLANG_OPT_DEEPGEMM_HC_PRENORM.get():
             # DeepGEMM implementation
             import deep_gemm
@@ -1945,6 +1962,14 @@ def hc_post(
             result = mhc_post(x, residual, post, comb)
             return result

+        elif _is_hip and envs.SGLANG_OPT_USE_AITER_MHC_POST.get():
+            from aiter.ops.mhc import mhc_post
+
+            result = torch.empty_like(residual)
+            mhc_post(result, x, residual, post, comb)
+
+            return result
+
         assert residual.shape == (x.shape[0], self.hc_mult, x.shape[-1])
         assert post.shape == (x.shape[0], self.hc_mult)
         assert comb.shape == (x.shape[0], self.hc_mult, self.hc_mult)
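The aiter.ops.mhc kernels above are the new default MHC path on ROCm. A minimal standalone smoke test for the mhc_post call as it is used in hc_post: the import path, the call signature, and the preallocated output mirror the patch, while n=32, hc_mult=4, hidden=7168, bfloat16, the "cuda" (HIP) device, and treating x as 2-D [num_tokens, hidden] are made-up example assumptions, not values taken from the patch.

    import torch
    from aiter.ops.mhc import mhc_post

    n, hc_mult, hidden = 32, 4, 7168  # example sizes only
    x = torch.randn(n, hidden, dtype=torch.bfloat16, device="cuda")
    residual = torch.randn(n, hc_mult, hidden, dtype=torch.bfloat16, device="cuda")
    post = torch.randn(n, hc_mult, dtype=torch.bfloat16, device="cuda")
    comb = torch.randn(n, hc_mult, hc_mult, dtype=torch.bfloat16, device="cuda")

    # Same calling convention as the hc_post branch: the kernel writes into a
    # preallocated output shaped like residual.
    result = torch.empty_like(residual)
    mhc_post(result, x, residual, post, comb)
    print(result.shape)  # torch.Size([32, 4, 7168])

For an end-to-end comparison against the existing torch/TileLang path, flip SGLANG_OPT_USE_AITER_MHC_PRE/POST to false in run_dsv4.sh (or via the Dockerfile ENV overrides) and rerun the same prompt.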