diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml new file mode 100644 index 000000000000..cca58097e8aa --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.335 + - name: "exact_match,flexible-extract" + value: 0.323 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml new file mode 100644 index 000000000000..54579a63a9b8 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1 +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.54 + - name: "exact_match,flexible-extract" + value: 0.59 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 000000000000..a2f235f48581 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.47 + - name: "exact_match,flexible-extract" + value: 0.64 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 37eeac85c933..27a1a9a82bd3 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 254d01edf844..36e0543879b3 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,10 +1,6 @@ -Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +Qwen2.5-1.5B-Instruct.yaml Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-compressed-tensors.yaml -Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml -Qwen2-1.5B-Instruct-FP8W8.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index d29903bf497f..97dcc42312f6 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -3,6 +3,9 @@ # This script runs test inside the corresponding ROCm docker container. set -o pipefail +# Export Python path +export PYTHONPATH=".." + # Print ROCm version echo "--- Confirming Clean Initial State" while true; do @@ -74,6 +77,15 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" + +if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} +fi + +if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} +fi + #ignore certain kernels tests if [[ $commands == *" kernels/core"* ]]; then commands="${commands} \ @@ -161,6 +173,8 @@ fi PARALLEL_JOB_COUNT=8 +MYPYTHONPATH=".." + # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then # assign job count as the number of shards used @@ -181,6 +195,7 @@ if [[ $commands == *"--shard-id="* ]]; then -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}_${GPU}" \ "${image_name}" \ /bin/bash -c "${commands_gpu}" \ @@ -211,6 +226,7 @@ else -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}" \ "${image_name}" \ /bin/bash -c "${commands}" diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 07b898787eba..939daddad92b 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -1,6 +1,6 @@ #!/bin/bash -set -xue +set -xu # Build the docker image. docker build -f docker/Dockerfile.tpu -t vllm-tpu . @@ -24,33 +24,80 @@ docker run --privileged --net host --shm-size=16G -it \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \ && echo HARDWARE \ && tpu-info \ - && echo TEST_0 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ - && echo TEST_1 \ - && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \ - && echo TEST_2 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ - && echo TEST_3 \ - && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ - && echo TEST_4 \ - && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ - && echo TEST_5 \ - && python3 /workspace/vllm/examples/offline_inference/tpu.py \ - && echo TEST_6 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \ - && echo TEST_7 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \ - && echo TEST_8 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ - && echo TEST_9 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ - && echo TEST_10 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ - && echo TEST_11 \ - && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py \ - && echo TEST_12 \ - && pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py" \ - + && { \ + echo TEST_0: Running test_perf.py; \ + pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ + echo TEST_0_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_1: Running test_compilation.py; \ + pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ + echo TEST_1_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_2: Running test_basic.py; \ + pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ + echo TEST_2_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + echo TEST_3_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_4: Running test_quantization_accuracy.py; \ + pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ + echo TEST_4_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_5: Running examples/offline_inference/tpu.py; \ + python3 /workspace/vllm/examples/offline_inference/tpu.py; \ + echo TEST_5_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_6: Running test_tpu_model_runner.py; \ + pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ + echo TEST_6_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_7: Running test_sampler.py; \ + pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ + echo TEST_7_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_8: Running test_topk_topp_sampler.py; \ + pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ + echo TEST_8_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_9: Running test_multimodal.py; \ + pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ + echo TEST_9_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_10: Running test_pallas.py; \ + pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ + echo TEST_10_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_11: Running test_struct_output_generate.py; \ + pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ + echo TEST_11_EXIT_CODE: \$?; \ + } & \ + && { \ + echo TEST_12: Running test_moe_pallas.py; \ + pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ + echo TEST_12_EXIT_CODE: \$?; \ + } & \ + # Disable the TPU LoRA tests until the feature is activated + # && { \ + # echo TEST_13: Running test_moe_pallas.py; \ + # pytest -s -v /workspace/vllm/tests/tpu/lora/; \ + # echo TEST_13_EXIT_CODE: \$?; \ + # } & \ + wait \ + && echo 'All tests have attempted to run. Check logs for individual test statuses and exit codes.' \ +" # TODO: This test fails because it uses RANDOM_SEED sampling # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 01d04759f536..f7e4af4f2af4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -32,6 +32,7 @@ steps: ##### fast check tests ##### - label: Documentation Build # 2min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/test_docs/docs" fast_check: true no_gpu: True @@ -42,6 +43,7 @@ steps: - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html - label: Async Engine, Inputs, Utils, Worker Test # 24min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/mq_llm_engine @@ -62,6 +64,7 @@ steps: - pytest -v -s worker # Worker - label: Python-only Installation Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh - setup.py @@ -69,7 +72,7 @@ steps: - bash standalone_tests/python_only_compile.sh - label: Basic Correctness Test # 30min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true torch_nightly: true source_file_dependencies: @@ -86,6 +89,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -94,7 +98,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true source_file_dependencies: - vllm/core @@ -104,10 +108,10 @@ steps: - pytest -v -s core - label: Entrypoints Test # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true torch_nightly: true - #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/entrypoints/llm @@ -126,6 +130,7 @@ steps: - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -158,7 +163,7 @@ steps: - popd - label: Metrics, Tracing Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 source_file_dependencies: - vllm/ @@ -172,7 +177,7 @@ steps: ##### 1 GPU test ##### - label: Regression Test # 5min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/test_regression @@ -182,7 +187,7 @@ steps: working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/engine @@ -196,7 +201,7 @@ steps: - pytest -v -s tokenization - label: V1 Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 @@ -221,8 +226,8 @@ steps: - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" - #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ @@ -246,7 +251,7 @@ steps: - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -254,6 +259,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test # 36min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -264,7 +270,7 @@ steps: - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - label: LogitsProcessor Test # 5min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/model_executor/layers - vllm/model_executor/guided_decoding @@ -275,6 +281,7 @@ steps: - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 40min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/spec_decode - tests/spec_decode @@ -285,7 +292,7 @@ steps: - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora @@ -293,6 +300,7 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -303,6 +311,7 @@ steps: - pytest -v -s compile/test_sequence_parallelism.py - label: PyTorch Fullgraph Smoke Test # 9min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -314,6 +323,7 @@ steps: - pytest -v -s compile/piecewise/test_toy_llama.py - label: PyTorch Fullgraph Test # 18min + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -322,7 +332,7 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Core Operation Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/ - tests/kernels/core @@ -330,7 +340,7 @@ steps: - pytest -v -s kernels/core - label: Kernels Attention Test %N - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/attention/ - vllm/attention @@ -341,7 +351,7 @@ steps: parallelism: 2 - label: Kernels Quantization Test %N - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/quantization/ - vllm/model_executor/layers/quantization @@ -351,7 +361,7 @@ steps: parallelism: 2 - label: Kernels MoE Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/moe/ - tests/kernels/moe @@ -360,7 +370,7 @@ steps: - pytest -v -s kernels/moe - label: Kernels Mamba Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba @@ -368,7 +378,7 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - # mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader @@ -379,14 +389,15 @@ steps: - pytest -v -s tensorizer_loader - label: Benchmarks # 9min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] source_file_dependencies: - benchmarks/ commands: - bash scripts/run-benchmarks.sh - label: Benchmarks CLI Test # 10min + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/benchmarks/ @@ -394,6 +405,7 @@ steps: - pytest -v -s benchmarks/ - label: Quantization Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization @@ -402,6 +414,7 @@ steps: - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ @@ -411,6 +424,7 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ @@ -419,6 +433,7 @@ steps: - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/encoder_decoder @@ -426,8 +441,8 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min + mirror_hardwares: [amdexperimental] fast_check: false - #mirror_hardwares: [ amd ] source_file_dependencies: - vllm/ - tests/tool_use @@ -439,6 +454,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min + mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ @@ -454,7 +470,7 @@ steps: - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' - label: Language Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/models/language @@ -464,6 +480,7 @@ steps: - pytest -v -s models/language -m core_model - label: Language Models Test (Extended) + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -474,7 +491,7 @@ steps: - pytest -v -s models/language -m 'not core_model' - label: Multi-Modal Models Test (Standard) - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/models/multimodal @@ -485,6 +502,7 @@ steps: - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work - label: Multi-Modal Models Test (Extended) 1 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -494,6 +512,7 @@ steps: - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' - label: Multi-Modal Models Test (Extended) 2 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -503,6 +522,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' - label: Multi-Modal Models Test (Extended) 3 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -512,7 +532,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' - label: Quantized Models Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers/quantization - tests/models/quantization @@ -521,7 +541,7 @@ steps: # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] optional: true commands: - echo 'Testing custom models...' @@ -533,7 +553,7 @@ steps: ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -544,6 +564,7 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - label: 2 Node Tests (4 GPUs in total) # 16min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 num_nodes: 2 @@ -562,7 +583,7 @@ steps: - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -599,6 +620,7 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -618,6 +640,7 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - label: Multi-step Tests (4 GPUs) # 36min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -638,6 +661,7 @@ steps: - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -651,6 +675,7 @@ steps: - pytest -v -s distributed/test_pipeline_parallel.py - label: LoRA TP Test (Distributed) + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 4 source_file_dependencies: - vllm/lora @@ -666,6 +691,7 @@ steps: - label: Weight Loading Multiple GPU Test # 33min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index 637d2dd11454..00b0f024c0da 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -75,7 +75,7 @@ body: ``` ``` - The error message you got, with the full traceback. + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. ``` validations: required: true diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b3bfe0af7f5..270c480003e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -288,6 +288,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/attention/mla/cutlass_mla_entry.cu") @@ -418,6 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -494,7 +496,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -532,7 +536,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") diff --git a/README.md b/README.md index dda3ae6009f5..df294c600770 100644 --- a/README.md +++ b/README.md @@ -16,18 +16,20 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). + +
+Previous News + - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). - [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted. -- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - -
-Previous News - - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py new file mode 100644 index 000000000000..0d091b47c3e1 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. +""" +import nvtx +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts, + fused_topk) +from vllm.scalar_type import scalar_types +from vllm.utils import FlexibleArgumentParser + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp( + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + + +def bench_run(results: list[benchmark.Measurement], model: str, + num_experts: int, topk: int, per_act_token: bool, + per_out_ch: bool, mkn: tuple[int, int, int]): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, + mkn)) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty((num_experts, 2 * n, k), + device=device, + dtype=torch.float8_e4m3fn) + w2_fp8q = torch.empty((num_experts, k, n), + device=device, + dtype=torch.float8_e4m3fn) + w1_fp8scale = torch.empty((num_experts, 1, 1), + device=device, + dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), + device=device, + dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty((num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn) + w2_blockscale = torch.empty((num_experts, k, n // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), + device=device, + dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 2), + device=device, + dtype=torch.uint8) + + w1_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts, ), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts, ), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert]) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert]) + + def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, num_repeats: int): + for _ in range(num_repeats): + fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale) + + def run_cutlass_moe_fp4(a: torch.Tensor, w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, w1_gs: torch.Tensor, + w2_gs: torch.Tensor, a1_gs: torch.Tensor, + a2_gs: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, + e: int, device: torch.device, num_repeats: int): + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + cutlass_moe_fp4(a=a, + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + + def run_cutlass_from_graph( + a: torch.Tensor, a1_gscale: torch.Tensor, w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, m: int, n: int, + k: int, e: int, device: torch.device): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return cutlass_moe_fp4(a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_alphas, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + + def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + return fused_experts(a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph(a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph(a, w1_fp8q_notransp, w2_fp8q_notransp, + topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, + a_fp8_scale) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, + topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + + run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_gs, + w2_gs, a1_gs, a2_gs, topk_weights, topk_ids, m, n, k, + num_experts, device, num_warmup) + + results.append( + benchmark.Timer( + stmt= + "run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time)) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time)) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run(results, model, num_experts, topk, + per_act_token, per_out_ch, mkn) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified " + "models/shapes/batches") + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", + nargs="+", + type=int, + default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 1884a80a4077..4e328b4d49e5 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -6,15 +6,16 @@ from contextlib import nullcontext from datetime import datetime from itertools import product +from types import SimpleNamespace from typing import Any, TypedDict import ray import torch from ray.experimental.tqdm_ray import tqdm -from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser @@ -534,8 +535,12 @@ def get_weight_block_size_safety(config, default_value=None): def main(args: argparse.Namespace): print(args) - config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code) + config = get_config(model=args.model, + trust_remote_code=args.trust_remote_code) + if args.model_prefix: + config = getattr(config, args.model_prefix) + config = SimpleNamespace(**config) + if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -546,15 +551,14 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif (config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM"): + elif (config.architectures[0] + in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM")): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in [ - "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" - ]: + elif config.architectures[0] in ("Qwen2MoeForCausalLM", + "Qwen3MoeForCausalLM"): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -569,7 +573,8 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + dtype = torch.float16 if current_platform.is_rocm() else getattr( + torch, config.torch_dtype) use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config) @@ -659,6 +664,7 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--model-prefix", type=str, required=False) args = parser.parse_args() main(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 00670bd398b5..fb763db9fc35 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) +elseif(POWER10_FOUND) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.7.2 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + set(DNNL_CPU_RUNTIME "OMP") + + FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) endif() @@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) +elseif(POWER10_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) endif() # @@ -214,4 +245,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") \ No newline at end of file +message(STATUS "Enabling C extension.") diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index a8e1be37eb41..089b9840ea2e 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace vec_op { @@ -62,6 +63,10 @@ typedef struct f32x4x4_t { __vector float val[4]; } f32x4x4_t; +typedef struct i32x4x4_t { + __vector int32_t val[4]; +} i32x4x4_t; + struct FP32Vec8; struct FP32Vec16; @@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec { vec_xst(reg.val[0], 0, (signed short*)ptr); vec_xst(reg.val[1], 16, (signed short*)ptr); } + + void save(void* ptr, const int elem_num) const { + const int clamped_elem = std::max(0, std::min(elem_num, 16)); + + // Calculate elements to store in each 128-bit part (8 elements each) + const int elements_val0 = std::min(clamped_elem, 8); + const int elements_val1 = std::max(clamped_elem - 8, 0); + + // Convert elements to bytes (2 bytes per element) + const size_t bytes_val0 = elements_val0 * sizeof(signed short); + const size_t bytes_val1 = elements_val1 * sizeof(signed short); + + signed short* dest = static_cast(ptr); + // Store the first part using vec_xst_len + if (bytes_val0 > 0) { + vec_xst_len(reg.val[0], dest, bytes_val0); + } + // Store the second part if needed + if (bytes_val1 > 0) { + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); + } + } }; const static __vector signed short zero = vec_splats((signed short)0); @@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec { } }; +struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + i32x4x4_t reg; + int32_t values[VEC_ELEM_NUM]; + }; + + i32x4x4_t reg; + + explicit INT32Vec16(const void* data_ptr) { + reg.val[0] = vec_xl(0, reinterpret_cast(data_ptr)); + reg.val[1] = + vec_xl(16, reinterpret_cast(data_ptr)); + reg.val[2] = + vec_xl(32, reinterpret_cast(data_ptr)); + reg.val[3] = + vec_xl(48, reinterpret_cast(data_ptr)); + } + + void save(int32_t* ptr) const { + vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr)); + } + + void save(int32_t* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(int32_t)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(int32_t)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(int32_t)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(int32_t)); + + vec_xst_len(reg.val[0], reinterpret_cast(ptr), bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16& v) { + reg.val[0] = vec_ctf(v.reg.val[0], 0); + reg.val[1] = vec_ctf(v.reg.val[1], 0); + reg.val[2] = vec_ctf(v.reg.val[2], 0); + reg.val[3] = vec_ctf(v.reg.val[3], 0); + } + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1]), @@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec { vec_div(reg.val[3], b.reg.val[3])})); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(f32x4x4_t( + {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))})); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 max(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + // Create a vector of element indices for each chunk + __vector unsigned int indices = {0, 1, 2, 3}; + __vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + // Compute masks for each chunk + __vector unsigned int chunk_offset0 = {0, 0, 0, + 0}; // Chunk 0: Elements 0-3 + __vector unsigned int chunk_offset1 = {4, 4, 4, + 4}; // Chunk 1: Elements 4-7 + __vector unsigned int chunk_offset2 = {8, 8, 8, + 8}; // Chunk 2: Elements 8-11 + __vector unsigned int chunk_offset3 = {12, 12, 12, + 12}; // Chunk 3: Elements 12-15 + + // Compute masks for each chunk + __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + // Apply masks to compute the result for each chunk + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_max(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_max(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_max(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_max(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]), + vec_min(reg.val[1], b.reg.val[1]), + vec_min(reg.val[2], b.reg.val[2]), + vec_min(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 min(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + vector unsigned int indices = {0, 1, 2, 3}; + vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + vector unsigned int chunk_offset0 = {0, 0, 0, 0}; + vector unsigned int chunk_offset1 = {4, 4, 4, 4}; + vector unsigned int chunk_offset2 = {8, 8, 8, 8}; + vector unsigned int chunk_offset3 = {12, 12, 12, 12}; + + vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_min(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_min(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_min(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_min(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 abs() const { + return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]), + vec_abs(reg.val[2]), vec_abs(reg.val[3])})); + } + + float reduce_max() { + __vector float max01 = vec_max(reg.val[0], reg.val[1]); + __vector float max23 = vec_max(reg.val[2], reg.val[3]); + __vector float max_all = vec_max(max01, max23); + __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8)); + temp = vec_max(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + + float reduce_min() { + __vector float min01 = vec_min(reg.val[0], reg.val[1]); + __vector float min23 = vec_min(reg.val[2], reg.val[3]); + __vector float min_all = vec_min(min01, min23); + __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8)); + temp = vec_min(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + float reduce_sum() const { AliasReg ar; ar.reg = reg; @@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec { vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[3], 48, ptr); } + + void save(float* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(float)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(float)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(float)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(float)); + + vec_xst_len(reg.val[0], ptr, bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + +struct INT8Vec16 : public Vec { + constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16 + + union AliasReg { + __vector signed char reg; + int8_t values[VEC_NUM_ELEM]; + }; + + __vector signed char reg; + + explicit INT8Vec16(const FP32Vec16& vec) { + __vector signed int ret[4]; + ret[0] = vec_cts(vec.reg.val[0], 0); + ret[1] = vec_cts(vec.reg.val[1], 0); + ret[2] = vec_cts(vec.reg.val[2], 0); + ret[3] = vec_cts(vec.reg.val[3], 0); + + __vector signed short packed1 = vec_packs(ret[0], ret[1]); + __vector signed short packed2 = vec_packs(ret[2], ret[3]); + + reg = vec_packs(packed1, packed2); + } + + void save(void* ptr) const { + *reinterpret_cast<__vector signed char*>(ptr) = reg; + } + void save(signed char* ptr, const int elem_num) { + vec_xst_len(reg, ptr, static_cast(elem_num)); + } }; template diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 6751e7e55fc5..f61dbcc948e8 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output, } } +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +#elif defined(__powerpc64__) +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } +} +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} template void dynamic_quant_epilogue(const float* input, scalar_t* output, const float* a_scale, const float* b_scale, @@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, float* scale, int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, + "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output, const float a_scale, const float* b_scale, const int32_t* azp_with_adj, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") } template @@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, const int32_t* azp, const int32_t* azp_with_adj, const scalar_t* bias, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, + "dynamic_quant_epilogue requires AVX512/powerpc64 support.") } #endif } // namespace @@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant( } }); } + +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_ppc64le only supports INT8 inputs."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + // We dont need this + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Compute C=s_a * C_inter + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); + } + }); +} + +#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 84b2a8555ccf..447e826bc1c0 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, const std::optional& azp, const std::optional& bias); +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias); +#endif + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor azp_adj," " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#elif defined(__powerpc64__) + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index dbe0e30f5cbf..0877da52435e 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel { #endif } }; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/ops.h b/csrc/ops.h index 1dfd2e067e85..21c5a9e29740 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -208,6 +208,12 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -235,6 +241,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 000000000000..84492553c02f --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,27 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK( + a.size(0) % 4 == 0, + "Input tensor must have a number of rows that is a multiple of 4. ", + "but got: ", a.size(0), " rows."); + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 000000000000..ef324364c6d5 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,205 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template +struct cutlass_3x_gemm_fp8_blockwise { + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using LayoutC = LayoutD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + // MMA and Cluster Tile Shapes + // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster + // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>; + static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{}); + static constexpr int ScaleGranularityM = + size<0>(MmaTileShape{}) / ScaleMsPerTile; + static constexpr int ScaleGranularityN = + size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{}); + static constexpr int ScaleGranularityK = + size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); + + // Shape of the threadblocks in a cluster + using ClusterShape_MNK = ClusterShape; + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + // clang-format off + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; + // clang-format on + + using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = cute::make_shape(m, n, k, 1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto m = a.size(0); + auto k = a.size(1); + auto n = b.size(1); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) { + return std::ceil(static_cast(m) / tile1SM) * + std::ceil(static_cast(n) / tile1SM) >= + sms; + }; + bool use_2sm = should_use_2sm(m, n); + if (use_2sm) { + cutlass_gemm_caller_blockwise, Shape<_256, _1, _1>, + Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Shape<_128, _1, _1>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 000000000000..b589a479081e --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,57 @@ +#include +#include "cuda_utils.h" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774d..c1242fdb39da 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu index 459eb1bb76eb..0cbd5305e3c2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm100 (Blackwell). @@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - TORCH_CHECK( - (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), - "Currently, block scaled fp8 gemm is not implemented for Blackwell"); - - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm100_fp8, + nullptr, // int8 not supported on SM100 + vllm::cutlass_scaled_mm_blockwise_sm100_fp8); } #endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu index bcb91040d5e2..211302171f07 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper). @@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - - if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { - vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias); - } else { - TORCH_CHECK(a.dtype() == torch::kInt8); - vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias); - } - } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); - - vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales); - } + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm90_fp8, + vllm::cutlass_scaled_mm_sm90_int8, + vllm::cutlass_scaled_mm_blockwise_sm90_fp8); } void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cb..9843cd857d48 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -37,12 +37,6 @@ void cutlass_moe_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -53,6 +47,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); +#endif + void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -110,6 +113,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { #if defined CUDA_VERSION if (cuda_device_capability >= 90 && cuda_device_capability < 100) { return CUDA_VERSION >= 12000; + } else if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; } #endif @@ -222,7 +227,8 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu new file mode 100644 index 000000000000..45ec3d29ce04 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -0,0 +1,402 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, + ElementSF** a_scales_offsets, ElementSF** b_scales_offsets, + ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets, + const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes, + const int K, const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. + int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && + "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ + TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ + LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), K, N); \ + } + +template +void run_get_group_gemm_starts( + const torch::Tensor& a_starts, const torch::Tensor& b_starts, + const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, + const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + /*these are used for their base addresses*/ + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor const& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& alphas, + torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, + torch::Tensor const& problem_sizes, int M, int N, int K) { + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + TORCH_CHECK(out_tensors.size(1) == N, + "Output tensor shape doesn't match expected shape"); + TORCH_CHECK(K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, + cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, torch::kFloat16, + half, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = + cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = + cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm:: + KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape, + ClusterShape, Shape<_128, _64>, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm100GroupParams< + typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + // Input validation + CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); + CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); + CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); + CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + + TORCH_CHECK(a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + TORCH_CHECK(b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have the shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32."); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int E = static_cast(b.size(0)); + int K = static_cast(2 * b.size(2)); + + if (output.scalar_type() == torch::kBFloat16) { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } else { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel, vLLM must " + "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "12.8 or above."); +#endif +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu new file mode 100644 index 000000000000..076c4a085337 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -0,0 +1,404 @@ +#include + +#include +#include + +#include +#include + +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + // Find index within the experts. + int rowIdx_in_expert = 0; + int expert_idx = 0; + for (int i = 0; i < n_experts; i++) { + if (rowIdx >= input_offset_by_experts[i] && + rowIdx < input_offset_by_experts[i + 1]) { + rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; + expert_idx = i; + break; + } + } + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, + device); + + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); + + cvt_fp16_to_fp4<<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), n_experts); +} + +/*Quantization entry for fp4 experts quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, + "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, + "output_scale_offset_by_experts must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k = input.size(1); + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, + n_experts, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, + k, n_experts, stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index b1426c43b456..badbb7e310df 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, torch::Tensor const& input_sf); #endif +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); +} + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 experts quantization kernel"); } diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index ea3bb4299046..03bd5964a7fc 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -9,7 +9,7 @@ at::Tensor as_g_workspace; torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, c10::optional const& b_qzeros, + torch::Tensor const& b_scales, std::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -918,7 +918,7 @@ void allspark_qgemm_w8a16_perc_ampere( torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, c10::optional const& b_qzeros, + torch::Tensor const& b_scales, std::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { // Verify device and strides diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/quantization/gptq_allspark/allspark_repack.cu index ea8eccf040df..7a5b2f95cc2e 100644 --- a/csrc/quantization/gptq_allspark/allspark_repack.cu +++ b/csrc/quantization/gptq_allspark/allspark_repack.cu @@ -100,9 +100,9 @@ void rearrange_kn_weight_as_n32k16_order_ldg16( void rearrange_kn_weight_as_n32k16_order( torch::Tensor const& b_qweight, torch::Tensor const& b_scales, - c10::optional const& b_zeros, bool has_zp, + std::optional const& b_zeros, bool has_zp, torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, - c10::optional const& b_zeros_reorder, const int64_t K, + std::optional const& b_zeros_reorder, const int64_t K, const int64_t N, const int64_t N_32align) { // Verify device and strides TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 292352649163..8cc5a0f4f218 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; const auto head_idx = blockIdx.x; const auto seq_idx = blockIdx.y; @@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; - + acc *= out_scale; const int64_t query_start_off = static_cast( query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + @@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE } // clang-format on @@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( PARTITION_SIZE, NPAR_LOOPS> \ <<>>( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions); + context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ + fp8_out_scale_ptr); template & query_start_loc, int max_context_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, const std::optional& fp8_out_scale) { int num_seqs = block_tables.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher( int* context_lens_ptr = context_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: fp8_out_scale is optional. + const auto fp8_out_scale_ptr = + fp8_out_scale + ? static_cast(fp8_out_scale.value().data_ptr()) + : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); @@ -1736,33 +1744,54 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \ - ALIBI_ENABLED) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale); - -#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - PSIZE) \ - if (alibi_slopes) { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \ - } else { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \ +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); + +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + false); \ } -#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ - switch (block_size) { \ - case 16: \ - CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \ - break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#if defined(__HIPCC__) && defined(__gfx90a__) + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ + } else { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ + 256); \ + } +#else + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + uint8_t, 256); \ + } else { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ + 256); \ + } +#endif + +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ @@ -1795,7 +1824,8 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, + const std::optional& fp8_out_scale) { // clang-format on const int head_size = query.size(2); if (kv_cache_dtype == "auto") { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b90cfdc617af..e538197dbcb0 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); -void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, - double scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - const std::optional& query_start_loc, - int64_t block_size, int64_t max_context_len, - const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale); +void paged_attention( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int64_t block_size, + int64_t max_context_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const std::optional& fp8_out_scale); diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 72d2820f2aab..b3717892db78 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid / num_warps; - const int qthreadid = threadid % num_warps; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; float acc[NUM_A_ROWS_PER_BLOCK]; @@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, // rowA_elem4[i] holds 8 * half numbers seen as a single float4. rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; } - colB_elem4x = bf4[threadid * 4 + 0]; - colB_elem4y = bf4[threadid * 4 + 1]; - colB_elem4z = bf4[threadid * 4 + 2]; - colB_elem4w = bf4[threadid * 4 + 3]; - scalar2_t Af2; - [[maybe_unused]] scalar2_t Bf2; float2 S; auto Ah2ptr = reinterpret_cast(&rowA_elem4); @@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, if (qwarpid < NUM_A_ROWS_PER_BLOCK) { acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; - for (int mask = num_warps / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); } - float oval2 = __shfl_xor(acc[qwarpid], num_warps); + float oval2 = __shfl_xor(acc[qwarpid], 16); - if (lane % (num_warps * 2) == 0) { + if (lane % 32 == 0) { oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; } @@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle // operations. const int NUM_THREADS = - K * 2 / 16 % WARP_SIZE == 0 - ? K * 2 / 16 - : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + max(rows_per_block * 16, + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE)); int NUM_BLOCKS = M / rows_per_block; @@ -275,13 +275,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -318,6 +327,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -343,7 +353,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -374,24 +388,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -419,32 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t n = 0; n < N; n++) { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]) - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -453,37 +436,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } - m += CuCount * _WvPrGrp * YTILE; } } @@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -573,6 +612,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -598,7 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -628,24 +672,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -676,32 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -710,34 +723,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -774,14 +835,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -857,6 +926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) kFit = min(kFit, K); float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -888,7 +958,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -937,24 +1011,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -989,32 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -1031,34 +1074,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 4ac6fd1e9940..34575477bcc9 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale," + " Tensor? fp8_out_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 7ca40a5e7827..1dbd11f5f2a5 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -363,6 +363,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass nvfp4 block scaled group GEMM + ops.def( + "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," + " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( @@ -492,6 +500,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! output_scale, Tensor input_scale) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute NVFP4 experts quantization. + ops.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // of the given capability ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 12009b8aa046..222b9c158e5e 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="7e1ed08" +ARG AITER_BRANCH="5a77249" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base diff --git a/docs/source/assets/deployment/chatbox-chat.png b/docs/source/assets/deployment/chatbox-chat.png new file mode 100644 index 000000000000..b1718cb50471 Binary files /dev/null and b/docs/source/assets/deployment/chatbox-chat.png differ diff --git a/docs/source/assets/deployment/chatbox-settings.png b/docs/source/assets/deployment/chatbox-settings.png new file mode 100644 index 000000000000..a8e3d7b2894c Binary files /dev/null and b/docs/source/assets/deployment/chatbox-settings.png differ diff --git a/docs/source/assets/deployment/dify-chat.png b/docs/source/assets/deployment/dify-chat.png new file mode 100644 index 000000000000..dfea23309c1c Binary files /dev/null and b/docs/source/assets/deployment/dify-chat.png differ diff --git a/docs/source/assets/deployment/dify-create-chatbot.png b/docs/source/assets/deployment/dify-create-chatbot.png new file mode 100644 index 000000000000..07bbde5ba285 Binary files /dev/null and b/docs/source/assets/deployment/dify-create-chatbot.png differ diff --git a/docs/source/assets/deployment/dify-settings.png b/docs/source/assets/deployment/dify-settings.png new file mode 100644 index 000000000000..7900cc774741 Binary files /dev/null and b/docs/source/assets/deployment/dify-settings.png differ diff --git a/docs/source/community/meetups.md b/docs/source/community/meetups.md index 085918bed2b0..aa1a71c86c0a 100644 --- a/docs/source/community/meetups.md +++ b/docs/source/community/meetups.md @@ -4,6 +4,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/docs/source/deployment/frameworks/chatbox.md b/docs/source/deployment/frameworks/chatbox.md new file mode 100644 index 000000000000..e62f4647150f --- /dev/null +++ b/docs/source/deployment/frameworks/chatbox.md @@ -0,0 +1,36 @@ +(deployment-chatbox)= + +# Chatbox + +[Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, Linux. + +It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. + +## Prerequisites + +- Setup vLLM environment + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +- Download and install [Chatbox desktop](https://chatboxai.app/en#download). + +- On the bottom left of settings, Add Custom Provider + - API Mode: `OpenAI API Compatible` + - Name: vllm + - API Host: `http://{vllm server host}:{vllm server port}/v1` + - API Path: `/chat/completions` + - Model: `qwen/Qwen1.5-0.5B-Chat` + +:::{image} /assets/deployment/chatbox-settings.png +::: + +- Go to `Just chat`, and start to chat: + +:::{image} /assets/deployment/chatbox-chat.png +::: diff --git a/docs/source/deployment/frameworks/dify.md b/docs/source/deployment/frameworks/dify.md new file mode 100644 index 000000000000..5cdf6a387637 --- /dev/null +++ b/docs/source/deployment/frameworks/dify.md @@ -0,0 +1,56 @@ +(deployment-dify)= + +# Dify + +[Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. + +It supports vLLM as a model provider to efficiently serve large language models. + +This guide walks you through deploying Dify using a vLLM backend. + +## Prerequisites + +- Setup vLLM environment +- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve Qwen/Qwen1.5-7B-Chat +``` + +- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): + +```console +git clone https://github.com/langgenius/dify.git +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +- Open the browser to access `http://localhost/install`, config the basic login information and login. + +- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +- Fill in the model provider details as follows: + - **Model Type**: `LLM` + - **Model Name**: `Qwen/Qwen1.5-7B-Chat` + - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` + - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` + - **Completion Mode**: `Completion` + +:::{image} /assets/deployment/dify-settings.png +::: + +- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: + +:::{image} /assets/deployment/dify-create-chatbot.png +::: + +- Click the chatbot you just created to open the chat interface and start interacting with the model: + +:::{image} /assets/deployment/dify-chat.png +::: diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md index d1c058eafa4c..6708f2c4135f 100644 --- a/docs/source/deployment/frameworks/index.md +++ b/docs/source/deployment/frameworks/index.md @@ -6,6 +6,8 @@ anything-llm bentoml cerebrium +chatbox +dify dstack helm lws diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index ec661d8ec641..0f7475777797 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -86,7 +86,7 @@ To improve privacy in shared environments, vLLM supports isolating prefix cache {"role": "user", "content": "Here is a document with details about the world series: ..."}, {"role": "user", "content": "Who won the world series in 2020?"} ], - "cache_salt": "Z3V2bmV3aGxza3ZubGFoZ3Zud3V3ZWZ2bmd0b3V2bnZmc2xpZ3RoZ2x2aQ==" + "cache_salt": "your-cache-salt" } ``` diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 7920131643c2..4d8ce0fd9227 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You `vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. + +### Full Cudagraph capture + +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"` + +Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/docs/source/features/quantization/fp8.md b/docs/source/features/quantization/fp8.md index 21969bbc2b9f..cb304d54726c 100644 --- a/docs/source/features/quantization/fp8.md +++ b/docs/source/features/quantization/fp8.md @@ -117,7 +117,7 @@ Here's an example of the resulting scores: ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the `vllm-project/llm-compressor` GitHub repository. +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. ## Online Dynamic Quantization diff --git a/docs/source/features/quantization/int4.md b/docs/source/features/quantization/int4.md index be48788a4ef6..7a0ab4ad229e 100644 --- a/docs/source/features/quantization/int4.md +++ b/docs/source/features/quantization/int4.md @@ -169,4 +169,4 @@ recipe = GPTQModifier( ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py). +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py). diff --git a/docs/source/features/quantization/int8.md b/docs/source/features/quantization/int8.md index d6ddca18e268..1e4b01d35575 100644 --- a/docs/source/features/quantization/int8.md +++ b/docs/source/features/quantization/int8.md @@ -138,4 +138,4 @@ Quantized models can be sensitive to the presence of the `bos` token. Make sure ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index a079eb8b77e7..4759d0c26c35 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -17,7 +17,9 @@ vLLM currently supports the following reasoning models: | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | | [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | -- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +:::{note} +IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +::: ## Quickstart @@ -83,7 +85,7 @@ Streaming chat completions are also supported for reasoning models. The `reasoni } ``` -OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client support extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: +OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client supports extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: ```python from openai import OpenAI @@ -221,7 +223,7 @@ print(f"Function called: {tool_call.name}") print(f"Arguments: {tool_call.arguments}") ``` -For more examples, please refer to . +For more examples, please refer to . ## Limitations @@ -229,7 +231,7 @@ For more examples, please refer to . ```python # import the required packages @@ -286,7 +288,7 @@ class ExampleParser(ReasoningParser): """ ``` -Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`. +Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in . ```python @dataclass @@ -312,7 +314,7 @@ class DeepSeekReasoner(Reasoner): ... ``` -The structured output engine like `xgrammar` will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. +The structured output engine like [xgrammar](https://github.com/mlc-ai/xgrammar) will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. Finally, you can enable reasoning for the model by using the `--reasoning-parser` flags. diff --git a/docs/source/getting_started/installation/python_env_setup.inc.md b/docs/source/getting_started/installation/python_env_setup.inc.md index a03d35030fe8..00b61ea5c826 100644 --- a/docs/source/getting_started/installation/python_env_setup.inc.md +++ b/docs/source/getting_started/installation/python_env_setup.inc.md @@ -14,6 +14,6 @@ Or you can create a new Python environment using [uv](https://docs.astral.sh/uv/ ```console # (Recommended) Create a new uv environment. Use `--seed` to install `pip` and `setuptools` in the environment. -uv venv vllm --python 3.12 --seed -source vllm/bin/activate +uv venv --python 3.12 --seed +source .venv/bin/activate ``` diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index bcaa4f9b96cd..bb2997f008ed 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions :::{important} A chat template is **required** to use Chat Completions API. +For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. -Although most models come with a chat template, for others you have to define one yourself. -The chat template can be inferred based on the documentation on the model's HuggingFace repo. -For example, DeepSeek-VL2 requires a chat template that can be found here: +If no default chat template is available, we will first look for a built-in fallback in . +If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. + +For certain models, we provide alternative chat templates inside . +For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. ::: ### Image Inputs diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index 6857c6e9e31d..8e6f78ed7de2 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -7,9 +7,8 @@ def create_parser(): parser = FlexibleArgumentParser() # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") # Add sampling params sampling_group = parser.add_argument_group("Sampling parameters") sampling_group.add_argument("--max-tokens", type=int) diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py index 54b52b22a45a..72f4a8208386 100644 --- a/examples/offline_inference/basic/generate.py +++ b/examples/offline_inference/basic/generate.py @@ -7,9 +7,8 @@ def create_parser(): parser = FlexibleArgumentParser() # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") # Add sampling params sampling_group = parser.add_argument_group("Sampling parameters") sampling_group.add_argument("--max-tokens", type=int) diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 91e2f68ecffb..020521611f33 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -118,8 +118,8 @@ def main(): acceptance_counts[step] += count print("-" * 50) - print(f"mean acceptance length: \ - {sum(acceptance_counts) / acceptance_counts[0]:.2f}") + print(f"mean acceptance length (including bonus tokens): \ + {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}") print("-" * 50) # print acceptance at each token position diff --git a/examples/offline_inference/reproduciblity.py b/examples/offline_inference/reproducibility.py similarity index 100% rename from examples/offline_inference/reproduciblity.py rename to examples/offline_inference/reproducibility.py diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index 9c57af1c158c..660369e55d40 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -138,7 +138,7 @@ def main(): api_key="-", ) - model = "Qwen/Qwen2.5-3B-Instruct" + model = client.models.list().data[0].id print("Guided Choice Completion:") print(guided_choice_completion(client, model)) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py index b807bc540526..42aa12c451c0 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -59,7 +59,7 @@ def main(): }] response = client.chat.completions.create( - model="meta-llama/Llama-3.1-8B-Instruct", + model=client.models.list().data[0].id, messages=messages, response_format={ "type": diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py index 5da9236c5306..a04f0cdf12f7 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -4,7 +4,7 @@ like DeepSeekR1. The thinking process will not be guided by the JSON schema provided by the user. Only the final output will be structured. -To run this example, you need to start the vLLM server with the reasoning +To run this example, you need to start the vLLM server with the reasoning parser: ```bash diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py index f9ef3e2da1a1..e2dce107e78a 100644 --- a/examples/online_serving/ray_serve_deepseek.py +++ b/examples/online_serving/ray_serve_deepseek.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """ Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. -See Ray Serve LLM documentation at: +See more details at: +https://docs.ray.io/en/latest/serve/tutorials/serve-deepseek.html +And see Ray Serve LLM documentation at: https://docs.ray.io/en/latest/serve/llm/serving-llms.html Run `python3 ray_serve_deepseek.py` to deploy the model. diff --git a/examples/template_florence2.jinja b/examples/template_florence2.jinja deleted file mode 100644 index 3fa2cccc2406..000000000000 --- a/examples/template_florence2.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/examples/template_paligemma.jinja b/examples/template_paligemma.jinja deleted file mode 100644 index 3fa2cccc2406..000000000000 --- a/examples/template_paligemma.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/examples/template_qwen_vl.jinja b/examples/template_qwen_vl.jinja deleted file mode 100644 index 3fa2cccc2406..000000000000 --- a/examples/template_qwen_vl.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/examples/tool_chat_template_mistral3.jinja b/examples/tool_chat_template_mistral3.jinja index 2b2f94d7ef52..7c4249ec44c5 100644 --- a/examples/tool_chat_template_mistral3.jinja +++ b/examples/tool_chat_template_mistral3.jinja @@ -29,7 +29,14 @@ {%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} -{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} +{%- set filtered_messages = [] %} +{%- for message in loop_messages %} + {%- if message["role"] not in ["tool", "tool_results"] and not message.get("tool_calls") %} + {%- set filtered_messages = filtered_messages + [message] %} + {%- endif %} +{%- endfor %} + +{%- for message in filtered_messages %} {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} {%- endif %} @@ -116,4 +123,4 @@ {%- else %} {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} {%- endif %} -{%- endfor %} \ No newline at end of file +{%- endfor %} diff --git a/requirements/common.txt b/requirements/common.txt index 7ea27753eab7..f537b3aab541 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -19,7 +19,7 @@ pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.11, < 0.11 -llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" +llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64" @@ -29,7 +29,7 @@ partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 -importlib_metadata +importlib_metadata; python_version < '3.10' mistral_common[opencv] >= 1.5.4 opencv-python-headless >= 4.11.0 # required for video IO pyyaml @@ -43,7 +43,7 @@ watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu -opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing +opentelemetry-sdk>=1.26.0 # vllm.tracing +opentelemetry-api>=1.26.0 # vllm.tracing +opentelemetry-exporter-otlp>=1.26.0 # vllm.tracing +opentelemetry-semantic-conventions-ai>=0.4.1 # vllm.tracing diff --git a/requirements/docs.txt b/requirements/docs.txt index ccc5ef0aa978..9c267edaceaf 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -6,6 +6,7 @@ sphinx-design==0.6.1 sphinx-togglebutton==0.3.2 myst-parser==3.0.1 # `myst-parser==4.0.1` breaks inline code in titles msgspec +snowballstemmer<3 # https://github.com/snowballstem/snowball/issues/229 commonmark # Required by sphinx-argparse when using :markdownhelp: # Custom autodoc2 is necessary for faster docstring processing diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 52fbf787f1df..abd4212c6e35 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,3 +1,5 @@ +# Common dependencies +-r common.txt # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 1458f0893a93..9f3b0e8ae079 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -5,11 +5,13 @@ """ import os import weakref +from unittest.mock import Mock import pytest from vllm import LLM from vllm.platforms import current_platform +from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import VllmRunner from ..models.utils import check_outputs_equal @@ -152,9 +154,44 @@ def test_models_distributed( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +def test_failed_model_execution(vllm_runner, monkeypatch) -> None: + + from vllm.envs import VLLM_USE_V1 + + if not VLLM_USE_V1: + pytest.skip("Skipping V0 test, dump input not supported") + + # Needed to mock an error in the same process + monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + + with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + if isinstance(vllm_model.model.llm_engine, LLMEngineV1): + v1_test_failed_model_execution(vllm_model) + + +def v1_test_failed_model_execution(vllm_model): + + engine = vllm_model.model.llm_engine + mocked_execute_model = Mock( + side_effect=RuntimeError("Mocked Critical Error")) + engine.engine_core.engine_core.model_executor.execute_model =\ + mocked_execute_model + + with pytest.raises(RuntimeError) as exc_info: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + vllm_model.generate_greedy(prompts, 200, use_tqdm=False) + assert isinstance(exc_info.value, RuntimeError) + assert "Mocked Critical Error" in str(exc_info.value) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py new file mode 100644 index 000000000000..a71a40cda73e --- /dev/null +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import os + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig + +MODEL = "Qwen/Qwen2-1.5B-Instruct" + + +@contextlib.contextmanager +def temporary_environ(env_vars): + """ + Temporarily set environment variables and restore them afterward. + We have to do this vs monkeypatch because monkeypatch doesn't work + with "module" scoped fixtures. + """ + original_env = {k: os.environ.get(k) for k in env_vars} + try: + os.environ.update(env_vars) + yield + finally: + for k, v in original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +@pytest.fixture(scope="module") +def full_cudagraph_llm(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.2, + compilation_config=CompilationConfig(full_cuda_graph=True)) + + +@pytest.fixture(scope="module") +def piecewise_llm(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.5, + compilation_config=CompilationConfig()) + + +def generate_text(llm: LLM, batch_size: int, max_tokens: int): + prompts = ["Hi my name is"] * batch_size + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + top_p=0.95) + + return llm.generate(prompts, sampling_params) + + +@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), + (16, 10), (25, 10), + (32, 10), (45, 10), + (64, 10), (8, 5), + (8, 20), (8, 200)]) +def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, + piecewise_llm): + """ + Load full cudagraph model and piecewise model once, and at the same time to + reuse them across various test cases. + + Test various batch sizes and max_tokens to ensure that the full cudagraph + compilation works for padded cases too. + """ + piecewise_responses = generate_text(piecewise_llm, + batch_size=batch_size, + max_tokens=max_tokens) + full_cudagraph_responses = generate_text(full_cudagraph_llm, + batch_size=batch_size, + max_tokens=max_tokens) + + # Check that all responses are the same + for i in range(len(piecewise_responses)): + assert piecewise_responses[i].outputs[ + 0].text == full_cudagraph_responses[i].outputs[0].text + + +def test_full_cudagraph_with_invalid_backend(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": + "2" #FA2 not supported with full_cuda_graph + }), pytest.raises(RuntimeError): + LLM(model=MODEL, + compilation_config=CompilationConfig(full_cuda_graph=True)) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index c09406385987..397517b8665b 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -9,7 +9,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, PassConfig from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -95,9 +95,6 @@ def test_full_graph( run_model(optimization_level, model, model_kwargs) -PassConfig = CompilationConfig.PassConfig - - # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( "compilation_config, model_info", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 1e1364ce7bf6..5d38ff91490e 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -11,7 +11,7 @@ kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from .backend import TestBackend @@ -53,9 +53,8 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, torch.set_default_device("cuda") vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config= \ - CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True)) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = FusionPass.instance(vllm_config) act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6a696fe0226b..4d56b34bdecf 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,7 +9,8 @@ FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) @@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) vllm_config.compilation_config.pass_config = \ - CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) + PassConfig(enable_fusion=True, enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 79f5486dadcd..6152f171705b 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -10,7 +10,7 @@ find_specified_fn_maybe, is_func) from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - VllmConfig) + PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) @@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig( - enable_sequence_parallelism=True, ), ) + vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( + enable_sequence_parallelism=True)) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py index 313848372e04..f87f175acd06 100644 --- a/tests/compile/test_silu_mul_quant_fusion.py +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -6,7 +6,7 @@ from vllm._custom_ops import scaled_fp8_quant from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul from .backend import TestBackend @@ -36,8 +36,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): # Reshape pass is needed for the fusion pass to work config = VllmConfig() config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True)) + pass_config=PassConfig(enable_fusion=True, enable_reshape=True)) fusion_pass = ActivationQuantFusionPass(config) backend = TestBackend(fusion_pass) diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 19497ad9c140..bbf3ed5843b3 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -206,7 +206,7 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallism': sp_enabled, + 'enable_sequence_parallelism': sp_enabled, 'enable_noop': True, 'enable_fusion': True, }, @@ -223,7 +223,7 @@ def _compare_sp( "--distributed-executor-backend", distributed_backend, "--compilation_config", - str(compilation_config), + json.dumps(compilation_config), ] tp_env = { diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 65471cb3af38..ce8873d58d4d 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -8,21 +8,18 @@ import pytest -from vllm.config import config +from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, literal_to_kwargs, nullable_kvs, - optional_type) + optional_type, parse_type) from vllm.utils import FlexibleArgumentParser @pytest.mark.parametrize(("type", "value", "expected"), [ (int, "42", 42), - (int, "None", None), (float, "3.14", 3.14), - (float, "None", None), (str, "Hello World!", "Hello World!"), - (str, "None", None), (json.loads, '{"foo":1,"bar":2}', { "foo": 1, "bar": 2 @@ -31,15 +28,20 @@ "foo": 1, "bar": 2 }), - (json.loads, "None", None), ]) -def test_optional_type(type, value, expected): - optional_type_func = optional_type(type) +def test_parse_type(type, value, expected): + parse_type_func = parse_type(type) context = nullcontext() if value == "foo=1,bar=2": context = pytest.warns(DeprecationWarning) with context: - assert optional_type_func(value) == expected + assert parse_type_func(value) == expected + + +def test_optional_type(): + optional_type_func = optional_type(int) + assert optional_type_func("None") is None + assert optional_type_func("42") == 42 @pytest.mark.parametrize(("type_hint", "type", "expected"), [ @@ -89,7 +91,40 @@ def test_literal_to_kwargs(type_hints, expected): @config @dataclass -class DummyConfigClass: +class NestedConfig: + field: int = 1 + """field""" + + +@config +@dataclass +class FromCliConfig1: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 1 + return inst + + +@config +@dataclass +class FromCliConfig2: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 2 + return inst + + +@config +@dataclass +class DummyConfig: regular_bool: bool = True """Regular bool with default True""" optional_bool: Optional[bool] = None @@ -108,18 +143,24 @@ class DummyConfigClass: """Literal of literals with default 1""" json_tip: dict = field(default_factory=dict) """Dict which will be JSON in CLI""" + nested_config: NestedConfig = field(default_factory=NestedConfig) + """Nested config""" + from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1) + """Config with from_cli method""" + from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2) + """Different config with from_cli method""" @pytest.mark.parametrize(("type_hint", "expected"), [ (int, False), - (DummyConfigClass, True), + (DummyConfig, True), ]) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected def test_get_kwargs(): - kwargs = get_kwargs(DummyConfigClass) + kwargs = get_kwargs(DummyConfig) print(kwargs) # bools should not have their type set @@ -142,6 +183,11 @@ def test_get_kwargs(): # dict should have json tip in help json_tip = "\n\nShould be a valid JSON string." assert kwargs["json_tip"]["help"].endswith(json_tip) + # nested config should should construct the nested config + assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) + # from_cli configs should be constructed with the correct method + assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3 + assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 @pytest.mark.parametrize(("arg", "expected"), [ @@ -177,7 +223,7 @@ def test_compilation_config(): # default value args = parser.parse_args([]) - assert args.compilation_config is None + assert args.compilation_config == CompilationConfig() # set to O3 args = parser.parse_args(["-O3"]) @@ -194,7 +240,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config", - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) @@ -202,7 +248,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index d51b7c26344f..6470249dddbc 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("backend", ["mp", "ray"]) @create_new_process_for_each_test() -def test_collective_rpc(tp_size, backend): +def test_collective_rpc(tp_size, backend, monkeypatch): if tp_size == 1 and backend == "ray": pytest.skip("Skip duplicate test case") if tp_size == 1: @@ -21,6 +21,7 @@ def test_collective_rpc(tp_size, backend): def echo_rank(self): return self.rank + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, load_format="dummy", diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 78e40eeecde1..48ede50e98f7 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -2,11 +2,13 @@ import pytest +from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (apply_hf_chat_template, load_chat_template) from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer +from ...models.registry import HF_EXAMPLE_MODELS from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" @@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike(): MODEL_TEMPLATE_GENERATON_OUTPUT) def test_get_gen_prompt(model, template, add_generation_prompt, continue_final_message, expected_output): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + # Initialize the tokenizer - tokenizer = get_tokenizer(tokenizer_name=model) + tokenizer = get_tokenizer( + tokenizer_name=model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) template_content = load_chat_template(chat_template=template) # Create a mock request object using keyword arguments @@ -106,8 +122,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Call the function and get the result result = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=True, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, tools=None, diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 1de30f0ac057..bcb25ed99062 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -4,8 +4,6 @@ from typing import Optional import pytest -from packaging.version import Version -from transformers import __version__ as TRANSFORMERS_VERSION from vllm.assets.image import ImageAsset from vllm.config import ModelConfig @@ -19,6 +17,7 @@ from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH EXAMPLES_DIR = VLLM_PATH / "examples" @@ -772,6 +771,7 @@ def get_conversation(is_hf: bool): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -793,8 +793,8 @@ def get_conversation(is_hf: bool): ) vllm_result = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, chat_template=None, tools=None, @@ -813,6 +813,16 @@ def get_conversation(is_hf: bool): @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=None, tools=tools, - trust_remote_code=True, ) assert isinstance(chat_template, str) @@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) # yapf: enable def test_resolve_content_format_hf_defined(model, expected_format): - if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version( - "4.49.0"): - pytest.skip("Qwen2.5-VL requires transformers>=4.49.0") + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) tokenizer_group = TokenizerGroup( model, enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=None, tools=None, - trust_remote_code=True, ) assert isinstance(chat_template, str) @@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( + model_config, + None, # Test detecting the tokenizer's chat_template + None, + "auto", + tokenizer, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("microsoft/Florence-2-base", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string")], +) +# yapf: enable +def test_resolve_content_format_fallbacks(model, expected_format): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + + tokenizer_group = TokenizerGroup( + model_config.tokenizer, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + trust_remote_code=model_config.trust_remote_code, + ) + tokenizer = tokenizer_group.tokenizer + + # Test detecting the tokenizer's chat_template + chat_template = resolve_hf_chat_template( + model_config, + tokenizer, + chat_template=None, + tools=None, + ) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + model_config, None, # Test detecting the tokenizer's chat_template None, "auto", tokenizer, - trust_remote_code=True, ) assert resolved_format == expected_format @@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("template_path", "expected_format"), [("template_alpaca.jinja", "string"), ("template_baichuan.jinja", "string"), - ("template_blip2.jinja", "string"), - ("template_chameleon.jinja", "string"), ("template_chatglm.jinja", "string"), ("template_chatglm2.jinja", "string"), ("template_chatml.jinja", "string"), - ("template_deepseek_vl2.jinja", "string"), ("template_dse_qwen2_vl.jinja", "openai"), ("template_falcon_180b.jinja", "string"), ("template_falcon.jinja", "string"), - ("template_florence2.jinja", "string"), - ("template_fuyu.jinja", "string"), ("template_inkbot.jinja", "string"), - ("template_paligemma.jinja", "string"), ("template_teleflm.jinja", "string"), - ("template_qwen_vl.jinja", "string"), - ("template_qwen_vl_chat.jinja", "string"), ("template_vlm2vec.jinja", "openai"), ("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"), @@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): ) # yapf: enable def test_resolve_content_format_examples(template_path, expected_format): + model_config = ModelConfig( + PHI3V_MODEL_ID, # Dummy + tokenizer=PHI3V_MODEL_ID, # Dummy + trust_remote_code=True, + ) + tokenizer_group = TokenizerGroup( - PHI3V_MODEL_ID, + PHI3V_MODEL_ID, # Dummy enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None @@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( + model_config, chat_template, None, "auto", dummy_tokenizer, - trust_remote_code=True, ) assert resolved_format == expected_format diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index b0414244c215..436cb430817e 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -102,7 +102,10 @@ def test_env( block_size, False, use_mla=use_mla) - assert backend.get_name() == name + if use_v1 and name != "TRITON_MLA": + assert backend.get_name() == f"{name}_VLLM_V1" + else: + assert backend.get_name() == name else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 4cf7bcb01d4d..6ffe27abf709 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ROCM_USE_AITER", "1") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index abf3e3667a75..d68310060386 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -286,6 +286,7 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) +@pytest.mark.flaky(reruns=2) @pytest.mark.parametrize("m", [1, 123, 666]) @pytest.mark.parametrize("n", [128, 1024]) @pytest.mark.parametrize("k", [256, 2048]) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py new file mode 100644 index 000000000000..ae63b379f39d --- /dev/null +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + quant_blocksize = 16 + round_up = lambda x, y: (x + y - 1) // y * y + sf_w1_2n = round_up(2 * n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), + device="cuda", + dtype=torch.float8_e4m3fn) + + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), + device="cuda", + dtype=torch.float8_e4m3fn) + + w1_q = torch.empty((e, 2 * n, k // 2), + device="cuda", + dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1).max().to(torch.float32) + w2_amax = torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1[expert], w1_gs[expert]) + + w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2[expert], w2_gs[expert]) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + cutlass_output = cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_q, + w1_blockscale=w1_blockscale, + w1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + w2_fp4=w2_q, + w2_blockscale=w2_blockscale, + w2_alphas=(1 / w2_gs), + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=e, + device=a.device, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py new file mode 100644 index 000000000000..58eaeee1c0b8 --- /dev/null +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.scalar_type import scalar_types + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8084d9bf2c2d..633addd421f4 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int, out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1) opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return + if m % 4 != 0 and current_platform.has_device_capability(100): + return cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index b08026c5867d..1f49900b2d90 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", @@ -19,95 +20,24 @@ SEEDS = [42] CUDA_DEVICES = ['cuda:0'] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -kE2M1ToFloatArray = [ - 0., - 0.5, - 1., - 1.5, - 2., - 3., - 4., - 6., -] - - -def e2m1_to_fp32(int4_value): - signBit = (int4_value & 0x8) - int4_absValue = int4_value & 0x7 - float_result = kE2M1ToFloatArray[int4_absValue] - if (signBit): - float_result = -float_result - return float_result - - -def break_fp4_bytes(a, dtype): - assert (a.dtype == torch.uint8) - m, n = a.shape - a = a.flatten() - # Get upper 4 bits - highHalfByte = (a & 0xF0) >> 4 - # Get lower 4 bits - lowHalfByte = a & 0x0F - fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device) - fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) - # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] - out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) - return out - - -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - sf_m, sf_k = a_sf_swizzled.shape - m_tiles = (m + 128 - 1) // 128 - f = block_size * 4 - k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) - return out[0:m, 0:k] - - -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): - """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. - assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out - def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, m, n, dtype, block_size, device): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert (m_k == n_k) - a_in_dtype = dequantize_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) return torch.matmul(a_in_dtype, b_in_dtype.t()) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 622079c39445..76d33169081a 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,7 +8,7 @@ DTYPES = [torch.bfloat16, torch.float16] M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] -K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0 N = [1, 2, 3, 4] SEEDS = [0] diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dc433f9dad26..b940f7190bb2 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -47,7 +47,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 67f3866beff5..0d4e0bf681f2 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict +from typing import NamedTuple, Optional from unittest.mock import patch import pytest @@ -9,52 +10,96 @@ from vllm.lora.utils import (get_adapter_absolute_path, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models.utils import WeightsMapper + + +class LoRANameParserTestConfig(NamedTuple): + name: str + module_name: str + is_lora_a: bool + is_bias: bool + weights_mapper: Optional[WeightsMapper] = None def test_parse_fine_tuned_lora_name_valid(): - fixture = { - ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False), - ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False), - ( + fixture = [ + LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", + "lm_head", True, False), + LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", + "lm_head", False, False), + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, False, ), - } - for name, module_name, is_lora_a, is_bias in fixture: + # Test with WeightsMapper + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + ] + for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name) + is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) def test_parse_fine_tuned_lora_name_invalid(): diff --git a/tests/models/quantization/test_mxfp4.py b/tests/models/quantization/test_mxfp4.py new file mode 100644 index 000000000000..9a060829525e --- /dev/null +++ b/tests/models/quantization/test_mxfp4.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# flake8: noqa +"""Tests Quark mxfp4 models against ground truth generation +""" +import pytest + +from vllm import LLM, SamplingParams + +MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"] + +EXPECTED_STRS_MAP = { + "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [ + '\n### Key Features\n\n* **High-throughput Inference**: vLL', + '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1', + 'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been', + 'A neural network is a machine learning model inspired by the structure of the human brain. It consists of', + '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol', + '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business', + 'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th', + " everybody knows this proverbial saying, but did you know that it's not entirely accurate?", + ] +} + + +@pytest.mark.skip(reason="Model to be released in the future") +@pytest.mark.quant_model +@pytest.mark.parametrize("model_name", MODELS) +def test_models(example_prompts, model_name) -> None: + sampling_params = SamplingParams(max_tokens=20, temperature=0) + llm = LLM( + model=model_name, + kv_cache_dtype="fp8", + quantization="quark", + ) + outputs = llm.generate(example_prompts, sampling_params) + for i, output in enumerate(outputs): + output_str = output.outputs[0].text + expected_str = EXPECTED_STRS_MAP[model_name][i] + assert expected_str == output_str, ( + f"Expected: {expected_str!r}\nvLLM: {output_str!r}") diff --git a/tests/models/registry.py b/tests/models/registry.py index cd5e1dab0a4a..a1f2edac02b9 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -182,7 +182,9 @@ def check_available_online( "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 - "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"), + "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", + extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501 "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), @@ -378,7 +380,7 @@ def check_available_online( # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 tokenizer="Isotr0py/Florence-2-tokenizer", - trust_remote_code=True), # noqa: E501 + trust_remote_code=True,), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index ce1429fda943..478184c34b91 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -26,6 +26,11 @@ "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png", ] +TEST_VIDEO_URLS = [ + "https://www.bogotobogo.com/python/OpenCV_Python/images/mean_shift_tracking/slow_traffic_small.mp4", + "https://filesamples.com/samples/video/avi/sample_640x360.avi", +] + @pytest.fixture(scope="module") def url_images() -> dict[str, Image.Image]: @@ -134,6 +139,18 @@ async def test_fetch_image_local_files(image_url: str): f"file://{temp_dir}/../{os.path.basename(image_url)}") +@pytest.mark.asyncio +@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) +@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) +async def test_fetch_video_http(video_url: str, num_frames: int): + connector = MediaConnector() + + video_sync = connector.fetch_video(video_url, num_frames=num_frames) + video_async = await connector.fetch_video_async(video_url, + num_frames=num_frames) + assert np.array_equal(video_sync, video_async) + + # Used for the next two tests related to `merge_and_sort_multimodal_metadata`. class TestCase(NamedTuple): mm_positions: "MultiModalPlaceholderDict" diff --git a/tests/quantization/test_quark.py b/tests/quantization/test_quark.py index ce918a324887..ae09ac58e675 100644 --- a/tests/quantization/test_quark.py +++ b/tests/quantization/test_quark.py @@ -5,6 +5,7 @@ """ import pytest +import torch from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501 QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8) @@ -63,3 +64,28 @@ def check_model(model): output = llm.generate_greedy("Hello my name is", max_tokens=20) assert output + + +def test_quark_fp8_parity(vllm_runner): + quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method" + fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method" + + llm_kwargs = { + "tensor_parallel_size": 1, + "enforce_eager": True, + "gpu_memory_utilization": 0.1 + } + with (vllm_runner(quark_model_id, **llm_kwargs) as + quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle): + quark_model = (quark_handle.model.llm_engine.model_executor. + driver_worker.model_runner.model) + quark_state_dict = quark_model.state_dict() + + fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker. + model_runner.model) + fp8_state_dict = fp8_model.state_dict() + + assert fp8_state_dict.keys() == quark_state_dict.keys() + + for key in fp8_state_dict: + assert torch.equal(fp8_state_dict[key], quark_state_dict[key]) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 6924aba11576..90340f8cff03 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -478,7 +478,7 @@ def test_sampler_mixed(seed: int, device: str): sampling_params = SamplingParams( temperature=random.random() + 0.1, top_p=min(random.random() + 0.1, 1), - top_k=random.randint(0, 10) or -1, + top_k=random.randint(0, 10), n=n, presence_penalty=random.randint(0, 1), ) diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index a88ae8cda73d..7efef163d2b9 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,12 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 - -import functools -import gc -from typing import Callable, TypeVar - import pytest -import torch -from typing_extensions import ParamSpec from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @@ -25,32 +18,6 @@ def cleanup(): cleanup_dist_env_and_memory(shutdown_ray=True) -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -def retry_until_skip(n: int): - - def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]: - - @functools.wraps(func) - def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R: - for i in range(n): - try: - return func(*args, **kwargs) - except AssertionError: - gc.collect() - torch.cuda.empty_cache() - if i == n - 1: - pytest.skip(f"Skipping test after {n} attempts.") - - raise AssertionError("Code should not be reached") - - return wrapper_retry - - return decorator_retry - - @pytest.fixture(autouse=True) def tensorizer_config(): config = TensorizerConfig(tensorizer_uri="vllm") diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index 5b9661bf6b05..7136dd44de03 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -28,7 +28,6 @@ from vllm.utils import PlaceholderModule, import_from_path from ..utils import VLLM_PATH, RemoteOpenAIServer -from .conftest import retry_until_skip try: from tensorizer import EncryptionParams @@ -325,7 +324,7 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs( assert outputs == deserialized_outputs -@retry_until_skip(3) +@pytest.mark.flaky(reruns=3) def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path): gc.collect() torch.cuda.empty_cache() diff --git a/tests/tpu/lora/__init__.py b/tests/tpu/lora/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py new file mode 100644 index 000000000000..21d7fce691c9 --- /dev/null +++ b/tests/tpu/lora/test_lora.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +# This file contains tests to ensure that LoRA works correctly on the TPU +# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct +# for this. The adapters are: +# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges +# from 1 to 4. + +# These adapters are trained using a standard huggingface peft training script, +# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run +# 100 training iterations with a training batch size of 100. + + +@pytest.fixture(scope="function", autouse=True) +def use_v1_only(monkeypatch: pytest.MonkeyPatch): + """ + Since Multi-LoRA is only supported on the v1 TPU backend, set VLLM_USE_V1=1 + for all tests in this file + """ + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + yield + + +def setup_vllm(num_loras: int) -> vllm.LLM: + return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", + num_scheduler_steps=1, + max_model_len=256, + max_seq_len_to_capture=256, + max_num_seqs=8, + enable_lora=True, + max_loras=num_loras, + max_lora_rank=8) + + +def test_single_lora(): + """ + This test ensures we can run a single LoRA adapter on the TPU backend. + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=1. + """ + + llm = setup_vllm(1) + + prompt = "What is 1+1? \n" + + lora_request = LoRARequest( + "lora_adapter_1", 1, + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter") + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams(max_tokens=256, + temperature=0), + lora_request=lora_request)[0].outputs[0].text + + answer = output.strip()[0] + + assert answer.isdigit() + assert int(answer) == 1 + + +def test_lora_hotswapping(): + """ + This test ensures we can run multiple LoRA adapters on the TPU backend, even + if we only have space to store 1. + + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. + """ + + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = setup_vllm(1) + + prompt = "What is 1+1? \n" + + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + answer = output.strip()[0] + + assert answer.isdigit() + assert int(answer) == i + 1 + + +def test_multi_lora(): + """ + This test ensures we can run multiple LoRA adapters on the TPU backend, when + we have enough space to store all of them. + + We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which + will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. + """ + lora_name_template = \ + "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" + lora_requests = [ + LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) + for i in range(1, 5) + ] + + llm = setup_vllm(4) + + prompt = "What is 1+1? \n" + + for i, req in enumerate(lora_requests): + output = llm.generate(prompt, + sampling_params=vllm.SamplingParams( + max_tokens=256, temperature=0), + lora_request=req)[0].outputs[0].text + + answer = output.strip()[0] + + assert answer.isdigit() + assert int(output.strip()[0]) == i + 1 diff --git a/tests/tpu/lora/test_pallas_kernels.py b/tests/tpu/lora/test_pallas_kernels.py new file mode 100644 index 000000000000..8bd47de50c34 --- /dev/null +++ b/tests/tpu/lora/test_pallas_kernels.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + +N_TOKENS = [16, 1024, 4096] +HIDDEN_SIZES = [1024, 2048, 4096] + +DTYPES = [torch.bfloat16] +NUM_LORA = [1, 4, 16] +RANKS = [32, 256, 512] + + +def generate_test_data(T, D, L, N, seed, dtype=torch.float32): + """ + Inputs: (All integers) + T: Total number of tokens + D: Input dim + L: LoRA Dim + N: N LoRAs + + Outputs: + inputs: torch.Tensor - shape (T, D) + loras: torch.Tensor - shape (N, 1, L, D) + idxs: torch.Tensor - shape (T, ) - all values must be in [0, N) + + ref_output: torch.Tensor - shape (T, L) - inputs @ loras[idxs].T + """ + torch.manual_seed(seed) + + inputs = torch.randn((T, D), device="xla", dtype=dtype) + loras = torch.randn((N, 1, L, D), device="xla", dtype=dtype) + idxs = torch.randint(0, N, (T, ), dtype=torch.int32, device="xla") + + ref_output = ref_bgmv(inputs, loras, idxs) + return inputs, loras, idxs, ref_output + + +def ref_bgmv(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.Tensor): + selected_loras = loras[idxs] + if len(selected_loras.shape) == 4: + selected_loras = selected_loras.squeeze(axis=1) + + batch_size, output_size, input_size = selected_loras.shape + return (selected_loras @ inputs.reshape( + (batch_size, input_size, 1))).reshape((batch_size, output_size)) + + +# Parameterize tests with various shapes and dtypes +@pytest.mark.parametrize("T", N_TOKENS) +@pytest.mark.parametrize("D", HIDDEN_SIZES) +@pytest.mark.parametrize("L", RANKS) +@pytest.mark.parametrize("N", NUM_LORA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seed", [0]) +def test_bgmv_correctness(T, D, L, N, dtype, op_type, seed): + if op_type == "expand": + D, L = L, D + + inputs, loras, idxs, ref_output = generate_test_data( + T, D, L, N, seed, dtype) + + # Run bgmv + output = torch.ops.xla.bgmv(inputs, loras, idxs) + + # Make sure we have no NaNs + assert not torch.any(torch.isnan(output)) + + # Compare with reference output + assert torch.allclose(output, ref_output, rtol=1e-2, atol=1e-2) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index df487ec2ccaa..1cdc80dd3546 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -539,7 +539,7 @@ def test_allocate_with_lookahead(): max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=2, # Total required: 3+2=5 tokens ) assert len(blocks.blocks) == 2 # ceil(5/4)=2 blocks @@ -550,7 +550,7 @@ def test_allocate_with_lookahead(): # required_blocks = ceil((3 + 2) /4) = 2 blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=2, ) assert len(blocks.blocks) == 2 @@ -561,7 +561,7 @@ def test_allocate_with_lookahead(): max_model_len=100) blocks = kv_cache_manager.allocate_slots( request, - num_tokens=3, + num_new_tokens=3, num_lookahead_tokens=4, ) assert len(blocks.blocks) == 2 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 01295e848ee9..a03810625466 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -194,7 +194,7 @@ def test_prefill_plp(): all_token_ids = common_token_ids + unique_token_ids req0 = make_request("0", all_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0) - assert len(manager.req_to_block_hashes[req0.request_id]) == 3 + assert len(manager.req_to_block_hashes[req0.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -256,7 +256,7 @@ def test_prefill_plp(): common_token_ids + unique_token_ids, prompt_logprobs=5) computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2) - assert len(manager.req_to_block_hashes[req2.request_id]) == 3 + assert len(manager.req_to_block_hashes[req2.request_id]) == 0 assert not computed_blocks.blocks assert num_computed_tokens == 0 blocks = manager.allocate_slots(req2, 55, computed_blocks) @@ -299,7 +299,8 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.allocate_slots(req0, 4) assert new_blocks is not None and len(new_blocks.blocks) == 0 - assert manager.req_to_blocks[req0.request_id][-1].block_hash is None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-1].block_hash is None # Append slots with allocating a new block. req0.num_computed_tokens = 59 @@ -309,8 +310,10 @@ def test_decode(): req0.append_output_token_ids(7) new_blocks = manager.allocate_slots(req0, 19) assert new_blocks is not None and len(new_blocks.blocks) == 1 - assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None - assert manager.req_to_blocks[req0.request_id][-1].block_hash is None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-2].block_hash is not None + assert manager.single_type_manager.req_to_blocks[ + req0.request_id][-1].block_hash is None def test_evict(): @@ -689,7 +692,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert not computed_blocks.blocks assert num_computed_tokens == 0 manager.allocate_slots(req0, 48, computed_blocks) - block_part0 = manager.req_to_blocks[req0.request_id] + block_part0 = manager.single_type_manager.req_to_blocks[req0.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... | req1 = make_request("1", common_token_ids * 2) @@ -697,7 +700,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks(): assert computed_blocks.blocks == block_part0 assert num_computed_tokens == 3 * 16 manager.allocate_slots(req1, 48, computed_blocks) - block_part1 = manager.req_to_blocks[req1.request_id] + block_part1 = manager.single_type_manager.req_to_blocks[req1.request_id] # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) | # | Req1-5(F)| ... | manager.free(req1) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index bfe9df10d4d1..0ca2ced89148 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -812,10 +812,11 @@ def _assert_right_kv_cache_manager( # Make sure the request stats are right. EXPECTED_TOTAL_BLOCKS = num_tokens // block_size for req_id in req_ids: - blocks = scheduler.kv_cache_manager.req_to_blocks[req_id] + blocks = (scheduler.kv_cache_manager.single_type_manager. + req_to_blocks[req_id]) hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id] - assert (scheduler.kv_cache_manager.num_cached_block[req_id] == - EXPECTED_TOTAL_BLOCKS) + assert (scheduler.kv_cache_manager.single_type_manager. + num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS) assert len(blocks) == EXPECTED_TOTAL_BLOCKS assert len(hashes) == EXPECTED_TOTAL_BLOCKS @@ -1195,9 +1196,11 @@ def assert_scheduler_empty(scheduler: Scheduler): assert len(scheduler.encoder_cache_manager.cached) == 0 # KVCache Manager. - assert len(scheduler.kv_cache_manager.req_to_blocks) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.req_to_blocks) == 0 assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0 - assert len(scheduler.kv_cache_manager.num_cached_block) == 0 + assert len( + scheduler.kv_cache_manager.single_type_manager.num_cached_block) == 0 num_free_blocks = ( scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks) assert num_free_blocks == ( diff --git a/tests/v1/core/test_specialized_manager.py b/tests/v1/core/test_specialized_manager.py index 595c8608fc64..540720cb9b2f 100644 --- a/tests/v1/core/test_specialized_manager.py +++ b/tests/v1/core/test_specialized_manager.py @@ -8,6 +8,14 @@ from vllm.v1.kv_cache_interface import SlidingWindowSpec +def get_sliding_window_manager(sliding_window_spec, block_pool): + return SlidingWindowManager(sliding_window_spec, + block_pool, + use_eagle=False, + num_kv_cache_groups=1, + caching_hash_fn=lambda x: x) + + def test_sliding_window_possible_cached_prefix(): sliding_window_spec = SlidingWindowSpec( block_size=2, @@ -19,9 +27,7 @@ def test_sliding_window_possible_cached_prefix(): ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, - block_pool, - use_eagle=False) + manager = get_sliding_window_manager(sliding_window_spec, block_pool) def run_one_case(block_is_cached, expect_length): block_hash_list = [ @@ -81,9 +87,7 @@ def test_sliding_window_remove_skipped_blocks(): block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) - manager = SlidingWindowManager(sliding_window_spec, - block_pool, - use_eagle=False) + manager = get_sliding_window_manager(sliding_window_spec, block_pool) null_block_id = block_pool.null_block.block_id @@ -104,39 +108,35 @@ def assert_block_id(block_table, ids): 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010 ] block_table = id_to_block_table(original_block_ids) - removed = manager.remove_skipped_blocks(block_table, 0) - assert_block_id(removed, []) + manager.req_to_blocks["test"] = block_table + + manager.remove_skipped_blocks("test", 0) assert_block_id(block_table, original_block_ids) # 4 tokens are computed. Only token 0 is out of the sliding window. As # block 1000 also contains token 1 that is in the sliding window, block 1000 # cannot be removed. - removed = manager.remove_skipped_blocks(block_table, 4) - assert_block_id(removed, []) + manager.remove_skipped_blocks("test", 4) assert_block_id(block_table, original_block_ids) # 5 tokens are computed. Token 0 & 1 are out of the sliding window. # Block 1000 can be removed. - removed = manager.remove_skipped_blocks(block_table, 5) - assert_block_id(removed, [original_block_ids[0]]) + manager.remove_skipped_blocks("test", 5) assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) # 6 tokens are computed. Token 0-2 are out of the sliding window. # Cannot remove new block as the block 1001 is still used by token 3. - removed = manager.remove_skipped_blocks(block_table, 6) - assert_block_id(removed, []) + manager.remove_skipped_blocks("test", 6) assert_block_id(block_table, [null_block_id] + original_block_ids[1:]) # 7 tokens are computed. Token 0-3 are out of the sliding window. # Block 1001 can be removed and block 1000 is already removed. - removed = manager.remove_skipped_blocks(block_table, 7) - assert_block_id(removed, [original_block_ids[1]]) + manager.remove_skipped_blocks("test", 7) assert_block_id(block_table, [null_block_id] * 2 + original_block_ids[2:]) # 11 tokens are computed. Token 0-7 are out of the sliding window. # Block 1002 & 1003 can be removed now. Block 1003 represents a longer # sequence, and is expected to be evicted earlier than 1002, so the order # of removed blocks should be [1003, 1002]. - removed = manager.remove_skipped_blocks(block_table, 11) - assert_block_id(removed, [original_block_ids[3], original_block_ids[2]]) + manager.remove_skipped_blocks("test", 11) assert_block_id(block_table, [null_block_id] * 4 + original_block_ids[4:]) diff --git a/tests/v1/entrypoints/llm/test_struct_output_generate.py b/tests/v1/entrypoints/llm/test_struct_output_generate.py index c3ea024f58cb..81601c87ad8b 100644 --- a/tests/v1/entrypoints/llm/test_struct_output_generate.py +++ b/tests/v1/entrypoints/llm/test_struct_output_generate.py @@ -105,8 +105,9 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=sample_json_schema)) outputs = llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" + (f"Give an example JSON for an employee profile that fits this " + f"schema. Make the response as short as possible. Schema: " + f"{sample_json_schema}") ] * 2, sampling_params=sampling_params, use_tqdm=True) @@ -136,7 +137,8 @@ def test_structured_output( outputs = llm.generate( prompts=("Generate a JSON object with curly braces for a person with " - "name and age fields for John Smith who is 31 years old."), + "name and age fields for John Smith who is 31 years old. " + "Make the response as short as possible."), sampling_params=sampling_params, use_tqdm=True) @@ -165,19 +167,20 @@ def test_structured_output( with pytest.raises(ValueError, match="The provided JSON schema contains features " "not supported by xgrammar."): - llm.generate(prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {unsupported_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True) + llm.generate( + prompts=[(f"Give an example JSON for an employee profile that " + f"fits this schema: {unsupported_json_schema}. " + f"Make the response as short as possible.")] * 2, + sampling_params=sampling_params, + use_tqdm=True) else: - outputs = llm.generate( - prompts=("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}"), - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts=( + "Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None for output in outputs: assert output is not None @@ -199,8 +202,10 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf)) outputs = llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -231,8 +236,10 @@ def test_structured_output( max_tokens=1000, guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark)) outputs = llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short as " + "possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -269,8 +276,10 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(grammar="not a grammar")) with pytest.raises(ValueError, match="Failed to convert the grammar "): llm.generate( - prompts=("Generate a sql statement that selects col_1 from " - "table_1 where it is equal to 1"), + prompts=( + "Generate a sql statement that selects col_1 from " + "table_1 where it is equal to 1. Make the response as short " + "as possible."), sampling_params=sampling_params, use_tqdm=True, ) @@ -284,7 +293,8 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(regex=sample_regex)) outputs = llm.generate( prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" + (f"Give an example IPv4 address with this regex: {sample_regex}. " + f"Make the response as short as possible.") ] * 2, sampling_params=sampling_params, use_tqdm=True, @@ -309,7 +319,8 @@ def test_structured_output( top_p=0.95, guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) outputs = llm.generate( - prompts="The best language for type-safe systems programming is ", + prompts=("The best language for type-safe systems programming is " + "(Make the response as short as possible.) "), sampling_params=sampling_params, use_tqdm=True) assert outputs is not None @@ -331,11 +342,12 @@ def test_structured_output( temperature=1.0, max_tokens=1000, guided_decoding=GuidedDecodingParams(json=json_schema)) - outputs = llm.generate( - prompts="Generate a JSON with the brand, model and car_type of" - "the most iconic car from the 90's", - sampling_params=sampling_params, - use_tqdm=True) + outputs = llm.generate(prompts=( + "Generate a JSON with the brand, model and car_type of the most " + "iconic car from the 90's. Make the response as short as " + "possible."), + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None @@ -373,7 +385,8 @@ def test_structured_output( guided_decoding=GuidedDecodingParams(json=json_schema)) outputs = llm.generate( - prompts="Generate a description of a frog using 50 characters.", + prompts=("Generate a description of a frog using 50 characters. " + "Make the response as short as possible."), sampling_params=sampling_params, use_tqdm=True) @@ -452,7 +465,8 @@ def test_structured_output( You are a helpful assistant. -Given the previous instructions, what is the weather in New York City? +Given the previous instructions, what is the weather in New York City? \ +Make the response as short as possible. """ # Change this once other backends support structural_tag @@ -509,9 +523,10 @@ def test_structured_output_auto_mode( max_tokens=1000, guided_decoding=GuidedDecodingParams(json=unsupported_json_schema)) - prompts = ("Give an example JSON object for a grade " - "that fits this schema: " - f"{unsupported_json_schema}") + prompts = ( + "Give an example JSON object for a grade " + "that fits this schema: " + f"{unsupported_json_schema}. Make the response as short as possible.") # This would fail with the default of "xgrammar", but in "auto" # we will handle fallback automatically. outputs = llm.generate(prompts=prompts, @@ -566,7 +581,8 @@ def test_guidance_no_additional_properties(monkeypatch: pytest.MonkeyPatch): prompt = ( "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a " "helpful assistant.<|im_end|>\n<|im_start|>user\nPlease generate a " - "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" + "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20. " + "Make the response as short as possible." "<|im_end|>\n<|im_start|>assistant\n") def generate_with_backend(backend): diff --git a/tests/v1/test_serial_utils.py b/tests/v1/test_serial_utils.py index d1271b210ad8..ee490071f6a2 100644 --- a/tests/v1/test_serial_utils.py +++ b/tests/v1/test_serial_utils.py @@ -9,8 +9,8 @@ import torch from vllm.multimodal.inputs import (MultiModalBatchedField, - MultiModalFieldElem, MultiModalKwargs, - MultiModalKwargsItem, + MultiModalFieldElem, MultiModalFlatField, + MultiModalKwargs, MultiModalKwargsItem, MultiModalSharedField, NestedTensors) from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder @@ -36,59 +36,62 @@ class MyType: empty_tensor: torch.Tensor -def test_encode_decode(): +def test_encode_decode(monkeypatch: pytest.MonkeyPatch): """Test encode/decode loop with zero-copy tensors.""" - obj = MyType( - tensor1=torch.randint(low=0, - high=100, - size=(1024, ), - dtype=torch.int32), - a_string="hello", - list_of_tensors=[ - torch.rand((1, 10), dtype=torch.float32), - torch.rand((3, 5, 4000), dtype=torch.float64), - torch.tensor(1984), # test scalar too - # Make sure to test bf16 which numpy doesn't support. - torch.rand((3, 5, 1000), dtype=torch.bfloat16), - torch.tensor([float("-inf"), float("inf")] * 1024, - dtype=torch.bfloat16), - ], - numpy_array=np.arange(512), - unrecognized=UnrecognizedType(33), - small_f_contig_tensor=torch.rand(5, 4).t(), - large_f_contig_tensor=torch.rand(1024, 4).t(), - small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], - large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], - empty_tensor=torch.empty(0), - ) + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") - encoder = MsgpackEncoder(size_threshold=256) - decoder = MsgpackDecoder(MyType) + obj = MyType( + tensor1=torch.randint(low=0, + high=100, + size=(1024, ), + dtype=torch.int32), + a_string="hello", + list_of_tensors=[ + torch.rand((1, 10), dtype=torch.float32), + torch.rand((3, 5, 4000), dtype=torch.float64), + torch.tensor(1984), # test scalar too + # Make sure to test bf16 which numpy doesn't support. + torch.rand((3, 5, 1000), dtype=torch.bfloat16), + torch.tensor([float("-inf"), float("inf")] * 1024, + dtype=torch.bfloat16), + ], + numpy_array=np.arange(512), + unrecognized=UnrecognizedType(33), + small_f_contig_tensor=torch.rand(5, 4).t(), + large_f_contig_tensor=torch.rand(1024, 4).t(), + small_non_contig_tensor=torch.rand(2, 4)[:, 1:3], + large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20], + empty_tensor=torch.empty(0), + ) - encoded = encoder.encode(obj) + encoder = MsgpackEncoder(size_threshold=256) + decoder = MsgpackDecoder(MyType) - # There should be the main buffer + 4 large tensor buffers - # + 1 large numpy array. "large" is <= 512 bytes. - # The two small tensors are encoded inline. - assert len(encoded) == 8 + encoded = encoder.encode(obj) + + # There should be the main buffer + 4 large tensor buffers + # + 1 large numpy array. "large" is <= 512 bytes. + # The two small tensors are encoded inline. + assert len(encoded) == 8 - decoded: MyType = decoder.decode(encoded) + decoded: MyType = decoder.decode(encoded) - assert_equal(decoded, obj) + assert_equal(decoded, obj) - # Test encode_into case + # Test encode_into case - preallocated = bytearray() + preallocated = bytearray() - encoded2 = encoder.encode_into(obj, preallocated) + encoded2 = encoder.encode_into(obj, preallocated) - assert len(encoded2) == 8 - assert encoded2[0] is preallocated + assert len(encoded2) == 8 + assert encoded2[0] is preallocated - decoded2: MyType = decoder.decode(encoded2) + decoded2: MyType = decoder.decode(encoded2) - assert_equal(decoded2, obj) + assert_equal(decoded2, obj) class MyRequest(msgspec.Struct): @@ -122,7 +125,7 @@ def test_multimodal_kwargs(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 44559, +-20 for minor changes - assert total_len >= 44539 and total_len <= 44579 + assert 44539 <= total_len <= 44579 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] assert all(nested_equal(d[k], decoded[k]) for k in d) @@ -135,14 +138,15 @@ def test_multimodal_items_by_modality(): "video", "v0", [torch.zeros(1000, dtype=torch.int8) for _ in range(4)], - MultiModalBatchedField(), + MultiModalFlatField( + [[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0), ) e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)) - e4 = MultiModalFieldElem("image", "i1", torch.zeros(1000, - dtype=torch.int32), - MultiModalBatchedField()) + e4 = MultiModalFieldElem( + "image", "i1", torch.zeros(1000, dtype=torch.int32), + MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2)) audio = MultiModalKwargsItem.from_elems([e1]) video = MultiModalKwargsItem.from_elems([e2]) image = MultiModalKwargsItem.from_elems([e3, e4]) @@ -161,7 +165,7 @@ def test_multimodal_items_by_modality(): total_len = sum(memoryview(x).cast("B").nbytes for x in encoded) # expected total encoding length, should be 14255, +-20 for minor changes - assert total_len >= 14235 and total_len <= 14275 + assert 14250 <= total_len <= 14300 decoded: MultiModalKwargs = decoder.decode(encoded).mm[0] # check all modalities were recovered and do some basic sanity checks @@ -178,8 +182,7 @@ def test_multimodal_items_by_modality(): def nested_equal(a: NestedTensors, b: NestedTensors): if isinstance(a, torch.Tensor): return torch.equal(a, b) - else: - return all(nested_equal(x, y) for x, y in zip(a, b)) + return all(nested_equal(x, y) for x, y in zip(a, b)) def assert_equal(obj1: MyType, obj2: MyType): @@ -199,11 +202,10 @@ def assert_equal(obj1: MyType, obj2: MyType): assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) -@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_dict_serialization(allow_pickle: bool): +def test_dict_serialization(): """Test encoding and decoding of a generic Python object using pickle.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder() # Create a sample Python object obj = {"key": "value", "number": 42} @@ -218,11 +220,10 @@ def test_dict_serialization(allow_pickle: bool): assert obj == decoded, "Decoded object does not match the original object." -@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_tensor_serialization(allow_pickle: bool): +def test_tensor_serialization(): """Test encoding and decoding of a torch.Tensor.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(torch.Tensor) # Create a sample tensor tensor = torch.rand(10, 10) @@ -238,11 +239,10 @@ def test_tensor_serialization(allow_pickle: bool): tensor, decoded), "Decoded tensor does not match the original tensor." -@pytest.mark.parametrize("allow_pickle", [True, False]) -def test_numpy_array_serialization(allow_pickle: bool): +def test_numpy_array_serialization(): """Test encoding and decoding of a numpy array.""" - encoder = MsgpackEncoder(allow_pickle=allow_pickle) - decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle) + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(np.ndarray) # Create a sample numpy array array = np.random.rand(10, 10) @@ -268,26 +268,31 @@ def __eq__(self, other): return isinstance(other, CustomClass) and self.value == other.value -def test_custom_class_serialization_allowed_with_pickle(): +def test_custom_class_serialization_allowed_with_pickle( + monkeypatch: pytest.MonkeyPatch): """Test that serializing a custom class succeeds when allow_pickle=True.""" - encoder = MsgpackEncoder(allow_pickle=True) - decoder = MsgpackDecoder(CustomClass, allow_pickle=True) - obj = CustomClass("test_value") + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + encoder = MsgpackEncoder() + decoder = MsgpackDecoder(CustomClass) - # Encode the custom class - encoded = encoder.encode(obj) + obj = CustomClass("test_value") - # Decode the custom class - decoded = decoder.decode(encoded) + # Encode the custom class + encoded = encoder.encode(obj) - # Verify the decoded object matches the original - assert obj == decoded, "Decoded object does not match the original object." + # Decode the custom class + decoded = decoder.decode(encoded) + + # Verify the decoded object matches the original + assert obj == decoded, ( + "Decoded object does not match the original object.") def test_custom_class_serialization_disallowed_without_pickle(): """Test that serializing a custom class fails when allow_pickle=False.""" - encoder = MsgpackEncoder(allow_pickle=False) + encoder = MsgpackEncoder() obj = CustomClass("test_value") diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py index 57c195982ca8..2bbeb3ddac91 100644 --- a/tests/v1/tpu/test_sampler.py +++ b/tests/v1/tpu/test_sampler.py @@ -26,7 +26,7 @@ def test_sampler_different(model_name: str): enforce_eager=False, max_num_seqs=1, max_model_len=512, - max_num_batched_tokens=512) + max_num_batched_tokens=256) prompts = [ "Write a short story about a robot that dreams for the first time." ] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c693c14f4f2d..80f549745219 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -117,13 +117,14 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, query_start_loc, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, - v_scale) + v_scale, fp8_out_scale) def mla_decode_kvcache_cpu( @@ -744,10 +745,11 @@ def get_cutlass_moe_mm_data( - output_permutation: Permutation that must be used to shuffle the output after executing the MMs. """ - torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, - problem_sizes1, problem_sizes2, - input_permutation, output_permutation, - num_experts, n, k) + return torch.ops._C.get_cutlass_moe_mm_data(topk_ids, expert_offsets, + problem_sizes1, problem_sizes2, + input_permutation, + output_permutation, + num_experts, n, k) def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, @@ -766,9 +768,41 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, MMs used in the fused MoE operation. - a/b/c_strides: The data strides passed to grouped matrix multiplication. """ - torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, - b_scales, expert_offsets, problem_sizes, - a_strides, b_strides, c_strides) + return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, + a_scales, b_scales, expert_offsets, + problem_sizes, a_strides, b_strides, + c_strides) + + +def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, + a_scales: torch.Tensor, b_scales: torch.Tensor, + alphas: torch.Tensor, problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, sf_offsets: torch.Tensor, + out_dtype: torch.dtype, device: torch.device): + """ + An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs + the gemms for each combination based on the specified problem sizes. + + This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward. + - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized + input and expert weights. + - a_/b_scales: The blockscales in FP8-E4M3 precision + - expert_offsets/sf_offsets: Indices that mark at which token index + each expert begins its computation. The number of tokens + computed with expert E is expert_offsets[E + 1] - + expert_offsets[E] And the sf_size per expert is + sf_offset[E+1] - sf_offset[E] + - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped + MMs used in the fused MoE operation. + """ + m_topk = a_tensors.shape[0] + n = b_tensors.shape[1] + c_shape = (m_topk, n) + c = torch.empty(c_shape, device=device, dtype=out_dtype) + torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales, + b_scales, alphas, problem_sizes, + expert_offsets, sf_offsets) + return c.to(out_dtype) # aqlm @@ -959,6 +993,57 @@ def scaled_fp4_quant( return output, output_scale +def scaled_fp4_experts_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + topk: int, + expert_map: Optional[torch.Tensor] = None, + MAX_TOKENS_PER_EXPERT: int = 163840, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP4 and return quantized tensor and scale, for + packed MoE Inputs. + Args: + input: The input tensor to be quantized to FP4 + expert_map: The expert map tensor + input_global_scale: A scalar scaling factor for the entire tensor. + expert_offsets: The expert offsets tensor + blockscale_offsets: The blockscale offsets tensor + Outputs: + output: The quantized tensor in FP4 + output_scales: The blockscale tensor in FP8-E4M3 + """ + assert not current_platform.is_rocm() + assert input_tensor.ndim == 2, ( + f'input.ndim needs to be == 2, but got {input_tensor.ndim}.') + + input_tensor = input_tensor[ + expert_map] if expert_map is not None else input_tensor + m_numtopk, k = input_tensor.shape + assert (m_numtopk <= MAX_TOKENS_PER_EXPERT * topk), ( + f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT * topk for" + f" scaled_fp4_experts_quant kernel, observed m_numtopk = {m_numtopk}") + scales_k = k // 16 + padded_k = (scales_k + (4 - 1)) // 4 + + # output is uint8 and packed fp4 values + output = torch.empty(m_numtopk, + k // 2, + device=input_tensor.device, + dtype=torch.uint8) + output_scales = torch.empty(MAX_TOKENS_PER_EXPERT * topk, + padded_k, + dtype=torch.int32, + device=input_tensor.device) + torch.ops._C.scaled_fp4_experts_quant(output, output_scales, input_tensor, + input_global_scale, expert_offsets, + blockscale_offsets) + output_scales = output_scales.view(torch.float8_e4m3fn) + return output, output_scales + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 55a63a81677f..d701c59a234f 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -57,16 +57,16 @@ def get_kv_cache_shape( def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dsts: torch.Tensor, ) -> None: - HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + HPUPagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dsts) @staticmethod def copy_blocks( kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + src_to_dsts: torch.Tensor, ) -> None: - HPUPagedAttention.copy_blocks(kv_caches, src_to_dists) + HPUPagedAttention.copy_blocks(kv_caches, src_to_dsts) @dataclass @@ -77,6 +77,7 @@ class HPUAttentionMetadata(HPUPagedAttentionMetadata, AttentionMetadata): is_prompt: bool attn_bias: Optional[torch.Tensor] seq_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] class HPUAttentionImpl(AttentionImpl, torch.nn.Module): @@ -198,8 +199,7 @@ def forward( key_cache = None value_cache = None if attn_metadata.is_prompt and self.attn_type \ - is not AttentionType.ENCODER_ONLY \ - and attn_metadata.block_list is None: + is not AttentionType.ENCODER_ONLY: key = key.unflatten(0, (block_indices.size(0), -1)) value = value.unflatten(0, (block_indices.size(0), -1)) if kv_cache is not None and isinstance(kv_cache, tuple): @@ -229,6 +229,9 @@ def forward( attn_bias = attn_bias.tile((1, self.num_kv_heads, 1, 1)) attn_bias.add_(position_bias) + block_list = attn_metadata.block_list if attn_metadata \ + and attn_metadata.block_list is not None else None + out = ops.prompt_attention( impl=self.prefill_impl, query=query.view(query_shape), @@ -237,23 +240,25 @@ def forward( is_causal=True, attn_bias=attn_bias, valid_seq_lengths=attn_metadata.seq_lens_tensor, - **self.common_attention_args()) + **self.common_attention_args(block_list, key_cache, + value_cache)) output = out.reshape(batch_size, seq_len, hidden_size) else: # Decoding run. output = HPUPagedAttention.forward_decode( query=query, - key_cache=key_cache, - value_cache=value_cache, - block_list=attn_metadata.block_list, block_mapping=attn_metadata.block_mapping, block_bias=attn_metadata.attn_bias, block_groups=attn_metadata.block_groups, - **self.common_attention_args()) + **self.common_attention_args(attn_metadata.block_list, + key_cache, value_cache)) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) - def common_attention_args(self): + def common_attention_args(self, + block_list=None, + key_cache=None, + value_cache=None): fsdpa_op = self.fused_scaled_dot_product_attention.apply \ if self.fused_scaled_dot_product_attention is not None else None return { @@ -266,6 +271,9 @@ def common_attention_args(self): 'keys_fetch_func': self.k_cache.fetch_from_cache, 'values_fetch_func': self.v_cache.fetch_from_cache, 'softmax_op': self.softmax, + 'block_list': block_list, + 'key_cache': key_cache, + 'value_cache': value_cache, } diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 0100c082aa21..363aa08ef003 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -211,8 +211,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.rotary_embedding import ( - DeepseekScalingRotaryEmbedding, RotaryEmbedding) from vllm.multimodal import MultiModalPlaceholderMap from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON @@ -377,7 +375,6 @@ def graph_capture_get_metadata_for_batch( seq_start_loc=None, context_lens_tensor=None, block_tables=self._graph_block_tables[:batch_size], - input_positions=self._positions[:batch_size], head_dim=self.runner.model_config.get_head_size()) if is_encoder_decoder_model: @@ -393,7 +390,6 @@ def get_graph_input_buffers(self, "slot_mapping": attn_metadata.slot_mapping, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "block_tables": attn_metadata.decode_metadata.block_tables, - "input_positions": attn_metadata.decode_metadata.input_positions, } if is_encoder_decoder_model: raise NotImplementedError( @@ -405,16 +401,10 @@ def prepare_graph_input_buffers(self, input_buffers, attn_metadata, is_encoder_decoder_model: bool = False): - input_positions = attn_metadata.input_positions - num_positions = input_positions.shape[0] input_buffers["seq_lens_tensor"].copy_( attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) input_buffers["block_tables"].copy_( attn_metadata.decode_metadata.block_tables, non_blocking=True) - # CUDA graph buffer is padded so only perform a partial copy based on - # num_positions - input_buffers["input_positions"][:num_positions].copy_( - input_positions, non_blocking=True) if is_encoder_decoder_model: raise NotImplementedError( "TritonMLAState does not support encoder/decoder yet") @@ -456,11 +446,6 @@ class MLACommonMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool - # New for MLA (compared to FlashAttention) - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -563,8 +548,6 @@ def prefill_metadata(self): self.context_lens_tensor[:self.num_prefills]) block_tables = (None if self.block_tables is None else self.block_tables[:self.num_prefills]) - input_positions = (None if self.input_positions is None else - self.input_positions[:self.num_prefill_tokens]) self._cached_prefill_metadata = self.__class__( # Required by ModelRunner @@ -578,7 +561,6 @@ def prefill_metadata(self): multi_modal_placeholder_index_maps=None, enable_kv_scales_calculation=False, # MLACommonMetadata - input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -615,8 +597,6 @@ def decode_metadata(self): self.seq_lens_tensor[self.num_prefills:]) block_tables = (None if self.block_tables is None else self.block_tables[self.num_prefills:]) - input_positions = (None if self.input_positions is None else - self.input_positions[self.num_prefill_tokens:]) self._cached_decode_metadata = self.__class__( # Required by ModelRunner @@ -646,7 +626,6 @@ def decode_metadata(self): if self.seq_start_loc is not None else None, context_lens_tensor=None, block_tables=block_tables, - input_positions=input_positions, head_dim=self.head_dim, is_profile_run=self.is_profile_run) return self._cached_decode_metadata @@ -765,7 +744,6 @@ def prepare(self): self.context_lens: List[int] = [] self.block_tables: List[List[int]] = [] self.curr_seq_lens: List[int] = [] - self.input_positions: List[int] = [] self.multimodal_placeholder_maps: Dict[ str, MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) @@ -786,13 +764,11 @@ def _add_seq_group( block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( + curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 @@ -912,8 +888,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device, self.runner.pin_memory) seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, self.runner.pin_memory) - input_positions = async_tensor_h2d(self.input_positions, torch.long, - device, self.runner.pin_memory) slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, device, self.runner.pin_memory) query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32, @@ -987,7 +961,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], multi_modal_placeholder_index_maps=None, # Not Attention Related enable_kv_scales_calculation=False, # MLACommonMetadata - input_positions=input_positions, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, @@ -1033,7 +1006,6 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - rotary_emb: RotaryEmbedding, kv_b_proj: ColumnParallelLinear, ) -> None: self.num_heads = num_heads @@ -1048,10 +1020,6 @@ def __init__( self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - - self.rotary_emb = rotary_emb - self.use_yarn_rope = isinstance(rotary_emb, - DeepseekScalingRotaryEmbedding) self.kv_b_proj = kv_b_proj self.triton_fa_func = triton_attention @@ -1367,41 +1335,15 @@ def forward( has_decode = attn_metadata.decode_metadata is not None has_prefill = attn_metadata.prefill_metadata is not None - # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") - num_prefill_tokens: int = attn_metadata.num_prefill_tokens q = q.view(-1, self.num_heads, self.qk_head_dim) decode_q = q[num_prefill_tokens:] - decode_k_pe = k_pe[num_prefill_tokens:] - decode_input_positions = \ - attn_metadata.input_positions[num_prefill_tokens:] prefill_q = q[:num_prefill_tokens] prefill_k_pe = k_pe[:num_prefill_tokens] - prefill_input_positions = \ - attn_metadata.input_positions[:num_prefill_tokens] prefill_k_c_normed = k_c_normed[:num_prefill_tokens] - if has_decode: - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - decode_input_positions, decode_q_pe, decode_k_pe) - - if has_prefill: - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - prefill_input_positions, prefill_q_pe, prefill_k_pe) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -1424,6 +1366,15 @@ def forward( attn_metadata) if has_decode: + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + output[num_prefill_tokens:] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) diff --git a/vllm/attention/backends/rocm_aiter_mla.py b/vllm/attention/backends/rocm_aiter_mla.py index 2984bc1dad64..4936c8201399 100644 --- a/vllm/attention/backends/rocm_aiter_mla.py +++ b/vllm/attention/backends/rocm_aiter_mla.py @@ -148,13 +148,11 @@ def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool, block_tables = inter_data.block_tables for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, - curr_sliding_window_block, input_positions) in zip( + curr_sliding_window_block) in zip( inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], inter_data.orig_seq_lens, inter_data.seq_lens, inter_data.query_lens, inter_data.context_lens, - inter_data.curr_sliding_window_blocks, - inter_data.input_positions): - self.input_positions.extend(input_positions) + inter_data.curr_sliding_window_blocks): self.context_lens.append(context_len) if is_prompt: self.num_prefills += 1 diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py index 1dedd2ffc5fa..a97c36338d3c 100644 --- a/vllm/attention/ops/hpu_paged_attn.py +++ b/vllm/attention/ops/hpu_paged_attn.py @@ -5,7 +5,7 @@ ############################################################################### from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from vllm_hpu_extension import cache_ops, ops @@ -63,43 +63,25 @@ def write_to_paged_cache(key: torch.Tensor, value: torch.Tensor, def forward_decode(**kwargs) -> torch.Tensor: return ops.flat_pa(**kwargs) - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_query_len: int, - alibi_slopes: Optional[torch.Tensor], - sliding_window: Optional[int], - ) -> torch.Tensor: - raise NotImplementedError( - "forward_prefix is not implemented for HPUPagedAttention") - @staticmethod def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_kv_cache: Tuple[torch.Tensor, torch.Tensor], + dst_kv_cache: Tuple[torch.Tensor, torch.Tensor], + src_to_dsts: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] - cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dsts) src_value_cache = src_kv_cache[1] dst_value_cache = dst_kv_cache[1] - cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dsts) @staticmethod def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dsts: torch.Tensor, ) -> None: key_caches = [kv_cache[0] for kv_cache in kv_caches] value_caches = [kv_cache[1] for kv_cache in kv_caches] - cache_ops.copy_blocks(key_caches, value_caches, src_to_dists) + cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) diff --git a/vllm/attention/ops/rocm_aiter_mla.py b/vllm/attention/ops/rocm_aiter_mla.py index 1c90f8c19b09..3348d18804aa 100644 --- a/vllm/attention/ops/rocm_aiter_mla.py +++ b/vllm/attention/ops/rocm_aiter_mla.py @@ -4,6 +4,9 @@ import torch +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op + def get_aiter_mla_metadata(max_batch_size: int, block_size: int, max_block_per_batch: int, @@ -30,6 +33,28 @@ def aiter_mla_decode_fwd( kv_last_page_lens: Optional[torch.Tensor] = None, logit_cap: float = 0.0, ): + + torch.ops.vllm.rocm_aiter_mla_decode_fwd(q, + kv_buffer.view( + -1, 1, 1, q.shape[-1]), + o, + kv_indptr, + kv_indices, + kv_last_page_lens, + sm_scale=sm_scale, + logit_cap=logit_cap) + + +def mla_decode_fwd_impl( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: from aiter.mla import mla_decode_fwd mla_decode_fwd(q, @@ -40,3 +65,24 @@ def aiter_mla_decode_fwd( kv_last_page_lens, sm_scale=sm_scale, logit_cap=logit_cap) + + +def mla_decode_fwd_fake( + q: torch.Tensor, + kv_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: Optional[torch.Tensor] = None, + kv_indices: Optional[torch.Tensor] = None, + kv_last_page_lens: Optional[torch.Tensor] = None, + sm_scale: float = 1.0, + logit_cap: float = 0.0, +) -> None: + pass + + +if current_platform.is_rocm(): + direct_register_custom_op(op_name="rocm_aiter_mla_decode_fwd", + op_func=mla_decode_fwd_impl, + mutates_args=["o"], + fake_impl=mla_decode_fwd_fake, + tags=[torch.Tag.needs_fixed_stride_order]) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a1ff5fb1196b..c2e8c726c943 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -17,7 +17,8 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors -from .compiler_interface import EagerAdaptor, InductorAdaptor +from .compiler_interface import (CompilerInterface, EagerAdaptor, + InductorAdaptor, InductorStandaloneAdaptor) from .counter import compilation_counter from .inductor_pass import InductorPass from .monitor import end_monitoring_torch_compile @@ -26,6 +27,19 @@ logger = init_logger(__name__) +def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface: + if compilation_config.use_inductor: + if envs.VLLM_TEST_STANDALONE_COMPILE: + logger.info("Using InductorStandaloneAdaptor") + return InductorStandaloneAdaptor() + else: + logger.info("Using InductorAdaptor") + return InductorAdaptor() + else: + logger.info("Using EagerAdaptor") + return EagerAdaptor() + + class CompilerManager: """ A manager to manage the compilation process, including @@ -41,11 +55,11 @@ class CompilerManager: support int as key. """ - def __init__(self, use_inductor: bool): + def __init__(self, compilation_config: CompilationConfig): self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() - cls = InductorAdaptor if use_inductor else EagerAdaptor - self.compiler = cls() self.is_cache_updated = False + self.compilation_config = compilation_config + self.compiler = make_compiler(compilation_config) def compute_hash(self, vllm_config: VllmConfig) -> str: return self.compiler.compute_hash(vllm_config) @@ -123,8 +137,15 @@ def compile(self, # no compiler cached the graph, or the cache is disabled, # we need to compile it + if isinstance(self.compiler, InductorAdaptor): + # Let compile_fx generate a key for us + maybe_key = None + else: + maybe_key = \ + f"artifact_shape_{runtime_shape}_subgraph_{graph_index}" compiled_graph, handle = self.compiler.compile( - graph, example_inputs, additional_inductor_config, runtime_shape) + graph, example_inputs, additional_inductor_config, runtime_shape, + maybe_key) assert compiled_graph is not None, "Failed to compile the graph" @@ -336,7 +357,7 @@ def __init__( self.compilation_config = vllm_config.compilation_config self.compiler_manager: CompilerManager = CompilerManager( - self.compilation_config.use_inductor) + self.compilation_config) # `torch.compile` is JIT compiled, so we don't need to # do anything here diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index b7e7a79bef0b..423581784f7a 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -50,7 +50,8 @@ def compile( graph: fx.GraphModule, example_inputs: List[Any], compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None + runtime_shape: Optional[int] = None, + key: Optional[str] = None, ) -> Tuple[Optional[Callable], Optional[Any]]: """ Compile the graph with the given example inputs and compiler config, @@ -71,6 +72,10 @@ def compile( If the compiler doesn't support caching, it should return None for the handle. If the compiler fails to compile the graph, it should return None for the compiled function as well. + + `key` is required for StandaloneInductorAdapter, it specifies where to + save the compiled artifact. The compiled artifact gets saved to + `cache_dir/key`. """ return None, None @@ -127,23 +132,108 @@ def produce_guards_expression(self, *args, **kwargs): return "" +def get_inductor_factors() -> List[Any]: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + return factors + + +class InductorStandaloneAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler. + Requires PyTorch 2.8+. + This is not on by default yet, but we plan to turn it on by default for + PyTorch 2.8. + + Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off. + """ + name = "inductor_standalone" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors = get_inductor_factors() + hash_str = hashlib.md5(str(factors).encode(), + usedforsecurity=False).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.cache_dir = cache_dir + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None, + key: Optional[str] = None, + ) -> Tuple[Optional[Callable], Optional[Any]]: + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) + set_inductor_config(current_config, runtime_shape) + + if isinstance(runtime_shape, int): + dynamic_shapes = "from_example_inputs" + else: + dynamic_shapes = "from_tracing_context" + + from torch._inductor import standalone_compile + with pass_context(runtime_shape): + compiled_graph = standalone_compile( + graph, + example_inputs, + dynamic_shapes=dynamic_shapes, + options={"config_patches": current_config}) + + # Save the compiled artifact to disk in the specified path + assert key is not None + path = os.path.join(self.cache_dir, key) + compiled_graph.save(path=path, format="unpacked") + return compiled_graph, (key, path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + path = handle[1] + inductor_compiled_graph = torch._inductor.CompiledArtifact.load( + path=path, format="unpacked") + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + def compiled_graph_wrapper(*args): + graph_output = inductor_compiled_graph(*args) + # unpack the tuple if needed + # TODO(rzou): the implication is that we're not + # reading the python bytecode correctly in vLLM? + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph_wrapper + + class InductorAdaptor(CompilerInterface): """ - The adaptor for the Inductor compiler, version 2.5 and 2.6. + The adaptor for the Inductor compiler, version 2.5, 2.6, 2.7. """ name = "inductor" def compute_hash(self, vllm_config: VllmConfig) -> str: - factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) + factors = get_inductor_factors() hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()[:10] return hash_str @@ -168,23 +258,19 @@ def compile( graph: fx.GraphModule, example_inputs: List[Any], compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None + runtime_shape: Optional[int] = None, + key: Optional[str] = None, ) -> Tuple[Optional[Callable], Optional[Any]]: - current_config = {} from torch._inductor.compile_fx import compile_fx + current_config = {} + if compiler_config is not None: + current_config.update(compiler_config) # disable remote cache current_config["fx_graph_cache"] = True current_config["fx_graph_remote_cache"] = False - if compiler_config is not None: - current_config.update(compiler_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - current_config["coordinate_descent_tuning"] = True + set_inductor_config(current_config, runtime_shape) # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 @@ -422,6 +508,14 @@ def metrics_context(self) -> contextlib.AbstractContextManager: return contextlib.nullcontext() +def set_inductor_config(config, runtime_shape): + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + config["max_autotune"] = True + config["coordinate_descent_tuning"] = True + + class EagerAdaptor(CompilerInterface): name = "eager" @@ -430,7 +524,8 @@ def compile( graph: fx.GraphModule, example_inputs: List[Any], compiler_config: Dict[str, Any], - runtime_shape: Optional[int] = None + runtime_shape: Optional[int] = None, + key: Optional[str] = None, ) -> Tuple[Optional[Callable], Optional[Any]]: # we don't need to compile the graph, just return the graph itself. # It does not support caching, return None for the handle. diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index e8bffb406f14..c95e0bce5f2e 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -4,7 +4,7 @@ import torch -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import PassConfig, VllmConfig # yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( @@ -56,10 +56,7 @@ def end_and_log(self): class PrinterInductorPass(VllmInductorPass): - def __init__(self, - name: str, - config: CompilationConfig.PassConfig, - always=False): + def __init__(self, name: str, config: PassConfig, always=False): super().__init__(config) self.name = name self.always = always diff --git a/vllm/config.py b/vllm/config.py index 11e4e500aa09..ef0163eaff85 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -7,13 +7,12 @@ import inspect import json import re -import sys import textwrap import warnings from collections import Counter from contextlib import contextmanager -from dataclasses import (MISSING, dataclass, field, fields, is_dataclass, - replace) +from dataclasses import (MISSING, Field, asdict, dataclass, field, fields, + is_dataclass, replace) from functools import cached_property from importlib.util import find_spec from pathlib import Path @@ -21,7 +20,6 @@ Protocol, TypeVar, Union, cast, get_args, get_origin) import torch -from pydantic import BaseModel, Field, PrivateAttr from torch.distributed import ProcessGroup, ReduceOp from transformers import PretrainedConfig from typing_extensions import deprecated @@ -34,7 +32,7 @@ QuantizationMethods, get_quantization_config) from vllm.model_executor.models import ModelRegistry -from vllm.platforms import CpuArchEnum, current_platform +from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback from vllm.transformers_utils.config import ( ConfigFormat, get_config, get_hf_image_processor_config, @@ -58,7 +56,7 @@ ConfigType = type[DataclassInstance] else: - QuantizationConfig = None + QuantizationConfig = Any ConfigType = type logger = init_logger(__name__) @@ -170,6 +168,12 @@ def config(cls: ConfigT) -> ConfigT: """ A decorator that ensures all fields in a dataclass have default values and that each field has a docstring. + + If a `ConfigT` is used as a CLI argument itself, the default value provided + by `get_kwargs` will be the result parsing a JSON string as the kwargs + (i.e. `ConfigT(**json.loads(cli_arg))`). However, if a particular `ConfigT` + requires custom construction from CLI (i.e. `CompilationConfig`), it can + have a `from_cli` method, which will be called instead. """ if not is_dataclass(cls): raise TypeError("The decorated class must be a dataclass.") @@ -203,7 +207,7 @@ def get_field(cls: ConfigType, name: str) -> Field: cls_fields = {f.name: f for f in fields(cls)} if name not in cls_fields: raise ValueError(f"Field '{name}' not found in {cls.__name__}.") - named_field: Field = cls_fields.get(name) + named_field: Field = cls_fields[name] if (default_factory := named_field.default_factory) is not MISSING: return field(default_factory=default_factory) if (default := named_field.default) is not MISSING: @@ -212,6 +216,10 @@ def get_field(cls: ConfigType, name: str) -> Field: f"{cls.__name__}.{name} must have a default value or default factory.") +def is_init_field(cls: ConfigType, name: str) -> bool: + return next(f for f in fields(cls) if f.name == name).init + + TokenizerMode = Literal["auto", "slow", "mistral", "custom"] ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"] @@ -2008,13 +2016,13 @@ def compute_hash(self) -> str: def __post_init__(self) -> None: if self.max_model_len is None: self.max_model_len = 8192 - logger.warning( + logger.warning_once( "max_model_len was is not set. Defaulting to arbitrary value " "of %d.", self.max_model_len) if self.max_num_seqs is None: self.max_num_seqs = 128 - logger.warning( + logger.warning_once( "max_num_seqs was is not set. Defaulting to arbitrary value " "of %d.", self.max_num_seqs) @@ -2050,6 +2058,13 @@ def __post_init__(self) -> None: _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) + # When using default settings, + # Ensure max_num_batched_tokens does not exceed model limit. + # Some models (e.g., Whisper) have embeddings tied to max length. + self.max_num_batched_tokens = min( + self.max_num_seqs * self.max_model_len, + self.max_num_batched_tokens) + self.max_num_encoder_input_tokens = self.max_num_batched_tokens self.encoder_cache_size = self.max_num_batched_tokens @@ -2090,6 +2105,13 @@ def _verify_args(self) -> None: "be greater than or equal to max_num_seqs " f"({self.max_num_seqs}).") + if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len: + logger.warning( + "max_num_batched_tokens (%d) exceeds max_num_seqs" + "* max_model_len (%d). This may lead to unexpected behavior.", + self.max_num_batched_tokens, + self.max_num_seqs * self.max_model_len) + if self.num_lookahead_slots < 0: raise ValueError( "num_lookahead_slots " @@ -2694,8 +2716,8 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 """Maximum size of extra vocabulary that can be present in a LoRA adapter (added to the base model vocabulary).""" - # This is a constant. - lora_vocab_padding_size: ClassVar[int] = 256 + lora_vocab_padding_size: ClassVar[int] = current_platform\ + .get_lora_vocab_padding_size() long_lora_scaling_factors: Optional[tuple[float, ...]] = None """Specify multiple scaling factors (which can be different from base model scaling factor - see eg. Long LoRA) to allow for multiple LoRA adapters @@ -2723,6 +2745,7 @@ def compute_hash(self) -> str: factors.append(self.fully_sharded_loras) factors.append(self.lora_dtype) factors.append(self.lora_extra_vocab_size) + factors.append(self.lora_vocab_padding_size) factors.append(self.long_lora_scaling_factors) factors.append(self.bias_enabled) hash_str = hashlib.md5(str(factors).encode(), @@ -2826,8 +2849,8 @@ def verify_with_model_config(self, model_config: ModelConfig): class MultiModalConfig: """Controls the behavior of multimodal models.""" - limit_per_prompt: dict[str, int] = get_field(ModelConfig, - "limit_mm_per_prompt") + limit_per_prompt: dict[str, int] = \ + cast(dict[str, int], get_field(ModelConfig, "limit_mm_per_prompt")) """ The maximum number of input items allowed per prompt for each modality. Defaults to 1 (V0) or 999 (V1) for each modality. @@ -2973,6 +2996,7 @@ def _get_and_verify_dtype( if isinstance(dtype, str): dtype = dtype.lower() if dtype == "auto": + # Set default dtype from model config if config_dtype == torch.float32: # Following common practice, we use float16 for float32 models torch_dtype = torch.float16 @@ -2980,37 +3004,33 @@ def _get_and_verify_dtype( torch_dtype = config_dtype if config.model_type == "plamo2": - logger.info( + logger.warning( "For PLaMo2, we cast models to bfloat16 instead of using " "float16 by default. This is because float16 does not work." ) torch_dtype = torch.bfloat16 + # Deal with torch dtype fallback for device compatibility. from vllm.platforms import current_platform - if (current_platform.is_cpu() - and current_platform.get_cpu_architecture() - == CpuArchEnum.POWERPC - and (config_dtype == torch.float16 - or config_dtype == torch.float32)): - logger.info( - "For POWERPC, we cast models to bfloat16 instead of " - "using float16 by default. Float16 is not currently " - "supported for POWERPC.") - torch_dtype = torch.bfloat16 + if torch_dtype not in current_platform.supported_dtypes: + device_name = current_platform.get_device_name() - # TODO: change this condition to check if the platform support bf16 - # instead of checking the OS. For instance M2 shall supports bf16 - # already. But we need to modify `cpu_extension.cmake` to activate - # the feature in the build. - if (current_platform.is_cpu() and sys.platform.startswith("darwin") - and current_platform.get_cpu_architecture() - == CpuArchEnum.ARM and config_dtype == torch.bfloat16): - logger.info("For macOS with Apple Silicon, currently bfloat16 " - "is not supported. Setting dtype to float16.") - torch_dtype = torch.float16 + if ((capability := current_platform.get_device_capability()) + is None): + compute_str = "" + else: + version_str = capability.as_version_str() + compute_str = f" (with compute capability {version_str})" + fallback_dtype = current_platform.supported_dtypes[0] + logger.warning( + "Your %s device%s doesn't support %s. " \ + "Falling back to %s for compatibility.", + device_name, compute_str, torch_dtype, fallback_dtype + ) + torch_dtype = fallback_dtype - if current_platform.is_hpu() and config_dtype == torch.float16: - logger.info( + if current_platform.is_hpu() and torch_dtype == torch.float16: + logger.warning( "For HPU, we cast models to bfloat16 instead of " "using float16 by default. Please specify `dtype` if you " "want to use float16.") @@ -3404,41 +3424,49 @@ def _parse_collect_detailed_traces(self): self.collect_detailed_traces[0].split(",")) -class KVTransferConfig(BaseModel): +KVProducer = Literal["kv_producer", "kv_both"] +KVConsumer = Literal["kv_consumer", "kv_both"] +KVRole = Literal[KVProducer, KVConsumer] + + +@config +@dataclass +class KVTransferConfig: """Configuration for distributed KV cache transfer.""" - # The KV connector for vLLM to transmit KV caches between vLLM instances. kv_connector: Optional[str] = None + """The KV connector for vLLM to transmit KV caches between vLLM instances. + """ - # The device used by kv connector to buffer the KV cache. - # Currently only support 'cuda'. kv_buffer_device: Optional[str] = "cuda" + """The device used by kv connector to buffer the KV cache. + Currently only support 'cuda'.""" - # The buffer size for TorchDistributedConnector. Measured in number of - # bytes. Recommended value: 1e9 (about 1GB). kv_buffer_size: float = 1e9 + """The buffer size for TorchDistributedConnector. Measured in number of + bytes. Recommended value: 1e9 (about 1GB).""" - # Whether this vLLM instance produces, consumes KV cache, or both. Choices - # are 'kv_producer', 'kv_consumer', and 'both'. - kv_role: Optional[str] = None + kv_role: Optional[KVRole] = None + """Whether this vLLM instance produces, consumes KV cache, or both. Choices + are 'kv_producer', 'kv_consumer', and 'both'.""" - # The rank of this vLLM instance in the KV cache transfer. Typical value: - # 0 for prefill instance, 1 for decode instance. - # Currently only 1P1D is supported. kv_rank: Optional[int] = None + """The rank of this vLLM instance in the KV cache transfer. Typical value: + 0 for prefill instance, 1 for decode instance. + Currently only 1P1D is supported.""" - # The number of parallel instances for KV cache transfer. For - # PyNcclConnector, this should be 2. kv_parallel_size: int = 1 + """The number of parallel instances for KV cache transfer. For + PyNcclConnector, this should be 2.""" - # The KV connector ip, used to build distributed connection kv_ip: str = "127.0.0.1" + """The KV connector ip, used to build distributed connection.""" - # The KV connector port, used to build distributed connection kv_port: int = 14579 + """The KV connector port, used to build distributed connection.""" - # any extra config that the connector may need - kv_connector_extra_config: dict[str, Any] = {} + kv_connector_extra_config: dict[str, Any] = field(default_factory=dict) + """any extra config that the connector may need.""" def compute_hash(self) -> str: """ @@ -3459,46 +3487,37 @@ def compute_hash(self) -> str: usedforsecurity=False).hexdigest() return hash_str - @classmethod - def from_cli(cls, cli_value: str) -> "KVTransferConfig": - """Parse the CLI value for the kv cache transfer config.""" - return KVTransferConfig.model_validate_json(cli_value) - - def model_post_init(self, __context: Any) -> None: - - if self.kv_role is not None and self.kv_role not in [ - "kv_producer", "kv_consumer", "kv_both" - ]: - raise ValueError( - f"Unsupported kv_role: {self.kv_role}. " - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") + def __post_init__(self) -> None: + if self.kv_role is not None and self.kv_role not in get_args(KVRole): + raise ValueError(f"Unsupported kv_role: {self.kv_role}. " + f"Supported roles are {get_args(KVRole)}") if self.kv_connector is not None and self.kv_role is None: raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") + f"is set, supported roles are {get_args(KVRole)}") @property def is_kv_transfer_instance(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_consumer", "kv_both"] + self.kv_role in get_args(KVRole) @property def is_kv_producer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_producer", "kv_both"] + self.kv_role in get_args(KVProducer) @property def is_kv_consumer(self) -> bool: return self.kv_connector is not None and \ - self.kv_role in ["kv_consumer", "kv_both"] + self.kv_role in get_args(KVConsumer) def get_from_extra_config(self, key, default) -> Any: return self.kv_connector_extra_config.get(key, default) -class KVEventsConfig(BaseModel): +@config +@dataclass +class KVEventsConfig: """Configuration for KV event publishing.""" enable_kv_cache_events: bool = False @@ -3537,11 +3556,6 @@ class KVEventsConfig(BaseModel): this topic to receive events. """ - @classmethod - def from_cli(cls, cli_value: str) -> "KVEventsConfig": - """Parse the CLI value for the event publisher config.""" - return KVEventsConfig.model_validate_json(cli_value) - class CompilationLevel: # constants for the levels of the compilation process @@ -3551,76 +3565,72 @@ class CompilationLevel: PIECEWISE = 3 -class CompilationConfig(BaseModel): - """ - Configuration for compilation. - It has three parts: +@config +@dataclass +class PassConfig: + """Configuration for custom Inductor passes. + + This is separate from general `CompilationConfig` so that inductor passes + don't all have access to full configuration - that would create a cycle as + the `PassManager` is set as a property of config.""" + + dump_graph_stages: list[str] = field(default_factory=list) + """List of stages for which we want to dump the graph. Each pass defines + its own stages (before, after, maybe in-between).""" + dump_graph_dir: Path = Path(".") + """Directory to dump the graphs.""" + # TODO(luka) better pass enabling system. + enable_fusion: bool = True + """Whether to enable the custom fusion pass.""" + enable_noop: bool = True + """Whether to enable the custom no-op elimination pass.""" + enable_sequence_parallelism: bool = False + """Whether to enable sequence parallelism.""" + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + include = { + "enable_fusion", "enable_noop", "enable_sequence_parallelism" + } + dict_ = {k: v for k, v in asdict(self).items() if k in include} + return InductorPass.hash_dict(dict_) + + def __post_init__(self) -> None: + if not self.enable_noop and self.enable_fusion: + logger.warning_once( + "Fusion enabled but reshape elimination disabled. " + "RMSNorm + quant (fp8) fusion might not work") + + +@config +@dataclass +class CompilationConfig: + """Configuration for compilation. It has three parts: + - Top-level Compilation control: - - level: the level of compilation. - - 0: no compilation. - - 1: dynamo as is. - - 2: dynamo once. - - 3: piecewise compilation. - - debug_dump_path: the path to dump the debug information. - - cache_dir: the directory to store the compiled graph, to - accelerate Inductor compilation. By default, it will use - model-related information to generate a cache directory. - - backend: the backend for compilation. It needs to be a string. - - "" (empty string): use the default backend. - - "eager"/"openxla"/...: use the specified backend registered in PyTorch. - - "full.module.name": a qualified name which can be used to import the backend function. - We use string to avoid serialization issues when using compilation in a distributed setting. - When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph). - When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph). - - custom_ops: fine-grained control over which custom ops to enable/disable. - Use 'all' to enable all, 'none' to disable all. - Also specify a list of custom op names to enable (prefixed with a '+'), - or disable (prefixed with a '-'). - Examples: - - 'all,-op1' to enable all except op1 - - 'none,+op1,+op2' to enable only op1 and op2 - By default, all custom ops are enabled when running without Inductor - and disabled when running with Inductor (compile_level >= Inductor). - - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. + - {attr}`level` + - {attr}`debug_dump_path` + - {attr}`cache_dir` + - {attr}`backend` + - {attr}`custom_ops` + - {attr}`splitting_ops` - CudaGraph capture: - - use_cudagraph: whether to use cudagraph inside compilation. - - False: cudagraph inside compilation is not used. - - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses, and all - splitting ops write their outputs to input buffers. - Note that this is orthogonal to the cudagraph capture logic - outside of compilation. - TODO: move outside cudagraph logic into compilation. - torch.compile will handle cudagraph capture logic in the future. - - cudagraph_capture_sizes: sizes to capture cudagraph. - - None (default): capture sizes are inferred from vllm config. - - list[int]: capture sizes are specified as given. - - cudagraph_num_of_warmups: number of warmup runs for cudagraph. - It means the first several runs will be treated as warmup runs. - Only after that, the execution will be recorded, and the recorded - cudagraph will be used for subsequent runs. - - cudagraph_copy_inputs: whether to copy input tensors for - cudagraph. If the caller can guarantee that the same input buffers - are always used, it can set this to False. Otherwise, it should - set this to True, and the compiler will copy the input to an - internally managed buffer. Default is False. + - {attr}`use_cudagraph` + - {attr}`cudagraph_capture_sizes` + - {attr}`cudagraph_num_of_warmups` + - {attr}`cudagraph_copy_inputs` + - {attr}`full_cuda_graph` - Inductor compilation: - - use_inductor: whether to use inductor compilation. - - False: inductor compilation is not used. graph runs in eager. - - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for compile_sizes, - using configurations in inductor_compile_config. - - compile_sizes: sizes to compile for inductor. In addition - to integers, it also supports "cudagraph_capture_sizes" to - specify the sizes for cudagraph capture. - - inductor_compile_config: additional configurations for inductor. - - None: use default configurations. - - inductor_passes: additional passes for inductor. It is a dictionary - from pass name to pass function qualified name. We use function - name because the config uses json format. If we pass the config - from Python, functions can also be passed directly via Python object - constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: see PassConfig for more details + - {attr}`use_inductor` + - {attr}`compile_sizes` + - {attr}`inductor_compile_config` + - {attr}`inductor_passes` + - custom inductor passes Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -3631,82 +3641,135 @@ class CompilationConfig(BaseModel): static shapes. However, we find the general shape compilation is sufficient for most cases. It might be beneficial to compile for certain small batchsizes, where inductor is good at optimizing. - """ # noqa + """ + # Top-level Compilation control level: int = 0 + """The level of compilation: + + - 0: no compilation. + - 1: dynamo as is. + - 2: dynamo once. + - 3: piecewise compilation.""" debug_dump_path: str = "" + """The path to dump the debug information.""" cache_dir: str = "" + """The directory to store the compiled graph, to accelerate Inductor + compilation. By default, it will use model-related information to generate + a cache directory.""" backend: str = "" - custom_ops: list[str] = Field(default_factory=list) - splitting_ops: list[str] = Field(default=None) # type: ignore - + """The backend for compilation. It needs to be a string: + + - "" (empty string): use the default backend. + - "eager"/"openxla"/...: use the specified backend registered in PyTorch. + - "full.module.name": a qualified name which can be used to import the + + backend function. + We use string to avoid serialization issues when using compilation in a + distributed setting. When the compilation level is 1 or 2, the backend is + used for the compilation directly (it sees the whole graph). When the + compilation level is 3, the backend is used for the piecewise compilation + (it sees a part of the graph).""" + custom_ops: list[str] = field(default_factory=list) + """Fine-grained control over which custom ops to enable/disable. Use 'all' + to enable all, 'none' to disable all. Also specify a list of custom op + names to enable (prefixed with a '+'), or disable (prefixed with a '-'). + Examples: + + - 'all,-op1' to enable all except op1 + - 'none,+op1,+op2' to enable only op1 and op2 + + By default, all custom ops are enabled when running without Inductor and + disabled when running with Inductor (compile_level >= Inductor).""" + splitting_ops: list[str] = field(default_factory=list) + """A list of ops to split the full graph into subgraphs, used in piecewise + compilation.""" + + # Inductor capture use_inductor: bool = True - compile_sizes: Optional[list[Union[int, str]]] = Field(default=None) - inductor_compile_config: dict = Field(default_factory=dict) - inductor_passes: dict[str, str] = Field(default_factory=dict) - + """Whether to use inductor compilation: + + - False: inductor compilation is not used. graph runs in eager. + - True: inductor compilation is used. one graph for symbolic shape + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config.""" + compile_sizes: Optional[list[Union[int, str]]] = None + """Sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture.""" + inductor_compile_config: dict = field(default_factory=dict) + """Additional configurations for inductor. + - None: use default configurations.""" + inductor_passes: dict[str, str] = field(default_factory=dict) + """Additional passes for inductor. It is a dictionary + from pass name to pass function qualified name. We use function + name because the config uses JSON format. If we pass the config + from Python, functions can also be passed directly via Python object + constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" + + # CudaGraph compilation use_cudagraph: bool = False + """Whether to use cudagraph inside compilation. + - False: cudagraph inside compilation is not used. + - True: cudagraph inside compilation is used. It requires + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. + TODO: move outside cudagraph logic into compilation. + torch.compile will handle cudagraph capture logic in the future.""" cudagraph_num_of_warmups: int = 0 + """Number of warmup runs for cudagraph. + It means the first several runs will be treated as warmup runs. + Only after that, the execution will be recorded, and the recorded + cudagraph will be used for subsequent runs.""" cudagraph_capture_sizes: Optional[list[int]] = None + """Sizes to capture cudagraph. + - None (default): capture sizes are inferred from vllm config. + - list[int]: capture sizes are specified as given.""" cudagraph_copy_inputs: bool = False - - class PassConfig(BaseModel): - """ - Configuration for custom Inductor passes. - This is separate from general CompilationConfig so that inductor passes - don't all have access to full configuration - that would create a cycle - as the PassManager is set as a property of config. - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graphs. Default is . - - enable_fusion: whether to enable the custom fusion pass. - - enable_noop: whether to enable the custom no-op elimination pass. - TODO(luka) better pass enabling system. - - enable_sequence_parallelism: whether to enable sequence parallelism. - """ - dump_graph_stages: list[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True - enable_noop: bool = True - enable_sequence_parallelism: bool = False - - def uuid(self): - """ - Produces a hash unique to the pass configuration. - Any new fields that affect compilation should be added to the hash. - Do not include dump_graph_* in the hash - they don't affect - compilation. - """ - dict_ = self.model_dump(include={"enable_fusion", "enable_noop", \ - "enable_sequence_parallelism"}) - return InductorPass.hash_dict(dict_) - - def model_post_init(self, __context: Any) -> None: - if not self.enable_noop and self.enable_fusion: - logger.warning_once( - "Fusion enabled but reshape elimination disabled. " - "RMSNorm + quant (fp8) fusion might not work") - - pass_config: PassConfig = Field(default_factory=PassConfig) - - # not configurable, computed after init - max_capture_size: int = PrivateAttr - local_cache_dir: str = PrivateAttr # local cache dir for each rank - # optimization: - # Intuitively, bs_to_padded_graph_size should be dict[int, int]. - # since we know all keys are in a range [0, max_capture_size], - # we can optimize it to list[int] for better lookup performance. - bs_to_padded_graph_size: list[int] = PrivateAttr + """Whether to copy input tensors for + cudagraph. If the caller can guarantee that the same input buffers + are always used, it can set this to False. Otherwise, it should + set this to True, and the compiler will copy the input to an + internally managed buffer. Default is False.""" + full_cuda_graph: bool = False + """whether to use a full cuda graph for the entire forward pass rather than + splitting certain operations such as attention into subgraphs. Thus this + flag cannot be used together with splitting_ops. This may provide + performance benefits for smaller models.""" + + pass_config: PassConfig = field(default_factory=PassConfig) + """Custom inductor passes, see PassConfig for more details""" + + max_capture_size: int = field(default=None, init=False) # type: ignore + """not configurable, computed after init""" + local_cache_dir: str = field(default=None, init=False) # type: ignore + """local cache dir for each rank""" + bs_to_padded_graph_size: list[int] = field( + default=None, # type: ignore + init=False) + """optimization: + Intuitively, bs_to_padded_graph_size should be dict[int, int]. + since we know all keys are in a range [0, max_capture_size], + we can optimize it to list[int] for better lookup performance.""" # keep track of enabled and disabled custom ops - enabled_custom_ops: Counter[str] = PrivateAttr - disabled_custom_ops: Counter[str] = PrivateAttr - traced_files: set[str] = PrivateAttr - compilation_time: float = PrivateAttr - - # Per-model forward context - # Map from layer name to layer objects that need to be accessed outside - # model code, e.g., Attention, FusedMOE when dp_size>1. - static_forward_context: dict[str, Any] = PrivateAttr + enabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are enabled""" + disabled_custom_ops: Counter[str] = field(default_factory=Counter, + init=False) + """custom ops that are disabled""" + traced_files: set[str] = field(default_factory=set, init=False) + """files that are traced for compilation""" + compilation_time: float = field(default=0.0, init=False) + """time taken for compilation""" + + static_forward_context: dict[str, Any] = field(default_factory=dict, + init=False) + """Per-model forward context + Map from layer name to layer objects that need to be accessed outside + model code, e.g., Attention, FusedMOE when dp_size>1.""" def compute_hash(self) -> str: """ @@ -3741,7 +3804,17 @@ def __repr__(self) -> str: "pass_config", "traced_files", } - return self.model_dump_json(exclude=exclude, exclude_unset=True) + include = dict() + for k, v in asdict(self).items(): + if k in exclude: + continue + f = get_field(CompilationConfig, k) + if (d := f.default) is not MISSING and d == v: + continue + if (df := f.default_factory) is not MISSING and df() == v: + continue + include[k] = v + return json.dumps(include) __str__ = __repr__ @@ -3750,12 +3823,9 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" if cli_value in ["0", "1", "2", "3"]: return cls(level=int(cli_value)) - # do not use `eval`, it is dangerous and can execute arbitrary code - dict_value = ast.literal_eval(cli_value) - return CompilationConfig.model_validate(dict_value) - - def model_post_init(self, __context: Any) -> None: + return cls(**json.loads(cli_value)) + def __post_init__(self) -> None: count_none = self.custom_ops.count("none") count_all = self.custom_ops.count("all") assert count_none + count_all <= 1, "Can only specify 'none' or 'all'" @@ -3773,9 +3843,6 @@ def model_post_init(self, __context: Any) -> None: if KEY not in self.inductor_compile_config: self.inductor_compile_config[KEY] = False - if self.splitting_ops is None: - self.splitting_ops = [] - for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( @@ -3792,11 +3859,8 @@ def model_post_init(self, __context: Any) -> None: self.inductor_compile_config[k] = func if isinstance( func, InductorPass) else CallableInductorPass(func) - self.enabled_custom_ops = Counter() - self.disabled_custom_ops = Counter() - self.traced_files = set() - self.static_forward_context = {} - self.compilation_time = 0.0 + if isinstance(self.pass_config, dict): + self.pass_config = PassConfig(**self.pass_config) def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -3870,48 +3934,80 @@ def init_with_cudagraph_sizes(self, self.max_capture_size] = self.max_capture_size def set_splitting_ops_for_v1(self): - # If default, override splitting ops for piecewise cudagraph on V1. # NOTE: this function needs to be called + if self.splitting_ops and self.full_cuda_graph: + raise ValueError("full_cuda_graph cannot be used together with " + "splitting_ops, as Full CUDA graph will override " + f"the splitting_ops: {self.splitting_ops}") + if not self.splitting_ops: - self.splitting_ops = [ + self.splitting_ops = [] if self.full_cuda_graph else [ "vllm.unified_attention", "vllm.unified_attention_with_output", ] +@config @dataclass class VllmConfig: """Dataclass which contains all vllm-related configuration. This simplifies passing around the distinct configurations in the codebase. """ - model_config: ModelConfig = field(default=None, init=True) # type: ignore - cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default_factory=ParallelConfig, - init=True) - scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, - init=True) - device_config: DeviceConfig = field(default=None, - init=True) # type: ignore - load_config: LoadConfig = field(default=None, init=True) # type: ignore + model_config: ModelConfig = field(default_factory=ModelConfig) + """Model configuration.""" + cache_config: CacheConfig = field(default_factory=CacheConfig) + """Cache configuration.""" + parallel_config: ParallelConfig = field(default_factory=ParallelConfig) + """Parallel configuration.""" + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig) + """Scheduler configuration.""" + device_config: DeviceConfig = field(default_factory=DeviceConfig) + """Device configuration.""" + load_config: LoadConfig = field(default_factory=LoadConfig) + """Load configuration.""" lora_config: Optional[LoRAConfig] = None - speculative_config: SpeculativeConfig = field(default=None, - init=True) # type: ignore + """LoRA configuration.""" + speculative_config: Optional[SpeculativeConfig] = None + """Speculative decoding configuration.""" decoding_config: Optional[DecodingConfig] = None + """Decoding configuration.""" observability_config: Optional[ObservabilityConfig] = None + """Observability configuration.""" prompt_adapter_config: Optional[PromptAdapterConfig] = None + """Prompt adapter configuration.""" quant_config: Optional[QuantizationConfig] = None - compilation_config: CompilationConfig = field(default=None, - init=True) # type: ignore - kv_transfer_config: KVTransferConfig = field(default=None, - init=True) # type: ignore + """Quantization configuration.""" + compilation_config: CompilationConfig = field( + default_factory=CompilationConfig) + """`torch.compile` configuration for the model. + + When it is a number (0, 1, 2, 3), it will be interpreted as the + optimization level. + + NOTE: level 0 is the default level without any optimization. level 1 and 2 + are for internal testing only. level 3 is the recommended level for + production. + + Following the convention of traditional compilers, using `-O` without space + is also supported. `-O3` is equivalent to `-O 3`. + + You can specify the full compilation config like so: + `{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}` + """ + kv_transfer_config: Optional[KVTransferConfig] = None + """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None + """The configurations for event publishing.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. - additional_config: SupportsHash = field(default=None, - init=True) # type: ignore + additional_config: Union[dict, SupportsHash] = field(default_factory=dict) + """Additional config for specified platform. Different platforms may + support different configs. Make sure the configs are valid for the platform + you are using. Contents must be hashable.""" instance_id: str = "" + """The ID of the vLLM instance.""" def compute_hash(self) -> str: """ @@ -3992,7 +4088,14 @@ def compute_hash(self) -> str: else: vllm_factors.append("None") if self.additional_config: - vllm_factors.append(self.additional_config.compute_hash()) + if isinstance(additional_config := self.additional_config, dict): + additional_config_hash = hashlib.md5( + json.dumps(additional_config, sort_keys=True).encode(), + usedforsecurity=False, + ).hexdigest() + else: + additional_config_hash = additional_config.compute_hash() + vllm_factors.append(additional_config_hash) else: vllm_factors.append("None") factors.append(vllm_factors) @@ -4150,6 +4253,12 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.compilation_config.full_cuda_graph and \ + not self.model_config.disable_cascade_attn: + logger.warning_once( + "full_cuda_graph is not supported with " + "cascade attention. Disabling cascade attention.") + self.model_config.disable_cascade_attn = True if self.model_config and self.model_config.use_mla and \ not (current_platform.is_cuda() or current_platform.is_rocm()): diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 1f4b4faf1190..296f5f2b424b 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -22,7 +22,8 @@ def __init__(self, super().__init__(cpu_group, device, device_group, unique_name) self.dist_module = torch.distributed - if current_platform.get_cpu_architecture() == CpuArchEnum.X86: + if (current_platform.get_cpu_architecture() == CpuArchEnum.X86) \ + and hasattr(torch.ops._C, "init_shm_manager"): self.dist_module = _CPUSHMDistributed(self) def all_reduce(self, input_): diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 960913858527..1141a8e53c3b 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -5,6 +5,7 @@ import time from abc import ABC, abstractmethod from collections import deque +from dataclasses import asdict from itertools import count from queue import Queue from typing import Any, Callable, Optional, Union @@ -284,7 +285,7 @@ def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher: if not config: return NullEventPublisher() - config_dict = config.model_dump() + config_dict = asdict(config) kind = config_dict.pop("publisher", "null") config_dict.pop("enable_kv_cache_events") diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 5d8b5d1f618a..0ff6a6fbbc1c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,10 +7,10 @@ import re import threading import warnings -from dataclasses import MISSING, dataclass, fields +from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (Any, Callable, Dict, List, Literal, Optional, Type, - TypeVar, Union, cast, get_args, get_origin) +from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional, + Type, TypeVar, Union, cast, get_args, get_origin) import torch from typing_extensions import TypeIs, deprecated @@ -36,7 +36,8 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor +from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build, + is_in_ray_actor) # yapf: enable @@ -48,12 +49,9 @@ TypeHintT = Union[type[T], object] -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: +def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _optional_type(val: str) -> Optional[T]: - if val == "" or val == "None": - return None + def _parse_type(val: str) -> T: try: if return_type is json.loads and not re.match("^{.*}$", val): return cast(T, nullable_kvs(val)) @@ -62,14 +60,24 @@ def _optional_type(val: str) -> Optional[T]: raise argparse.ArgumentTypeError( f"Value {val} cannot be converted to {return_type}.") from e + return _parse_type + + +def optional_type( + return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: + + def _optional_type(val: str) -> Optional[T]: + if val == "" or val == "None": + return None + return parse_type(return_type)(val) + return _optional_type def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]: if not re.match("^{.*}$", val): return str(val) - else: - return optional_type(json.loads)(val) + return optional_type(json.loads)(val) @deprecated( @@ -144,10 +152,25 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: cls_docs = get_attr_docs(cls) kwargs = {} for field in fields(cls): + # Get the set of possible types for the field + type_hints: set[TypeHint] = set() + if get_origin(field.type) in {Union, Annotated}: + type_hints.update(get_args(field.type)) + else: + type_hints.add(field.type) + + # If the field is a dataclass, we can use the model_validate_json + generator = (th for th in type_hints if is_dataclass(th)) + dataclass_cls = next(generator, None) + # Get the default value of the field - default = field.default - if field.default_factory is not MISSING: - default = field.default_factory() + if field.default is not MISSING: + default = field.default + elif field.default_factory is not MISSING: + if is_dataclass(field.default_factory) and is_in_doc_build(): + default = {} + else: + default = field.default_factory() # Get the help text for the field name = field.name @@ -158,16 +181,17 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: # Initialise the kwargs dictionary for the field kwargs[name] = {"default": default, "help": help} - # Get the set of possible types for the field - type_hints: set[TypeHint] = set() - if get_origin(field.type) is Union: - type_hints.update(get_args(field.type)) - else: - type_hints.add(field.type) - # Set other kwargs based on the type hints json_tip = "\n\nShould be a valid JSON string." - if contains_type(type_hints, bool): + if dataclass_cls is not None: + dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x)) + # Special case for configs with a from_cli method + if hasattr(dataclass_cls, "from_cli"): + from_cli = dataclass_cls.from_cli + dataclass_init = lambda x, f=from_cli: f(x) + kwargs[name]["type"] = dataclass_init + kwargs[name]["help"] += json_tip + elif contains_type(type_hints, bool): # Creates --no- and -- flags kwargs[name]["action"] = argparse.BooleanOptionalAction elif contains_type(type_hints, Literal): @@ -202,7 +226,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): # Dict arguments will always be optional - kwargs[name]["type"] = optional_type(json.loads) + kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += json_tip elif (contains_type(type_hints, str) or any(is_not_builtin(th) for th in type_hints)): @@ -771,68 +795,26 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: scheduler_group.add_argument("--scheduler-cls", **scheduler_kwargs["scheduler_cls"]) - # Compilation arguments - # compilation_kwargs = get_kwargs(CompilationConfig) - compilation_group = parser.add_argument_group( - title="CompilationConfig", - description=CompilationConfig.__doc__, - ) - compilation_group.add_argument( - "--compilation-config", - "-O", - type=CompilationConfig.from_cli, - default=None, - help="torch.compile configuration for the model. " - "When it is a number (0, 1, 2, 3), it will be " - "interpreted as the optimization level.\n" - "NOTE: level 0 is the default level without " - "any optimization. level 1 and 2 are for internal " - "testing only. level 3 is the recommended level " - "for production.\n" - "To specify the full compilation config, " - "use a JSON string, e.g. ``{\"level\": 3, " - "\"cudagraph_capture_sizes\": [1, 2, 4, 8]}``\n" - "Following the convention of traditional " - "compilers, using ``-O`` without space is also " - "supported. ``-O3`` is equivalent to ``-O 3``.") - - # KVTransfer arguments - # kv_transfer_kwargs = get_kwargs(KVTransferConfig) - kv_transfer_group = parser.add_argument_group( - title="KVTransferConfig", - description=KVTransferConfig.__doc__, - ) - kv_transfer_group.add_argument( - "--kv-transfer-config", - type=KVTransferConfig.from_cli, - default=None, - help="The configurations for distributed KV cache " - "transfer. Should be a JSON string.") - kv_transfer_group.add_argument( - '--kv-events-config', - type=KVEventsConfig.from_cli, - default=None, - help='The configurations for event publishing.') - # vLLM arguments - # vllm_kwargs = get_kwargs(VllmConfig) + vllm_kwargs = get_kwargs(VllmConfig) vllm_group = parser.add_argument_group( title="VllmConfig", description=VllmConfig.__doc__, ) - vllm_group.add_argument( - "--additional-config", - type=json.loads, - default=None, - help="Additional config for specified platform in JSON format. " - "Different platforms may support different configs. Make sure the " - "configs are valid for the platform you are using. The input format" - " is like '{\"config_key\":\"config_value\"}'") + vllm_group.add_argument("--kv-transfer-config", + **vllm_kwargs["kv_transfer_config"]) + vllm_group.add_argument('--kv-events-config', + **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument("--compilation-config", "-O", + **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--additional-config", + **vllm_kwargs["additional_config"]) # Other arguments parser.add_argument('--use-v2-block-manager', action='store_true', default=True, + deprecated=True, help='[DEPRECATED] block manager v1 has been ' 'removed and SelfAttnBlockSpaceManager (i.e. ' 'block manager v2) is now the default. ' @@ -1318,6 +1300,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "ROCM_AITER_MLA", ] if (envs.is_set("VLLM_ATTENTION_BACKEND") and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 3d7b73f97a19..a5dcf9e2d945 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -284,7 +284,7 @@ def handle_new_input(self): except Exception as e: self._set_errored(e) self._send_unhealthy(e) - raise e + raise e from None def _handle_process_request(self, request: RPCProcessRequest): """Handle RPCProcessRequest by adding it to the LLMEngine.""" @@ -447,4 +447,4 @@ def run_mp_engine(vllm_config: VllmConfig, usage_context: UsageContext, except BaseException as e: logger.exception(e) engine_alive.value = False - raise e + raise e from None diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 23dded7f226f..38fe98572178 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -38,6 +38,10 @@ from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.utils import MediaConnector +# yapf: disable +from vllm.transformers_utils.chat_templates import ( + get_chat_template_fallback_path) +# yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -325,11 +329,10 @@ def resolve_mistral_chat_template( return None def resolve_hf_chat_template( + model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], - *, - trust_remote_code: bool, ) -> Optional[str]: # 1st priority: The given chat template if chat_template is not None: @@ -342,7 +345,7 @@ def resolve_hf_chat_template( tokenizer.name_or_path, processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin), - trust_remote_code=trust_remote_code, + trust_remote_code=model_config.trust_remote_code, ) if isinstance(processor, ProcessorMixin) and \ processor.chat_template is not None: @@ -358,22 +361,34 @@ def resolve_hf_chat_template( logger.debug("Failed to load AutoTokenizer chat template for %s", tokenizer.name_or_path, exc_info=True) - return None + # 4th priority: Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=model_config.tokenizer, + ) + if path is not None: + logger.info("Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", tokenizer.name_or_path) + chat_template = load_chat_template(path) + else: + logger.debug("There is no chat template fallback for %s", + tokenizer.name_or_path) + + return chat_template def _resolve_chat_template_content_format( + model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, - *, - trust_remote_code: bool, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): hf_chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=chat_template, - trust_remote_code=trust_remote_code, tools=tools, ) else: @@ -413,19 +428,18 @@ def _log_chat_template_content_format( def resolve_chat_template_content_format( + model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, - *, - trust_remote_code: bool = False, ) -> _ChatTemplateContentFormat: detected_format = _resolve_chat_template_content_format( + model_config, chat_template, tools, given_format, tokenizer, - trust_remote_code=trust_remote_code, ) _log_chat_template_content_format( @@ -1177,20 +1191,20 @@ def parse_chat_messages_futures( def apply_hf_chat_template( + model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: list[ConversationMessage], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], *, - trust_remote_code: bool = False, tokenize: bool = False, # Different from HF's default **kwargs: Any, ) -> str: hf_chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=chat_template, tools=tools, - trust_remote_code=trust_remote_code, ) if hf_chat_template is None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a04ab885a72b..cebddcc8e6aa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,7 +13,8 @@ from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.config import CompilationConfig, ModelDType, TokenizerMode +from vllm.config import (CompilationConfig, ModelDType, TokenizerMode, + is_init_field) from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -204,9 +205,13 @@ def __init__( kwargs["worker_cls"] = cloudpickle.dumps(worker_cls) if compilation_config is not None: - if isinstance(compilation_config, (int, dict)): - compilation_config_instance = CompilationConfig.from_cli( - str(compilation_config)) + if isinstance(compilation_config, int): + compilation_config_instance = CompilationConfig( + level=compilation_config) + elif isinstance(compilation_config, dict): + predicate = lambda x: is_init_field(CompilationConfig, x[0]) + compilation_config_instance = CompilationConfig( + **dict(filter(predicate, compilation_config.items()))) else: compilation_config_instance = compilation_config else: @@ -726,11 +731,11 @@ def chat( tokenizer = self.get_tokenizer(lora_request) model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( + model_config, chat_template, tools, chat_template_content_format, tokenizer, - trust_remote_code=model_config.trust_remote_code, ) _chat_template_kwargs: dict[str, Any] = dict( @@ -762,8 +767,8 @@ def chat( ) else: prompt_str = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, **_chat_template_kwargs, ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 9746d9697a66..e034eacb24ef 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -41,7 +41,8 @@ resolve_mistral_chat_template) from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.cli_args import (make_arg_parser, +from vllm.entrypoints.openai.cli_args import (log_non_default_args, + make_arg_parser, validate_parsed_serve_args) # yapf conflicts with isort for this block # yapf: disable @@ -936,10 +937,11 @@ async def init_app_state( chat_template=resolved_chat_template) else: hf_chat_template = resolve_hf_chat_template( + vllm_config.model_config, tokenizer, chat_template=None, tools=None, - trust_remote_code=model_config.trust_remote_code) + ) if hf_chat_template != resolved_chat_template: logger.warning( @@ -1040,7 +1042,7 @@ def create_server_socket(addr: tuple[str, int]) -> socket.socket: async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) - logger.info("args: %s", args) + log_non_default_args(args) if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a2639d374791..d8cec2202134 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -17,8 +17,11 @@ from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.logger import init_logger from vllm.utils import FlexibleArgumentParser +logger = init_logger(__name__) + class LoRAParserAction(argparse.Action): @@ -285,6 +288,15 @@ def validate_parsed_serve_args(args: argparse.Namespace): "--tool-call-parser") +def log_non_default_args(args: argparse.Namespace): + non_default_args = {} + parser = make_arg_parser(FlexibleArgumentParser()) + for arg, default in vars(parser.parse_args([])).items(): + if default != getattr(args, arg): + non_default_args[arg] = getattr(args, arg) + logger.info("non-default args: %s", non_default_args) + + def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( prog="-m vllm.entrypoints.openai.api_server") diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 40e477f03194..aa01e785f21a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,10 +5,11 @@ import json import re import time +from http import HTTPStatus from typing import Annotated, Any, ClassVar, Literal, Optional, Union import torch -from fastapi import UploadFile +from fastapi import HTTPException, UploadFile from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, ValidationInfo, field_validator, model_validator) from typing_extensions import TypeAlias @@ -409,7 +410,7 @@ class ChatCompletionRequest(OpenAIBaseModel): "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -853,7 +854,7 @@ class CompletionRequest(OpenAIBaseModel): "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -1679,7 +1680,7 @@ class TranscriptionRequest(OpenAIBaseModel): "repetition_penalty": 1.0, "temperature": 1.0, "top_p": 1.0, - "top_k": -1, + "top_k": 0, "min_p": 0.0, } @@ -1727,7 +1728,13 @@ def to_sampling_params( @model_validator(mode="before") @classmethod - def validate_stream_options(cls, data): + def validate_transcription_request(cls, data): + if isinstance(data.get("file"), str): + raise HTTPException( + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + detail="Expected 'file' to be a file-like object, not 'str'.", + ) + stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 25069c28a0a2..bb11650815ec 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -394,11 +394,11 @@ async def _preprocess_chat( model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( + model_config, chat_template, tool_dicts, chat_template_content_format, tokenizer, - trust_remote_code=model_config.trust_remote_code, ) conversation, mm_data_future = parse_chat_messages_futures( messages, @@ -425,8 +425,8 @@ async def _preprocess_chat( ) else: request_prompt = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, **_chat_template_kwargs, ) diff --git a/vllm/envs.py b/vllm/envs.py index ea40bfff11b5..d7f332cb0a73 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -84,6 +84,7 @@ VLLM_ROCM_FP8_PADDING: bool = True VLLM_ROCM_MOE_PADDING: bool = True VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_QUARK_EMU_MEM_OPT: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False @@ -110,6 +111,7 @@ VLLM_USE_DEEP_GEMM: bool = False VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 + VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False def get_default_cache_root(): @@ -261,6 +263,10 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # Internal flag to enable/disable Inductor standalone compile + "VLLM_TEST_STANDALONE_COMPILE": + lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0", + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": @@ -583,6 +589,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")), + # If set, when running in Quark emulation mode, do not dequantize the + # weights at load time. Instead, dequantize weights on-the-fly during + # kernel execution. + # This allows running larger models at the cost of slower inference. + # This flag has no effect when not running in Quark emulation mode. + "VLLM_QUARK_EMU_MEM_OPT": + lambda: bool(int(os.getenv("VLLM_QUARK_EMU_MEM_OPT", "0"))), + # Divisor for dynamic query scale factor calculation for FP8 KV Cache "Q_SCALE_CONSTANT": lambda: int(os.getenv("Q_SCALE_CONSTANT", "200")), @@ -727,6 +741,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: # limit will actually be zero-copy decoded. "VLLM_MSGPACK_ZERO_COPY_THRESHOLD": lambda: int(os.getenv("VLLM_MSGPACK_ZERO_COPY_THRESHOLD", "256")), + + # If set, allow insecure serialization using pickle. + # This is useful for environments where it is deemed safe to use the + # insecure method and it is needed for some reason. + "VLLM_ALLOW_INSECURE_SERIALIZATION": + lambda: bool(int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))), } # end-env-vars-definition @@ -789,6 +809,7 @@ def factorize(name: str): "VLLM_USE_TRITON_AWQ", "VLLM_DP_RANK", "VLLM_DP_SIZE", + "VLLM_TEST_STANDALONE_COMPILE", ] for key in environment_variables_to_hash: if key in environment_variables: diff --git a/vllm/logging_utils/dump_input.py b/vllm/logging_utils/dump_input.py new file mode 100644 index 000000000000..169e24794095 --- /dev/null +++ b/vllm/logging_utils/dump_input.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import enum +import json +from typing import Optional + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.metrics.stats import SchedulerStats +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) + + +def prepare_object_to_dump(obj) -> str: + if isinstance(obj, str): + return "'{obj}'" # Double quotes + elif isinstance(obj, dict): + dict_str = ', '.join({f'{str(k)}: {prepare_object_to_dump(v)}' \ + for k, v in obj.items()}) + return f'{{{dict_str}}}' + elif isinstance(obj, list): + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" + elif isinstance(obj, set): + return f"[{', '.join([prepare_object_to_dump(v) for v in list(obj)])}]" + # return [prepare_object_to_dump(v) for v in list(obj)] + elif isinstance(obj, tuple): + return f"[{', '.join([prepare_object_to_dump(v) for v in obj])}]" + elif isinstance(obj, enum.Enum): + return repr(obj) + elif isinstance(obj, torch.Tensor): + # We only print the 'draft' of the tensor to not expose sensitive data + # and to get some metadata in case of CUDA runtime crashed + return (f"Tensor(shape={obj.shape}, " + f"device={obj.device}," + f"dtype={obj.dtype})") + elif hasattr(obj, 'anon_repr'): + return obj.anon_repr() + elif hasattr(obj, '__dict__'): + items = obj.__dict__.items() + dict_str = ','.join([f'{str(k)}={prepare_object_to_dump(v)}' \ + for k, v in items]) + return (f"{type(obj).__name__}({dict_str})") + else: + # Hacky way to make sure we can serialize the object in JSON format + try: + return json.dumps(obj) + except (TypeError, OverflowError): + return repr(obj) + + +def dump_engine_exception(config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats]): + # NOTE: ensure we can log extra info without risking raises + # unexpected errors during logging + with contextlib.suppress(BaseException): + _dump_engine_exception(config, scheduler_output, scheduler_stats) + + +def _dump_engine_exception(config: VllmConfig, + scheduler_output: SchedulerOutput, + scheduler_stats: Optional[SchedulerStats]): + logger.error("Dumping input data") + + logger.error( + "V1 LLM engine (v%s) with config: %s, ", + VLLM_VERSION, + config, + ) + + try: + dump_obj = prepare_object_to_dump(scheduler_output) + logger.error("Dumping scheduler output for model execution:") + logger.error(dump_obj) + if scheduler_stats: + logger.error(scheduler_stats) + except BaseException as exception: + logger.error("Error preparing object to dump") + logger.error(repr(exception)) diff --git a/vllm/logits_process.py b/vllm/logits_process.py index e3faf20029ec..29a73656bf65 100644 --- a/vllm/logits_process.py +++ b/vllm/logits_process.py @@ -4,11 +4,12 @@ import torch -from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer -LogitsProcessor = Union[Callable[[list[int], torch.Tensor], torch.Tensor], - Callable[[list[int], list[int], torch.Tensor], - torch.Tensor]] +LogitsProcessor = Union[ + Callable[[list[int], torch.Tensor], torch.Tensor], + Callable[[list[int], list[int], torch.Tensor], torch.Tensor], +] """LogitsProcessor is a function that takes a list of previously generated tokens, the logits tensor for the next token and, optionally, prompt tokens as a @@ -29,12 +30,8 @@ def get_bad_words_logits_processors( prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - if isinstance(tokenizer, MistralTokenizer): - # Mistral tokenizers should not add special tokens - prompt_token_ids = tokenizer.encode(text=prompt) - else: - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode(text=prompt, + add_special_tokens=False) # If no space at the beginning # or if prefix space produces a new word token diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 41e1ec94145d..e195f8cf5e8e 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -16,6 +16,7 @@ MergedQKVParallelLinearWithLoRA, QKVParallelLinearWithLoRA, RowParallelLinearWithLoRA) +from vllm.platforms import current_platform if TYPE_CHECKING: pass @@ -57,15 +58,25 @@ def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA): device=x.device, ) - layer.punica_wrapper.add_shrink(buffers, x, layer.lora_a_stacked, 1.0) + shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink( + buffers, x, layer.lora_a_stacked, 1.0) + + if not current_platform.can_update_inplace(): + buffers = shrunk_buffers + buffers = tensor_model_parallel_all_gather(buffers) - layer.punica_wrapper.add_expand(output, - buffers, - layer.lora_b_stacked, - layer.lora_bias_stacked, - layer.output_slices, - offset_start=0, - add_input=True) + + lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand( + output, + buffers, + layer.lora_b_stacked, + layer.lora_bias_stacked, + layer.output_slices, + offset_start=0, + add_input=True) + + if not current_platform.can_update_inplace(): + output = lora_output output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -292,7 +303,11 @@ def apply(self, device=x.device, ) - self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink( + buffer, x, self.lora_a_stacked, 1.0) + if not current_platform.can_update_inplace(): + buffer = shrunk_buffer + buffer = tensor_model_parallel_all_reduce(buffer) # following S-LoRA, allows the fusing of all_gather and all_reduce @@ -304,7 +319,7 @@ def apply(self, # NOTE offset are based on the rank. shard_size = self.lora_b_stacked[0].shape[2] offset_start = self.tp_rank * shard_size - self.punica_wrapper.add_expand( + lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand( output, buffer, self.lora_b_stacked, @@ -313,6 +328,10 @@ def apply(self, offset_start=offset_start, add_input=True, ) + + if not current_platform.can_update_inplace(): + output = lora_output + output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index d9de0f3cfeb3..6749ec16a097 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -261,10 +261,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_lora_a_embeddings.shape[1], -1, ) - self.punica_wrapper.add_lora_embedding(full_output, - full_lora_a_embeddings, - self.lora_b_stacked, - add_input=True) + + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_embedding( + full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + + if not current_platform.can_update_inplace(): + full_output = lora_output + return full_output.view_as(full_output_org) @classmethod @@ -410,10 +417,13 @@ def apply(self, output = output.flatten(0, 1) x = x.flatten(0, 1) - self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked, - self.lora_b_stacked, - self.lora_bias_stacked, 1.0, - self.output_slices) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_linear( + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.lora_bias_stacked, 1.0, self.output_slices) + if not current_platform.can_update_inplace(): + output = lora_output + return output @property @@ -1133,15 +1143,23 @@ def _get_logits( torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) - lora_logits[-1] = float("-inf") + + neg_inf, pos_inf = current_platform.get_infinity_values( + lora_logits.dtype) + + lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded + + if current_platform.is_tpu(): + indices_padded = indices_padded[:logits.size(0)] + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], - ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), - posinf=float("inf"), - neginf=float("-inf"))) + ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf, + posinf=pos_inf, + neginf=neg_inf)) # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): @@ -1151,10 +1169,13 @@ def _get_logits( self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + lora_logits.shape[1]] = lora_logits - # LogitsProcessorWithLoRA always using bgmv - self.punica_wrapper.add_lora_logits(logits, hidden_states, - self.lora_a_stacked, - self.lora_b_stacked, 1.0) + lora_output: Optional[ + torch.Tensor] = self.punica_wrapper.add_lora_logits( + logits, hidden_states, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + if not current_platform.can_update_inplace(): + logits = lora_output # Remove paddings in vocab (if any). logits = logits[:, :self.base_layer.vocab_size] diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 81e0741a03cf..9f9d808679d7 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -344,7 +344,7 @@ def __init__( self.supported_lora_modules = get_supported_lora_modules(self.model) assert self.supported_lora_modules, "No supported LoRA modules found in" - f"{self.model.__class__.__name__}." + f" {self.model.__class__.__name__}." if lora_config.long_lora_scaling_factors: # We need to replace rotary emb layer to do batch computation # for long lora. diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py new file mode 100644 index 000000000000..94062b05d916 --- /dev/null +++ b/vllm/lora/ops/xla_ops/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 + +from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice, + bgmv_shrink) + +__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py new file mode 100644 index 000000000000..acbec0cfab9c --- /dev/null +++ b/vllm/lora/ops/xla_ops/lora_ops.py @@ -0,0 +1,106 @@ +# SPDX-License-Identifier: Apache-2.0 + +import torch + +# Required to register the custom ops +import vllm.lora.ops.xla_ops.pallas # noqa # pylint: disable=unused-import + + +def bgmv_expand(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + n_tokens = outputs.size(0) + + limit = output_tensor.shape[0] + if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: + limit = 1 + + outputs = torch.cat( + (outputs, + torch.zeros((n_tokens, output_tensor.shape[1] - outputs.shape[1]), + device=outputs.device)), + dim=1) + + if add_inputs: + return output_tensor + outputs[:limit, :] + else: + return outputs[:limit, :] + + +def bgmv_shrink(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + output_tensor (torch.Tensor): (Unused) output tensor (placeholder). + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + scaling (float, optional): Scalar multiplier applied to the output. + """ + + return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, + lora_indices_tensor) + + +def bgmv_expand_slice(inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True): + """ + Args: + inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. + + lora_b_weights (torch.Tensor): LoRA weights of shape + [num_loras, lora_rank, hidden_size]. + + output_tensor (torch.Tensor): output tensor of shape + [num_tokens, hidden_size * num_slices]. + + lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] + indicating which LoRA matrix to use for each token. + add_inputs (bool): Whether or not to add the input tensor to the output + tensor. + """ + outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) + n_tokens = outputs.size(0) + + outputs = torch.cat(( + torch.zeros((n_tokens, slice_offset), device=outputs.device), + outputs, + torch.zeros( + (n_tokens, output_tensor.shape[1] - (slice_offset + slice_size)), + device=outputs.device), + ), + dim=1) + + if add_inputs: + return output_tensor + outputs + else: + return outputs diff --git a/vllm/lora/ops/xla_ops/pallas.py b/vllm/lora/ops/xla_ops/pallas.py new file mode 100644 index 000000000000..35dc307539bf --- /dev/null +++ b/vllm/lora/ops/xla_ops/pallas.py @@ -0,0 +1,133 @@ +# SPDX-License-Identifier: Apache-2.0 +import functools + +import jax +import jax.numpy as jnp +import torch +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from torch.library import impl +from torch_xla.experimental.custom_kernel import (XLA_LIB, jax_import_guard, + make_kernel_from_pallas) + +# TODO: Tune these +TOKENS_BLOCK = 16 +LORA_RANK_BLOCK = 128 +DIM_BLOCK_SIZE = 128 + + +def _bgmv_kernel(bT: int, bL: int, idx_ref, inp_ref, lora_ref, out_ref, + acc_ref, mask_ref): + + @pl.when(pl.program_id(2) == 0) + def _(): + acc_ref[...] = jnp.zeros_like(acc_ref[...], dtype=jnp.float32) + + t = pl.program_id(0) + + for i in range(bT): + idx = idx_ref[i + bT * t] + mask_ref[...] = jnp.zeros_like(mask_ref[...], dtype=jnp.float32) + mask_ref[i, :] = jnp.ones((bL, ), dtype=jnp.float32) + + acc_ref[...] += jax.lax.dot_general( + inp_ref[...], + lora_ref[idx, ...], (((1, ), (1, )), ((), ())), + preferred_element_type=jnp.float32) * mask_ref[...] + + @pl.when(pl.program_id(2) == pl.num_programs(2) - 1) + def _(): + out_ref[...] = acc_ref[...].astype(out_ref.dtype) + + +@jax.jit +def _bgmv( + idxs: jax.Array, # (T, ) int32 + inputs: jax.Array, # (T, D) model dtype + loras: jax.Array # (N, L, D) model dtype +) -> jax.Array: # (T, L) model dtype + T, D = inputs.shape + N, L, _ = loras.shape + + return pl.pallas_call( + kernel=functools.partial(_bgmv_kernel, TOKENS_BLOCK, LORA_RANK_BLOCK), + out_shape=jax.ShapeDtypeStruct((T, L), dtype=inputs.dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(T // TOKENS_BLOCK, L // LORA_RANK_BLOCK, + D // DIM_BLOCK_SIZE), + in_specs=[ + pl.BlockSpec((TOKENS_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (i, k)), + pl.BlockSpec((N, LORA_RANK_BLOCK, DIM_BLOCK_SIZE), + lambda i, j, k, block_idx: (0, j, k)), + ], + out_specs=pl.BlockSpec((TOKENS_BLOCK, LORA_RANK_BLOCK), + lambda i, j, k, block_idx: (i, j)), + scratch_shapes=[ + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32), + pltpu.VMEM((TOKENS_BLOCK, LORA_RANK_BLOCK), jnp.float32) + ]), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + name="bgmv")(idxs, inputs, loras) + + +def bgmv_shape_function(idxs, inputs, loras): + T, _ = inputs.shape + _, L, _ = loras.shape + + return [((T, L), inputs.dtype)] + + +XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor", ) + + +@impl(XLA_LIB, "bgmv", "XLA") +def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): + inputs = inputs.to(dtype=loras.dtype) + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + jax_import_guard() + kernel = make_kernel_from_pallas(_bgmv, bgmv_shape_function) + + T, _ = inputs.shape + _, L, D = loras.shape + + # Pad the loras' rank if it's too low. This is to allow it to fit in a TPU + # register. This has to happen in pytorch, doing it in Jax will lead to NaNs + L1 = L + if LORA_RANK_BLOCK > L or L % LORA_RANK_BLOCK != 0: + L1 = (L // LORA_RANK_BLOCK + 1) * LORA_RANK_BLOCK + + D1 = D + if DIM_BLOCK_SIZE > D or D % DIM_BLOCK_SIZE != 0: + D1 = (D // DIM_BLOCK_SIZE + 1) * DIM_BLOCK_SIZE + + T1 = T + if TOKENS_BLOCK > T or T % TOKENS_BLOCK != 0: + T1 = (T // TOKENS_BLOCK + 1) * TOKENS_BLOCK + + if D1 != D or L1 != L: + loras = torch.nn.functional.pad(loras, (0, D1 - D, 0, L1 - L, 0, 0)) + if D1 != D or T1 != T: + inputs = torch.nn.functional.pad(inputs, (0, D1 - D, 0, T1 - T)) + if T1 != T: + idxs = torch.nn.functional.pad(idxs, ((0, T1 - T))) + + return kernel(idxs, inputs, loras)[:T, :L] + + +@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") +def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, + idxs: torch.IntTensor): + T, _ = inputs.shape + + if len(loras.shape) == 4: + loras = loras.squeeze(axis=1) + + _, L, _ = loras.shape + + return torch.empty((T, L), device=inputs.device) diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 94fa3f27ab60..78866c51895b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -48,7 +48,7 @@ def add_shrink( lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. """ @@ -66,7 +66,7 @@ def add_expand( offset_start: int = 0, add_inputs=True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. """ @@ -80,7 +80,7 @@ def add_lora_embedding( lora_b_stacked: torch.Tensor, add_inputs: bool = True, **kwargs, - ) -> None: + ) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA, and this layer only requires the expand operation. @@ -98,7 +98,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. """ @@ -114,7 +114,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. """ @@ -207,7 +207,8 @@ def _update_base_metadata( self._long_lora_indices.zero_() self.indices_len[:] = indices_len - def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, token_nums, @@ -334,7 +335,7 @@ def update_metadata( long_lora_context) if mapping.is_prefill: # Update metadata required for prefill-related operators. - self._update_prefill_metada(self.token_lora_indices) + self._update_prefill_metadata(self.token_lora_indices) self.is_prefill = True else: self.is_prefill = False @@ -342,7 +343,7 @@ def update_metadata( @abstractmethod def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], - scale: float, **kwargs) -> None: + scale: float, **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM for multiple slices of lora_a. @@ -369,7 +370,7 @@ def add_expand(self, output_slices: Tuple[int, ...], offset_start: int = 0, add_inputs=True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Performs GEMM and bias addition for multiple slices of lora_b. @@ -401,7 +402,7 @@ def add_lora_embedding(self, x: torch.Tensor, lora_b_stacked: torch.Tensor, add_inputs: bool = True, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for VocabParallelEmbeddingWithLoRA. and this layer only requires the expand operation. @@ -428,7 +429,7 @@ def add_lora_linear(self, output_slices: Tuple[int, ...], *, buffer: Optional[Tuple[torch.Tensor, ...]] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applicable to linear-related lora. @@ -463,7 +464,7 @@ def add_lora_logits(self, scale, *, buffer: Optional[torch.Tensor] = None, - **kwargs) -> None: + **kwargs) -> Optional[torch.Tensor]: """ Applies lora specifically for LogitsProcessorWithLoRA. diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py new file mode 100644 index 000000000000..37544c755d90 --- /dev/null +++ b/vllm/lora/punica_wrapper/punica_tpu.py @@ -0,0 +1,325 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F + +from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink + +from .punica_base import PunicaWrapperBase + + +class PunicaWrapperTPU(PunicaWrapperBase): + """ + PunicaWrapperTPU is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the pytorch punica ops. + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + + # PunicaWrapperBase defines some tensors with dtype=torch.int64, which + # isn't supported by the TPU. So convert those tensors to int32. + # Not all of them are used by the TPU so only convert the useful ones. + self._token_lora_indices = self._token_lora_indices.to( + dtype=torch.int32) + self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) + self._sampler_indices_padded = self._sampler_indices_padded.to( + dtype=torch.int32) + + torch._dynamo.mark_dynamic(self._token_lora_indices, 0) + torch._dynamo.mark_dynamic(self._embeddings_indices, 1) + torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) + + def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: + return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + return self._embeddings_indices[:] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + return self._sampler_indices_padded[:] + + def shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + if self.no_lora: + return y + return bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), + scale) + + def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + add_inputs: bool): + return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), + add_inputs) + + def expand_slice(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, y_offset: int, y_slice_size: int, + y_total_size: int, add_inputs: bool) -> torch.Tensor: + return bgmv_expand_slice(x, w_t_all, y, + self._get_token_lora_indices(x), y_offset, + y_slice_size, add_inputs) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs) -> Optional[torch.Tensor]: + """ + Performs GEMM for multiple slices of lora_a. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + torch.ops.xla.dynamo_set_buffer_donor_(y, True) + x = x.view(-1, x.shape[-1]) + + for slice_idx in range(len(lora_a_stacked)): + y_s = y[slice_idx] + lora_s = lora_a_stacked[slice_idx] + y_s = self.shrink(y_s, x, lora_s, scale) + y[slice_idx, :, :] = y_s # type: ignore[index] + return y + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> torch.Tensor: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + + if lora_bias_stacked is not None: + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + y = self.expand_slice( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + y_total_size=sum(output_slices), + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + return y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only needs the expand op + return self.expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> torch.Tensor: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will not be changed in-place. + x (torch.Tensor): Input tensor (T, E) + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._get_token_lora_indices(y), y, + output_slices, lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default, consistent with the + # triton op + T = x.size(0) + buffer = torch.zeros( + (len(output_slices), T, r), + dtype=torch.float32, + device=x.device, + ) + buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + return self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + if self.no_lora: + return y + + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default, consistent with the + # triton op + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + buffer = bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, + scale) + y = bgmv_expand(buffer, + lora_b_stacked, + y, + self.sampler_indices, + add_inputs=True) + return y.view_as(y_org) + + def _apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + lora_bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + lora_bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = lora_bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias = torch.where(indices[:, None] == -1, 0, bias) + + bias = F.pad(bias, (offset_left, output.shape[1] - + (offset_left + slice), 0, 0)) + + output += bias + offset_left += slice + + return output.view_as(org_output) + + def _update_prefill_metadata(self, + token_lora_tensor: torch.Tensor) -> None: + self.batch_size = 1 + self._lora_indices_per_batch[:self.batch_size].copy_( + token_lora_tensor[:self.batch_size]) + # TODO: .item() is extremely inefficient on TPU, so find a way around it + self.no_lora = torch.all(token_lora_tensor == -1).item() diff --git a/vllm/lora/punica_wrapper/utils.py b/vllm/lora/punica_wrapper/utils.py index dbc2d27c597f..f4e5542b177d 100644 --- a/vllm/lora/punica_wrapper/utils.py +++ b/vllm/lora/punica_wrapper/utils.py @@ -125,11 +125,13 @@ def convert_mapping( indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) - embeddings_indices[embeddings_indices == -1] = max_loras - 1 + embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, + embeddings_indices) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() - sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.where(sampler_indices_padded == -1, + max_loras - 1, sampler_indices_padded) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 883ca938ea1a..01064e5d007e 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -117,16 +117,18 @@ def parse_fine_tuned_lora_name( # LoRA weight qualified name usually starts with `base_model.model.`, # so we remove the prefix `base_model.model.` to make the following # mapping correctly. - if "base_model.model." in name: + if name.startswith("base_model.model."): name = name.replace("base_model.model.", "") name = weights_mapper._map_name(name) if weights_mapper else name # recover the prefix `base_model.model.` name = "base_model.model." + name + else: + name = weights_mapper._map_name(name) if weights_mapper else name # In some situations, we may not start with `base_model.model.`. # If we don't (e.g., ibm-granite/granite-speech-3.3-8b), # we should keep the prefix intact. - start_index = 2 if "base_model.model." in name else 0 + start_index = 2 if name.startswith("base_model.model.") else 0 parts = name.split(".") if parts[-1] == "weight" and (parts[-2] == "lora_A" diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 9829ccdb384f..53e7769b2042 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -36,7 +36,7 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.cutlass_moe import ( - cutlass_moe_fp8) + cutlass_moe_fp4, cutlass_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk) @@ -48,4 +48,5 @@ def get_config() -> Optional[Dict[str, Any]]: "get_config_file_name", "grouped_topk", "cutlass_moe_fp8", + "cutlass_moe_fp4", ] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json new file mode 100644 index 000000000000..e1c4cac9c826 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json b/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json new file mode 100644 index 000000000000..5a9910a4d37e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index 960c7f834857..1b34e952208a 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -"""Fused MoE kernel.""" +""" CUTLASS based Fused MoE kernels.""" from typing import Optional import torch from vllm import _custom_ops as ops +from vllm.scalar_type import scalar_types #TODO make the grouped gemm kernel consistent with scaled gemm kernel @@ -178,3 +179,126 @@ def cutlass_moe_fp8( if not apply_router_weight_on_input: c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype) return c2.sum(dim=1) + + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max +MAX_TOKENS_PER_EXPERT = 65536 + + +def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, topk_weights: torch.Tensor, + topk_ids: torch.Tensor, m: int, n: int, k: int, e: int, + device: torch.device): + """ + MoE implementation for FP4 Inputs + + # Gemm 1 + a: Input tensor: [m, k] (half/bfloat16) + a1_gscale: Activation scale per expert: [e] (float32) + w1(gate up) (not an argument to cutlass_moe_fp4): [e, 2 * n, k] + w1_fp4: [e, 2 * n, k // 2], dtype: torch.uint8 (stacked fp4: E2M1) + (Note: `n` is the up projection output dim, `k` is the input dim in + full precision) + w1_blockscale: [e, 2 * n, k // block_size] (float8_e4m3) + (Block size = 16 for NVFP4) + + # Gemm 2 + a2_gscale: Activation scale per expert: [e] + w2(down projection) (not an argument to cutlass_moe_fp4): [e, k, n] + w2_fp4: [e, k, n // 2], dtype: torch.uint8 (stacked E2M1) + w2_blockscale: [e, k, n // block_size], dtype: float8_e4m3 + + topk_weights: [m, topk] dtype: float8 + topk_ids: [m, topk] dtype: float8 + + m, n, k: Unquantized weight shapes, dtype: int + e: number of experts, dtype: int + + assumes that topk < k < n to satisfy - up/down projection expectations. + """ + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert w1_fp4.dtype == torch.uint8, "weight 1 must be uint8" + assert w2_fp4.dtype == torch.uint8, "weight 2 must be uint8" + assert (w1_fp4.ndim == 3 and w2_fp4.ndim == 3 and w1_blockscale.ndim == 3 + and w2_blockscale.ndim + == 3), ("All Weights must be of rank 3 for cutlass_moe_fp4") + m_a, k_a = a.shape + e_w1, nx2_w1, half_k_w1 = w1_fp4.shape + e_w2, k_w2, half_n_w2 = w2_fp4.shape + + assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match", + " between weights.") + assert (k_a // 2 == half_k_w1 + and k == k_w2), ("Hidden size mismatch between a, w1 and w2") + assert (nx2_w1 == n * 2 and half_n_w2 == n // 2), ("mismatch in " + "expected `n`") + assert (m == m_a), "input shape mismatch" + assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" + assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" + assert (topk_weights.shape[0] == m and topk_ids.shape[0] + == m), ("topk must be provided for each row of a") + assert (m <= MAX_TOKENS_PER_EXPERT), ( + f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})" + f" for cutlass_moe_fp4, observed m = {m}") + out_dtype = a.dtype + num_topk = topk_ids.shape[1] + + expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,2n,k)) + problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device) + # Problem size: (num_experts, (m,n,k)) + problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device) + + a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) + + # problem shapes should have [m, n, k] + # Note that problem sizes are based on logical number of elements. + ops.get_cutlass_moe_mm_data(topk_ids, expert_offsets, problem_sizes1, + problem_sizes2, a_map, c_map, e, n, k) + + tokens_per_expert = problem_sizes1[:, 0] + rounded_tokens_per_expert = (tokens_per_expert + (128 - 1)) // 128 * 128 + blockscale_offsets = torch.zeros(e + 1, dtype=torch.int32, device=device) + blockscale_offsets[1:] = torch.cumsum(rounded_tokens_per_expert, dim=0) + + rep_a_fp4, rep_a_blockscale = ops.scaled_fp4_experts_quant( + a, + a1_gscale, + expert_offsets, + blockscale_offsets, + num_topk, + expert_map=a_map, + MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT) + + c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, + w1_blockscale, w1_alphas, problem_sizes1, + expert_offsets[:-1], blockscale_offsets[:-1], + out_dtype, device) + del rep_a_fp4, rep_a_blockscale + # hidden size dimension is split to one halfpytho sized tensor. + intermediate = torch.empty((m * num_topk, w1_fp4.shape[1] // 2), + device=device, + dtype=out_dtype) + + torch.ops._C.silu_and_mul(intermediate, c1) + + int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( + intermediate, + a2_gscale, + expert_offsets, + blockscale_offsets, + num_topk, + MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT) + + c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, + w2_alphas, problem_sizes2, expert_offsets[:-1], + blockscale_offsets[:-1], out_dtype, device) + del int_fp4, int_blockscale + out = (c2[c_map].view(m, num_topk, k) * + topk_weights.view(m, num_topk, 1).half()).sum(dim=1) + return out.to(dtype=out_dtype) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 35994c8ac6af..5337ff0037da 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -643,7 +643,7 @@ def weight_loader(self, param: torch.nn.Parameter, expert_id = self._map_global_expert_id_to_local_expert_id(expert_id) if expert_id == -1: return - + quant_method_name = self.quant_method.__class__.__name__ # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -697,8 +697,9 @@ def weight_loader(self, param: torch.nn.Parameter, # this is needed for compressed-tensors only loaded_weight = loaded_weight.to(param.data.device) - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: + if ("compressed" in quant_method_name.lower() + and param.data[expert_id] != 1 + and (param.data[expert_id] - loaded_weight).abs() > 1e-5): raise ValueError( "input_scales of w1 and w3 of a layer " f"must be equal. But got {param.data[expert_id]} " @@ -718,6 +719,22 @@ def weight_loader(self, param: torch.nn.Parameter, tp_rank=self.tp_rank) return + if "ModelOpt" in quant_method_name: + if ('weight_scale_2' in weight_name + or 'input_scale' in weight_name): + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + elif "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) + return + # Case weight scales, zero_points and offset if ("scale" in weight_name or "zero" in weight_name or "offset" in weight_name): diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index acaa93f5a23e..7d7bce9ec6ab 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -145,7 +145,7 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake( block_shape: List[int], smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - return torch.empty_like(a1, dtype=torch.bf16) + return torch.empty_like(a1, dtype=hidden_states_dtype) def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index ae16a20cfaab..4a3fc2a1a6b9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -2,7 +2,7 @@ import enum from enum import Enum -from typing import Callable, List, Optional +from typing import Callable, Optional import torch from compressed_tensors import CompressionFormat @@ -14,9 +14,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa + WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP) from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_moe_marlin_supports_layer, marlin_make_workspace_new, + marlin_moe_permute_scales) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs @@ -54,18 +57,19 @@ def get_moe_method( "input_activations") if quant_config._is_wNa16_group_channel(weight_quant, input_quant): - # Prefer to use the non-marlin kernel when: - # 1. Many experts (MarlinMoE gives poor performance when >= 16) - # 2. Non-FP16 dtype (MarlinMoE only supports FP16) - # 3. Actorder is not group/dynamic (g_idx is unsupported) - # 4. Scaled are grouped (channelwise is unsupported) - if ((layer.local_num_experts >= 16 - or layer.params_dtype != torch.float16) and - weight_quant.actorder not in (ActivationOrdering.GROUP, - ActivationOrdering.DYNAMIC) - and weight_quant.strategy in QuantizationStrategy.GROUP): + # Prefer to use the MarlinMoE kernel when it is supported. + if not check_moe_marlin_supports_layer(layer, + weight_quant.group_size): + if (weight_quant.strategy in QuantizationStrategy.GROUP and + weight_quant.actorder in (ActivationOrdering.GROUP, + ActivationOrdering.DYNAMIC)): + raise ValueError( + "WNA16MoE is not supported with actorder=group/dynamic." + ) + logger.info_once("Using CompressedTensorsWNA16MoEMethod") return CompressedTensorsWNA16MoEMethod(quant_config) else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") return CompressedTensorsWNA16MarlinMoEMethod(quant_config) elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant) and layer.activation == "silu"): @@ -705,15 +709,12 @@ def __init__( f"{CompressionFormat.pack_quantized.value} ", "is supported for the following bits: ", f"{WNA16_SUPPORTED_BITS}") + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits] def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): - assert params_dtype == torch.float16, ( - "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501 - ) - intermediate_size_full = extra_weight_attrs.pop( "intermediate_size_full") @@ -837,50 +838,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.marlin_state = GPTQMarlinState.REPACK def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - - def replace_tensor(name, new_t): - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def get_scale_perms(num_bits: int): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, - group_size: int, num_bits: int): - scale_perm, scale_perm_single = get_scale_perms(num_bits) - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, - scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - return s - - def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, - size_n: int, group_size: int, - num_bits: int): - num_experts = s.shape[0] - output = torch.empty((num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype) - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, - group_size, num_bits) - return output - - size_k2 = layer.w2_weight_packed.shape[2] - size_k13 = layer.w13_weight_packed.shape[2] - num_experts = layer.w13_weight_g_idx.shape[0] device = layer.w13_weight_g_idx.device @@ -938,7 +895,7 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, layer.w13_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w13_weight_packed", marlin_w13_qweight) + replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_weight_packed, layer.w2_g_idx_sort_indices, @@ -946,25 +903,25 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, layer.w2_weight_packed.shape[2], self.num_bits, ) - replace_tensor("w2_weight_packed", marlin_w2_qweight) + replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( - layer.w13_weight_scale, - size_k13, - layer.w13_weight_scale.shape[2], - self.group_size, - self.num_bits, + s=layer.w13_weight_scale, + size_k=layer.w13_weight_packed.shape[2], + size_n=layer.w13_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w13_weight_scale", marlin_w13_scales) + replace_parameter(layer, "w13_weight_scale", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( - layer.w2_weight_scale, - layer.w2_weight_scale.shape[1] * + s=layer.w2_weight_scale, + size_k=layer.w2_weight_scale.shape[1] * (self.group_size if self.group_size != -1 else self.packed_factor), - size_k2, - self.group_size, - self.num_bits, + size_n=layer.w2_weight_scale.shape[2], + group_size=self.group_size, ) - replace_tensor("w2_weight_scale", marlin_w2_scales) + replace_parameter(layer, "w2_weight_scale", marlin_w2_scales) + + layer.workspace = marlin_make_workspace_new(device, 4) def apply( self, @@ -985,10 +942,6 @@ def apply( activation: str = "silu", ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " - "fused Marlin MoE method.") if apply_router_weight_on_input: raise NotImplementedError( "Apply router weight on input is not supported for " @@ -1015,11 +968,14 @@ def apply( router_logits, topk_weights, topk_ids, + quant_type_id=self.quant_type.id, + global_num_experts=global_num_experts, + expert_map=expert_map, g_idx1=layer.w13_weight_g_idx, g_idx2=layer.w2_weight_g_idx, sort_indices1=layer.w13_g_idx_sort_indices, sort_indices2=layer.w2_g_idx_sort_indices, - num_bits=self.num_bits, + workspace=layer.workspace, is_k_full=self.is_k_full) @@ -1203,7 +1159,7 @@ def apply( activation: str = "silu", ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts - assert activation == "silu", "Only SiLU activation is supported." + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -1223,6 +1179,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_int4_w4a16=self.num_bits == 4, use_int8_w8a16=self.num_bits == 8, global_num_experts=global_num_experts, diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 5dff8b09693c..67723c7c91cc 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -124,11 +124,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # These are used in the final Attention.forward() layer._q_scale.copy_(q_scale) layer._prob_scale.copy_(prob_scale) - if q_scale == 1.0 or prob_scale == 1.0: + if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0 + or prob_scale == 1.0): logger.warning_once( - f"Using Q scale {q_scale} and prob scale {prob_scale} " - "with fp8 attention. This may cause accuracy issues. " - "Please make sure Q/prob scaling factors are " + f"Using uncalibrated q_scale {q_scale} and/or prob_scale " + f"{prob_scale} with fp8 attention. This may cause accuracy " + "issues. Please make sure q/prob scaling factors are " "available in the fp8 checkpoint.") del layer.k_scale diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 828447dd1019..e9b16b8a0acd 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Module @@ -9,6 +9,8 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm, cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods @@ -210,25 +212,37 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config": "`hf_quant_config.json` file for your model's " "quant configuration.") is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) - kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] - group_size = quant_config["group_size"] - exclude_modules = quant_config["exclude_modules"] - if not (group_size and kv_cache_quant_algo and exclude_modules): + if ("group_size" and "kv_cache_quant_algo" + and "exclude_modules") not in quant_config: raise ValueError("NVFP4 quantization requires group size and " "kv_cache_quant_algo specified in " "hf_quant_config.json") + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + group_size = quant_config["group_size"] + exclude_modules = quant_config["exclude_modules"] return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, exclude_modules, group_size) + def is_layer_excluded(self, prefix: str, exclude_modules: List): + import re + for pattern in exclude_modules: + regex_str = pattern.replace('.', r'\.').replace('*', r'.*') + if re.fullmatch(regex_str, prefix): + return True + return False + def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.exclude_modules): + if (is_layer_skipped(prefix, self.exclude_modules) + or self.is_layer_excluded(prefix, self.exclude_modules)): return UnquantizedLinearMethod() return ModelOptNvFp4LinearMethod(self) elif isinstance(layer, Attention): return ModelOptFp8KVCacheMethod(self) + elif isinstance(layer, FusedMoE): + return ModelOptNvFp4FusedMoE(self) return None @@ -409,3 +423,235 @@ def apply( if bias is not None: out = out + bias return out.view(*output_shape) + + +class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): + """ + MoE Method for FP4 Quantization. + Args: + quant_config: NVFP4 Quant Config + """ + + def __init__(self, quant_config: ModelOptNvFp4Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " + " dynamic quantization is not supported.") + + layer.quant_config = self.quant_config + weight_dtype = torch.uint8 + weight_scale_dtype = torch.float8_e4m3fn + weight_loader = extra_weight_attrs.get("weight_loader") + # GEMM 1 + w13_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight", w13_weight) + + # GEMM 2 + w2_weight = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=weight_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight", w2_weight) + + w13_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = ModelWeightParameter( + data=torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // + self.quant_config.group_size, + dtype=weight_scale_dtype), + input_dim=1, + output_dim=2, + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) + + w13_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) + + w2_weight_scale_2 = PerTensorScaleParameter( + data=torch.empty(num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + + w13_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, 2, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w13_input_scale", w13_input_scale) + + w2_input_scale = PerTensorScaleParameter(data=torch.empty( + num_experts, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("w2_input_scale", w2_input_scale) + + def swizzle_blockscale(self, scale: torch.tensor): + assert (scale.dtype == torch.float8_e4m3fn) + # Pad and blockwise interleave weight_scale + scale_ndim = scale.ndim + if scale.ndim == 2: + scale = scale.unsqueeze(0) + assert scale.ndim == 3 + B, M, K = scale.shape + round_up_multiple = lambda x, m: (x + m - 1) // m * m + M_padded = round_up_multiple(M, 128) + K_padded = round_up_multiple(K, 4) + padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype) + padded_scale[:B, :M, :K] = scale + batches, rows, cols = padded_scale.shape + assert rows % 128 == 0 + assert cols % 4 == 0 + padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32, + cols // 4, 4) + swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5)) + swizzled_scale = swizzled_scale.contiguous().cuda() + return (swizzled_scale.reshape(M, K) + if scale_ndim == 2 else swizzled_scale.reshape(B, M, K)) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # GEMM 1 + + assert torch.allclose( + layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1]), ( + "Expected w1_weight_scale_2 to equal w3_weight_scale_2") + + w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0] + layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, + requires_grad=False) + + w13_input_scale = layer.w13_input_scale.max(dim=1).values.to( + torch.float32) + layer.g1_alphas = Parameter( + (w13_input_scale * w13_weight_scale_2).to(torch.float32), + requires_grad=False) + + assert (layer.w13_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w13_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w13_blockscale_swizzled = self.swizzle_blockscale( + layer.w13_weight_scale) + + layer.w13_blockscale_swizzled = Parameter(w13_blockscale_swizzled, + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w13_input_scale_quant = Parameter( + (1 / w13_input_scale).to(torch.float32), requires_grad=False) + + # GEMM 2 + layer.g2_alphas = Parameter( + (layer.w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), + requires_grad=False) + + # This is for quantization, so we need to invert it. + layer.w2_input_scale_quant = Parameter( + (1 / layer.w2_input_scale).to(torch.float32), requires_grad=False) + + assert (layer.w2_weight_scale.shape[2] % 16 == 0), ( + "Expected weight_scale.dim(1) to be divisible by 16") + assert (layer.w2_weight_scale.dtype == torch.float8_e4m3fn), ( + "Weight Blockscale must be represented as FP8-E4M3") + w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) + + layer.w2_blockscale_swizzled = Parameter(w2_blockscale_swizzled, + requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + ): + assert activation == "silu", "Only SiLU activation is supported." + assert not apply_router_weight_on_input, ( + "Router weight on input is not " + "supported for ModelOptNvFp4FusedMoE.") + assert expert_map is None, ("Expert Parallelism /expert_map " + "is currently not supported for " + "ModelOptNvFp4FusedMoE.") + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + from vllm.model_executor.layers.fused_moe.cutlass_moe import ( + cutlass_moe_fp4) + + # Cutlass moe takes in activations in BF16/Half precision + # and fp4 quantized weights loaded from the checkpoint + return cutlass_moe_fp4(a=x, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_blockscale_swizzled, + w1_alphas=layer.g1_alphas, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_blockscale_swizzled, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=x.shape[0], + n=layer.w2_weight.shape[2] * 2, + k=x.shape[1], + e=layer.w13_weight.shape[0], + a1_gscale=layer.w13_input_scale_quant, + a2_gscale=layer.w2_input_scale_quant, + device=x.device).to(x.dtype) diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index da2312190084..66e677f56ffd 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -5,6 +5,7 @@ import torch +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 QuarkMoEMethod) from vllm.model_executor.layers.quantization.quark.schemes import ( - QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) + QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8) from vllm.model_executor.layers.quantization.quark.utils import ( deep_compare, should_ignore_layer) from vllm.platforms import current_platform __all__ = ["QuarkLinearMethod"] +logger = init_logger(__name__) + class QuarkConfig(QuantizationConfig): @@ -67,6 +70,7 @@ def get_quant_method(self, layer: torch.nn.Module, return QuarkLinearMethod(self) if isinstance(layer, Attention): return QuarkKVCacheMethod(self) + if isinstance(layer, FusedMoE): return QuarkMoEMethod.get_moe_method(self, module=layer, @@ -205,6 +209,54 @@ def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], # Only symmetric weight quantization supported. return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + logger.debug("Quark model is not in MX-FP4 format: " + "weight_quant or input_quant not set") + return False + + # Input and weight dtype needs to be fp4. + if weight_quant.get("dtype") != "fp4" or input_quant.get( + "dtype") != "fp4": + logger.debug("Quark model is not in MX-FP4 format: dtype not fp4") + return False + + # Input and weight qscheme needs to be per group. + if weight_quant.get("qscheme") != "per_group" or input_quant.get( + "qscheme") != "per_group": + logger.debug("Quark model is not in MX-FP4 format: not per_group") + return False + + # Input and weight group size needs to be 32. + if weight_quant.get("group_size") != 32 or input_quant.get( + "group_size") != 32: + logger.debug( + "Quark model is not in MX-FP4 format: not group_size=32") + return False + + # Weights need to use static quantization. + if weight_quant.get("is_dynamic") is True: + logger.debug( + "Quark model is not in MX-FP4 format: not weight static") + return False + + # Activations need to use dynamic quantization. + if input_quant.get("is_dynamic") is False: + logger.debug( + "Quark model is not in MX-FP4 format: not activation dynamic") + return False + + # Activations and weight scales need to be in e8m0 format. + if weight_quant.get("scale_format") != "e8m0" or input_quant.get( + "scale_format") != "e8m0": + logger.debug( + "Quark model is not in MX-FP4 format: not scale_format e8m0") + return False + + return True + def _find_matched_config(self, layer_name: str, module: torch.nn.Module) -> Dict[str, Any]: @@ -269,6 +321,8 @@ def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": return QuarkW8A8Int8(qscheme=weight_qscheme, is_static_input_scheme=True, input_symmetric=input_config.get("symmetric")) + elif self._is_mx_fp4(weight_config, input_config): + return QuarkW4A4MXFP4(weight_config, input_config) raise NotImplementedError("No quark compatible scheme was found. " f"Weight config: {weight_config}, " diff --git a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py index 9069b5a0d515..d7dac17574ff 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 from .quark_scheme import QuarkScheme +from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4 from .quark_w8a8_fp8 import QuarkW8A8Fp8 from .quark_w8a8_int8 import QuarkW8A8Int8 -__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"] +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"] diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py new file mode 100644 index 000000000000..9da52a732fc4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, List, Optional + +import torch +import torch.nn.functional as F + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( + OCP_MX_BLOCK_SIZE, per_token_group_quant_mxfp4) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW4A4MXFP4"] + + +class QuarkW4A4MXFP4(QuarkScheme): + + def __init__(self, weight_quant_spec: Dict[str, Any], + input_quant_spec: Dict[str, Any]): + self.out_dtype = torch.get_default_dtype() + self.qscheme = "per_group" + self.weight_quant_spec = weight_quant_spec + self.input_quant_spec = input_quant_spec + self.emulate = not current_platform.supports_mx() + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + if self.emulate: + try: + from quark.torch.export.nn.modules import realquantizer + from quark.torch.quantization.config.config import ( + QuantizationSpec) + except ImportError as err: + raise ImportError( + "The package `amd-quark` is required to use AMD Quark " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + weight_quant_spec = QuantizationSpec.from_dict( + self.weight_quant_spec) + + weight_quantizer = realquantizer.get_real_quantizer( + qspec=weight_quant_spec, + quantizer=None, + real_quantized=True, + reorder=False, + float_dtype=self.out_dtype, + scale_shape=layer.weight_scale.shape, + zero_point_shape=None, + ) + weight_quantizer.scale.data = layer.weight_scale.data + + if not envs.VLLM_QUARK_EMU_MEM_OPT: + layer.weight = torch.nn.Parameter( + weight_quantizer(layer.weight.data).to(self.out_dtype), + requires_grad=False, + ) + else: + self.weight_quantizer = weight_quantizer + layer.weight_scale = None + + # This call is necessary to release the scales memory. + torch.cuda.empty_cache() + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // 2, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=2, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.emulate: + if envs.VLLM_QUARK_EMU_MEM_OPT: + dq_w = self.weight_quantizer(layer.weight).to(self.out_dtype) + else: + dq_w = layer.weight + qdq_x, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE) + return F.linear(qdq_x, dq_w, bias) + else: + raise NotImplementedError() diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index afd4bb722dad..f8eb3611592e 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -34,21 +34,24 @@ def process_weights_after_loading(self, layer) -> None: # tensor scales (thus N scales being passed to the kernel), # requantize so we can always run per tensor if self.qscheme == "per_tensor": - max_w_scale, weight = requantize_with_max_scale( - weight=layer.weight, - weight_scale=layer.weight_scale, - logical_widths=layer.logical_widths, - ) - - if current_platform.is_fp8_fnuz(): + if current_platform.is_rocm(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( - weight=weight, - weight_scale=max_w_scale, + weight=layer.weight, + weight_scale=layer.weight_scale, input_scale=input_scale) if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) + else: + max_w_scale = layer.weight_scale + weight = layer.weight + + max_w_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=max_w_scale, + logical_widths=layer.logical_widths, + ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 064cbb8cf52d..3bb42e737f10 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear( or br not in (1, weight.shape[0])): shape_supported_by_cutlass = False if cutlass_block_fp8_supported and shape_supported_by_cutlass: + rows, cols = input_2d.shape + # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for + # optimal tensor core usage. Can be removed when targeting platforms + # without this constraint. + should_pad = current_platform.has_device_capability( + 100) and rows % 4 != 0 + if should_pad: + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, 4 - (rows % 4)), + value=0).contiguous() q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], column_major_scales=True) @@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear( out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale.T) + if should_pad: + output = output[:rows, :] else: q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py new file mode 100644 index 000000000000..6312c3934fd4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -0,0 +1,45 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Tuple + +import torch + +OCP_MX_BLOCK_SIZE = 32 + + +def per_token_group_quant_mxfp4(x: torch.Tensor, + block_k: int, + scale_calculation_mode: str = "even" + ) -> Tuple[torch.Tensor, torch.Tensor]: + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_fp4_fp6_per_group_with_scale) + from quark.torch.quantization.utils import (even_round, + reshape_to_blocks) + except ImportError as err: + raise ImportError("The package `amd-quark` is required to use " + "MX-FP4 models. Please install it with `pip install " + "amd-quark`.") from err + + axis = -1 + block_x = reshape_to_blocks(x, block_k, axis) + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + amax = amax.squeeze(-1) + + # TODO: there are other rounding strategies supported in quark and in the + # config.json that we do not check for here! + if scale_calculation_mode != "even": + raise NotImplementedError( + f"Scale calculation mode {scale_calculation_mode} is not yet " + "supported in MX-FP4 quantization") + scale = even_round(amax, "fp4") + + # Apply dequantize(quantize(x)). + x = fake_quantize_fp4_fp6_per_group_with_scale( + x, + scale.to(x.device), + axis=axis, + group_size=block_k, + quant_dtype="fp4", + ) + + return x, scale diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 32c2a2859b49..f8392eb679d2 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -808,8 +808,9 @@ def forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device) + if self.cos_sin_cache.device != positions.device: + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index adb966c4b1c0..751b86787c7b 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -84,7 +84,7 @@ def rocm_unquantized_gemm(x: torch.Tensor, m = weight.shape[0] cu_count = current_platform.get_cu_count() - if m > 8 and 0 < n < 4: + if m > 8 and 0 < n <= 4: out = ops.wvSplitK(weight, x_view, cu_count) return out.view(*x.shape[:-1], weight.shape[0]) elif m % 4 == 0 and n == 1 and k <= 8192: diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 42528cd7e433..ddc857aebdc8 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -220,7 +220,7 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = [ - "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark" ] if (model_config.quantization is not None diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index ce86b9b2c4f0..0366895ef02e 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -453,7 +453,6 @@ def __init__( qk_rope_head_dim=self.qk_rope_head_dim, qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, - rotary_emb=self.rotary_emb, kv_b_proj=self.kv_b_proj, ) @@ -475,6 +474,13 @@ def forward( [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + # Add head dim of 1 to k_pe + k_pe = k_pe.unsqueeze(1) + + q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( + positions, q[..., self.qk_nope_head_dim:], k_pe) + attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index c48cb157084d..6f56eb2d5e38 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -504,9 +504,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) diff --git a/vllm/model_executor/models/nemotron_nas.py b/vllm/model_executor/models/nemotron_nas.py index 264999496876..988b994b7689 100644 --- a/vllm/model_executor/models/nemotron_nas.py +++ b/vllm/model_executor/models/nemotron_nas.py @@ -334,14 +334,6 @@ class DeciLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, HasNoOps): } # LoRA specific attributes - supported_lora_modules = [ - "qkv_proj", - "o_proj", - "gate_up_proj", - "down_proj", - "embed_tokens", - "lm_head", - ] embedding_modules = { "embed_tokens": "input_embeddings", "lm_head": "output_embeddings", diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 6035994f4336..e5ff9ceddef7 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -955,11 +955,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - padding_size=( - DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size), + padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, ) if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/qwen2_5_omni_thinker.py b/vllm/model_executor/models/qwen2_5_omni_thinker.py index 039f528db13b..d8e178f9cd47 100644 --- a/vllm/model_executor/models/qwen2_5_omni_thinker.py +++ b/vllm/model_executor/models/qwen2_5_omni_thinker.py @@ -145,9 +145,11 @@ def get_hf_processor( kwargs["fps"] = fps processor = self.ctx.get_hf_processor( Qwen2_5OmniProcessor, - image_processor=self.get_image_processor(min_pixels=min_pixels, - max_pixels=max_pixels, - size=size), + image_processor=self.get_image_processor( + min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + use_fast=kwargs.get("use_fast")), **kwargs, ) if not hasattr(processor, "audio_token"): diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py index 5bef4129bfa8..8728de95134d 100644 --- a/vllm/model_executor/models/qwen2_5_vl.py +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -758,9 +758,11 @@ def get_hf_processor( return self.ctx.get_hf_processor( Qwen2_5_VLProcessor, - image_processor=self.get_image_processor(min_pixels=min_pixels, - max_pixels=max_pixels, - size=size), + image_processor=self.get_image_processor( + min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + use_fast=kwargs.get("use_fast")), **kwargs, ) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index a00b756ecec0..ac0a6de523df 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -759,9 +759,11 @@ def get_hf_processor( ) -> Qwen2VLProcessor: return self.ctx.get_hf_processor( Qwen2VLProcessor, - image_processor=self.get_image_processor(min_pixels=min_pixels, - max_pixels=max_pixels, - size=size), + image_processor=self.get_image_processor( + min_pixels=min_pixels, + max_pixels=max_pixels, + size=size, + use_fast=kwargs.get("use_fast")), **kwargs, ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index d76c75d9e6ce..888ca3e5009e 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -416,7 +416,7 @@ def from_sampling_metadata( # k should not be greater than the vocab size. top_k = min(sampling_params.top_k, vocab_size) - top_k = vocab_size if top_k == -1 else top_k + top_k = vocab_size if top_k < 1 else top_k if temperature < _SAMPLING_EPS: # NOTE: Zero temperature means deterministic sampling # (i.e., greedy sampling or beam search). diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 6d875a1c651e..72e9b65d763c 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -81,7 +81,8 @@ def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: total_frames_num = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) full_read = num_frames == -1 or total_frames_num < num_frames if full_read: - frame_idx = list(range(0, total_frames_num)) + num_frames = total_frames_num + frame_idx = list(range(0, num_frames)) else: uniform_sampled_frames = np.linspace(0, total_frames_num - 1, @@ -104,7 +105,8 @@ def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray: frames[i] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) i += 1 # we expect all frames loaded - assert i == num_frames + assert i == num_frames, (f"Expected reading {num_frames} frames, " + f"but only loaded {i} frames from video.") return frames diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index e45522a4c407..d286c8939512 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -10,7 +10,7 @@ from vllm.logger import init_logger -from .interface import Platform, PlatformEnum, _Backend +from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend logger = init_logger(__name__) @@ -26,6 +26,20 @@ class CpuPlatform(Platform): device_type: str = "cpu" dispatch_key: str = "CPU" + @property + def supported_dtypes(self) -> list: + if self.get_cpu_architecture() == CpuArchEnum.POWERPC: + return [torch.bfloat16, torch.float32] + elif sys.platform.startswith( + "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM: + # TODO: change this condition to check if the platform support bf16 + # instead of checking the OS. For instance M2 shall supports bf16 + # already. But we need to modify `cpu_extension.cmake` to activate + # the feature in the build. + return [torch.bfloat16, torch.float32] + # x86/aarch64 CPU has supported both bf16 and fp16 natively. + return [torch.bfloat16, torch.float16, torch.float32] + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "cpu" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index ab03dece8c13..f116285870ec 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -73,6 +73,19 @@ class CudaPlatformBase(Platform): ray_device_key: str = "GPU" device_control_env_var: str = "CUDA_VISIBLE_DEVICES" + @property + def supported_dtypes(self) -> List[torch.dtype]: + if self.has_device_capability(80): + # Ampere and Hopper or later NVIDIA GPUs. + return [torch.bfloat16, torch.float16, torch.float32] + elif (not self.has_device_capability(80) + ) and self.has_device_capability(60): + # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported + return [torch.float16, torch.float32] + # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported, + # though vLLM doesn't support these GPUs. + return [torch.float32] + @classmethod def get_device_capability(cls, device_id: int = 0 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 5df0e9d3d072..68b90796ece2 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -39,7 +39,8 @@ class _Backend(enum.Enum): TRITON_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() - ROCM_AITER_MLA = enum.auto() + ROCM_AITER_MLA = enum.auto() # Supported by V1 + ROCM_AITER_MLA_VLLM_V1 = enum.auto() TORCH_SDPA = enum.auto() FLASHINFER = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 @@ -121,6 +122,14 @@ class Platform: additional_env_vars: list[str] = [] + @property + def supported_dtypes(self) -> list[torch.dtype]: + """Returns the supported dtypes for the current platform.""" + # Be careful with the order of the dtypes. The first dtype will + # be used as the default dtype fallback for the current platform, + # when encountering unsupported dtypes in "auto" dtype. + return [torch.bfloat16, torch.float16, torch.float32] + def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -332,6 +341,27 @@ def get_punica_wrapper(cls) -> str: """ raise NotImplementedError + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + """ + Return the platform specific values for (-inf, inf) + """ + return float("-inf"), float("inf") + + @classmethod + def can_update_inplace(cls) -> bool: + """ + Checks if the platform allows inplace memory updates + """ + return True + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + """ + Returns how much padding the LoRA logits need for kernels + """ + return 256 + @classmethod def get_device_communicator_cls(cls) -> str: """ @@ -339,6 +369,13 @@ def get_device_communicator_cls(cls) -> str: """ return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + @classmethod + def supports_mx(cls) -> bool: + """ + Returns whether the current platform supports MX types. + """ + return False + @classmethod def supports_fp8(cls) -> bool: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index ff63f9656c01..ea028e13fc4d 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -104,6 +104,7 @@ def device_id_to_physical_device_id(device_id: int) -> int: return device_id +@cache def on_mi250_mi300() -> bool: GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName return any(arch in GPU_ARCH for arch in ["gfx90a", "gfx942"]) @@ -167,10 +168,15 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, raise ValueError( f" The selected backend, {selected_backend.name}," f"does not support block size {block_size}.") - elif selected_backend == _Backend.ROCM_AITER_MLA: + elif selected_backend == _Backend.ROCM_AITER_MLA \ + or selected_backend == _Backend.ROCM_AITER_MLA_VLLM_V1: if block_size == 1: - logger.info("Using AITER MLA backend.") - return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + if use_v1: + logger.info("Using AITER MLA backend on V1 engine.") + return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501 + else: + logger.info("Using AITER MLA backend") + return "vllm.attention.backends.rocm_aiter_mla.AiterMLABackend" # noqa: E501 else: raise ValueError( f" The selected backend, {selected_backend.name}," @@ -327,6 +333,11 @@ def get_current_memory_usage(cls, def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + @classmethod + def supports_mx(cls) -> bool: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ["gfx95"]) + @classmethod def supports_fp8(cls) -> bool: gcn_arch = torch.cuda.get_device_properties(0).gcnArchName diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 8c968e7df3ef..d0a5af3587c4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union, cast import torch from tpu_info import device @@ -13,9 +13,10 @@ from .interface import Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import ModelConfig, VllmConfig + from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.pooling_params import PoolingParams else: + BlockSize = None ModelConfig = None VllmConfig = None PoolingParams = None @@ -67,6 +68,22 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool: return not envs.VLLM_USE_V1 + @classmethod + def get_punica_wrapper(cls) -> str: + return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" + + @classmethod + def get_infinity_values(cls, dtype: torch.dtype) -> Tuple[float, float]: + return torch.finfo(dtype).min, torch.finfo(dtype).max + + @classmethod + def can_update_inplace(cls): + return False + + @classmethod + def get_lora_vocab_padding_size(cls) -> int: + return 1 + @classmethod def inference_mode(cls): return torch.no_grad() @@ -78,7 +95,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config # For v0, the default block size is 16. if cache_config and cache_config.block_size is None: - cache_config.block_size = 16 + cache_config.block_size = cast(BlockSize, 16) compilation_config = vllm_config.compilation_config # TPU only supports DYNAMO_ONCE compilation level @@ -102,7 +119,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm.v1.attention.backends.pallas import ( PallasAttentionBackend) cache_config.block_size = PallasAttentionBackend.get_page_size( - vllm_config) + vllm_config) # type: ignore[assignment] min_page_size = PallasAttentionBackend.get_min_page_size( vllm_config) if min_page_size > cache_config.block_size: @@ -112,7 +129,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config.block_size, min_page_size, ) - cache_config.block_size = min_page_size + cache_config.block_size = min_page_size # type: ignore[assignment] parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 66a77681be9a..dc38daa388ce 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -13,7 +13,6 @@ from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer logger = init_logger(__name__) @@ -150,7 +149,7 @@ class SamplingParams( top_p: Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to 1 to consider all tokens. top_k: Integer that controls the number of top tokens to consider. Set - to -1 to consider all tokens. + to 0 (or -1) to consider all tokens. min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. @@ -210,7 +209,7 @@ class SamplingParams( repetition_penalty: float = 1.0 temperature: float = 1.0 top_p: float = 1.0 - top_k: int = -1 + top_k: int = 0 min_p: float = 0.0 seed: Optional[int] = None stop: Optional[Union[str, list[str]]] = None @@ -257,7 +256,7 @@ def from_optional( repetition_penalty: Optional[float] = 1.0, temperature: Optional[float] = 1.0, top_p: Optional[float] = 1.0, - top_k: int = -1, + top_k: int = 0, min_p: float = 0.0, seed: Optional[int] = None, stop: Optional[Union[str, list[str]]] = None, @@ -377,7 +376,7 @@ def __post_init__(self) -> None: if self.temperature < _SAMPLING_EPS: # Zero temperature means greedy sampling. self.top_p = 1.0 - self.top_k = -1 + self.top_k = 0 self.min_p = 0.0 self._verify_greedy_sampling() @@ -405,8 +404,9 @@ def _verify_args(self) -> None: f"temperature must be non-negative, got {self.temperature}.") if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") - if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, " + # quietly accept -1 as disabled, but prefer 0 + if self.top_k < -1: + raise ValueError(f"top_k must be 0 (disable), or at least 1, " f"got {self.top_k}.") if not isinstance(self.top_k, int): raise TypeError( @@ -491,13 +491,8 @@ def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: for add_prefix_space in [False, True]: prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - - if isinstance(tokenizer, MistralTokenizer): - # Mistral tokenizers should not add special tokens - prompt_token_ids = tokenizer.encode(text=prompt) - else: - prompt_token_ids = tokenizer.encode( - text=prompt, add_special_tokens=False) + prompt_token_ids = tokenizer.encode(text=prompt, + add_special_tokens=False) # If no space at the beginning # or if prefix space produces a new word token diff --git a/vllm/transformers_utils/chat_templates/__init__.py b/vllm/transformers_utils/chat_templates/__init__.py new file mode 100644 index 000000000000..fe2bd3ca4125 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +from .registry import get_chat_template_fallback_path + +__all__ = ["get_chat_template_fallback_path"] diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py new file mode 100644 index 000000000000..853fed5d4409 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Callable, Optional, Union + +from vllm.logger import init_logger + +logger = init_logger(__file__) + +CHAT_TEMPLATES_DIR = Path(__file__).parent + +ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]] + + +def _get_qwen_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + if tokenizer_name_or_path.endswith("-Chat"): + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + return CHAT_TEMPLATES_DIR / "template_basic.jinja" + + +# yapf: disable +_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { + "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", + "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", + "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "qwen": _get_qwen_chat_template_fallback, +} +# yapf: enable + + +def register_chat_template_fallback_path( + model_type: str, + chat_template: ChatTemplatePath, +) -> None: + if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: + logger.warning( + "Model type %s already has a chat template registered. " + "It will be overwritten by the new chat template %s.", model_type, + chat_template) + + _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template + + +def get_chat_template_fallback_path( + model_type: str, + tokenizer_name_or_path: str, +) -> Optional[Path]: + chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type) + if callable(chat_template): + chat_template = chat_template(tokenizer_name_or_path) + + if chat_template is None: + return None + + return chat_template diff --git a/examples/template_chameleon.jinja b/vllm/transformers_utils/chat_templates/template_basic.jinja similarity index 100% rename from examples/template_chameleon.jinja rename to vllm/transformers_utils/chat_templates/template_basic.jinja diff --git a/examples/template_blip2.jinja b/vllm/transformers_utils/chat_templates/template_blip2.jinja similarity index 100% rename from examples/template_blip2.jinja rename to vllm/transformers_utils/chat_templates/template_blip2.jinja diff --git a/examples/template_qwen_vl_chat.jinja b/vllm/transformers_utils/chat_templates/template_chatml.jinja similarity index 100% rename from examples/template_qwen_vl_chat.jinja rename to vllm/transformers_utils/chat_templates/template_chatml.jinja diff --git a/examples/template_deepseek_vl2.jinja b/vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja similarity index 100% rename from examples/template_deepseek_vl2.jinja rename to vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja diff --git a/examples/template_fuyu.jinja b/vllm/transformers_utils/chat_templates/template_fuyu.jinja similarity index 100% rename from examples/template_fuyu.jinja rename to vllm/transformers_utils/chat_templates/template_fuyu.jinja diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f6c2b35535b6..2fbd996dbb0b 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -686,9 +686,54 @@ def recurse_elems(elem: Any): config_dict["hidden_act"] = config_dict.get("activation", "silu") config_dict["tie_word_embeddings"] = config_dict.get( "tie_embeddings", False) - config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) - config_dict["max_position_embeddings"] = config_dict.get( - "max_position_embeddings", 128_000) + # Check if max_position_embeddings is in params.json + mpe_from_params = config_dict.get("max_position_embeddings") + final_mpe_to_set = mpe_from_params + + if final_mpe_to_set is None: + # Not found in params.json, try to get from standard HF AutoConfig + hf_config_for_defaults = None + try: + trust_remote_code_val = kwargs.get("trust_remote_code", False) + token_val = kwargs.get("token") # Passed from get_config + + hf_config_for_defaults = AutoConfig.from_pretrained( + model, + revision=revision, + trust_remote_code=trust_remote_code_val, + token=token_val) + except Exception as e: + error_message = ( + "Invalid repository ID or local directory specified:" + " '{model}'.\nPlease verify the following requirements:\n" + "1. Provide a valid Hugging Face repository ID.\n" + "2. Specify a local directory that contains a recognized " + "configuration file.\n").format(model=model) + + raise ValueError(error_message) from e + + if hf_config_for_defaults: + # Try to get from text_config first, then top-level + mpe_from_hf_config = None + text_config_obj = getattr(hf_config_for_defaults, "text_config", + None) + if text_config_obj and hasattr(text_config_obj, + "max_position_embeddings"): + mpe_from_hf_config = getattr(text_config_obj, + "max_position_embeddings", None) + + if mpe_from_hf_config is None and hasattr( + hf_config_for_defaults, "max_position_embeddings"): + mpe_from_hf_config = getattr(hf_config_for_defaults, + "max_position_embeddings", None) + + if mpe_from_hf_config is not None: + final_mpe_to_set = mpe_from_hf_config + + if final_mpe_to_set is None: # Still not found, use ultimate fallback + final_mpe_to_set = 128_000 + + config_dict["max_position_embeddings"] = final_mpe_to_set if config_dict.get("quantization") is not None: quantization = config_dict.get("quantization", {}) diff --git a/vllm/utils.py b/vllm/utils.py index 212138e4ba6e..6779c5b3f8d9 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -41,7 +41,6 @@ from concurrent.futures.process import ProcessPoolExecutor from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps -from gettext import gettext as _gettext from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, @@ -1333,31 +1332,10 @@ def add_arguments(self, actions): super().add_arguments(actions) -class _FlexibleArgumentGroup(_ArgumentGroup): - - def __init__(self, parser: FlexibleArgumentParser, *args, **kwargs): - self._parser = parser - super().__init__(*args, **kwargs) - - def add_argument(self, *args: Any, **kwargs: Any): - if sys.version_info < (3, 13): - deprecated = kwargs.pop('deprecated', False) - action = super().add_argument(*args, **kwargs) - object.__setattr__(action, 'deprecated', deprecated) - if deprecated and action.dest not in \ - self._parser.__class__._deprecated: - self._parser._deprecated.add(action) - return action - - # python>3.13 - return super().add_argument(*args, **kwargs) - - class FlexibleArgumentParser(ArgumentParser): """ArgumentParser that allows both underscore and dash in names.""" _deprecated: set[Action] = set() - _seen: set[str] = set() def __init__(self, *args, **kwargs): # Set the default 'formatter_class' to SortedHelpFormatter @@ -1366,39 +1344,36 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if sys.version_info < (3, 13): + # Enable the deprecated kwarg for Python 3.12 and below - def parse_known_args( # type: ignore[override] - self, - args: Sequence[str] | None = None, - namespace: Namespace | None = None, - ) -> tuple[Namespace | None, list[str]]: + def parse_known_args(self, args=None, namespace=None): namespace, args = super().parse_known_args(args, namespace) for action in FlexibleArgumentParser._deprecated: - if action.dest not in FlexibleArgumentParser._seen and getattr( - namespace, action.dest, - None) != action.default: # noqa: E501 - self._warning( - _gettext("argument '%(argument_name)s' is deprecated") - % {'argument_name': action.dest}) - FlexibleArgumentParser._seen.add(action.dest) + if (hasattr(namespace, dest := action.dest) + and getattr(namespace, dest) != action.default): + logger.warning_once("argument '%s' is deprecated", dest) return namespace, args - def add_argument(self, *args: Any, **kwargs: Any): - # add a deprecated=True compatibility - # for python < 3.13 - deprecated = kwargs.pop('deprecated', False) + def add_argument(self, *args, **kwargs): + deprecated = kwargs.pop("deprecated", False) action = super().add_argument(*args, **kwargs) - object.__setattr__(action, 'deprecated', deprecated) - if deprecated and \ - action not in FlexibleArgumentParser._deprecated: - self._deprecated.add(action) - + if deprecated: + FlexibleArgumentParser._deprecated.add(action) return action - def _warning(self, message: str): - self._print_message( - _gettext('warning: %(message)s\n') % {'message': message}, - sys.stderr) + class _FlexibleArgumentGroup(_ArgumentGroup): + + def add_argument(self, *args, **kwargs): + deprecated = kwargs.pop("deprecated", False) + action = super().add_argument(*args, **kwargs) + if deprecated: + FlexibleArgumentParser._deprecated.add(action) + return action + + def add_argument_group(self, *args, **kwargs): + group = self._FlexibleArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group def parse_args( # type: ignore[override] self, @@ -1575,15 +1550,6 @@ def _load_config_file(self, file_path: str) -> list[str]: return processed_args - def add_argument_group( - self, - *args: Any, - **kwargs: Any, - ) -> _FlexibleArgumentGroup: - group = _FlexibleArgumentGroup(self, self, *args, **kwargs) - self._action_groups.append(group) - return group - async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): @@ -1854,6 +1820,14 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor: return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor) +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(zmq, _MockModule) + except ModuleNotFoundError: + return False + + def import_from_path(module_name: str, file_path: Union[str, os.PathLike]): """ Import a Python file according to its file path. diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index db7926902154..605dff3749fb 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -291,6 +291,7 @@ class FlashAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner"): model_config = runner.model_config + compilation_config = runner.vllm_config.compilation_config self.runner = runner self.num_heads_q = model_config.get_num_attention_heads( @@ -300,7 +301,14 @@ def __init__(self, runner: "GPUModelRunner"): self.headdim = model_config.get_head_size() self.page_size = self.runner.block_size - self.aot_schedule = (get_flash_attn_version() == 3) + if get_flash_attn_version() == 3: + self.aot_schedule = not compilation_config.full_cuda_graph + if not self.aot_schedule: + logger.warning( + "AOT Schedule is disabled when using full_cuda_graph") + else: + self.aot_schedule = False + # Sliding window size to be used with the AOT scheduler will be # populated on first build() call. self.aot_sliding_window: Optional[tuple[int, int]] = None @@ -317,8 +325,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, seq_lens = common_attn_metadata.seq_lens block_table = ( self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) - slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( - self.runner.device, non_blocking=True).long() + slot_mapping = self.runner.slot_mapping[:num_actual_tokens] if self.aot_sliding_window is None: self.aot_sliding_window = (-1, -1) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 0d18a5639c2a..0c740fbcc6b7 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -204,7 +204,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, UnquantizedLinearMethod) -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform from vllm.utils import cdiv, round_down from vllm.v1.attention.backends.utils import CommonAttentionMetadata @@ -269,9 +268,6 @@ class ChunkedContextMetadata: max_seq_lens: list[int] workspace: torch.Tensor - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor query_start_loc: torch.Tensor max_query_len: int @@ -280,9 +276,6 @@ class ChunkedContextMetadata: @dataclass class MLACommonDecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor @@ -443,10 +436,8 @@ def reorder_batch(self, input_batch: "InputBatch", return modified_batch - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, seq_lens: torch.Tensor): + def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor): return MLACommonDecodeMetadata( - input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, ) @@ -464,8 +455,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.runner.input_batch.block_table.get_device_tensor()[:num_reqs]) slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - input_positions = self.runner.positions_cpu[:num_actual_tokens].to( - device, non_blocking=True).long() query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -473,7 +462,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, prefill_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start - tokens_start = self._num_decode_tokens context_lens_cpu = self.runner.input_batch.\ num_computed_tokens_cpu_tensor[reqs_start:num_reqs] @@ -496,11 +484,12 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, max_context_chunk = (self.chunked_prefill_workspace_size // num_prefills_with_context_cpu) - # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, - self.page_size) + if self.aot_schedule: + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) assert max_context_chunk > 0 num_chunks = cdiv(max_context_len_cpu, max_context_chunk) @@ -541,7 +530,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, self.chunked_prefill_workspace_size prefill_metadata = MLACommonPrefillMetadata( - input_positions=input_positions[tokens_start:], block_table=block_table[reqs_start:, ...], query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, @@ -551,7 +539,6 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, decode_metadata = None if self._num_decodes > 0: decode_metadata = self._build_decode( - input_positions=input_positions[:self._num_decode_tokens], block_table=block_table[:self._num_decodes, ...], seq_lens=seq_lens[:self._num_decodes], ) @@ -598,7 +585,6 @@ def __init__( qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, - rotary_emb: RotaryEmbedding, kv_b_proj: ColumnParallelLinear, ) -> None: self.num_heads = num_heads @@ -613,15 +599,6 @@ def __init__( self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim - - # Hack for V1 for now to avoid torch library overhead (since we are - # already inside an attention custom op), pull out the forward - # method from the rotary embedding and call it directly - # TODO(lucas): we should probably find a cleaner way to do this - self.rotary_emb = rotary_emb.forward_native - if current_platform.is_cuda(): - self.rotary_emb = rotary_emb.forward_cuda - self.kv_b_proj = kv_b_proj self.vllm_flash_attn_version = get_flash_attn_version() @@ -893,9 +870,6 @@ def forward( k_c_normed = k_c_normed[:num_actual_toks, ...] k_pe = k_pe[:num_actual_toks, ...] - # Restore head dim (for rotary embedding) - k_pe = k_pe.unsqueeze(1) - assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ attn_metadata.num_decode_tokens is not None @@ -904,35 +878,12 @@ def forward( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - q = q.view(-1, self.num_heads, self.qk_head_dim) decode_q = q[:num_decode_tokens] - decode_k_pe = k_pe[:num_decode_tokens] prefill_q = q[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] - if has_decode: - assert attn_metadata.decode is not None - decode_q_nope, decode_q_pe = decode_q.split( - [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - # Convert from (B, N, P) to (N, B, P) - decode_q_nope = decode_q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - decode_ql_nope = decode_ql_nope.transpose(0, 1) - decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) - - if has_prefill: - assert attn_metadata.prefill is not None - prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] - - prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - attn_metadata.prefill.input_positions, prefill_q_pe, - prefill_k_pe) - # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( @@ -950,6 +901,16 @@ def forward( attn_metadata) if has_decode: + assert attn_metadata.decode is not None + decode_q_nope, decode_q_pe = decode_q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + # Convert from (B, N, P) to (N, B, P) + decode_q_nope = decode_q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + decode_ql_nope = decode_ql_nope.transpose(0, 1) + output[:num_decode_tokens] = self._forward_decode( decode_ql_nope, decode_q_pe, kv_cache, attn_metadata) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index f18c9c8b6462..2f35f9b0a54f 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -58,8 +58,7 @@ def __init__(self, runner): self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def _build_decode(self, input_positions: torch.Tensor, - block_table: torch.Tensor, + def _build_decode(self, block_table: torch.Tensor, seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( @@ -69,7 +68,6 @@ def _build_decode(self, input_positions: torch.Tensor, ) return FlashMLADecodeMetadata( - input_positions=input_positions, block_table=block_table, seq_lens=seq_lens, tile_scheduler_metadata=tile_scheduler_metadata, diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py new file mode 100644 index 000000000000..37b72c08d52b --- /dev/null +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +import vllm.envs as envs +from vllm.attention.ops.rocm_aiter_mla import aiter_mla_decode_fwd +# yapf conflicts with isort for this docstring +# yapf: disable +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) + +# yapf: enable + + +def is_aiter_mla_enabled() -> bool: + return envs.VLLM_ROCM_USE_AITER \ + and envs.VLLM_ROCM_USE_AITER_MLA + + +class AiterMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "ROCM_AITER_MLA_VLLM_V1" + + @staticmethod + def get_impl_cls() -> type["AiterMLAImpl"]: + return AiterMLAImpl + + @staticmethod + def get_metadata_cls() -> type["AiterMLAMetadata"]: + return AiterMLAMetadata + + @staticmethod + def get_builder_cls() -> type["AiterMLAMetadataBuilder"]: + return AiterMLAMetadataBuilder + + +@dataclass +class AiterMLADecodeMetadata(MLACommonDecodeMetadata): + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + + +class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): + pass + + +class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): + + def __init__(self, runner): + super().__init__(runner) + max_model_len = self.runner.model_config.max_model_len + assert max_model_len == 32768,\ + "AITER MLA requires max_model_len=32768" + assert self.runner.block_size == 1, "AITER MLA" \ + "only supports block size 1." + + def _get_paged_kv_tensors( + self, block_table: torch.Tensor, + seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]: + page_size = self.runner.block_size + block_table_bounds = (seq_lens + page_size - 1) // page_size + + mask = (torch.arange(block_table.size(1), + dtype=block_table.dtype, + device=block_table.device).unsqueeze(0) + < block_table_bounds.unsqueeze(1)) + paged_kv_indices = block_table[mask] + + paged_kv_indptr = torch.cat([ + torch.zeros(1, + dtype=block_table_bounds.dtype, + device=block_table_bounds.device), + block_table_bounds.cumsum(dim=0, dtype=torch.int32) + ]) + + paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, + page_size, paged_kv_last_page_len) + return ( + paged_kv_indices, + paged_kv_indptr, + paged_kv_last_page_len, + ) + + def _build_decode(self, block_table: torch.Tensor, + seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + + ( + paged_kv_indices, + paged_kv_indptr, + paged_last_page_len, + ) = self._get_paged_kv_tensors(block_table, seq_lens) + + attn_metadata = AiterMLADecodeMetadata( + block_table=block_table, + seq_lens=seq_lens, + paged_kv_indptr=paged_kv_indptr, + paged_kv_indices=paged_kv_indices, + paged_kv_last_page_len=paged_last_page_len) + + return attn_metadata + + +class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]], + logits_soft_cap: Optional[float], + attn_type: str, + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap, attn_type, + **mla_args) + + unsupported_features = [ + alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap + ] + if any(unsupported_features): + raise NotImplementedError( + "Aiter MLA does not support one of the following: " + "alibi_slopes, sliding_window, blocksparse_params, " + "logits_soft_cap") + + from aiter import flash_attn_varlen_func + self.flash_attn_varlen_func = flash_attn_varlen_func + + def _flash_attn_varlen_diff_headdims(self, + q, + k, + v, + return_softmax_lse=False, + softmax_scale=None, + **kwargs): + output = self.flash_attn_varlen_func( + q=q, + k=k, + v=v, + softmax_scale=softmax_scale, + return_lse=return_softmax_lse, + **kwargs, + ) + + return output + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: AiterMLAMetadata, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + B = q_nope.shape[0] + + q = torch.cat([q_nope, q_pe], dim=-1) + o = torch.zeros(B, + self.num_heads, + self.kv_lora_rank, + dtype=q.dtype, + device=q.device) + + kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2) + + aiter_mla_decode_fwd(q, kv_buffer, o, self.scale, + attn_metadata.decode.paged_kv_indptr, + attn_metadata.decode.paged_kv_indices, + attn_metadata.decode.paged_kv_last_page_len) + + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 79ec67b89e97..8187e457d9e6 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -95,7 +95,7 @@ class PallasMetadata: block_tables: torch.Tensor context_lens: torch.Tensor query_start_loc: torch.Tensor - num_seqs: int + num_seqs: torch.Tensor class PallasAttentionBackendImpl(AttentionImpl): diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 9e172b6bdb00..c4ed127ece60 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -1,17 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 from collections import defaultdict -from collections.abc import Iterable from dataclasses import dataclass from typing import Optional from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger -from vllm.utils import cdiv, sha256 +from vllm.utils import sha256 from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock, hash_request_tokens) -from vllm.v1.core.specialized_manager import get_specialized_manager +from vllm.v1.core.specialized_manager import get_manager_for_kv_cache_spec from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request, RequestStatus @@ -56,7 +55,6 @@ def __init__( self.block_size = kv_cache_spec.block_size self.num_gpu_blocks = kv_cache_config.num_blocks self.max_model_len = max_model_len - self.max_num_blocks_per_req = cdiv(max_model_len, self.block_size) self.enable_caching = enable_caching self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash @@ -68,30 +66,20 @@ def __init__( self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching, enable_kv_cache_events) - self.specialized_manager = get_specialized_manager( + self.single_type_manager = get_manager_for_kv_cache_spec( kv_cache_spec=kv_cache_spec, block_pool=self.block_pool, use_eagle=self.use_eagle, + num_kv_cache_groups=1, + caching_hash_fn=self.caching_hash_fn, ) - # Mapping from request ID to blocks to track the blocks allocated - # for each request, so that we can free the blocks when the request - # is finished. - self.req_to_blocks: defaultdict[str, - list[KVCacheBlock]] = defaultdict(list) - # Mapping from request ID to kv block hashes. # This is to avoid recomputing the block hashes for each call of # `get_computed_blocks` or `allocate_slots`. self.req_to_block_hashes: defaultdict[ str, list[BlockHashType]] = defaultdict(list) - # {req_id: The number of cached blocks for this given request} - # This is used to track the number of cached blocks for each request. - # This is only used to track the RUNNING requests, we do not track the - # data for reempted ones. - self.num_cached_block: dict[str, int] = {} - @property def usage(self) -> float: """Get the KV cache usage. @@ -126,8 +114,11 @@ def get_computed_blocks(self, - A list of blocks that are computed for the request. - The number of computed tokens. """ - if not self.enable_caching: - # Prefix caching is disabled. + + # Prefix caching is disabled or + # When the request requires prompt logprobs, we skip prefix caching. + if (not self.enable_caching + or request.sampling_params.prompt_logprobs is not None): return KVCacheBlocks.create_empty(), 0 # The block hashes for the request may already be computed @@ -141,9 +132,6 @@ def get_computed_blocks(self, if self.log_stats: assert self.prefix_cache_stats is not None self.prefix_cache_stats.requests += 1 - # When the request requires prompt logprobs, we skip prefix caching. - if request.sampling_params.prompt_logprobs is not None: - return KVCacheBlocks.create_empty(), 0 if len(block_hashes) * self.block_size == request.num_tokens: # When prompt length is divisible by the block size and all @@ -159,7 +147,7 @@ def get_computed_blocks(self, last_block_hash = None computed_blocks = ( - self.specialized_manager.find_longest_cache_hit(block_hashes)) + self.single_type_manager.find_longest_cache_hit(block_hashes)) if self.log_stats: assert self.prefix_cache_stats is not None @@ -181,7 +169,7 @@ def get_computed_blocks(self, def allocate_slots( self, request: Request, - num_tokens: int, + num_new_tokens: int, new_computed_blocks: Optional[KVCacheBlocks] = None, num_lookahead_tokens: int = 0, ) -> Optional[KVCacheBlocks]: @@ -189,7 +177,7 @@ def allocate_slots( Args: request: The request to allocate slots. - num_tokens: The number of tokens to allocate, including external + num_new_tokens: The number of tokens to allocate, including external tokens. Note that this does not include tokens that have already been computed locally (i.e. new_computed_blocks). new_computed_blocks: The new computed blocks just hitting the @@ -215,44 +203,38 @@ def allocate_slots( Returns: A list of new allocated blocks. """ - if num_tokens == 0: - raise ValueError("num_tokens must be greater than 0") + if num_new_tokens == 0: + raise ValueError("num_new_tokens must be greater than 0") if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks else: new_computed_block_list = [] - req_blocks = self.req_to_blocks[request.request_id] - # Free the blocks that are skipped during the attention computation # (e.g., tokens outside the sliding window). # We can do this even if we cannot schedule this request due to # insufficient free blocks. # Should call this function before allocating new blocks to reduce # the number of evicted blocks. - removed_blocks = self.specialized_manager.remove_skipped_blocks( - req_blocks, request.num_computed_tokens) - self.block_pool.free_blocks(removed_blocks) + self.single_type_manager.remove_skipped_blocks( + request.request_id, request.num_computed_tokens) # The number of computed tokens is the number of computed tokens plus # the new prefix caching hits num_computed_tokens = (request.num_computed_tokens + len(new_computed_block_list) * self.block_size) - num_required_blocks = cdiv( - num_computed_tokens + num_tokens + num_lookahead_tokens, - self.block_size) - num_new_blocks = (num_required_blocks - len(req_blocks) - - len(new_computed_block_list)) - - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = sum(1 - for blk in new_computed_block_list - if blk.ref_cnt == 0) - if (num_new_blocks > self.block_pool.get_num_free_blocks() - - num_evictable_computed_blocks): + num_tokens_need_slot = min( + num_computed_tokens + num_new_tokens + num_lookahead_tokens, + self.max_model_len) + num_blocks_to_allocate = ( + self.single_type_manager.get_num_blocks_to_allocate( + request_id=request.request_id, + num_tokens=num_tokens_need_slot, + new_computed_blocks=new_computed_block_list, + )) + + if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): # Cannot allocate new blocks return None @@ -266,74 +248,33 @@ def allocate_slots( # Append the new computed blocks to the request blocks until now to # avoid the case where the new blocks cannot be allocated. - req_blocks.extend(new_computed_block_list) - - # Start to handle new blocks + self.single_type_manager.save_new_computed_blocks( + request.request_id, new_computed_block_list) - if num_new_blocks <= 0: - # No new block is needed. - new_blocks = [] - else: - # Get new blocks from the free block pool. - num_new_blocks = min( - num_new_blocks, - self.block_pool.get_num_free_blocks(), - # Should not exceed the maximum number of blocks per request. - # This is especially because the block table has the shape - # [..., max_num_blocks_per_req]. - self.max_num_blocks_per_req - len(req_blocks), - ) - assert num_new_blocks > 0 - - # Concatenate the computed block IDs and the new block IDs. - new_blocks = self.block_pool.get_new_blocks(num_new_blocks) - req_blocks.extend(new_blocks) + new_blocks = self.single_type_manager.allocate_new_blocks( + request.request_id, num_tokens_need_slot) if not self.enable_caching: return KVCacheBlocks(new_blocks) - # Use `new_computed_block_list` for a new request, and - # `num_cached_block` for a running request. - num_cached_blocks = self.num_cached_block.get( - request.request_id, len(new_computed_block_list)) # Speculated tokens might be rejected in the future, so we does # not cache any speculated tokens. We only cache blocks with # generated (accepted) tokens. - num_full_blocks_after_append = (num_computed_tokens + num_tokens - len( - request.spec_token_ids)) // self.block_size - - self.block_pool.cache_full_blocks( - request=request, - blocks=req_blocks, - block_hashes=self.req_to_block_hashes[request.request_id], - num_cached_blocks=num_cached_blocks, - num_full_blocks=num_full_blocks_after_append, - block_size=self.block_size, - hash_fn=self.caching_hash_fn, - ) + self.single_type_manager.cache_blocks( + request, self.req_to_block_hashes[request.request_id], + num_computed_tokens + num_new_tokens - len(request.spec_token_ids)) - self.num_cached_block[ - request.request_id] = num_full_blocks_after_append return KVCacheBlocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. - When caching is enabled, we free the blocks in reverse order so that - the tail blocks are evicted first. + We free the blocks in reverse order so that he tail blocks are evicted + first when caching is enabled. Args: request: The request to free the blocks. """ - # Default to [] in case a request is freed (aborted) before alloc. - blocks = self.req_to_blocks.pop(request.request_id, []) - ordered_blocks: Iterable[KVCacheBlock] = blocks - if self.enable_caching: - # Free blocks in reverse order so that the tail blocks are - # freed first. - ordered_blocks = reversed(blocks) - - self.block_pool.free_blocks(ordered_blocks) - self.num_cached_block.pop(request.request_id, None) + self.single_type_manager.free(request.request_id) def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF @@ -390,14 +331,8 @@ def get_num_common_prefix_blocks( int: The number of common prefix blocks. """ assert request.status == RequestStatus.RUNNING - blocks = self.req_to_blocks[request.request_id] - num_common_blocks = 0 - for block in blocks: - if block.ref_cnt == num_running_requests: - num_common_blocks += 1 - else: - break - return num_common_blocks + return self.single_type_manager.get_num_common_prefix_blocks( + request.request_id, num_running_requests) def free_block_hashes(self, request: Request) -> None: """Discard the block hashes for the request. diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index 928fb231a1f2..24032498e50b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -48,6 +48,33 @@ def from_request( lora_request=request.lora_request, ) + def __repr__(self): + return (f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids={self.prompt_token_ids}," + f"mm_inputs={self.mm_inputs}," + f"mm_hashes={self.mm_hashes}," + f"mm_positions={self.mm_positions}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}" + ")") + + # Version of __repr__ with the prompt data obfuscated + def anon_repr(self): + return (f"NewRequestData(" + f"req_id={self.req_id}," + f"prompt_token_ids_len={len(self.prompt_token_ids)}," + f"mm_inputs={self.mm_inputs}," + f"mm_hashes={self.mm_hashes}," + f"mm_positions={self.mm_positions}," + f"sampling_params={self.sampling_params}," + f"block_ids={self.block_ids}," + f"num_computed_tokens={self.num_computed_tokens}," + f"lora_request={self.lora_request}" + ")") + @dataclass class CachedRequestData: diff --git a/vllm/v1/core/specialized_manager.py b/vllm/v1/core/specialized_manager.py index f04eedf42662..3fd3cb2841e0 100644 --- a/vllm/v1/core/specialized_manager.py +++ b/vllm/v1/core/specialized_manager.py @@ -1,17 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Callable from vllm.utils import cdiv from vllm.v1.core.block_pool import BlockPool from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec, SlidingWindowSpec) +from vllm.v1.request import Request -class SpecializedManager(ABC): +class SingleTypeKVCacheManager(ABC): """ - An abstract base class for specialized managers that handle the kv - cache management logic of different attention layers. + An abstract base class for a manager that handle the kv cache management + logic of one specific type of attention layer. """ def __init__( @@ -19,12 +22,18 @@ def __init__( kv_cache_spec: KVCacheSpec, block_pool: BlockPool, use_eagle: bool, + num_kv_cache_groups: int, + caching_hash_fn: Callable, ) -> None: """ Initializes the SpecializedManager. Args: kv_cache_spec: The kv_cache_spec for this manager. block_pool: The block pool. + use_eagle: Whether to use eagle. + num_kv_cache_groups: The number of kv cache groups managed by this + manager. + caching_hash_fn: The caching hash function. """ self.block_size = kv_cache_spec.block_size @@ -34,6 +43,149 @@ def __init__( # Needs special handling for find_longest_cache_hit if eagle is enabled self.use_eagle = use_eagle + # Mapping from request ID to blocks to track the blocks allocated + # for each request, so that we can free the blocks when the request + # is finished. + self.req_to_blocks: defaultdict[str, + list[KVCacheBlock]] = defaultdict(list) + + # {req_id: The number of cached blocks for this given request} + # This is used to track the number of cached blocks for each request. + # This is only used to track the RUNNING requests, we do not track the + # data for reempted ones. + self.num_cached_block: dict[str, int] = {} + + self.num_kv_cache_groups = num_kv_cache_groups + self.caching_hash_fn = caching_hash_fn + + def get_num_blocks_to_allocate( + self, request_id: str, num_tokens: int, + new_computed_blocks: list[KVCacheBlock]) -> int: + """ + Get the number of blocks needed to be allocated for the request. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + new_computed_blocks: The new computed blocks just hitting the + prefix caching. + + Returns: + The number of blocks. + """ + + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = (num_required_blocks - len(new_computed_blocks) - + len(self.req_to_blocks[request_id])) + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it will be changed from a free block + # to a computed block when the request is allocated, so we also count + # it as needed to be allocated. + num_evictable_computed_blocks = sum(blk.ref_cnt == 0 + for blk in new_computed_blocks) + return ((num_new_blocks + num_evictable_computed_blocks) * + self.num_kv_cache_groups) + + def save_new_computed_blocks( + self, request_id: str, + new_computed_blocks: list[KVCacheBlock]) -> None: + """ + Add the new computed blocks to the request. + + Args: + request_id: The request ID. + new_computed_blocks: The new computed blocks just hitting the + prefix cache. + """ + if request_id not in self.num_cached_block: + # A new request. + req_blocks = self.req_to_blocks[request_id] + assert len(req_blocks) == 0 + req_blocks.extend(new_computed_blocks) + self.num_cached_block[request_id] = len(new_computed_blocks) + else: + # A running request. Should not have new computed blocks. + assert len(new_computed_blocks) == 0 + + def allocate_new_blocks(self, request_id: str, + num_tokens: int) -> list[KVCacheBlock]: + """ + Allocate new blocks for the request to give it at least `num_tokens` + token slots. + + Args: + request_id: The request ID. + num_tokens: The total number of tokens that need a slot (including + tokens that are already allocated). + + Returns: + The new allocated blocks. + """ + req_blocks = self.req_to_blocks[request_id] + num_required_blocks = cdiv(num_tokens, self.block_size) + num_new_blocks = num_required_blocks - len(req_blocks) + if num_new_blocks <= 0: + return [] + else: + new_blocks = self.block_pool.get_new_blocks( + num_new_blocks * self.num_kv_cache_groups) + req_blocks.extend(new_blocks) + return new_blocks + + def cache_blocks(self, request: Request, block_hashes: list[BlockHashType], + num_tokens: int) -> None: + """ + Cache the blocks for the request. + + Args: + request: The request. + block_hashes: The block hashes of the request. + num_tokens: The total number of tokens that need to be cached + (including tokens that are already cached). + """ + num_cached_blocks = self.num_cached_block[request.request_id] + num_full_blocks = num_tokens // self.block_size + + self.block_pool.cache_full_blocks( + request=request, + blocks=self.req_to_blocks[request.request_id], + block_hashes=block_hashes, + num_cached_blocks=num_cached_blocks, + num_full_blocks=num_full_blocks, + block_size=self.block_size, + hash_fn=self.caching_hash_fn, + ) + + self.num_cached_block[request.request_id] = num_full_blocks + + def free(self, request_id: str) -> None: + # Default to [] in case a request is freed (aborted) before alloc. + req_blocks = self.req_to_blocks.pop(request_id, []) + + # Free blocks in reverse order so that the tail blocks are + # freed first. + ordered_blocks = reversed(req_blocks) + + self.block_pool.free_blocks(ordered_blocks) + self.num_cached_block.pop(request_id, None) + + @abstractmethod + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + """ + Get the number of common prefix blocks for a request. + + Args: + request_id: The request ID. + block_hashes: The block hashes of the request. + + Returns: + The number of common prefix blocks. + """ + + raise NotImplementedError + @abstractmethod def find_longest_cache_hit( self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: @@ -41,7 +193,8 @@ def find_longest_cache_hit( Get the longest cache hit prefix of the blocks. If no cache hit is found, return an empty list. if eagle is enabled, drop the last matched block to force recompute the last block to get the required hidden - states for eagle drafting head. + states for eagle drafting head. Need to be customized for each attention + type. Args: block_hashes: The block hashes of the request. @@ -55,24 +208,23 @@ def find_longest_cache_hit( raise NotImplementedError @abstractmethod - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: """ Remove the blocks that are no longer needed from `blocks`. The removed blocks should be replaced by null_block. Return the removed blocks in eviction order, where the first returned block should be evicted first. - Don't free the removed blocks in this function. + Don't free the removed blocks in this function. Need to be customized + for each attention type. Args: - blocks: The list of blocks to be updated. + request_id: The request ID. num_computed_tokens: The number of tokens that have been computed. - Returns: - The removed blocks in eviction order. """ raise NotImplementedError -class FullAttentionManager(SpecializedManager): +class FullAttentionManager(SingleTypeKVCacheManager): def find_longest_cache_hit( self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]: @@ -89,17 +241,28 @@ def find_longest_cache_hit( computed_blocks.pop() return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # No need to remove blocks for full attention. - return [] + pass + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + blocks = self.req_to_blocks[request_id] + num_common_blocks = 0 + for block in blocks: + if block.ref_cnt == num_running_requests: + num_common_blocks += 1 + else: + break + return num_common_blocks -class SlidingWindowManager(SpecializedManager): +class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, - use_eagle: bool): - super().__init__(kv_cache_spec, block_pool, use_eagle) + use_eagle: bool, **kwargs) -> None: + super().__init__(kv_cache_spec, block_pool, use_eagle, **kwargs) self.sliding_window = kv_cache_spec.sliding_window # The number of contiguous blocks needed for prefix cache hit. # -1 since the input token itself is also included in the window @@ -148,13 +311,13 @@ def find_longest_cache_hit( computed_blocks.pop() return computed_blocks - def remove_skipped_blocks(self, blocks: list[KVCacheBlock], - num_computed_tokens: int) -> list[KVCacheBlock]: + def remove_skipped_blocks(self, request_id: str, + num_computed_tokens: int) -> None: # Remove the blocks that are no longer be in the sliding window and # skipped during the attention computation. last_useful_token = num_computed_tokens - self.sliding_window + 1 last_useful_block = last_useful_token // self.block_size - + blocks = self.req_to_blocks[request_id] removed_blocks: list[KVCacheBlock] = [] for i in range(last_useful_block - 1, -1, -1): if blocks[i] == self._null_block: @@ -164,17 +327,27 @@ def remove_skipped_blocks(self, blocks: list[KVCacheBlock], break removed_blocks.append(blocks[i]) blocks[i] = self._null_block - return removed_blocks + self.block_pool.free_blocks(removed_blocks) + + def get_num_common_prefix_blocks(self, request_id: str, + num_running_requests: int) -> int: + """ + NOTE(Chen): The prefix blocks are null blocks for sliding window layers. + So it's not correct to count ref_cnt like FullAttentionManager. Return + 0 here for correctness. Need to support cascade attention + sliding + window in the future. + """ + return 0 -spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = { +spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, } -def get_specialized_manager(kv_cache_spec: KVCacheSpec, - **kwargs) -> SpecializedManager: +def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec, + **kwargs) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] manager = manager_class(kv_cache_spec, **kwargs) return manager diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 14ce820cc39e..00ceb7d3d0c4 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -120,8 +120,9 @@ def __init__( executor_class=executor_class, log_stats=self.log_stats, ) - for stat_logger in self.stat_loggers[0]: - stat_logger.log_engine_initialized() + if self.stat_loggers: + for stat_logger in self.stat_loggers[0]: + stat_logger.log_engine_initialized() self.output_handler: Optional[asyncio.Task] = None try: # Start output handler eagerly if we are in the asyncio eventloop. diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index e772615b7861..d9dd4957cff2 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -19,6 +19,7 @@ from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.executor.multiproc_worker_utils import _add_prefix from vllm.logger import init_logger +from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -56,6 +57,7 @@ def __init__(self, executor_fail_callback: Optional[Callable] = None): assert vllm_config.model_config.runner_type != "pooling" + self.vllm_config = vllm_config logger.info("Initializing a V1 LLM engine (v%s) with config: %s", VLLM_VERSION, vllm_config) @@ -191,6 +193,16 @@ def abort_requests(self, request_ids: list[str]): self.scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED) + def execute_model(self, scheduler_output: SchedulerOutput): + try: + return self.model_executor.execute_model(scheduler_output) + except BaseException as err: + # NOTE: This method is exception-free + dump_engine_exception(self.vllm_config, scheduler_output, + self.scheduler.make_stats()) + # Re-raise exception + raise err + def step(self) -> EngineCoreOutputs: """Schedule, execute, and make output.""" @@ -202,9 +214,9 @@ def step(self) -> EngineCoreOutputs: scheduler_stats=self.scheduler.make_stats(), ) scheduler_output = self.scheduler.schedule() - output = self.model_executor.execute_model(scheduler_output) + model_output = self.execute_model(scheduler_output) engine_core_outputs = self.scheduler.update_from_output( - scheduler_output, output) # type: ignore + scheduler_output, model_output) # type: ignore return engine_core_outputs diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 0d5d92f72537..91a0a75a3081 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -442,9 +442,10 @@ def _wait_for_engine_startup(self): logger.info("Core engine process %d ready.", eng_id) identities.discard(eng_id) # Setup KV cache config with initialization state from - # engine core process. - self.vllm_config.cache_config.num_gpu_blocks = message_dict[ - 'num_gpu_blocks'] + # engine core process. Sum values from all engines in DP case. + num_gpu_blocks = self.vllm_config.cache_config.num_gpu_blocks or 0 + num_gpu_blocks += message_dict['num_gpu_blocks'] + self.vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks def _init_core_engines( self, diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index ff449901030c..74b226b45424 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -258,9 +258,10 @@ def shutdown(self): self.io_thread_pool.shutdown(wait=False, cancel_futures=True) self.io_thread_pool = None - for w in self.workers: - w.worker_response_mq = None - self._ensure_worker_termination([w.proc for w in self.workers]) + if workers := getattr(self, 'workers', None): + for w in workers: + w.worker_response_mq = None + self._ensure_worker_termination([w.proc for w in workers]) self.rpc_broadcast_mq = None diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 9109bdcf42f2..7455f1813cd7 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -128,9 +128,7 @@ def log(self): scheduler_stats.gpu_cache_usage * 100, self.prefix_caching_metrics.hit_rate * 100, ) - - if scheduler_stats.spec_decoding_stats is not None: - self.spec_decoding_logging.log(log_fn=log_fn) + self.spec_decoding_logging.log(log_fn=log_fn) def log_engine_initialized(self): logger.info( diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py index e00ecde66af0..0dcf02113f5a 100644 --- a/vllm/v1/serial_utils.py +++ b/vllm/v1/serial_utils.py @@ -14,6 +14,7 @@ from msgspec import msgpack from vllm import envs +from vllm.logger import init_logger from vllm.multimodal.inputs import (BaseMultiModalField, MultiModalBatchedField, MultiModalFieldConfig, MultiModalFieldElem, @@ -21,6 +22,8 @@ MultiModalKwargsItem, MultiModalSharedField, NestedTensors) +logger = init_logger(__name__) + CUSTOM_TYPE_PICKLE = 1 CUSTOM_TYPE_CLOUDPICKLE = 2 CUSTOM_TYPE_RAW_VIEW = 3 @@ -37,6 +40,11 @@ bytestr = Union[bytes, bytearray, memoryview, zmq.Frame] +def _log_insecure_serialization_warning(): + logger.warning_once("Allowing insecure serialization using pickle due to " + "VLLM_ALLOW_INSECURE_SERIALIZATION=1") + + class MsgpackEncoder: """Encoder with custom torch tensor and numpy array serialization. @@ -47,9 +55,7 @@ class MsgpackEncoder: via dedicated messages. Note that this is a per-tensor limit. """ - def __init__(self, - size_threshold: Optional[int] = None, - allow_pickle: bool = True): + def __init__(self, size_threshold: Optional[int] = None): if size_threshold is None: size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) @@ -58,7 +64,8 @@ def __init__(self, # pass custom data to the hook otherwise. self.aux_buffers: Optional[list[bytestr]] = None self.size_threshold = size_threshold - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + _log_insecure_serialization_warning() def encode(self, obj: Any) -> Sequence[bytestr]: try: @@ -89,6 +96,12 @@ def enc_hook(self, obj: Any) -> Any: if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'): return self._encode_ndarray(obj) + if isinstance(obj, slice): + # We are assuming only int-based values will be used here. + return tuple( + int(v) if v is not None else None + for v in (obj.start, obj.stop, obj.step)) + if isinstance(obj, MultiModalKwargs): mm: MultiModalKwargs = obj if not mm.modalities: @@ -108,8 +121,10 @@ def enc_hook(self, obj: Any) -> Any: for itemlist in mm._items_by_modality.values() for item in itemlist] - if not self.allow_pickle: - raise TypeError(f"Object of type {type(obj)} is not serializable") + if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + raise TypeError(f"Object of type {type(obj)} is not serializable" + "Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow " + "fallback to pickle-based serialization.") if isinstance(obj, FunctionType): # `pickle` is generally faster than cloudpickle, but can have @@ -185,13 +200,14 @@ class MsgpackDecoder: not thread-safe when encoding tensors / numpy arrays. """ - def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True): + def __init__(self, t: Optional[Any] = None): args = () if t is None else (t, ) self.decoder = msgpack.Decoder(*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook) self.aux_buffers: Sequence[bytestr] = () - self.allow_pickle = allow_pickle + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: + _log_insecure_serialization_warning() def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): @@ -212,6 +228,8 @@ def dec_hook(self, t: type, obj: Any) -> Any: return self._decode_ndarray(obj) if issubclass(t, torch.Tensor): return self._decode_tensor(obj) + if t is slice: + return slice(*obj) if issubclass(t, MultiModalKwargs): if isinstance(obj, list): return MultiModalKwargs.from_items( @@ -253,6 +271,12 @@ def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: factory_meth_name, *field_args = v["field"] factory_meth = getattr(MultiModalFieldConfig, factory_meth_name) + + # Special case: decode the union "slices" field of + # MultiModalFlatField + if factory_meth_name == "flat": + field_args[0] = self._decode_nested_slices(field_args[0]) + v["field"] = factory_meth(None, *field_args).field elems.append(MultiModalFieldElem(**v)) decoded_items.append(MultiModalKwargsItem.from_elems(elems)) @@ -269,11 +293,17 @@ def _decode_nested_tensors(self, obj: Any) -> NestedTensors: return self._decode_tensor(obj) return [self._decode_nested_tensors(x) for x in obj] + def _decode_nested_slices(self, obj: Any) -> Any: + assert isinstance(obj, (list, tuple)) + if obj and not isinstance(obj[0], (list, tuple)): + return slice(*obj) + return [self._decode_nested_slices(x) for x in obj] + def ext_hook(self, code: int, data: memoryview) -> Any: if code == CUSTOM_TYPE_RAW_VIEW: return data - if self.allow_pickle: + if envs.VLLM_ALLOW_INSECURE_SERIALIZATION: if code == CUSTOM_TYPE_PICKLE: return pickle.loads(data) if code == CUSTOM_TYPE_CLOUDPICKLE: diff --git a/vllm/v1/spec_decode/metrics.py b/vllm/v1/spec_decode/metrics.py index 33ce98284e20..f71a59908ef3 100644 --- a/vllm/v1/spec_decode/metrics.py +++ b/vllm/v1/spec_decode/metrics.py @@ -67,13 +67,17 @@ def observe(self, spec_decoding_stats: SpecDecodingStats): spec_decoding_stats.num_accepted_tokens_per_pos) def log(self, log_fn=logger.info): + if not self.num_drafts: + return num_drafts = np.sum(self.num_drafts) num_draft_tokens = np.sum(self.num_draft_tokens) num_accepted_tokens = np.sum(self.num_accepted_tokens) draft_acceptance_rate = (num_accepted_tokens / num_draft_tokens * 100 if num_draft_tokens > 0 else float("nan")) - mean_acceptance_length = (num_accepted_tokens / num_drafts) + + # Conventionally, mean acceptance length includes the bonus token + mean_acceptance_length = 1 + (num_accepted_tokens / num_drafts) pos_matrix = np.array(self.accepted_tokens_per_pos_lists) acceptance_rates = np.sum(pos_matrix, axis=0) / num_drafts @@ -103,10 +107,12 @@ class SpecDecodingProm: rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / rate(vllm:spec_decode_num_draft_tokens_total[$interval]) - The mean acceptance length can be calculated using: + The mean acceptance length (conventionally including bonus tokens) + can be calculated using: + 1 + ( rate(vllm:spec_decode_num_accepted_tokens_total[$interval]) / - rate(vllm:spec_decode_num_drafts[$interval]) + rate(vllm:spec_decode_num_drafts[$interval])) A per-position acceptance rate vector can be computed using diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e0c3d05c7976..bd8c87fd9efc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.attention import AttentionType, get_attn_backend from vllm.attention.layer import Attention +from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.config import (CompilationLevel, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.kv_transfer import (get_kv_transfer_group, @@ -139,6 +140,16 @@ def __init__( raise NotImplementedError( "Non-Attention backend is not supported by V1 GPUModelRunner.") + if self.vllm_config.compilation_config.full_cuda_graph: + attn_backend_name = self.attn_backend.__name__ + flash_attn_version = get_flash_attn_version() + if attn_backend_name != "FlashAttentionBackend" or \ + flash_attn_version != 3: + raise ValueError( + f"full_cuda_graph is only supported with " + f"FA3. Current attention backend is {attn_backend_name}, " + f"FlashAttention version is {flash_attn_version}.") + self.attn_metadata_builder = self.attn_backend.get_builder_cls()( weakref.proxy(self)) self.cascade_attn_enabled = not self.model_config.disable_cascade_attn @@ -219,6 +230,16 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.query_start_loc = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + self.seq_lens = torch.zeros(self.max_num_reqs, + dtype=torch.int32, + device=self.device) + self.slot_mapping = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + # None in the first PP rank. The rest are set after load_model. self.intermediate_tensors: Optional[IntermediateTensors] = None @@ -271,7 +292,7 @@ def __init__( pin_memory=self.pin_memory) self.positions_np = self.positions_cpu.numpy() self.slot_mapping_cpu = torch.zeros(self.max_num_tokens, - dtype=torch.int32, + dtype=torch.int64, device="cpu", pin_memory=self.pin_memory) self.slot_mapping_np = self.slot_mapping_cpu.numpy() @@ -589,10 +610,22 @@ def _prepare_inputs( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) - query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to( - self.device, non_blocking=True) - seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device, - non_blocking=True) + self.query_start_loc[:num_reqs + 1].copy_( + self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True) + self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs], + non_blocking=True) + self.slot_mapping[:total_num_scheduled_tokens].copy_( + self.slot_mapping_cpu[:total_num_scheduled_tokens], + non_blocking=True) + + # Fill unused with -1. Needed for reshape_and_cache + self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) + self.seq_lens[num_reqs:].fill_(0) + self.query_start_loc[num_reqs + 1:].fill_(-1) + + query_start_loc = self.query_start_loc[:num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, seq_lens=seq_lens) @@ -1478,6 +1511,7 @@ def _get_prompt_logprobs_dict( def _dummy_run( self, num_tokens: int, + skip_attn: bool = True, ) -> torch.Tensor: # Set num_scheduled_tokens based on num_tokens and max_num_seqs @@ -1494,6 +1528,23 @@ def _dummy_run( num_scheduled_tokens = np.array(num_scheduled_tokens_list, dtype=np.int32) + if skip_attn: + attn_metadata = None + else: + query_start_loc = self.query_start_loc[:num_reqs + 1] + seq_lens = self.seq_lens[:num_reqs] + + common_attn_metadata = CommonAttentionMetadata( + query_start_loc=query_start_loc, seq_lens=seq_lens) + + attn_metadata = self.attn_metadata_builder.build( + num_reqs=num_tokens, + num_actual_tokens=num_tokens, + max_query_len=num_tokens, + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens): model = self.model @@ -1522,7 +1573,7 @@ def _dummy_run( for k, v in self.intermediate_tensors.items() }) - with set_forward_context(None, + with set_forward_context(attn_metadata, self.vllm_config, num_tokens=num_tokens): outputs = model( @@ -1708,11 +1759,12 @@ def capture_model(self) -> None: # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. with graph_capture(device=self.device): + skip_attn = not self.vllm_config.compilation_config.full_cuda_graph for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(num_tokens) - self._dummy_run(num_tokens) + self._dummy_run(num_tokens, skip_attn=skip_attn) + self._dummy_run(num_tokens, skip_attn=skip_attn) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index f5626abb2a12..be059c30435c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -39,6 +39,7 @@ from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import sanity_check_mm_encoder_outputs @@ -90,7 +91,7 @@ # The dummy_run should be comprehensive, ensuring all potential input shapes and # branch predictions are included as subgraph inputs to facilitate # pre-compilation. -class TPUModelRunner: +class TPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -568,6 +569,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): self.device) seq_lens = self.seq_lens_cpu[:self.max_num_reqs].to(self.device) + if self.lora_config is not None: + # We need to respect padding when activating LoRA adapters + padded_num_scheduled_tokens_per_req = np.copy( + num_scheduled_tokens_per_req + ) # Copying to avoid accidental state corruption bugs + padded_num_scheduled_tokens_per_req[-1] += \ + padded_total_num_scheduled_tokens - total_num_scheduled_tokens + + self.set_active_loras(self.input_batch, + padded_num_scheduled_tokens_per_req) + attn_metadata = PallasMetadata( slot_mapping=slot_mapping, block_tables=block_tables, @@ -907,6 +919,11 @@ def load_model(self) -> None: "get_tensor_model_parallel_rank", return_value=xm_tp_rank): model = get_model(vllm_config=self.vllm_config) + if self.lora_config is not None: + model = self.load_lora_model(model, self.model_config, + self.scheduler_config, + self.lora_config, self.device) + # Sync all pending XLA execution during model initialization and weight # loading. xm.mark_step() @@ -970,7 +987,10 @@ def _dummy_run(self, num_tokens: int) -> None: for layer_name in layer_names } - with set_forward_context(per_layer_attn_metadata, self.vllm_config, 0): + with self.maybe_dummy_run_with_lora( + self.lora_config, + np.array([num_tokens], dtype=np.int32)), set_forward_context( + per_layer_attn_metadata, self.vllm_config, 0): out = self.model(input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds) diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index de676541effa..9eea26d85249 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -15,6 +15,7 @@ from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import SchedulerOutput @@ -82,6 +83,10 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V1 TPU backend doesn't support LoRA serving") + def init_device(self): os.environ["PJRT_DEVICE"] = "TPU" # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D @@ -211,6 +216,9 @@ def profile(self, is_start: bool = True): else: xp.stop_trace() + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + def load_model(self) -> None: self.model_runner.load_model() diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index e25864349e28..a343e2fedb23 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -14,7 +14,7 @@ import os import time from array import array -from enum import IntEnum +from enum import Enum, IntEnum from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type, TypeVar, Union) @@ -75,6 +75,12 @@ DUMMY_TOKEN_ID = -1 +class PhaseType(Enum): + PREFILL = 'prefill' + PREFIX_PREFILL = 'prefix_prefill' + DECODE = 'decode' + + def subtuple(obj: object, typename: str, to_copy: List[str], @@ -213,20 +219,40 @@ def _compile_region(self, model, name, module): def _set_attn_bias(self, attn_metadata, batch_size, seq_len, device, dtype): - prefill_metadata = attn_metadata - if prefill_metadata is None or self.prefill_use_fusedsdpa: + if (attn_metadata is None + or (self.prefill_use_fusedsdpa \ + and attn_metadata.block_list is None) + or not attn_metadata.is_prompt): return attn_metadata + prefill_metadata = attn_metadata + seq_lens_t = prefill_metadata.seq_lens_tensor + context_lens_t = prefill_metadata.context_lens_tensor + query_lens_t = seq_lens_t - context_lens_t + + block_list = attn_metadata.block_list + max_context_len = (block_list.size(-1) // + batch_size if block_list is not None else 0) + max_context_len = max_context_len * self.block_size + past_mask = torch.arange(0, + max_context_len, + dtype=torch.int32, + device=device) + past_mask = (past_mask.view(1, -1).expand(batch_size, -1).ge( + context_lens_t.view(-1, 1)).view(batch_size, 1, -1).expand( + batch_size, seq_len, -1).view(batch_size, 1, seq_len, -1)) + len_mask = (torch.arange(0, seq_len, device=device, dtype=torch.int32).view(1, seq_len).ge( - seq_lens_t.unsqueeze(-1)).view( + query_lens_t.unsqueeze(-1)).view( batch_size, 1, 1, seq_len)) causal_mask = torch.triu(torch.ones((batch_size, 1, seq_len, seq_len), device=device, dtype=torch.bool), diagonal=1) mask = causal_mask.logical_or(len_mask) + mask = torch.concat((past_mask, mask), dim=-1) attn_bias = (torch.zeros_like(mask, dtype=dtype).masked_fill_( mask, -math.inf)) attn_metadata = prefill_metadata._replace(attn_bias=attn_bias) @@ -517,6 +543,11 @@ def __init__( False, self.max_model_len) self.graphed_buckets: Set[Any] = set() self._set_gc_threshold() + if self.vllm_config.cache_config.enable_prefix_caching: + os.environ.setdefault("VLLM_CONTIGUOUS_PA", "False") + assert os.environ.get( + "VLLM_CONTIGUOUS_PA", + "").lower() != "true", "Contiguous PA doesn't support APC" self.use_contiguous_pa = envs.VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH # For multi-step scheduling @@ -702,6 +733,10 @@ def _prepare_prompt( computed_block_nums) > 0 and self.sliding_window is None: # Prefix is not supported with sliding_window context_len = len(computed_block_nums) * self.block_size + if context_len == seq_len \ + and self.vllm_config.cache_config.enable_prefix_caching: + # Fully cached prompt - compute only last token + context_len = context_len - 1 prompt_tokens = prompt_tokens[context_len:] prefix_block_tables.append(computed_block_nums) elif self.scheduler_config.chunked_prefill_enabled: @@ -779,12 +814,33 @@ def _prepare_prompt( if lora_id > 0: lora_requests.add(seq_group_metadata.lora_request) - lora_index_mapping += [lora_id] * (max_prompt_len - context_len) + lora_index_mapping += [lora_id] * max_prompt_len lora_prompt_mapping.extend( [lora_id] * - (max_prompt_len - context_len + (max_prompt_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if any(context_lens): + assert not self.scheduler_config.chunked_prefill_enabled + # prefix caching + + max_num_block = max(len(bt) for bt in prefix_block_tables) + prefix_block_list = list( + itertools.chain.from_iterable( + bt if len(bt) == max_num_block else bt + + ([_PAD_BLOCK_ID] * (max_num_block - len(bt))) + for bt in prefix_block_tables)) + + pad_len = len(prefix_block_list) + prefix_block_list = pad_list(prefix_block_list, pad_len, + _PAD_BLOCK_ID) + + prefix_block_list_tensor = torch.tensor(prefix_block_list, + dtype=torch.long, + device=self.device) + else: + prefix_block_list_tensor = None + input_tokens = make_tensor_with_pad(input_tokens, max_len=max_prompt_len, pad=0, @@ -807,11 +863,15 @@ def _prepare_prompt( dtype=torch.long, device=self.device) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.long, + device=self.device) + block_indices, block_offsets = precompute_indices_and_offsets( self.block_size, slot_mapping, True) attn_metadata = self.attn_backend.make_metadata( is_prompt=True, - block_list=None, + block_list=prefix_block_list_tensor, block_mapping=None, block_usage=None, block_indices=block_indices, @@ -819,6 +879,7 @@ def _prepare_prompt( block_groups=None, attn_bias=None, seq_lens_tensor=seq_lens_tensor, + context_lens_tensor=context_lens_tensor, num_prefills=real_num_seqs, num_prefill_tokens=sum_query_len, num_decode_tokens=0, @@ -987,6 +1048,7 @@ def _prepare_decode( block_groups=block_groups, attn_bias=None, seq_lens_tensor=None, + context_lens_tensor=None, num_prefills=0, num_prefill_tokens=0, num_decode_tokens=num_decode_tokens, @@ -1091,7 +1153,7 @@ def prepare_input_tensors( # FIXME: We need to adjust selected_token_indices to accommodate # for padding max_len = input_tokens.size(1) - paddings = [max_len - s for s in seq_lens] + paddings = [max_len - q for q in query_lens] paddings = [0] + paddings[:-1] paddings = list(itertools.accumulate(paddings)) paddings_prompt_logprobs = [] @@ -1187,9 +1249,17 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object: # input_hash(123) != input_hash(321) # input_hash("abc") != input_hash("cba") attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [ - 'attn_bias', 'seq_lens_tensor', 'block_list', 'block_mapping', - 'block_usage', 'slot_mapping', 'is_prompt', 'block_indices', - 'block_offsets', 'block_groups' + 'attn_bias', + 'seq_lens_tensor', + 'context_lens_tensor', + 'block_list', + 'block_mapping', + 'block_usage', + 'slot_mapping', + 'is_prompt', + 'block_indices', + 'block_offsets', + 'block_groups', ]) return attention_metadata @@ -1733,14 +1803,44 @@ def finish_measurements(self): from neural_compressor.torch.quantization import finalize_calibration finalize_calibration(self.model.model) - def _check_config(self, batch_size, seq_len, is_prompt, warmup_mode): - cfg = (batch_size, seq_len, is_prompt) + def _num_blocks(self, attn_metadata): + if attn_metadata.block_list is None: + return 0 + return attn_metadata.block_list.numel() + + def _phase(self, attn_metadata): + phase_type: PhaseType + is_prompt = attn_metadata.is_prompt + is_prefix_prefill = is_prompt and attn_metadata.block_list is not None + if is_prompt and is_prefix_prefill: + phase_type = PhaseType.PREFIX_PREFILL + elif is_prompt and not is_prefix_prefill: + phase_type = PhaseType.PREFILL + elif not is_prompt: + phase_type = PhaseType.DECODE + else: + raise ValueError("Unrecognized pass type, likely due to malformed " + "attention metadata") + return phase_type + + def _check_config(self, batch_size, seq_len, attn_metadata, warmup_mode): + is_prefix_caching = self.vllm_config.cache_config.enable_prefix_caching + cfg: Optional[tuple] = None + assert cfg is None, "Configs changed between 2D and 3D" + if is_prefix_caching: + phase = self._phase(attn_metadata) + num_blocks = self._num_blocks(attn_metadata) + cfg = (batch_size, seq_len, num_blocks, phase) + else: + phase = 'prompt' if attn_metadata.is_prompt else 'decode' + cfg = (batch_size, seq_len, phase) seen = cfg in self.seen_configs self.seen_configs.add(cfg) if not seen and not warmup_mode: - phase = 'prompt' if is_prompt else 'decode' - logger.warning("Configuration: (%s, %s, %s) was not warmed-up!", - phase, batch_size, seq_len) + logger.warning("Configuration: %s was not warmed-up!", + (phase.value, batch_size, seq_len, + num_blocks) if is_prefix_caching else + (phase, batch_size, seq_len)) def create_lora_mask(self, input_tokens: torch.Tensor, lora_ids: List[int], is_prompt: bool): @@ -1912,7 +2012,7 @@ def execute_model( batch_size = input_tokens.size(0) seq_len = self._seq_len(attn_metadata) use_graphs = self._use_graphs(batch_size, seq_len, is_prompt) - self._check_config(batch_size, seq_len, is_prompt, warmup_mode) + self._check_config(batch_size, seq_len, attn_metadata, warmup_mode) lora_mask: torch.Tensor = None lora_logits_mask: torch.Tensor = None diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index c80b69e78dc0..e97adf757cc1 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -348,7 +348,7 @@ def _convert_to_neuron_sampling_params( if temperature == 0.0: # Enable greedy sampling on zero temperature return (1, 1.0, 1.0) - if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: + if top_k < 1 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: top_k = self._MAX_NEURON_SAMPLING_TOP_K return (top_k, top_p, temperature) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 53541a2579ed..e0cca9072745 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -525,7 +525,7 @@ def _prepare_sample( "Top-p sampling is currently disabled for the TPU backend " "due to performance issues.") p.append(sampling_params.top_p) - if sampling_params.top_k != -1: + if sampling_params.top_k > 0: raise NotImplementedError( "Top-k sampling is currently disabled for the TPU backend " "due to performance issues.") diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index bbcc4d59ae1c..4bb9bea022f9 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -54,6 +54,10 @@ def __init__( if self.model_config.seed is None: self.model_config.seed = 0 + if vllm_config.lora_config is not None: + raise NotImplementedError( + "The V0 TPU backend doesn't support LoRA serving") + def init_device(self) -> None: os.environ["PJRT_DEVICE"] = "TPU" torch.set_grad_enabled(False)