diff --git a/CMakeLists.txt b/CMakeLists.txt index 65df275cd314..b2db8419a623 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -94,10 +94,10 @@ find_package(Torch REQUIRED) # This check must happen after find_package(Torch) because that's when CMAKE_CUDA_COMPILER_VERSION gets defined if(DEFINED CMAKE_CUDA_COMPILER_VERSION AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) - set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0") + set(CUDA_SUPPORTED_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0;10.0;11.0;12.0;12.1") elseif(DEFINED CMAKE_CUDA_COMPILER_VERSION AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8) - set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") + set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0;12.1") else() set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") endif() @@ -530,9 +530,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f;12.1f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) set(SRCS @@ -648,9 +648,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(FP4_ARCHS "12.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(FP4_ARCHS "12.0f;12.1f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(FP4_ARCHS "12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(FP4_ARCHS "12.0a;12.1a" 
"${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS) set(SRCS @@ -699,10 +699,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() # CUTLASS MLA Archs and flags + # Note: CUTLASS MLA is built only for the SM10x/SM11x arch lists below, NOT SM12x (GB10/SM121) + # SM12x devices use TRITON_MLA instead if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(MLA_ARCHS "10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MLA_ARCHS "10.0f;10.3f;11.0f" "${CUDA_ARCHS}") else() - cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MLA_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") endif() if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS) set(SRCS @@ -818,7 +820,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # moe_data.cu is used by all CUTLASS MoE kernels. if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f;12.1f" "${CUDA_ARCHS}") else() cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") endif() diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 000000000000..5d2be18f2d05 --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,474 @@ +// vLLM CI/CD Pipeline for DGX Spark 2 (SM121/GB10) +// ================================================ +// Builds and tests vLLM on spark2 (ARM64 + NVIDIA GB10) +// Jenkins master on node5 SSHs to spark2 for execution +// +// Target: NVIDIA GB10 Blackwell (SM121) + ARM64 + CUDA 13.1 +// Model: Qwen3-Next-80B-A3B-FP8 + +pipeline { + agent any + + options { + timestamps() + ansiColor('xterm') + buildDiscarder(logRotator(numToKeepStr: '30', artifactNumToKeepStr: '10')) + timeout(time: 180, unit: 'MINUTES') + disableConcurrentBuilds() + } + + environment { + 
// Spark2 connection (ARM64 + GPU build server) + SPARK2_HOST = '192.168.4.208' + SPARK2_USER = 'seli' + + // vLLM server endpoint + VLLM_API_URL = 'http://192.168.4.208:8000' + + // Remote paths on spark2 + VLLM_DIR = '/data/vllm' + VLLM_ENV = '/data/vllm-env' + CONTAINER_DIR = '/data/vllm-container' + + // Local paths for results + RESULTS_DIR = 'test-results' + ALLURE_RESULTS = 'test-results/allure-results' + METRICS_DIR = 'test-results/metrics' + + // Build settings + TORCH_CUDA_ARCH_LIST = '12.1' + MAX_JOBS = '20' + } + + parameters { + booleanParam( + name: 'REBUILD_VLLM', + defaultValue: false, + description: 'Rebuild vLLM from source (preserves torch)' + ) + booleanParam( + name: 'REBUILD_FLASHINFER', + defaultValue: false, + description: 'Rebuild FlashInfer from source' + ) + booleanParam( + name: 'REBUILD_DOCKER', + defaultValue: false, + description: 'Rebuild Docker container image' + ) + booleanParam( + name: 'RUN_BENCHMARKS', + defaultValue: true, + description: 'Run performance benchmarks' + ) + booleanParam( + name: 'SYNC_CODE', + defaultValue: true, + description: 'Pull latest code on spark2 before build' + ) + string( + name: 'BENCHMARK_PROMPTS', + defaultValue: '20', + description: 'Number of prompts for benchmark' + ) + } + + stages { + stage('Prepare') { + steps { + echo '๐Ÿ“ Preparing test environment...' + sh """ + mkdir -p ${RESULTS_DIR} + mkdir -p ${ALLURE_RESULTS} + mkdir -p ${METRICS_DIR} + rm -f ${RESULTS_DIR}/*.xml ${RESULTS_DIR}/*.html 2>/dev/null || true + rm -rf ${ALLURE_RESULTS}/* 2>/dev/null || true + rm -f ${METRICS_DIR}/*.csv 2>/dev/null || true + """ + } + } + + stage('Spark2 Health Check') { + steps { + echo '๐Ÿ” Checking spark2 connectivity and GPU status...' 
+ sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + echo '=== System Info ===' + uname -a + echo '' + echo '=== GPU Status ===' + nvidia-smi --query-gpu=name,memory.total,memory.used,utilization.gpu,temperature.gpu --format=csv + echo '' + echo '=== Docker Status ===' + docker ps --format 'table {{.Names}}\\t{{.Status}}' | head -10 + " + ''' + } + } + } + + stage('vLLM Server Health') { + steps { + echo '๐Ÿ” Checking vLLM server status and collecting metrics...' + script { + def health = sh( + script: "curl -sf ${VLLM_API_URL}/health && echo 'OK' || echo 'DOWN'", + returnStdout: true + ).trim() + + def models = sh( + script: "curl -sf ${VLLM_API_URL}/v1/models | jq -r '.data[0].id' 2>/dev/null || echo 'UNKNOWN'", + returnStdout: true + ).trim() + + // Collect GPU metrics for plotting + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + nvidia-smi --query-gpu=memory.used,utilization.gpu,temperature.gpu --format=csv,noheader,nounits + " > ${METRICS_DIR}/gpu_initial.csv || true + ''' + } + + if (health.contains('OK')) { + echo "โœ… vLLM Server: Running" + echo "๐Ÿ“ฆ Loaded Model: ${models}" + } else { + echo "โš ๏ธ vLLM Server: Not responding" + } + } + } + } + + stage('Sync Code') { + when { + expression { params.SYNC_CODE == true } + } + steps { + echo '๐Ÿ“ฅ Syncing latest code to spark2...' + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + cd ${VLLM_DIR} && \ + git fetch origin && \ + git status && \ + git pull origin \\$(git branch --show-current) && \ + git log -1 --pretty=format:'%h - %s (%an, %ar)' + " + ''' + } + } + } + + stage('Verify Torch') { + steps { + echo '๐Ÿ”ง Verifying CUDA torch installation...' 
+ sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + source ${VLLM_ENV}/bin/activate && \ + python3 -c \\" +import torch +print(f'PyTorch Version: {torch.__version__}') +print(f'CUDA Available: {torch.cuda.is_available()}') +if torch.cuda.is_available(): + print(f'CUDA Version: {torch.version.cuda}') + print(f'GPU: {torch.cuda.get_device_name(0)}') + print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB') +\\" + " + ''' + } + } + } + + stage('Build FlashInfer') { + when { + expression { params.REBUILD_FLASHINFER == true } + } + steps { + echo '๐Ÿ”จ Rebuilding FlashInfer on spark2...' + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + cd ${CONTAINER_DIR} && \ + ./safe-rebuild-vllm.sh --flashinfer-only --clear-cache -y + " + ''' + } + } + } + + stage('Build vLLM') { + when { + expression { params.REBUILD_VLLM == true } + } + steps { + echo '๐Ÿ”จ Rebuilding vLLM on spark2 (preserving torch)...' + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + cd ${CONTAINER_DIR} && \ + ./safe-rebuild-vllm.sh --vllm-only -y + " + ''' + } + } + } + + stage('Build Docker') { + when { + expression { params.REBUILD_DOCKER == true } + } + steps { + echo '๐Ÿณ Rebuilding Docker container on spark2...' + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + cd ${CONTAINER_DIR} && \ + ./safe-rebuild-vllm.sh --sync --docker -y 2>&1 | tee build.log + " + ''' + } + } + } + + stage('API Health Tests') { + steps { + echo '๐ŸŒ Running API tests against live vLLM server...' 
+ script { + // Test health endpoint + def healthResult = sh( + script: "curl -sf ${VLLM_API_URL}/health && echo 'PASS' || echo 'FAIL'", + returnStdout: true + ).trim() + + // Test models endpoint + def modelsResult = sh( + script: "curl -sf ${VLLM_API_URL}/v1/models | jq -e '.data | length > 0' && echo 'PASS' || echo 'FAIL'", + returnStdout: true + ).trim() + + // Quick inference test with timing + def inferStart = System.currentTimeMillis() + def inferResult = sh( + script: ''' + curl -sf -X POST "${VLLM_API_URL}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d '{"model":"/models/Qwen3-Next-80B-A3B-FP8","messages":[{"role":"user","content":"Say hello in one word"}],"max_tokens":10}' \ + | jq -e '.choices[0].message.content' && echo 'PASS' || echo 'FAIL' + ''', + returnStdout: true + ).trim() + def inferTime = System.currentTimeMillis() - inferStart + + echo "Health Check: ${healthResult.contains('PASS') ? 'โœ…' : 'โŒ'}" + echo "Models Check: ${modelsResult.contains('PASS') ? 'โœ…' : 'โŒ'}" + echo "Inference Check: ${inferResult.contains('PASS') ? 'โœ…' : 'โŒ'} (${inferTime}ms)" + + // Write API test metrics CSV for plotting + writeFile file: "${METRICS_DIR}/api_tests.csv", text: """test,result,latency_ms +health,${healthResult.contains('PASS') ? 1 : 0},0 +models,${modelsResult.contains('PASS') ? 1 : 0},0 +inference,${inferResult.contains('PASS') ? 1 : 0},${inferTime} +""" + + // Write JUnit result + def failures = [healthResult, modelsResult, inferResult].count { it.contains('FAIL') } + writeFile file: "${RESULTS_DIR}/api_results.xml", text: """ + + ${healthResult.contains('FAIL') ? '' : ''} + ${modelsResult.contains('FAIL') ? '' : ''} + ${inferResult.contains('FAIL') ? 
'' : ''} +""" + } + } + post { + always { + junit( + testResults: "${RESULTS_DIR}/api_results.xml", + allowEmptyResults: true, + skipPublishingChecks: true + ) + } + } + } + + stage('Benchmarks') { + when { + anyOf { + expression { params.RUN_BENCHMARKS == true } + triggeredBy 'TimerTrigger' + } + } + steps { + echo "โšก Running performance benchmarks (${params.BENCHMARK_PROMPTS} prompts)..." + script { + // Run benchmark and capture detailed metrics + sshagent(['ssh-credentials']) { + sh ''' + ssh -o StrictHostKeyChecking=no ${SPARK2_USER}@${SPARK2_HOST} " + # Warmup request + curl -sf -X POST '${VLLM_API_URL}/v1/chat/completions' \ + -H 'Content-Type: application/json' \ + -d '{\"model\":\"/models/Qwen3-Next-80B-A3B-FP8\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello\"}],\"max_tokens\":5}' > /dev/null + + echo 'Warmup complete, starting benchmark...' + " + ''' + + // Run multiple inference requests and collect timing + sh """ + echo 'prompt_id,tokens,latency_ms,tokens_per_sec' > ${METRICS_DIR}/benchmark_detailed.csv + + for i in \$(seq 1 ${params.BENCHMARK_PROMPTS}); do + PROMPT="Write a haiku about number \$i" + START=\$(date +%s%3N) + + RESPONSE=\$(curl -sf -X POST "${VLLM_API_URL}/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -d "{\\\"model\\\":\\\"/models/Qwen3-Next-80B-A3B-FP8\\\",\\\"messages\\\":[{\\\"role\\\":\\\"user\\\",\\\"content\\\":\\\"\$PROMPT\\\"}],\\\"max_tokens\\\":50}" 2>/dev/null) + + END=\$(date +%s%3N) + LATENCY=\$((END - START)) + + TOKENS=\$(echo "\$RESPONSE" | jq -r '.usage.completion_tokens // 0' 2>/dev/null || echo "0") + if [ "\$LATENCY" -gt 0 ] && [ "\$TOKENS" -gt 0 ]; then + TPS=\$(echo "scale=2; \$TOKENS * 1000 / \$LATENCY" | bc) + else + TPS=0 + fi + + echo "\$i,\$TOKENS,\$LATENCY,\$TPS" >> ${METRICS_DIR}/benchmark_detailed.csv + echo "Request \$i: \${TOKENS} tokens in \${LATENCY}ms (\${TPS} tok/s)" + done + """ + + // Collect final GPU metrics + sh ''' + ssh -o StrictHostKeyChecking=no 
${SPARK2_USER}@${SPARK2_HOST} " + nvidia-smi --query-gpu=memory.used,utilization.gpu,temperature.gpu --format=csv,noheader,nounits + " > ${METRICS_DIR}/gpu_final.csv || true + ''' + } + + // Calculate summary statistics + sh ''' + if [ -f ${METRICS_DIR}/benchmark_detailed.csv ]; then + # Calculate averages using awk + awk -F',' 'NR>1 { + sum_tokens+=$2; sum_latency+=$3; sum_tps+=$4; count++ + } END { + if(count>0) { + printf "avg_tokens,avg_latency_ms,avg_tokens_per_sec\\n" + printf "%.1f,%.1f,%.2f\\n", sum_tokens/count, sum_latency/count, sum_tps/count + } + }' ${METRICS_DIR}/benchmark_detailed.csv > ${METRICS_DIR}/benchmark_summary.csv + + echo "=== Benchmark Summary ===" + cat ${METRICS_DIR}/benchmark_summary.csv + fi + ''' + } + } + post { + always { + archiveArtifacts( + artifacts: "${METRICS_DIR}/*.csv", + allowEmptyArchive: true + ) + } + } + } + + stage('Plot Metrics') { + steps { + echo '๐Ÿ“Š Generating performance plots...' + script { + // Plot benchmark latency over requests + plot( + csvFileName: 'benchmark_latency.csv', + csvSeries: [[ + file: "${METRICS_DIR}/benchmark_detailed.csv", + inclusionFlag: 'OFF', + displayTableFlag: false, + exclusionValues: 'prompt_id', + url: '' + ]], + group: 'vLLM Performance', + title: 'Inference Latency (ms)', + style: 'line', + yaxis: 'Latency (ms)', + numBuilds: '30', + useDescr: false + ) + + // Plot tokens per second + plot( + csvFileName: 'benchmark_tps.csv', + csvSeries: [[ + file: "${METRICS_DIR}/benchmark_detailed.csv", + inclusionFlag: 'OFF', + displayTableFlag: false, + exclusionValues: 'prompt_id,tokens,latency_ms', + url: '' + ]], + group: 'vLLM Performance', + title: 'Tokens per Second', + style: 'line', + yaxis: 'Tokens/sec', + numBuilds: '30', + useDescr: false + ) + } + } + } + + stage('Allure Report') { + steps { + echo '๐Ÿ“ˆ Generating Allure report...' 
+ script { + try { + allure([ + includeProperties: false, + jdk: '', + properties: [], + reportBuildPolicy: 'ALWAYS', + results: [[path: "${ALLURE_RESULTS}"]] + ]) + } catch (e) { + echo "Allure report generation skipped: ${e.message}" + } + } + } + } + } + + post { + always { + echo '๐Ÿงน Archiving results...' + archiveArtifacts( + artifacts: "${RESULTS_DIR}/**/*", + allowEmptyArchive: true + ) + } + success { + echo ''' +โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•— +โ•‘ โœ… vLLM Pipeline completed successfully! โ•‘ +โ•‘ โ•‘ +โ•‘ Server: http://192.168.4.208:8000 โ•‘ +โ•‘ Model: Qwen3-Next-80B-A3B-FP8 โ•‘ +โ•‘ GPU: NVIDIA GB10 (SM121) - 115GB โ•‘ +โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ• +''' + } + failure { + echo 'โŒ vLLM Pipeline failed! Check logs for details.' + } + unstable { + echo 'โš ๏ธ vLLM Pipeline unstable (some tests failed)' + } + } +} diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 00fb959211fa..a57a39639d31 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -218,14 +218,47 @@ def serialize_compile_artifacts( state.pop("shape_env") state.pop("vllm_backend", None) state.pop("_fake_mode", None) - for node in state["graph_module"].graph.nodes: - node.meta.pop("source_fn_stack", None) - node.meta.pop("nn_module_stack", None) - for name, submod in state["graph_module"].named_children(): - if hasattr(submod, "graph"): - for node in submod.graph.nodes: - node.meta.pop("source_fn_stack", None) - node.meta.pop("nn_module_stack", None) + + def _strip_unpicklable_node_meta(graph_module: torch.fx.GraphModule) -> None: + """Remove metadata containing raw torch.fx.Node references. 
+ + GraphPickler raises "Unexpected raw Node during pickling" when + node.meta contains raw Node objects. The default key filter + already strips source_fn_stack / nn_module_stack / + fwd_source_fn_stack, but other keys (e.g. from_node on newer + PyTorch nightlies, or custom keys injected by passes) can + also carry Node references. Walk every value and drop any + key whose value tree contains a raw Node. + """ + def _has_node_ref(obj: Any, depth: int = 0) -> bool: + if depth > 8: + return False + if isinstance(obj, torch.fx.Node): + return True + if isinstance(obj, dict): + return any(_has_node_ref(v, depth + 1) for v in obj.values()) + if isinstance(obj, (list, tuple)): + return any(_has_node_ref(v, depth + 1) for v in obj) + return False + + for node in graph_module.graph.nodes: + keys_to_remove = [ + k for k, v in node.meta.items() + if _has_node_ref(v) + ] + for k in keys_to_remove: + del node.meta[k] + for _name, submod in graph_module.named_children(): + if hasattr(submod, "graph"): + for node in submod.graph.nodes: + keys_to_remove = [ + k for k, v in node.meta.items() + if _has_node_ref(v) + ] + for k in keys_to_remove: + del node.meta[k] + + _strip_unpicklable_node_meta(state["graph_module"]) graph_reducer_override = GraphPickler.reducer_override @@ -240,6 +273,11 @@ def _graph_reducer_override( return obj._torch_unpickler, (obj._torch_handler_name,) if isinstance(obj, FakeTensorMode): return type(None), () + # Handle raw torch.fx.Node references that survived metadata + # stripping (e.g. nested inside complex structures). Serialize + # as None to avoid GraphPickler's "Unexpected raw Node" assertion. 
+ if isinstance(obj, torch.fx.Node): + return type(None), () return graph_reducer_override(self, obj) if state.get("sym_tensor_indices"): diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4839fc80c1a1..b7796d93d5c8 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1569,11 +1569,17 @@ def _postprocess_messages(messages: list[ConversationMessage]) -> None: continue for item in tool_calls: - # if arguments is None or empty string, set to {} - if content := item["function"].get("arguments"): - if not isinstance(content, (dict, list)): + content = item["function"].get("arguments") + if isinstance(content, str): + try: + # This handles valid JSON. It will raise a JSONDecodeError + # for empty, whitespace-only, or malformed strings. item["function"]["arguments"] = json.loads(content) - else: + except json.JSONDecodeError: + # Default to an empty dict for any string that isn't valid JSON. + item["function"]["arguments"] = {} + elif not isinstance(content, (dict, list)): + # This handles None and other unexpected types. item["function"]["arguments"] = {} diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index b1dc1a860501..e995290e65b8 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1282,11 +1282,19 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool: @functools.cache def use_flashinfer_prefill() -> bool: - # For blackwell default to flashinfer prefill if it's available since - # it is faster than FA2. + """Check if FlashInfer prefill should be used for MLA attention. + + Note: This uses FlashInfer's general prefill path with is_blackwell_class(), + which supports SM10x, SM11x, and SM12x variants via Blackwell-family kernels. + This is distinct from FlashInfer MLA-specific backends which only support + SM100/SM103. 
The prefill kernels use gen_fmha_cutlass_sm100a_module. + See FlashInfer README: "beta support for 103, 110, 120, and 121" + """ from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() + # NOTE(review): unclear whether SM121/GB10 takes this path or falls back to TRITON_MLA; + # the capability check is outside this hunk and the docstring claims SM12x support -- TODO confirm. if not ( not vllm_config.attention_config.disable_flashinfer_prefill and has_flashinfer() @@ -1300,20 +1308,36 @@ def use_flashinfer_prefill() -> bool: @functools.cache def use_cudnn_prefill() -> bool: + """Check if cuDNN prefill should be used for MLA attention. + + The cuDNN SDPA cubins (named cudnn_sm100_*) are architecture-family + binaries. FlashInfer's cubin loader downloads these from NVIDIA artifactory. + Uses is_blackwell_class() to support all Blackwell variants (SM10x, SM12x). + See: https://github.com/flashinfer-ai/flashinfer (supports SM121 beta) + """ from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() return ( has_flashinfer() and vllm_config.attention_config.use_cudnn_prefill - and current_platform.is_device_capability_family(100) + and current_platform.is_blackwell_class() and has_nvidia_artifactory() ) @functools.cache def use_trtllm_ragged_deepseek_prefill() -> bool: - """Check if TRT-LLM ragged DeepSeek prefill should be used.""" + """Check if TRT-LLM ragged DeepSeek prefill should be used. + + Note: This uses FlashInfer's trtllm_ragged_attention_deepseek kernel which + is only supported on SM100/SM103 (B200/GB200), NOT on SM120/SM121 (GB10). + FlashInfer's benchmark matrix confirms this: + - SM10.0/10.3 (B200/GB200): trtllm-native supported + - SM12.0/12.1 (GB10): trtllm-native NOT supported (uses fa2/cudnn fallback) + + We restrict this to is_device_capability_family(100) to exclude SM12x. 
+ """ from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() diff --git a/vllm/model_executor/layers/batch_invariant.py b/vllm/model_executor/layers/batch_invariant.py index 9f8b1955eb09..c603582843b3 100644 --- a/vllm/model_executor/layers/batch_invariant.py +++ b/vllm/model_executor/layers/batch_invariant.py @@ -935,7 +935,7 @@ def enable_batch_invariant_mode(): _batch_invariant_LIB = torch.library.Library("aten", "IMPL") if ( - current_platform.is_device_capability_family(100) + current_platform.is_blackwell_class() or current_platform.is_device_capability(80) or current_platform.is_device_capability(89) ): diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 539712587a71..50fac227df7f 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -319,12 +319,11 @@ def supports_expert_map(self) -> bool: def supports_packed_ue8m0_act_scales(self) -> bool: """ - DeepGemm supports packed ue8m0 activation scales format in devices == sm100 + DeepGemm supports packed ue8m0 activation scales format on Blackwell-class + devices (SM10x, SM11x, SM12x). The E8M0 format is architecture-family + compatible across all Blackwell variants. """ - return ( - is_deep_gemm_e8m0_used() - and current_platform.is_device_capability_family(100) - ) + return is_deep_gemm_e8m0_used() and current_platform.is_blackwell_class() def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: # Let PrepareAndFinalize::finalize() decide the impl. 
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..f47e4c67b456 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json new file mode 100644 index 000000000000..f47e4c67b456 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB10,dtype=fp8_w8a8,block_shape=[128,128].json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.5.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 4ee2aab25068..106c9d3a3a49 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -28,6 +28,7 @@ from vllm.utils.flashinfer import ( flashinfer_cutlass_fused_moe, has_flashinfer_cutlass_fused_moe, + has_flashinfer_nvfp4, ) logger = init_logger(__name__) @@ -152,14 +153,14 @@ def _supports_quant_scheme( ] and p.has_device_capability(90) ) - # fp8 block-scale, wmxfp4a16 on 9.0 + # fp8 block-scale, wmxfp4a16 on 9.0, and fp8 block-scale on 12.0+ (SM121/GB10) or ( scheme in [ (kMxfp4Static, None), (kFp8Static128BlockSym, kFp8Dynamic128Sym), ] - and p.is_device_capability(90) + and (p.is_device_capability(90) or p.is_device_capability_family(120)) ) # nvfp4, wmxfp4amxfp8 on 10.0+ or ( @@ -169,6 +170,7 @@ def _supports_quant_scheme( (kNvfp4Static, kNvfp4Dynamic), ] and p.has_device_capability(100) + and (scheme != (kNvfp4Static, kNvfp4Dynamic) or has_flashinfer_nvfp4()) ) ) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 1cff68162183..d8b26ddffa29 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -121,11 +121,13 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: logger.info_once("Using FlashInfer MXFP4 BF16 backend for SM90") return Mxfp4Backend.SM90_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability_family(100) + current_platform.is_blackwell_class() and has_flashinfer() and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS ): - logger.info_once("Using FlashInfer MXFP4 MXFP8 CUTLASS backend for SM100") + logger.info_once( + "Using FlashInfer MXFP4 MXFP8 
CUTLASS backend for Blackwell-class GPU" + ) return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS elif ( current_platform.is_device_capability_family(100) @@ -136,16 +138,16 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100", scope="local" ) return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM - elif current_platform.is_device_capability_family(100) and has_flashinfer(): + elif current_platform.is_blackwell_class() and has_flashinfer(): logger.info_once( - "Using FlashInfer MXFP4 BF16 backend for SM100, " - "For faster performance on SM100, consider setting " + "Using FlashInfer MXFP4 BF16 backend for Blackwell-class GPU. " + "For faster performance, consider setting " "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, though this may impact " "accuracy." ) return Mxfp4Backend.SM100_FI_MXFP4_BF16 elif ( - current_platform.is_device_capability_family(100) + current_platform.is_blackwell_class() or current_platform.is_device_capability(90) ) and not has_flashinfer(): logger.warning_once( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index a8be1d61ac24..4839c203ccdd 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -108,9 +108,11 @@ def get_flashinfer_moe_backend() -> FlashinferMoeBackend: flashinfer_moe_backend == "latency" and not current_platform.is_device_capability_family(100) ): + # TRTLLM MOE backend only supports SM100/SM103 (B100/B200), + # NOT SM120/SM121 (GB10 DGX Spark). Fall back to CUTLASS. 
logger.info_once( "Flashinfer TRTLLM MOE backend is only supported on " - "SM100 and later, using CUTLASS backend instead", + "SM100/SM103, using CUTLASS backend instead", scope="local", ) return FlashinferMoeBackend.CUTLASS diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 23d7cf55474a..b29086438390 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -71,7 +71,7 @@ def _swizzle_mxfp4(quant_tensor, scale, num_warps): "split_k": 1, } opt_flags.update_opt_flags_constraints(constraints) - elif current_platform.is_device_capability_family(100): + elif current_platform.is_blackwell_class(): constraints = { "is_persistent": True, "epilogue_subtile": 1, diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index e00a17a153fb..4ab8b2761b4f 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -861,7 +861,7 @@ def fastsafetensors_weights_iterator( # Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which # initializes the GDS DMA subsystem for all visible GPUs, creating # unwanted CUDA contexts on every device. 
- nogds = pg.size() > 1 + nogds = pg.size() > 1 or current_platform.is_device_capability(121) # GB10: no GDS for f_list in tqdm( weight_files_sub_lists, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index b76168281380..7bff36cc4c5f 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -149,7 +149,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: else: kernel_block_alignment_size = 16 if ( - current_platform.is_device_capability_family(100) + current_platform.is_blackwell_class() and model_config.get_head_size() == 256 and ( attention_config.backend is None diff --git a/vllm/model_executor/warmup/kernel_warmup.py b/vllm/model_executor/warmup/kernel_warmup.py index 70abd8a6c503..9c30295e16b6 100644 --- a/vllm/model_executor/warmup/kernel_warmup.py +++ b/vllm/model_executor/warmup/kernel_warmup.py @@ -39,7 +39,7 @@ def kernel_warmup(worker: "Worker"): enable_flashinfer_autotune = ( worker.vllm_config.kernel_config.enable_flashinfer_autotune ) - # FlashInfer autotune for Hopper (SM 9.0) and Blackwell (SM 10.0) GPUs + # FlashInfer autotune for Hopper (SM 9.0) and Blackwell-class (SM 10.x/12.x) GPUs if enable_flashinfer_autotune is False: logger.info("Skipping FlashInfer autotune because it is disabled.") elif has_flashinfer() and current_platform.has_device_capability(90): diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2025c41ab8d9..d961b13cc19d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -44,6 +44,18 @@ torch.backends.cuda.enable_cudnn_sdp(False) +def _is_blackwell_class(device_capability: DeviceCapability) -> bool: + """Check if device is Blackwell-class (SM10x, SM11x, SM12x). + + Blackwell architecture includes: + - SM100/SM101: B100, B200 (major=10) + - SM120/SM121: GB10 DGX Spark (major=12) + + Note: SM110 (major=11) is Jetson Thor. 
+ """ + return device_capability.major in (10, 11, 12) + + @cache def _get_backend_priorities( use_mla: bool, @@ -51,8 +63,10 @@ def _get_backend_priorities( num_heads: int | None = None, ) -> list[AttentionBackendEnum]: """Get backend priorities with lazy import to avoid circular dependency.""" + is_blackwell = _is_blackwell_class(device_capability) + if use_mla: - if device_capability.major == 10: + if is_blackwell: # Prefer FlashInfer at low head counts (FlashMLA uses padding) if num_heads is not None and num_heads <= 16: sparse_backends = [ @@ -81,7 +95,7 @@ def _get_backend_priorities( AttentionBackendEnum.FLASHMLA_SPARSE, ] else: - if device_capability.major == 10: + if is_blackwell: return [ AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASH_ATTN, @@ -164,6 +178,21 @@ def is_fully_connected(cls, device_ids: list[int]) -> bool: def log_warnings(cls): pass + @classmethod + def is_blackwell_class(cls, device_id: int = 0) -> bool: + """Check if device is Blackwell-class (SM10x, SM11x, SM12x). + + Blackwell architecture includes: + - SM100/SM101: B100, B200 (major=10) + - SM120/SM121: GB10 DGX Spark (major=12) + + Note: SM110 (major=11) is Jetson Thor. + """ + capability = cls.get_device_capability(device_id) + if capability is None: + return False + return capability.major in (10, 11, 12) + @classmethod def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: parallel_config = vllm_config.parallel_config @@ -325,7 +354,16 @@ def get_attn_backend_cls( @classmethod def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: - if cls.has_device_capability(80): + if cls.is_device_capability_family(120): + # SM12x (GB10): Flash Attention ViT kernels lack SM121 PTX, + # prefer FlashInfer which is compiled with SM121 support. 
+ return [ + AttentionBackendEnum.FLASHINFER, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.TORCH_SDPA, + AttentionBackendEnum.FLASH_ATTN, + ] + elif cls.has_device_capability(80): return [ AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TRITON_ATTN, @@ -339,7 +377,6 @@ def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER, ] - @classmethod def get_vit_attn_backend( cls, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 774d9e0713da..32e4f0fb6b2c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -345,6 +345,23 @@ def is_device_capability_family( return False return (current_capability.to_int() // 10) == (capability // 10) + @classmethod + def is_blackwell_class(cls, device_id: int = 0) -> bool: + """Check if device is Blackwell-class GPU (SM10x, SM11x, SM12x). + + Blackwell architecture family includes: + - SM100/SM101: B100, B200 data center GPUs (major=10) + - SM120/SM121: GeForce Blackwell, GB10 DGX Spark (major=12) + + Note: SM110 (major=11) is Jetson Thor. + + Returns False for non-CUDA platforms. 
+ """ + capability = cls.get_device_capability(device_id=device_id) + if capability is None: + return False + return capability.major in (10, 11, 12) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index f05bc555bfdc..766236524065 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -2,6 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os +import shutil +import subprocess import types from importlib.util import find_spec @@ -9,6 +11,99 @@ logger = init_logger(__name__) + +def _configure_triton_ptxas_for_new_gpus(): + """ + Configure TRITON_PTXAS_PATH for GPUs that may not be supported by + Triton's bundled ptxas (e.g., Jetson Thor sm_110a, DGX Spark sm_121a). + + Triton bundles a ptxas binary (currently CUDA 12.8) that may not support + the newest GPU architectures. When running on such GPUs, Triton kernel + compilation fails with errors like: + ptxas fatal: Value 'sm_121a' is not defined for option 'gpu-name' + + This function uses Triton's native GPU detection to check the architecture + and configures Triton to use the system's CUDA toolkit ptxas instead, + which typically has broader architecture support (e.g., CUDA 13.0+). + """ + # Don't override if already set by user + if os.environ.get("TRITON_PTXAS_PATH"): + return + + # Try to find system ptxas + cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda") + system_ptxas_paths = [ + os.path.join(cuda_home, "bin", "ptxas"), + "/usr/local/cuda/bin/ptxas", + shutil.which("ptxas"), # Check PATH + ] + + system_ptxas = None + for path in system_ptxas_paths: + if path and os.path.isfile(path) and os.access(path, os.X_OK): + system_ptxas = path + break + + if not system_ptxas: + # No system ptxas found, can't help + return + + # Use Triton's native GPU detection to get the architecture. 
+ # This is how Triton itself determines the target GPU. + try: + from triton.backends import backends + + nvidia_backend = backends.get("nvidia") + if nvidia_backend is None or nvidia_backend.driver is None: + return + + if not nvidia_backend.driver.is_active(): + return + + # Get the current GPU target using Triton's driver + driver_instance = nvidia_backend.driver() + target = driver_instance.get_current_target() + arch = target.arch # e.g., 121 for sm_121a (CC 12.1) + + # GPUs with arch >= 110 (compute capability >= 11.0) may need system ptxas + # - arch 110: Jetson Thor (sm_110a, CC 11.0) + # - arch 120: GeForce Blackwell (sm_120, CC 12.0) + # - arch 121: DGX Spark GB10 (sm_121a, CC 12.1) + if arch >= 110: + # Check if system ptxas is functional + try: + result = subprocess.run( + [system_ptxas, "--version"], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # System ptxas is available, use it + os.environ["TRITON_PTXAS_PATH"] = system_ptxas + major, minor = divmod(arch, 10) + logger.info( + "Detected GPU with compute capability %d.%d (arch=%d). " + "Configuring TRITON_PTXAS_PATH=%s to ensure " + "Triton kernel compilation compatibility.", + major, + minor, + arch, + system_ptxas, + ) + except (subprocess.TimeoutExpired, FileNotFoundError, OSError) as e: + logger.debug("Cannot use system ptxas: %s", e) + + except Exception as e: + # Don't fail if detection doesn't work - user can still set + # TRITON_PTXAS_PATH manually + logger.debug("Failed to auto-configure TRITON_PTXAS_PATH: %s", e) + + +# Configure ptxas at import time, before Triton compiles any kernel, so that +# kernels can build on new GPU architectures (Thor, GB10, etc.) 
+_configure_triton_ptxas_for_new_gpus() + HAS_TRITON = ( find_spec("triton") is not None or find_spec("pytorch-triton-xpu") is not None # Not compatible diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index ee104a6cc75c..e3259377bd48 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -53,7 +53,7 @@ def init_oracle_cache(cls) -> None: cls._oracle_cache = ( # type: ignore cls.UE8M0 - if current_platform.is_device_capability_family(100) + if current_platform.is_blackwell_class() else cls.FLOAT32_CEIL_UE8M0 ) @@ -72,7 +72,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) + or current_platform.is_blackwell_class() ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index c3ac839c21d1..a34af18a7424 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -207,16 +207,39 @@ def has_flashinfer_trtllm_fused_moe() -> bool: @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: - """Return `True` if FlashInfer CUTLASS fused MoE is available.""" + """Return `True` if FlashInfer CUTLASS fused MoE engine is available. + + Only checks for the core CUTLASS MoE entry point. FP4-specific + utilities (fp4_quantize, nvfp4_block_scale_interleave) are checked + separately via has_flashinfer_nvfp4() and gated by + _supports_quant_scheme(). This allows FP8 CUTLASS MoE to work on + architectures like SM121 (GB10) that have cutlass_fused_moe but + may lack FP4 utilities. 
+ """ if not has_flashinfer_moe(): return False - # Check if all required functions are available required_functions = [ ("flashinfer.fused_moe", "cutlass_fused_moe"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + +@functools.cache +def has_flashinfer_nvfp4() -> bool: + """Return `True` if FlashInfer NVFP4 quantization utilities are available. + + Checks for fp4_quantize and nvfp4_block_scale_interleave which are + required for NVFP4 quantization paths but not for FP8 CUTLASS MoE. + """ + required_functions = [ ("flashinfer", "fp4_quantize"), ("flashinfer", "nvfp4_block_scale_interleave"), - ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), ] for module_name, attr_name in required_functions: @@ -277,14 +300,19 @@ def has_nvidia_artifactory() -> bool: @functools.cache def supports_trtllm_attention() -> bool: """ - TRTLLM attention is supported if the platform is SM100, + TRTLLM attention is supported if the platform is SM100/SM103, NVIDIA artifactory is accessible, and batch-invariant mode is not enabled. + + Note: TRTLLM attention kernels are NOT supported on SM12x (GB10). + FlashInfer's benchmark matrix confirms trtllm-native is only available + for SM10.0/10.3 (B200/GB200), not SM12.0/12.1 (GB10). SM12x devices should + fall back to other attention backends (FA2, cuDNN, etc.). 
""" # Batch-invariant mode disables TRTLLM attention if vllm_is_batch_invariant(): return False - # Requires SM100 and NVIDIA artifactory to be accessible to download cubins + # Requires SM100/SM103 only (NOT SM12x) and NVIDIA artifactory for cubins return ( current_platform.is_device_capability_family(100) and has_nvidia_artifactory() ) @@ -768,6 +796,7 @@ def should_use_flashinfer_for_blockscale_fp8_gemm( "has_flashinfer_comm", "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", + "has_flashinfer_nvfp4", "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_flashinfer_fp8_blockscale_gemm", "has_nvidia_artifactory", diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 20502cbf0feb..5bb130992b92 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -93,7 +93,8 @@ def get_flash_attn_version( fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations - if device_capability.major >= 10 and fa_version == 3: + # Blackwell-class: SM10x, SM11x, SM12x (GB10) - FA3 not supported + if device_capability.major in (10, 11, 12) and fa_version == 3: logger.warning_once( "Cannot use FA version 3 on Blackwell platform, " "defaulting to FA version 4 if supported, otherwise FA2." 
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 844e8597e5b1..12d607fbf98a 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -385,7 +385,8 @@ def supports_sink(cls) -> bool: @classmethod def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: capability = current_platform.get_device_capability() - if capability is not None and capability.major == 10: + # Blackwell-class: SM10x, SM11x, SM12x (GB10) + if capability is not None and capability.major in (10, 11, 12): return "HND" return None @@ -630,7 +631,7 @@ def __init__( self.paged_kv_indices = self._make_buffer(max_num_pages) self.paged_kv_last_page_len = self._make_buffer(max_num_reqs) - if self.head_dim == 256 and current_platform.is_device_capability_family(100): + if self.head_dim == 256 and current_platform.is_blackwell_class(): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, ( diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py index 0751b5f0f34c..904a9e2ea7ab 100644 --- a/vllm/v1/attention/backends/mla/cutlass_mla.py +++ b/vllm/v1/attention/backends/mla/cutlass_mla.py @@ -62,6 +62,8 @@ def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]: @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + # Cutlass MLA only supports SM100/SM103 (B200/GB200), NOT SM12x (GB10). + # SM12x devices should use other attention backends. 
return capability.major == 10 diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py index 102d5706b997..bb290b13629c 100644 --- a/vllm/v1/attention/backends/mla/flashinfer_mla.py +++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py @@ -61,6 +61,8 @@ def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]: @classmethod def supports_compute_capability(cls, capability: DeviceCapability) -> bool: + # FlashInfer MLA only supports SM100/SM103 (B200/GB200), NOT SM12x (GB10). + # SM12x devices should use other attention backends. return capability.major == 10 @classmethod diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 7cc50ec84584..2f625b0eb342 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -274,6 +274,7 @@ def __init__( # - SM100 (Blackwell head128): num_sm_parts = num_sms / s_q / 2 # For max buffer size, use s_q = 1 (the case that produces largest output) # Use padded head count since that's what will be passed to the kernel + # Note: SM121/GB10 does not support FlashMLA Sparse (uses TRITON_MLA instead) h_q = self.fp8_decode_padded_heads if current_platform.is_device_capability_family(100): # SM100 head64 or head64x2 uses full SM count @@ -562,7 +563,8 @@ def __init__( self.softmax_scale = scale assert indexer is not None self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer - # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell + # Prefill BF16 kernel requires 64 on Hopper, 128 on Blackwell (SM100 only) + # Note: SM121/GB10 does not support FlashMLA Sparse (uses TRITON_MLA instead) self.prefill_padding = ( 128 if current_platform.is_device_capability_family(100) else 64 ) diff --git a/vllm/v1/attention/ops/flashmla.py b/vllm/v1/attention/ops/flashmla.py index aa667570a823..23a4ccad84ae 100644 --- 
a/vllm/v1/attention/ops/flashmla.py +++ b/vllm/v1/attention/ops/flashmla.py @@ -73,7 +73,8 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]: ): return ( False, - "FlashMLA Sparse is only supported on Hopper and Blackwell devices.", + "FlashMLA Sparse is only supported on SM90 (Hopper) " + "and SM100 (Blackwell B200/GB200).", ) return True, None