Skip to content
This repository was archived by the owner on Apr 20, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions configs/rebuild-deepep.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/bin/bash
# Rebuild DeepEP with kNumMaxTopK=16 so Qwen3.5 (topk=10) fits; the
# library default of kNumMaxTopK=8 is insufficient.
#
# Requires: DeepEP sources mounted at /sgl-workspace/DeepEP (via
# extra_mount) and an NVSHMEM install under /usr/local.
set -euxo pipefail

echo "=== Rebuilding DeepEP with kNumMaxTopK=16 for Qwen3.5 (topk=10) ==="

DEEPEP_SRC="/sgl-workspace/DeepEP"

if [ ! -d "$DEEPEP_SRC" ]; then
  echo "ERROR: DeepEP source not found at $DEEPEP_SRC (mount via extra_mount)" >&2
  exit 1
fi

cd "$DEEPEP_SRC"

# Locate NVSHMEM. Guard the pipeline: under pipefail, `head -n 1` closing
# the pipe early (SIGPIPE to find) must not abort the script before the
# explicit empty check below.
NVSHMEM_DIR=$(find /usr/local -name "nvshmem" -type d 2>/dev/null | head -n 1 || true)
if [ -z "${NVSHMEM_DIR:-}" ]; then
  echo "ERROR: NVSHMEM installation not found under /usr/local" >&2
  exit 1
fi
echo "NVSHMEM_DIR=$NVSHMEM_DIR"

# Fix missing nvshmem symlinks (container ships .so.3 but not the bare
# .so linker name the build expects).
NVSHMEM_LIB="$NVSHMEM_DIR/lib"
if [ ! -f "$NVSHMEM_LIB/libnvshmem_host.so" ] && [ -f "$NVSHMEM_LIB/libnvshmem_host.so.3" ]; then
  echo "Creating missing nvshmem symlinks..."
  ln -sf libnvshmem_host.so.3 "$NVSHMEM_LIB/libnvshmem_host.so"
fi

# Apply kNumMaxTopK=16 patch (Qwen3.5 uses topk=10; default kNumMaxTopK=8
# is insufficient).
# Note: the source has both kNumMaxTopK (uppercase K) and kNumMaxTopk
# (lowercase k) as separate variables, so patch both spellings.
sed -i 's/kNumMaxTopK[[:space:]]*=[[:space:]]*[0-9][0-9]*/kNumMaxTopK = 16/g' csrc/kernels/internode_ll.cu
sed -i 's/kNumMaxTopk[[:space:]]*=[[:space:]]*[0-9][0-9]*/kNumMaxTopk = 16/g' csrc/kernels/internode_ll.cu

# Verify the patch was applied. Use an explicit [Kk] bracket expression
# (not an any-character `.`), and a plain if/else rather than the
# `cmd && ok || fail` chain, whose fail branch also runs if `ok` fails.
if grep -q 'kNumMaxTop[Kk] = 16' csrc/kernels/internode_ll.cu; then
  echo "Patch verified: kNumMaxTopK/k=16"
else
  echo "ERROR: kNumMaxTopK patch failed to apply!" >&2
  exit 1
fi

# Build with full output so we can debug failures.
# set -e will auto-exit on failure.
TORCH_CUDA_ARCH_LIST="10.0" \
NVSHMEM_DIR="$NVSHMEM_DIR" \
pip install -e . --no-build-isolation 2>&1

echo "=== DeepEP rebuild complete ==="
python3 -c "import deep_ep; print('deep_ep imported successfully')"
126 changes: 126 additions & 0 deletions recipes/qwen3.5/1p1d-dep4-dep4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Qwen3.5-397B-A17B-FP8 Disaggregated 1P1D: DEP4 Prefill + DEP4 Decode
# Both sides use Data Expert Parallel (DP4 + TP4 + EP4) with dp-attention.
# Homogeneous TP layout to avoid KV/Mamba state slice transfer overhead.
#
# NOTE(review): nesting below was reconstructed from key order after the
# original indentation was lost in extraction — verify against the recipe
# schema before use.

name: "qwen3.5-1p1d-dep4-dep4"

model:
  path: "qwen3.5-fp8"
  container: "dev"  # docker://lmsysorg/sglang:dev
  precision: "fp8"

resources:
  gpu_type: "gb200"
  gpus_per_node: 4
  prefill_nodes: 1
  decode_nodes: 1
  prefill_workers: 1
  decode_workers: 1

backend:

  prefill_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    # Very large disaggregation timeouts: effectively disable the
    # heartbeat/bootstrap/waiting watchdogs.
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

  decode_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000"
    SGLANG_HACK_SEQ_BOOTSTRAP_ROOM: "1"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

sglang_config:
  prefill:
    served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
    model-path: "/model/"

    attention-backend: "trtllm_mha"
    quantization: "fp8"
    kv-cache-dtype: "fp8_e4m3"
    moe-runner-backend: "flashinfer_trtllm"

    # DEP4: DP4 + TP4 + EP4 with dp-attention (same layout as decode)
    tensor-parallel-size: 4
    data-parallel-size: 4
    expert-parallel-size: 4
    enable-dp-attention: true
    enable-dp-lm-head: true
    moe-dense-tp-size: 1

    mamba-scheduler-strategy: "no_buffer"
    mamba-track-interval: 2048
    mamba-ssm-dtype: "bfloat16"

    disaggregation-mode: "prefill"
    disable-radix-cache: true
    # Decode side runs TP4/DP4 (homogeneous with prefill above).
    disaggregation-decode-tp: 4
    disaggregation-decode-dp: 4

    mem-fraction-static: 0.80
    chunked-prefill-size: 16384
    # NOTE(review): 2020 barely covers isl+osl (1000+1000) from the
    # benchmark section below — confirm the 20-token headroom is enough.
    context-length: 2020
    load-balance-method: "round_robin"
    watchdog-timeout: 1000000
    disable-cuda-graph: true

  decode:
    served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
    model-path: "/model/"

    attention-backend: "trtllm_mha"
    quantization: "fp8"
    kv-cache-dtype: "fp8_e4m3"
    moe-runner-backend: "flashinfer_trtllm"

    # DEP4: DP4 + TP4 + EP4 with dp-attention
    tensor-parallel-size: 4
    data-parallel-size: 4
    expert-parallel-size: 4
    enable-dp-attention: true
    enable-dp-lm-head: true
    moe-dense-tp-size: 1

    mamba-scheduler-strategy: "no_buffer"
    mamba-track-interval: 2048
    mamba-ssm-dtype: "bfloat16"

    disaggregation-mode: "decode"
    disable-radix-cache: true

    mem-fraction-static: 0.80
    chunked-prefill-size: 16384
    context-length: 2020
    cuda-graph-max-bs: 1024
    watchdog-timeout: 1000000

    decode-log-interval: 1
    stream-interval: 50

benchmark:
  type: "sa-bench"
  isl: 1000
  osl: 1000
  concurrencies: "1x2x4x8x16x32x64x128x256x512x1024"
  req_rate: "inf"
115 changes: 115 additions & 0 deletions recipes/qwen3.5/1p1d-tep4-tep4.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Qwen3.5-397B-A17B-FP8 Disaggregated 1P1D: TEP4 Prefill + TEP4 Decode
# Both sides use Tensor Expert Parallel (TP4 + EP4), no dp-attention.
#
# NOTE(review): nesting below was reconstructed from key order after the
# original indentation was lost in extraction, and embedded review-UI
# artifact lines were removed — verify against the recipe schema.

name: "qwen3.5-1p1d-tep4-tep4"

model:
  path: "qwen3.5-fp8"
  container: "dev"  # docker://lmsysorg/sglang:dev
  precision: "fp8"

resources:
  gpu_type: "gb200"
  gpus_per_node: 4
  prefill_nodes: 1
  decode_nodes: 1
  prefill_workers: 1
  decode_workers: 1

backend:

  prefill_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    # Very large disaggregation timeouts: effectively disable the
    # heartbeat/bootstrap/waiting watchdogs.
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

  decode_environment:
    TORCH_DISTRIBUTED_DEFAULT_TIMEOUT: "1800"
    PYTHONUNBUFFERED: "1"
    NCCL_MNNVL_ENABLE: "1"
    NCCL_CUMEM_ENABLE: "1"
    MC_FORCE_MNNVL: "1"
    SGLANG_DG_CACHE_DIR: "/configs/deepgemm-cache"
    FLASHINFER_WORKSPACE_BASE: "/configs/flashinfer-cache"
    SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE: "100000"
    SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT: "100000"
    SGLANG_DISAGGREGATION_WAITING_TIMEOUT: "100000"
    SGLANG_DECODE_BOOTSTRAP_TIMEOUT: "1000"
    SGLANG_HACK_SEQ_BOOTSTRAP_ROOM: "1"
    SGLANG_MOONCAKE_CUSTOM_MEM_POOL: "True"
    SGLANG_USE_MESSAGE_QUEUE_BROADCASTER: "0"
    SGLANG_DISABLE_TP_MEMORY_INBALANCE_CHECK: "1"

sglang_config:
  prefill:
    served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
    model-path: "/model/"

    attention-backend: "trtllm_mha"
    quantization: "fp8"
    kv-cache-dtype: "fp8_e4m3"

    # TEP4: TP4 + EP4, standard TP attention (no dp-attention)
    tensor-parallel-size: 4
    expert-parallel-size: 4
    moe-dense-tp-size: 1

    mamba-scheduler-strategy: "no_buffer"
    mamba-track-interval: 2048
    mamba-ssm-dtype: "bfloat16"

    disaggregation-mode: "prefill"
    disable-radix-cache: true
    disaggregation-decode-tp: 4
    disaggregation-decode-dp: 1

    mem-fraction-static: 0.75
    chunked-prefill-size: 16384
    # NOTE(review): 2020 barely covers isl+osl (1000+1000) from the
    # benchmark section below — confirm the 20-token headroom is enough.
    context-length: 2020
    load-balance-method: "round_robin"
    watchdog-timeout: 1000000
    disable-cuda-graph: true

  decode:
    served-model-name: "Qwen/Qwen3.5-397B-A17B-FP8"
    model-path: "/model/"

    attention-backend: "trtllm_mha"
    quantization: "fp8"
    kv-cache-dtype: "fp8_e4m3"

    # TEP4: TP4 + EP4, standard TP attention (no dp-attention)
    tensor-parallel-size: 4
    expert-parallel-size: 4
    moe-dense-tp-size: 1

    mamba-scheduler-strategy: "no_buffer"
    mamba-track-interval: 2048
    mamba-ssm-dtype: "bfloat16"

    disaggregation-mode: "decode"
    disable-radix-cache: true

    mem-fraction-static: 0.70
    chunked-prefill-size: 16384
    context-length: 2020
    watchdog-timeout: 1000000

benchmark:
  type: "sa-bench"
  isl: 1000
  osl: 1000
  concurrencies: "8x32x128x256x512x1024"
  req_rate: "inf"
Loading
Loading