diff --git a/configs/patches/vllm_numa_bind_hash_fix.py b/configs/patches/vllm_numa_bind_hash_fix.py new file mode 100644 index 00000000..0759238c --- /dev/null +++ b/configs/patches/vllm_numa_bind_hash_fix.py @@ -0,0 +1,84 @@ +""" +Patch vLLM's ParallelConfig.compute_hash to exclude NUMA-bind fields +(numa_bind / numa_bind_nodes / numa_bind_cpus) from the DP consistency hash. + +Symptom (seen on GB300, 1 worker, DP=4, numa-bind=True): + RuntimeError: Configuration mismatch detected for engine 3. + All DP workers must have identical configurations for parameters that + affect collective communication ... + +Root cause: when numa-bind is enabled, each DP rank auto-detects and stores +its own per-rank NUMA node in ParallelConfig.numa_bind_nodes. These per-rank +values enter compute_hash(), so ranks on different NUMA nodes produce +different hashes and fail the DP startup check. NUMA binding affects only +host-side memory locality, not collective-communication semantics, so it is +safe to exclude from the DP hash. + +Reference: vllm/config/parallel.py, ParallelConfig.compute_hash(), +ignored_factors set. +""" + +import sys +from pathlib import Path + +TARGET = Path( + "/usr/local/lib/python3.12/dist-packages/vllm/config/parallel.py" +) + +# Idempotency: if any of our additions is already present, skip. +MARKER = '"numa_bind",' + +# Anchor: the last entry of the existing ignored_factors set in the +# upstream compute_hash method. We insert the three numa fields just +# before the closing brace. +OLD = ' "_api_process_rank",\n }' + +NEW = ( + ' "_api_process_rank",\n' + ' # srt-slurm-sa hotfix: numa-bind fields are per-rank runtime\n' + ' # topology, not collective-communication semantics.\n' + ' "numa_bind",\n' + ' "numa_bind_nodes",\n' + ' "numa_bind_cpus",\n' + ' }' +) + + +def main(): + if not TARGET.exists(): + print(f"[vllm-numa-bind-hash-fix] Target not found: {TARGET}", file=sys.stderr) + sys.exit(1) + + content = TARGET.read_text() + + if MARKER in content: + print("[vllm-numa-bind-hash-fix] Already patched, skipping.", file=sys.stderr) + return + + count = content.count(OLD) + if count == 0: + print( + "[vllm-numa-bind-hash-fix] Could not find ignored_factors anchor. " + "vLLM version may have drifted; inspect ParallelConfig.compute_hash().", + file=sys.stderr, + ) + sys.exit(1) + if count > 1: + print( + f"[vllm-numa-bind-hash-fix] Anchor is ambiguous ({count} occurrences); " + "refusing to patch.", + file=sys.stderr, + ) + sys.exit(1) + + content = content.replace(OLD, NEW) + TARGET.write_text(content) + print( + "[vllm-numa-bind-hash-fix] Added numa_bind/numa_bind_nodes/numa_bind_cpus " + "to ParallelConfig.compute_hash ignored_factors.", + file=sys.stderr, + ) + + +if __name__ == "__main__": + main() diff --git a/configs/vllm-container-deps.sh b/configs/vllm-container-deps.sh index 43807255..15e7733c 100644 --- a/configs/vllm-container-deps.sh +++ b/configs/vllm-container-deps.sh @@ -2,4 +2,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 -pip install msgpack \ No newline at end of file +pip install msgpack + +if [ -f /configs/patches/vllm_numa_bind_hash_fix.py ]; then + python3 /configs/patches/vllm_numa_bind_hash_fix.py +fi diff --git a/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-1p1d-dep8-dep8-16-c256-c512-c1024-offload.yaml b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-1p1d-dep8-dep8-16-c256-c512-c1024-offload.yaml new file mode 100644 index 00000000..d1da4f28 --- /dev/null +++ b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-1p1d-dep8-dep8-16-c256-c512-c1024-offload.yaml @@ -0,0 +1,114 @@ +name: "svf-vllm-disagg-gb200-1p1d-dep8-dep8" +model: + path: "deepseekv4-fp4" + container: "vllm/vllm-openai@sha256:2af012a17c2cee0bc1428c03a8a5e42b552f25dc6f73495ab5a29ccf4123c257" + precision: "fp4" + +dynamo: + hash: 6a159fedd8e4a1563aa647c31f622aedbf254b5b + install: true + +setup_script: vllm-container-deps.sh +resources: + gpu_type: "gb200" + gpus_per_node: 4 + prefill_nodes: 2 + decode_nodes: 2 + prefill_workers: 1 + decode_workers: 1 + gpus_per_prefill: 8 + gpus_per_decode: 8 +frontend: + type: dynamo + enable_multiple_frontends: false +backend: + type: vllm + connector: null + prefill_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: "1024" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: "2048" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + decode_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + vllm_config: + prefill: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + enforce-eager: true + max-model-len: 16384 + max-num-seqs: 16 + max-num-batched-tokens: 32768 + trust-remote-code: true + no-enable-prefix-caching: true + no-enable-flashinfer-autotune: true + no-async-scheduling: true + block-size: 256 + gpu-memory-utilization: 0.8 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + numa-bind: true + offload-group-size: 3 + offload-num-in-group: 1 + offload-prefetch-step: 2 + # offload-params: "w13_weight w2_weight w13_weight_scale w2_weight_scale wq_b wo_a wo_b shared_experts" + tokenizer-mode: deepseek_v4 + decode: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + max-model-len: 16384 + max-num-seqs: 128 + max-cudagraph-capture-size: 128 + max-num-batched-tokens: 128 + trust-remote-code: true + no-enable-prefix-caching: true + block-size: 256 + compilation-config: 
'{"cudagraph_mode":"FULL_DECODE_ONLY","mode":0}' + gpu-memory-utilization: 0.9 + stream-interval: 50 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + tokenizer-mode: deepseek_v4 +benchmark: + type: "sa-bench" + isl: 8192 + osl: 1024 + concurrencies: "4x8x16x32x64x256x512x1024" + req_rate: "inf" + tokenizer_mode: "deepseek_v4" + use_chat_template: true diff --git a/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep16-40-c4096-offload.yaml b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep16-40-c4096-offload.yaml new file mode 100644 index 00000000..e6e454f2 --- /dev/null +++ b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep16-40-c4096-offload.yaml @@ -0,0 +1,114 @@ +name: "svf-vllm-disagg-gb200-2p1d-dep8-dep16" +model: + path: "deepseekv4-fp4" + container: "vllm/vllm-openai@sha256:2af012a17c2cee0bc1428c03a8a5e42b552f25dc6f73495ab5a29ccf4123c257" + precision: "fp4" + +dynamo: + hash: 6a159fedd8e4a1563aa647c31f622aedbf254b5b + install: true + +setup_script: vllm-container-deps.sh +resources: + gpu_type: "gb200" + gpus_per_node: 4 + prefill_nodes: 6 + decode_nodes: 4 + prefill_workers: 3 + decode_workers: 1 + gpus_per_prefill: 8 + gpus_per_decode: 16 +frontend: + type: dynamo + enable_multiple_frontends: false +backend: + type: vllm + connector: null + prefill_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: "1024" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: "2048" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + decode_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + vllm_config: + prefill: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + enforce-eager: true + max-model-len: 16384 + max-num-seqs: 16 + max-num-batched-tokens: 32768 + trust-remote-code: true + no-enable-prefix-caching: true + no-enable-flashinfer-autotune: true + no-async-scheduling: true + block-size: 256 + gpu-memory-utilization: 0.8 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + numa-bind: true + offload-group-size: 3 + offload-num-in-group: 1 + offload-prefetch-step: 2 + # offload-params: "w13_weight w2_weight w13_weight_scale w2_weight_scale wq_b wo_a wo_b shared_experts" + tokenizer-mode: deepseek_v4 + decode: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 16 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + max-model-len: 16384 + max-num-seqs: 256 + max-cudagraph-capture-size: 256 + 
max-num-batched-tokens: 256 + trust-remote-code: true + no-enable-prefix-caching: true + block-size: 256 + compilation-config: '{"cudagraph_mode":"FULL_DECODE_ONLY","mode":0}' + gpu-memory-utilization: 0.9 + stream-interval: 50 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + tokenizer-mode: deepseek_v4 +benchmark: + type: "sa-bench" + isl: 8192 + osl: 1024 + concurrencies: "4x8x16x32x64x256x512x1024" + req_rate: "inf" + tokenizer_mode: "deepseek_v4" + use_chat_template: true diff --git a/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep8-32-c2048-offload.yaml b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep8-32-c2048-offload.yaml new file mode 100644 index 00000000..10fdbf1a --- /dev/null +++ b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-3p1d-dep8-dep8-32-c2048-offload.yaml @@ -0,0 +1,114 @@ +name: "svf-vllm-disagg-gb200-3p1d-dep8-dep8" +model: + path: "deepseekv4-fp4" + container: "vllm/vllm-openai@sha256:2af012a17c2cee0bc1428c03a8a5e42b552f25dc6f73495ab5a29ccf4123c257" + precision: "fp4" + +dynamo: + hash: 6a159fedd8e4a1563aa647c31f622aedbf254b5b + install: true + +setup_script: vllm-container-deps.sh +resources: + gpu_type: "gb200" + gpus_per_node: 4 + prefill_nodes: 6 + decode_nodes: 2 + prefill_workers: 3 + decode_workers: 1 + gpus_per_prefill: 8 + gpus_per_decode: 8 +frontend: + type: dynamo + enable_multiple_frontends: false +backend: + type: vllm + connector: null + prefill_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: "1024" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: "2048" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + decode_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + vllm_config: + prefill: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + enforce-eager: true + max-model-len: 16384 + max-num-seqs: 16 + max-num-batched-tokens: 32768 + trust-remote-code: true + no-enable-prefix-caching: true + no-enable-flashinfer-autotune: true + no-async-scheduling: true + block-size: 256 + gpu-memory-utilization: 0.8 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + numa-bind: true + offload-group-size: 3 + offload-num-in-group: 1 + offload-prefetch-step: 2 + # offload-params: "w13_weight w2_weight w13_weight_scale w2_weight_scale wq_b wo_a wo_b shared_experts" + tokenizer-mode: deepseek_v4 + decode: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 
13345 + enable-expert-parallel: true + max-model-len: 16384 + max-num-seqs: 256 + max-cudagraph-capture-size: 256 + max-num-batched-tokens: 256 + trust-remote-code: true + no-enable-prefix-caching: true + block-size: 256 + compilation-config: '{"cudagraph_mode":"FULL_DECODE_ONLY","mode":0}' + gpu-memory-utilization: 0.9 + stream-interval: 50 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + tokenizer-mode: deepseek_v4 +benchmark: + type: "sa-bench" + isl: 8192 + osl: 1024 + concurrencies: "4x8x16x32x64x256x512x1024" + req_rate: "inf" + tokenizer_mode: "deepseek_v4" + use_chat_template: true diff --git a/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-6p1d-dep8-dep16-64-c8192-offload.yaml b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-6p1d-dep8-dep16-64-c8192-offload.yaml new file mode 100644 index 00000000..90af2b3d --- /dev/null +++ b/recipes/vllm/deepseek-v4-pro/8k1k/disagg-gb200-6p1d-dep8-dep16-64-c8192-offload.yaml @@ -0,0 +1,114 @@ +name: "svf-vllm-disagg-gb200-6p1d-dep8-dep16" +model: + path: "deepseekv4-fp4" + container: "vllm/vllm-openai@sha256:2af012a17c2cee0bc1428c03a8a5e42b552f25dc6f73495ab5a29ccf4123c257" + precision: "fp4" + +dynamo: + hash: 6a159fedd8e4a1563aa647c31f622aedbf254b5b + install: true + +setup_script: vllm-container-deps.sh +resources: + gpu_type: "gb200" + gpus_per_node: 4 + prefill_nodes: 12 + decode_nodes: 4 + prefill_workers: 6 + decode_workers: 1 + gpus_per_prefill: 8 + gpus_per_decode: 16 +frontend: + type: dynamo + enable_multiple_frontends: false +backend: + type: vllm + connector: null + prefill_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_SPARSE_INDEXER_MAX_LOGITS_MB: "1024" + VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: "2048" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + decode_environment: + TILELANG_CLEANUP_TEMP_FILES: "1" + VLLM_USE_NCCL_SYMM_MEM: "1" + NCCL_CUMEM_ENABLE: "1" + NCCL_MNNVL_ENABLE: "1" + NCCL_NVLS_ENABLE: "1" + VLLM_SERVER_DEV_MODE: "1" + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: "1" + VLLM_MOE_ROUTING_SIMULATION_STRATEGY: "uniform_random" + UCX_MEMTYPE_CACHE: "n" + UCX_MEMTYPE_REG_WHOLE: "n" + UCX_TLS: "cuda_copy,cuda_ipc,tcp" + UCX_CUDA_IPC_ENABLE_MNNVL: "y" + NCCL_P2P_LEVEL: NVL + vllm_config: + prefill: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 8 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + enforce-eager: true + max-model-len: 16384 + max-num-seqs: 16 + max-num-batched-tokens: 32768 + trust-remote-code: true + no-enable-prefix-caching: true + no-enable-flashinfer-autotune: true + no-async-scheduling: true + block-size: 256 + gpu-memory-utilization: 0.8 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + numa-bind: true + offload-group-size: 3 + offload-num-in-group: 1 + offload-prefetch-step: 2 + # offload-params: "w13_weight w2_weight w13_weight_scale w2_weight_scale wq_b wo_a wo_b shared_experts" + tokenizer-mode: deepseek_v4 + decode: + kv-transfer-config: '{"kv_connector": "NixlConnector", "kv_role": "kv_both"}' + served-model-name: "deepseek-ai/DeepSeek-V4-Pro" + 
kv-cache-dtype: "fp8" + tensor-parallel-size: 1 + pipeline-parallel-size: 1 + data-parallel-size: 16 + data-parallel-rpc-port: 13345 + enable-expert-parallel: true + max-model-len: 16384 + max-num-seqs: 512 + max-cudagraph-capture-size: 512 + max-num-batched-tokens: 512 + trust-remote-code: true + no-enable-prefix-caching: true + block-size: 256 + compilation-config: '{"cudagraph_mode":"FULL_DECODE_ONLY","mode":0}' + gpu-memory-utilization: 0.9 + stream-interval: 50 + no-disable-hybrid-kv-cache-manager: true + enable-sleep-mode: true + tokenizer-mode: deepseek_v4 +benchmark: + type: "sa-bench" + isl: 8192 + osl: 1024 + concurrencies: "4x8x16x32x64x256x512x1024" + req_rate: "inf" + tokenizer_mode: "deepseek_v4" + use_chat_template: true diff --git a/src/srtctl/benchmarks/sa_bench.py b/src/srtctl/benchmarks/sa_bench.py index 5f220393..e690cb19 100644 --- a/src/srtctl/benchmarks/sa_bench.py +++ b/src/srtctl/benchmarks/sa_bench.py @@ -101,5 +101,6 @@ def build_command( str(b.num_warmup_mult) if b.num_warmup_mult is not None else "2", b.custom_tokenizer or "", str(b.use_chat_template).lower(), + b.tokenizer_mode or "auto", ] return cmd diff --git a/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py b/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py index 87f3f9ef..0014f221 100644 --- a/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py +++ b/src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py @@ -629,10 +629,30 @@ def get_tokenizer( "to use mistral tokenizer mode." ) from e return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path)) + if tokenizer_mode == "deepseek_v4": + try: + from vllm.tokenizers.deepseek_v4 import DeepseekV4Tokenizer + except ImportError as e: + raise ImportError( + "DeepseekV4Tokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use deepseek_v4 tokenizer mode." + ) from e + return DeepseekV4Tokenizer.from_pretrained(str(pretrained_model_name_or_path)) if custom_tokenizer: if custom_tokenizer == "glm_moe_dsa": return _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path) + if custom_tokenizer == "deepseek_v4": + try: + from vllm.tokenizers.deepseek_v4 import DeepseekV4Tokenizer + except ImportError as e: + raise ImportError( + "DeepseekV4Tokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use deepseek_v4 tokenizer." 
+ ) from e + return DeepseekV4Tokenizer.from_pretrained(str(pretrained_model_name_or_path)) from importlib import import_module try: module_path, class_name = custom_tokenizer.rsplit('.', 1) diff --git a/src/srtctl/benchmarks/scripts/sa-bench/bench.sh b/src/srtctl/benchmarks/scripts/sa-bench/bench.sh index acddf754..999705e0 100644 --- a/src/srtctl/benchmarks/scripts/sa-bench/bench.sh +++ b/src/srtctl/benchmarks/scripts/sa-bench/bench.sh @@ -64,6 +64,10 @@ NUM_PROMPTS_MULT=${13:-10} NUM_WARMUP_MULT=${14:-2} CUSTOM_TOKENIZER=${15:-} USE_CHAT_TEMPLATE=${16:-true} +TOKENIZER_MODE=${17:-auto} + +# Build tokenizer mode args (always passed; defaults to auto) +TOKENIZER_MODE_ARGS=(--tokenizer-mode "$TOKENIZER_MODE") # Build optional custom tokenizer args CUSTOM_TOKENIZER_ARGS=() @@ -136,6 +140,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do --percentile-metrics ttft,tpot,itl,e2el \ --max-concurrency "$concurrency" \ --trust-remote-code \ + "${TOKENIZER_MODE_ARGS[@]}" \ + "${CHAT_TEMPLATE_ARGS[@]}" \ "${CUSTOM_TOKENIZER_ARGS[@]}" num_prompts=$((concurrency * 10)) @@ -166,6 +172,7 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do --percentile-metrics ttft,tpot,itl,e2el \ --max-concurrency "$concurrency" \ --trust-remote-code \ + "${TOKENIZER_MODE_ARGS[@]}" \ "${CHAT_TEMPLATE_ARGS[@]}" \ "${CUSTOM_TOKENIZER_ARGS[@]}" \ --save-result --result-dir "$result_dir" --result-filename "$result_filename" @@ -179,4 +186,3 @@ done stop_all_profiling echo "SA-Bench complete. Results in $result_dir" - diff --git a/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py b/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py index a5ea6490..952a8b23 100644 --- a/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py +++ b/src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py @@ -1272,11 +1272,12 @@ def main(args: argparse.Namespace): "--tokenizer-mode", type=str, default="auto", - choices=["auto", "slow", "mistral", "custom"], + choices=["auto", "slow", "mistral", "custom", "deepseek_v4"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' "always use the slow tokenizer. \n* " - '"mistral" will always use the `mistral_common` tokenizer. \n*' + '"mistral" will always use the `mistral_common` tokenizer. \n* ' + '"deepseek_v4" will use vLLM\'s DeepSeek V4 tokenizer. 
\n* ' '"custom" will use --tokenizer to select the preregistered tokenizer.', ) diff --git a/src/srtctl/core/schema.py b/src/srtctl/core/schema.py index c535be39..3910819e 100644 --- a/src/srtctl/core/schema.py +++ b/src/srtctl/core/schema.py @@ -543,6 +543,7 @@ class BenchmarkConfig: num_warmup_mult: int | None = None # Multiplier for warmup prompts = concurrency * mult (default: 2) # Trace replay benchmark fields (uses aiperf with mooncake_trace dataset type) trace_file: str | None = None # Path to trace JSONL file (container path, e.g., /traces/dataset.jsonl) + tokenizer_mode: str | None = None # Tokenizer mode passed to SA-Bench (e.g., "auto", "deepseek_v4") custom_tokenizer: str | None = None # Custom tokenizer class (e.g., "module.path.ClassName") use_chat_template: bool = True # Pass --use-chat-template to benchmark (default: true) diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index c15759b2..5a2b2d47 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -77,6 +77,36 @@ def test_validate_config_valid(self): errors = runner.validate_config(config) assert errors == [] + def test_build_command_includes_tokenizer_mode(self): + """Passes tokenizer mode through to the SA-Bench script.""" + from unittest.mock import MagicMock + + from srtctl.benchmarks.sa_bench import SABenchRunner + from srtctl.core.schema import BenchmarkConfig, ModelConfig, ResourceConfig, SrtConfig + + runner = SABenchRunner() + runtime = MagicMock() + runtime.frontend_port = 8000 + runtime.is_hf_model = False + + config = SrtConfig( + name="test", + model=ModelConfig(path="/model", container="/image", precision="fp4"), + resources=ResourceConfig(gpu_type="h100"), + benchmark=BenchmarkConfig( + type="sa-bench", + isl=1024, + osl=1024, + concurrencies=[4, 8], + tokenizer_mode="deepseek_v4", + use_chat_template=True, + ), + ) + + cmd = runner.build_command(config, runtime) + + assert cmd[-3:] == ["", "true", "deepseek_v4"] + class TestSGLangBenchRunner: """Test SGLang-Bench runner."""
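
Post-install verification (a minimal sketch, not part of the diff above): the hotfix rewrites the installed vllm/config/parallel.py in place, so after vllm-container-deps.sh has run inside the container the change can be confirmed by grepping the source of ParallelConfig.compute_hash. The module path vllm.config.parallel and the method name are taken from the patch's own docstring; the script name below is hypothetical.

# verify_numa_bind_hash_fix.py -- hedged sketch; assumes the vLLM layout
# referenced by the hotfix (vllm/config/parallel.py, ParallelConfig.compute_hash).
import inspect
import sys

from vllm.config.parallel import ParallelConfig

# The patch inserts these three field names into the ignored_factors set
# inside compute_hash(); check the method source for each of them.
src = inspect.getsource(ParallelConfig.compute_hash)
missing = [
    field
    for field in ("numa_bind", "numa_bind_nodes", "numa_bind_cpus")
    if f'"{field}"' not in src
]
if missing:
    print(f"hotfix not applied; missing ignored factors: {missing}", file=sys.stderr)
    sys.exit(1)
print("numa-bind hash fix present in ParallelConfig.compute_hash")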