diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml new file mode 100644 index 000000000000..78347f63fa79 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1 +model_name: "neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.356 + - name: "exact_match,flexible-extract" + value: 0.358 +limit: 1000 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 64a0f428587a..6057229ac50f 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,6 +1,6 @@ Meta-Llama-3-8B-Instruct.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 98592ea7948f..3b7fa0f2d94b 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -3,7 +3,7 @@ steps: agents: queue: cpu_queue commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" # rename the files to change linux -> manylinux1 @@ -22,7 +22,7 @@ steps: agents: queue: cpu_queue commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." 
- "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" # rename the files to change linux -> manylinux1 diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index c2818c38965e..c331a9c49c0d 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -32,10 +32,10 @@ docker exec cpu-test bash -c " --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # Run compressed-tensor test -# docker exec cpu-test bash -c " -# pytest -s -v \ -# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ -# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" # Run AWQ test docker exec cpu-test bash -c " diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4c2fe41c739b..8c98aa36ac0f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -77,8 +77,8 @@ steps: - vllm/ - tests/basic_correctness/test_chunked_prefill commands: - - VLLM_ATTENTION_BACKEND=XFORMERS VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min mirror_hardwares: [amd] @@ -88,11 +88,7 @@ steps: - vllm/distributed - tests/core commands: - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core/test_scheduler.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/test_chunked_prefill_scheduler.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s core core/block/e2e/test_correctness_sliding_window.py - - pytest -v -s core --ignore=core/block/e2e/test_correctness.py --ignore=core/test_scheduler.py --ignore=core/test_chunked_prefill_scheduler.py --ignore=core/block/e2e/test_correctness.py --ignore=core/block/e2e/test_correctness_sliding_window.py + - pytest -v -s core - label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" @@ -184,6 +180,7 @@ steps: - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py + - python3 offline_profile.py --model facebook/opt-125m - label: Prefix Caching Test # 9min #mirror_hardwares: [amd] @@ -191,8 +188,7 @@ steps: - vllm/ - tests/prefix_caching commands: - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s prefix_caching/test_prefix_caching.py - - pytest -v -s prefix_caching --ignore=prefix_caching/test_prefix_caching.py + - pytest -v -s prefix_caching - label: Samplers Test # 36min source_file_dependencies: @@ -216,8 +212,7 @@ steps: - 
tests/spec_decode commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest -v -s spec_decode/e2e/test_compatibility.py - - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_compatibility.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] @@ -235,14 +230,12 @@ steps: commands: - pytest -v -s compile/test_basic_correctness.py -# TODO: re-write in comparison tests, and fix symbolic shape -# for quantization ops. -# - label: "PyTorch Fullgraph Test" # 18min -# source_file_dependencies: -# - vllm/ -# - tests/compile -# commands: -# - pytest -v -s compile/test_full_graph.py +- label: "PyTorch Fullgraph Test" # 18min + source_file_dependencies: + - vllm/ + - tests/compile + commands: + - pytest -v -s compile/test_full_graph.py - label: Kernels Test %N # 1h each mirror_hardwares: [amd] @@ -317,13 +310,22 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/*.py --ignore=models/test_oot_registration.py -- label: Decoder-only Language Models Test # 1h36min +- label: Decoder-only Language Models Test (Standard) # 35min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/language commands: - - pytest -v -s models/decoder_only/language + - pytest -v -s models/decoder_only/language/test_models.py + - pytest -v -s models/decoder_only/language/test_big_models.py + +- label: Decoder-only Language Models Test (Extended) # 1h20min + nightly: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + commands: + - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py --ignore=models/decoder_only/language/test_big_models.py - label: Decoder-only Multi-Modal Models Test # 1h31min #mirror_hardwares: [amd] @@ -340,10 +342,12 @@ steps: source_file_dependencies: - vllm/ - tests/models/embedding/language + - tests/models/embedding/vision_language - tests/models/encoder_decoder/language - tests/models/encoder_decoder/vision_language commands: - pytest -v -s models/embedding/language + - pytest -v -s models/embedding/vision_language - pytest -v -s models/encoder_decoder/language - pytest -v -s models/encoder_decoder/vision_language @@ -402,7 +406,7 @@ steps: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus diff --git a/.dockerignore b/.dockerignore index 575f087f3ef6..3863656915d0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,3 @@ -/.github/ /.venv /build dist diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml index 2a0e3239f58d..b80749aaa8fe 100644 --- a/.github/workflows/actionlint.yml +++ b/.github/workflows/actionlint.yml @@ -34,4 +34,5 @@ jobs: - 
name: "Run actionlint" run: | + echo "::add-matcher::.github/workflows/matchers/actionlint.json" tools/actionlint.sh -color diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml new file mode 100644 index 000000000000..c9d6d4259df9 --- /dev/null +++ b/.github/workflows/add_label_automerge.yml @@ -0,0 +1,21 @@ +name: Add label on auto-merge enabled +on: + pull_request_target: + types: + - auto_merge_enabled +jobs: + add-label-on-auto-merge: + runs-on: ubuntu-latest + steps: + - name: Add label + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + github.rest.issues.addLabels({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + labels: ['ready'] + }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index 064af291009f..68d60d7365ed 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -17,9 +17,9 @@ jobs: matrix: python-version: ["3.11"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -38,4 +38,4 @@ jobs: ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ - | xargs clang-format --dry-run --Werror \ No newline at end of file + | xargs clang-format --dry-run --Werror diff --git a/.github/workflows/matchers/mypy.json b/.github/workflows/matchers/mypy.json new file mode 100644 index 000000000000..f048fce52894 --- /dev/null +++ b/.github/workflows/matchers/mypy.json @@ -0,0 +1,16 @@ +{ + "problemMatcher": [ + { + "owner": "mypy", + "pattern": [ + { + "regexp": "^(.+):(\\d+):\\s(error|warning):\\s(.+)$", + "file": 1, + "line": 2, + "severity": 3, + "message": 4 + } + ] + } + ] +} diff --git a/.github/workflows/matchers/ruff.json b/.github/workflows/matchers/ruff.json new file mode 100644 index 000000000000..f6d4479ee199 --- /dev/null +++ b/.github/workflows/matchers/ruff.json @@ -0,0 +1,17 @@ +{ + "problemMatcher": [ + { + "owner": "ruff", + "pattern": [ + { + "regexp": "^(.+?):(\\d+):(\\d+): (\\w+): (.+)$", + "file": 1, + "line": 2, + "column": 3, + "code": 4, + "message": 5 + } + ] + } + ] + } diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 22e3564779ad..5f1e5f8eeaf7 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -17,9 +17,9 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -32,4 +32,5 @@ jobs: pip install types-setuptools - name: Mypy run: | - tools/mypy.sh + echo "::add-matcher::.github/workflows/matchers/mypy.json" + tools/mypy.sh 1 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 96549b3f9918..f959a1cacf86 100644 --- a/.github/workflows/publish.yml +++ 
b/.github/workflows/publish.yml @@ -21,7 +21,7 @@ jobs: upload_url: ${{ steps.create_release.outputs.upload_url }} steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Extract branch info shell: bash @@ -30,7 +30,7 @@ jobs: - name: Create Release id: create_release - uses: "actions/github-script@v7" + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 env: RELEASE_TAG: ${{ env.release_tag }} with: @@ -54,10 +54,10 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Setup ccache - uses: hendrikmuhs/ccache-action@v1.2 + uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 with: create-symlink: true key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} @@ -68,7 +68,7 @@ jobs: bash -x .github/workflows/scripts/env.sh - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} @@ -92,7 +92,7 @@ jobs: echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - name: Upload Release Asset - uses: actions/upload-release-asset@v1 + uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml new file mode 100644 index 000000000000..df62539c0b3d --- /dev/null +++ b/.github/workflows/reminder_comment.yml @@ -0,0 +1,21 @@ +name: PR Reminder Comment Bot +on: + pull_request_target: + types: [opened] + +jobs: + pr_reminder: + runs-on: ubuntu-latest + steps: + - name: Remind to run full CI on PR + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: context.issue.number, + body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. 
\n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' + }) + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index be73fb85ed1f..9cc8a9e91447 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -17,9 +17,9 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -28,7 +28,8 @@ jobs: pip install -r requirements-lint.txt - name: Analysing the code with ruff run: | - ruff check . + echo "::add-matcher::.github/workflows/matchers/ruff.json" + ruff check --output-format github . - name: Spelling check with codespell run: | codespell --toml pyproject.toml diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 9e0a698990b3..122e4e101e20 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -eux python_executable=python$1 cuda_home=/usr/local/cuda-$2 @@ -15,5 +16,8 @@ export MAX_JOBS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" + +bash tools/check_repo.sh + # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 000000000000..2418c61bdcf6 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,51 @@ +name: 'Close inactive issues and PRs' + +on: + schedule: + # Daily at 1:30 AM UTC + - cron: '30 1 * * *' + +jobs: + close-issues-and-pull-requests: + permissions: + issues: write + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/stale@28ca1036281a5e5922ead5184a1bbf96e5fc984e # v9.0.0 + with: + # Increasing this value ensures that changes to this workflow + # propagate to all issues and PRs in days rather than months + operations-per-run: 1000 + + exempt-draft-pr: true + exempt-issue-labels: 'keep-open' + exempt-pr-labels: 'keep-open' + + labels-to-add-when-unstale: 'unstale' + labels-to-remove-when-stale: 'unstale' + + days-before-issue-stale: 90 + days-before-issue-close: 30 + stale-issue-label: 'stale' + stale-issue-message: > + This issue has been automatically marked as stale because it has not + had any activity within 90 days. It will be automatically closed if no + further activity occurs within 30 days. Leave a comment if + you feel this issue should remain open. Thank you! + close-issue-message: > + This issue has been automatically closed due to inactivity. Please + feel free to reopen if you feel it is still relevant. Thank you! + + days-before-pr-stale: 90 + days-before-pr-close: 30 + stale-pr-label: 'stale' + stale-pr-message: > + This pull request has been automatically marked as stale because it + has not had any activity within 90 days. It will be automatically + closed if no further activity occurs within 30 days. 
Leave a comment + if you feel this pull request should remain open. Thank you! + close-pr-message: > + This pull request has been automatically closed due to inactivity. + Please feel free to reopen if you intend to continue working on it. + Thank you! diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index eb728ae04dfc..9f06b35c19e3 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -16,9 +16,9 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # v5.2.0 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a424ad7b110..d1956f3d409b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -83,24 +83,6 @@ endif() # find_package(Torch REQUIRED) -# -message(STATUS "Enabling core extension.") - -# Define _core_C extension -# built for (almost) every target platform, (excludes TPU and Neuron) - -set(VLLM_EXT_SRC - "csrc/core/torch_bindings.cpp") - -define_gpu_extension_target( - _core_C - DESTINATION vllm - LANGUAGE CXX - SOURCES ${VLLM_EXT_SRC} - COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - USE_SABI 3 - WITH_SOABI) - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -187,12 +169,12 @@ endif() # # Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. -# Configure it to place files in vllm/.deps, in order to play nicely with sccache. +# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache. +# Each dependency that produces build artifacts should override its BINARY_DIR to avoid +# conflicts between build types. It should instead be set to ${CMAKE_BINARY_DIR}/. 
# include(FetchContent) -get_filename_component(PROJECT_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE) -file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}") -set(FETCHCONTENT_BASE_DIR "${PROJECT_ROOT_DIR}/.deps") +file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # @@ -270,7 +252,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") else() message(STATUS "Not building Marlin kernels as no compatible archs found" - "in CUDA target architectures") + " in CUDA target architectures") endif() # @@ -286,10 +268,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") else() - # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't - # build any 3x kernels - set(SCALED_MM_3X_ARCHS) - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " "not >= 12.0, we recommend upgrading to CUDA 12.0 or " @@ -299,13 +277,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Not building scaled_mm_c3x as no compatible archs found " "in CUDA target architectures") endif() + + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) endif() # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}") + "7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -450,7 +432,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") else() message(STATUS "Not building Marlin MOE kernels as no compatible archs found" - "in CUDA target architectures") + " in CUDA target architectures") endif() endif() @@ -527,6 +509,8 @@ else() GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd GIT_PROGRESS TRUE + # Don't share the vllm-flash-attn build between build types + BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn ) endif() diff --git a/Dockerfile b/Dockerfile index d527868bc4c2..0a562253c537 100644 --- a/Dockerfile +++ b/Dockerfile @@ -70,8 +70,10 @@ COPY requirements-build.txt requirements-build.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-build.txt -# files and directories related to build wheels COPY . . 
+ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # max jobs used by Ninja to build extensions ARG max_jobs=2 diff --git a/Dockerfile.cpu b/Dockerfile.cpu index b9134d4ae41c..f1a21d6bd13f 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -33,19 +33,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ pip install --upgrade pip && \ pip install -r requirements-build.txt -# install oneDNN -RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git - -RUN --mount=type=cache,target=/root/.cache/ccache \ - cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ - -DONEDNN_BUILD_DOC=OFF \ - -DONEDNN_BUILD_EXAMPLES=OFF \ - -DONEDNN_BUILD_TESTS=OFF \ - -DONEDNN_BUILD_GRAPH=OFF \ - -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ - -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \ - cmake --build ./oneDNN/build --target install --config Release - FROM cpu-test-1 AS build WORKDIR /workspace/vllm @@ -55,7 +42,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ pip install -v -r requirements-cpu.txt -COPY ./ ./ +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512 diff --git a/Dockerfile.neuron b/Dockerfile.neuron index adae6db87ba8..3d9d8e7da487 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -17,7 +17,7 @@ RUN apt-get update && \ # When launching the container, mount the code directory to /app ARG APP_MOUNT=/app VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT} +WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas @@ -25,17 +25,17 @@ RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -COPY . /app/vllm +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi -RUN cd /app/vllm \ - && python3 -m pip install -U \ +RUN python3 -m pip install -U \ cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron RUN --mount=type=bind,source=.git,target=.git \ - cd /app/vllm \ - && pip install --no-build-isolation -v -e . \ - && cd .. + pip install --no-build-isolation -v -e . \ CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index d65bfa08ccd9..a05ff452cd36 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -10,13 +10,16 @@ RUN apt-get update -y && \ WORKDIR /workspace COPY . . 
+ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # install build requirements -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/requirements-build.txt # build vLLM with OpenVINO backend -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace -COPY examples/ /workspace/vllm/examples -COPY benchmarks/ /workspace/vllm/benchmarks +COPY examples/ /workspace/examples +COPY benchmarks/ /workspace/benchmarks CMD ["/bin/bash"] diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 1f374b01b9bc..cd5fcf481f07 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -14,6 +14,9 @@ RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p COPY ./ /workspace/vllm WORKDIR /workspace/vllm +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # These packages will be in rocketce eventually RUN --mount=type=cache,target=/root/.cache/pip \ @@ -30,4 +33,4 @@ WORKDIR /workspace/ RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks -ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 496e6bed7c02..d35889f053e2 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -117,6 +117,9 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ FROM base AS final # Import the vLLM development directory from the build context COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ diff --git a/Dockerfile.tpu b/Dockerfile.tpu index d8f1a42c4517..bdfab3f61910 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -2,7 +2,7 @@ ARG NIGHTLY_DATE="20240828" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE -WORKDIR /workspace +WORKDIR /workspace/vllm # Install some basic utilities RUN apt-get update && apt-get install -y \ @@ -16,14 +16,17 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Build vLLM. -COPY . /workspace/vllm +COPY . . 
+ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi + ENV VLLM_TARGET_DEVICE="tpu" RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ - cd /workspace/vllm && \ python3 -m pip install \ cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ -r requirements-tpu.txt -RUN cd /workspace/vllm && python3 setup.py develop +RUN python3 setup.py develop CMD ["/bin/bash"] diff --git a/Dockerfile.ubi b/Dockerfile.ubi index d1a29ac7ed61..51bea5e54e25 100644 --- a/Dockerfile.ubi +++ b/Dockerfile.ubi @@ -207,7 +207,7 @@ USER root RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \ - uv pip install $(echo dist/*.whl)'[tensorizer]' vllm-tgis-adapter==0.5.3 + uv pip install $(echo dist/*.whl)'[tensorizer]' git+https://github.com/opendatahub-io/vllm-tgis-adapter.git@ibm-20241024-adapter ENV GRPC_PORT=8033 \ PORT=8000 \ diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 83db341556ea..0ecb46df6256 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -33,7 +33,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ -r requirements-xpu.txt -COPY ./ /workspace/vllm +COPY . . +ARG GIT_REPO_CHECK +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ENV VLLM_TARGET_DEVICE=xpu diff --git a/README.md b/README.md index 72c3273edc61..0836d872358f 100644 --- a/README.md +++ b/README.md @@ -127,5 +127,6 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs * For technical questions and feature requests, please use Github issues or discussions. * For discussing with fellow users, please use Discord. +* For coordinating contributions and development, please use Slack. * For security disclosures, please use Github's security advisory feature. * For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 79a48b2a1a84..0a14aedd5feb 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,5 +1,6 @@ """Benchmark the latency of processing a single batch of requests.""" import argparse +import dataclasses import json import time from pathlib import Path @@ -10,44 +11,19 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs +from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser def main(args: argparse.Namespace): print(args) + engine_args = EngineArgs.from_cli_args(args) + # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM( - model=args.model, - speculative_model=args.speculative_model, - num_speculative_tokens=args.num_speculative_tokens, - speculative_draft_tensor_parallel_size=\ - args.speculative_draft_tensor_parallel_size, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - max_model_len=args.max_model_len, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - quantization_param_path=args.quantization_param_path, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, - enable_chunked_prefill=args.enable_chunked_prefill, - download_dir=args.download_dir, - block_size=args.block_size, - gpu_memory_utilization=args.gpu_memory_utilization, - load_format=args.load_format, - distributed_executor_backend=args.distributed_executor_backend, - otlp_traces_endpoint=args.otlp_traces_endpoint, - enable_prefix_caching=args.enable_prefix_caching, - ) + llm = LLM(**dataclasses.asdict(engine_args)) sampling_params = SamplingParams( n=args.n, @@ -126,19 +102,6 @@ def run_to_completion(profile_dir: Optional[str] = None): parser = FlexibleArgumentParser( description='Benchmark the latency of processing a single batch of ' 'requests till completion.') - parser.add_argument('--model', type=str, default='facebook/opt-125m') - parser.add_argument('--speculative-model', type=str, default=None) - parser.add_argument('--num-speculative-tokens', type=int, default=None) - parser.add_argument('--speculative-draft-tensor-parallel-size', - '-spec-draft-tp', - type=int, - default=None) - parser.add_argument('--tokenizer', type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--output-len', type=int, default=128) parser.add_argument('--batch-size', type=int, default=8) @@ -155,45 +118,6 @@ def run_to_completion(profile_dir: Optional[str] = None): type=int, default=30, help='Number of iterations to run.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--enforce-eager', - action='store_true', - help='enforce eager mode and disable CUDA graph') - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. 
' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') parser.add_argument( '--profile', action='store_true', @@ -204,81 +128,12 @@ def run_to_completion(profile_dir: Optional[str] = None): default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument("--device", - type=str, - default="auto", - choices=DEVICE_OPTIONS, - help='device type for vLLM execution') - parser.add_argument('--block-size', - type=int, - default=16, - help='block size of key/value cache') - parser.add_argument( - '--enable-chunked-prefill', - action='store_true', - help='If True, the prefill requests can be chunked based on the ' - 'max_num_batched_tokens') - parser.add_argument("--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching") - parser.add_argument('--use-v2-block-manager', - action='store_true', - default=EngineArgs.use_v2_block_manager) - parser.add_argument( - "--ray-workers-use-nsight", - action='store_true', - help="If specified, use nsight to profile ray workers", - ) - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the latency results in JSON format.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. 
When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--otlp-traces-endpoint', - type=str, - default=None, - help='Target URL to which OpenTelemetry traces will be sent.') + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index f14092d34734..1aac029992db 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -25,6 +25,7 @@ --input-length-range 128:256 """ +import dataclasses import json import random import time @@ -130,13 +131,9 @@ def main(args): filtered_datasets = [(PROMPT, prompt_len, args.output_len) ] * args.num_prompts - llm = LLM(model=args.model, - tokenizer_mode='auto', - trust_remote_code=True, - enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, - tensor_parallel_size=args.tensor_parallel_size, - enable_prefix_caching=args.enable_prefix_caching) + engine_args = EngineArgs.from_cli_args(args) + + llm = LLM(**dataclasses.asdict(engine_args)) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -164,22 +161,11 @@ def main(args): parser = FlexibleArgumentParser( description= 'Benchmark the performance with or without automatic prefix caching.') - parser.add_argument('--model', - type=str, - default='baichuan-inc/Baichuan2-13B-Chat') parser.add_argument("--dataset-path", type=str, default=None, help="Path to the dataset.") - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - default=EngineArgs.use_v2_block_manager, - help='Use BlockSpaceMangerV2') parser.add_argument('--num-prompts', type=int, default=1, @@ -196,9 +182,7 @@ def main(args): default='128:256', help='Range of input lengths for sampling prompts,' 'specified as "min:max" (e.g., "128:256").') - parser.add_argument("--seed", - type=int, - default=0, - help='Random seed for reproducibility') + + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 8843e3a927a0..e0c9e6a6db50 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,5 +1,6 @@ """Benchmark offline prioritization.""" import argparse +import dataclasses import json import random import time @@ -7,7 +8,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser def sample_requests( @@ -62,46 +64,11 @@ def sample_requests( def run_vllm( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - gpu_memory_utilization: float = 0.9, - download_dir: Optional[str] = None, + engine_args: EngineArgs, ) -> 
float: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - disable_log_stats=False, - ) + llm = LLM(**dataclasses.asdict(engine_args)) # Add the requests to the engine. prompts = [] @@ -142,16 +109,8 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.model, args.tokenizer, - args.quantization, args.tensor_parallel_size, - args.seed, args.n, args.trust_remote_code, - args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, - args.enable_chunked_prefill, - args.max_num_batched_tokens, - args.gpu_memory_utilization, args.download_dir) + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) else: raise ValueError(f"Unknown backend: {args.backend}") total_num_tokens = sum(prompt_len + output_len @@ -173,7 +132,7 @@ def main(args: argparse.Namespace): if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument("--backend", type=str, choices=["vllm", "hf", "mii"], @@ -191,13 +150,6 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, @@ -206,81 +158,13 @@ def main(args: argparse.Namespace): type=int, default=200, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' 
- 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument( - "--device", - type=str, - default="cuda", - choices=["cuda", "cpu"], - help='device type for vLLM execution, supporting CUDA and CPU.') - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', - type=int, - default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the throughput results in JSON format.') + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 04999518b713..0d205014b15b 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -53,6 +53,8 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser +MILLISECONDS_TO_SECONDS_CONVERSION = 1000 + @dataclass class BenchmarkMetrics: @@ -60,6 +62,7 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float + request_goodput: float output_throughput: float total_token_throughput: float mean_ttft_ms: float @@ -202,6 +205,7 @@ def sample_hf_requests( dataset_split: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, + random_seed: int, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: dataset = load_dataset(dataset_path, @@ -210,8 +214,8 @@ def sample_hf_requests( streaming=True) assert "conversations" in dataset.features, ( "HF Dataset must have 'conversations' column.") - filtered_dataset = dataset.shuffle().filter( - lambda x: len(x["conversations"]) >= 2) + filter_func = lambda x: len(x["conversations"]) >= 2 + filtered_dataset = dataset.shuffle(seed=random_seed).filter(filter_func) sampled_requests: List[Tuple[str, int, int, Dict[str, Collection[str]]]] = [] for data in filtered_dataset: @@ -315,12 +319,15 @@ def calculate_metrics( tokenizer: PreTrainedTokenizerBase, selected_percentile_metrics: List[str], selected_percentiles: List[float], + gootput_config_dict: Dict[str, float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] 
total_input = 0 completed = 0 + good_completed = 0 itls: List[float] = [] tpots: List[float] = [] + all_tpots: List[float] = [] ttfts: List[float] = [] e2els: List[float] = [] for i in range(len(outputs)): @@ -334,9 +341,13 @@ def calculate_metrics( add_special_tokens=False).input_ids) actual_output_lens.append(output_len) total_input += input_requests[i][1] + tpot = 0 if output_len > 1: - tpots.append( - (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) + tpot = (outputs[i].latency - outputs[i].ttft) / (output_len - + 1) + tpots.append(tpot) + # Note: if output_len <= 1, we regard tpot as 0 for goodput + all_tpots.append(tpot) itls += outputs[i].itl ttfts.append(outputs[i].ttft) e2els.append(outputs[i].latency) @@ -344,6 +355,28 @@ def calculate_metrics( else: actual_output_lens.append(0) + if gootput_config_dict: + valid_metrics = [] + slo_values = [] + + if "ttft" in gootput_config_dict: + valid_metrics.append(ttfts) + slo_values.append(gootput_config_dict["ttft"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "tpot" in gootput_config_dict: + valid_metrics.append(all_tpots) + slo_values.append(gootput_config_dict["tpot"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + if "e2el" in gootput_config_dict: + valid_metrics.append(e2els) + slo_values.append(gootput_config_dict["e2el"] / + MILLISECONDS_TO_SECONDS_CONVERSION) + + for req_metric in zip(*valid_metrics): + is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) + if is_good_req: + good_completed += 1 + if completed == 0: warnings.warn( "All requests failed. This is likely due to a misconfiguration " @@ -354,6 +387,7 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, + request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * @@ -397,6 +431,8 @@ async def benchmark( selected_percentile_metrics: List[str], selected_percentiles: List[str], ignore_eos: bool, + gootput_config_dict: Dict[str, float], + max_concurrency: Optional[int], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -431,42 +467,56 @@ async def benchmark( if profile: print("Starting profiler...") - profile_input = RequestFuncInput( - model=model_id, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - best_of=best_of, - multi_modal_content=test_mm_content, - ) + profile_input = RequestFuncInput(model=model_id, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: print("Profiler started") print(f"Traffic request rate: {request_rate}") + print(f"Maximum request concurrency: {max_concurrency}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + # This can be used once the minimum Python version is 3.10 or higher, + # and it will simplify the code in limited_request_func. 
+ # semaphore = (asyncio.Semaphore(max_concurrency) + # if max_concurrency else contextlib.nullcontext()) + semaphore = (asyncio.Semaphore(max_concurrency) + if max_concurrency else None) + + async def limited_request_func(request_func_input, pbar): + if semaphore is None: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + async with semaphore: + return await request_func(request_func_input=request_func_input, + pbar=pbar) + benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): prompt, prompt_len, output_len, mm_content = request - request_func_input = RequestFuncInput( - model=model_id, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - best_of=best_of, - multi_modal_content=mm_content, - ) + request_func_input = RequestFuncInput(model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=mm_content, + ignore_eos=ignore_eos) tasks.append( asyncio.create_task( - request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, + pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -496,6 +546,7 @@ async def benchmark( tokenizer=tokenizer, selected_percentile_metrics=selected_percentile_metrics, selected_percentiles=selected_percentiles, + gootput_config_dict=gootput_config_dict, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -507,6 +558,9 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) + if gootput_config_dict: + print("{:<40} {:<10.2f}".format("Request goodput (req/s):", + metrics.request_goodput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", @@ -518,6 +572,8 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, + "request_goodput:": + metrics.request_goodput if gootput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -571,6 +627,41 @@ def process_one_metric( return result +def check_goodput_args(args): + # Check and parse goodput arguments + gootput_config_dict = {} + VALID_NAMES = ["ttft", "tpot", "e2el"] + if args.goodput: + gootput_config_dict = parse_goodput(args.goodput) + for slo_name, slo_val in gootput_config_dict.items(): + if slo_name not in VALID_NAMES: + raise ValueError( + f"Invalid metric name found, {slo_name}: {slo_val}. " + "The service level objective name should be one of " + f"{str(VALID_NAMES)}. ") + if slo_val < 0: + raise ValueError( + f"Invalid value found, {slo_name}: {slo_val}. " + "The service level objective value should be " + "non-negative.") + return gootput_config_dict + + +def parse_goodput(slo_pairs): + gootput_config_dict = {} + try: + for slo_pair in slo_pairs: + slo_name, slo_val = slo_pair.split(":") + gootput_config_dict[slo_name] = float(slo_val) + except ValueError as err: + raise argparse.ArgumentTypeError( + "Invalid format found for service level objectives. 
" + "Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is a " + "number in milliseconds.") from err + return gootput_config_dict + + def main(args: argparse.Namespace): print(args) random.seed(args.seed) @@ -648,6 +739,7 @@ def main(args: argparse.Namespace): dataset_split=args.hf_split, num_requests=args.num_prompts, tokenizer=tokenizer, + random_seed=args.seed, fixed_output_len=args.hf_output_len, ) @@ -664,6 +756,8 @@ def main(args: argparse.Namespace): else: raise ValueError(f"Unknown dataset: {args.dataset_name}") + gootput_config_dict = check_goodput_args(args) + benchmark_result = asyncio.run( benchmark( backend=backend, @@ -682,6 +776,8 @@ def main(args: argparse.Namespace): float(p) for p in args.metric_percentiles.split(",") ], ignore_eos=args.ignore_eos, + gootput_config_dict=gootput_config_dict, + max_concurrency=args.max_concurrency, )) # Save config and results to json @@ -711,13 +807,16 @@ def main(args: argparse.Namespace): # Traffic result_json["request_rate"] = ( args.request_rate if args.request_rate < float("inf") else "inf") + result_json["max_concurrency"] = args.max_concurrency # Merge with benchmark result result_json = {**result_json, **benchmark_result} # Save to file base_model_id = model_id.split("/")[-1] - file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = (f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None else "") + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa if args.result_filename: file_name = args.result_filename if args.result_dir: @@ -768,6 +867,19 @@ def main(args: argparse.Namespace): default=None, help="Path to the sharegpt/sonnet dataset. " "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--max-concurrency", + type=int, + default=None, + help="Maximum number of concurrent requests. This can be used " + "to help simulate an environment where a higher level component " + "is enforcing a maximum number of concurrent requests. While the " + "--request-rate argument controls the rate at which requests are " + "initiated, this argument will control how many are actually allowed " + "to execute at a time. This means that when used in combination, the " + "actual request rate may be lower than specified with --request-rate, " + "if the server is not processing requests fast enough to keep up.") + parser.add_argument( "--model", type=str, @@ -881,6 +993,17 @@ def main(args: argparse.Namespace): "Default value is \"99\". " "Use \"--percentile-metrics\" to select metrics.", ) + parser.add_argument( + "--goodput", + nargs="+", + required=False, + help="Specify service level objectives for goodput as \"KEY:VALUE\" " + "pairs, where the key is a metric name, and the value is in " + "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + "separated by spaces. Allowed request level metric names are " + "\"ttft\", \"tpot\", \"e2el\". 
For more context on the definition of " + "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " + "and the blog: https://hao-ai-lab.github.io/blogs/distserve") # group for dataset specific arguments sonnet_group = parser.add_argument_group("sonnet dataset options") diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index b7bc2a640237..ee41c8ea3838 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,5 +1,6 @@ """Benchmark offline inference throughput.""" import argparse +import dataclasses import json import random import time @@ -11,10 +12,9 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( build_async_engine_client_from_engine_args) -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser, merge_async_iterators @@ -67,55 +67,11 @@ def sample_requests( def run_vllm( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - num_scheduler_steps: int = 1, - use_v2_block_manager: bool = False, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, - disable_async_output_proc: bool = False, + engine_args: EngineArgs, ) -> float: from vllm import LLM, SamplingParams - llm = LLM( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, - num_scheduler_steps=num_scheduler_steps, - use_v2_block_manager=use_v2_block_manager, - disable_async_output_proc=disable_async_output_proc, - ) + llm = LLM(**dataclasses.asdict(engine_args)) # Add the requests to the engine. 
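The `LLM(**dataclasses.asdict(engine_args))` call above replaces the long hand-maintained keyword list: every field of the engine-args dataclass is forwarded as a keyword argument. A toy sketch of the pattern (the `ToyEngineArgs` and `ToyLLM` names are invented for illustration; the real `EngineArgs` has many more fields):

```python
import dataclasses
from typing import Optional

@dataclasses.dataclass
class ToyEngineArgs:
    model: str = "facebook/opt-125m"
    tensor_parallel_size: int = 1
    max_model_len: Optional[int] = None

class ToyLLM:
    def __init__(self, **kwargs):
        self.config = kwargs  # every dataclass field arrives as a kwarg

args = ToyEngineArgs(tensor_parallel_size=2)
llm = ToyLLM(**dataclasses.asdict(args))
print(llm.config)
# {'model': 'facebook/opt-125m', 'tensor_parallel_size': 2, 'max_model_len': None}
```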
prompts: List[str] = [] @@ -157,58 +113,11 @@ def run_vllm( async def run_vllm_async( requests: List[Tuple[str, int, int]], - model: str, - tokenizer: str, - quantization: Optional[str], - tensor_parallel_size: int, - seed: int, n: int, - trust_remote_code: bool, - dtype: str, - max_model_len: Optional[int], - enforce_eager: bool, - kv_cache_dtype: str, - quantization_param_path: Optional[str], - device: str, - enable_prefix_caching: bool, - enable_chunked_prefill: bool, - max_num_batched_tokens: int, - distributed_executor_backend: Optional[str], - gpu_memory_utilization: float = 0.9, - num_scheduler_steps: int = 1, - use_v2_block_manager: bool = False, - download_dir: Optional[str] = None, - load_format: str = EngineArgs.load_format, - disable_async_output_proc: bool = False, + engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, ) -> float: from vllm import SamplingParams - engine_args = AsyncEngineArgs( - model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - quantization_param_path=quantization_param_path, - device=device, - enable_prefix_caching=enable_prefix_caching, - download_dir=download_dir, - enable_chunked_prefill=enable_chunked_prefill, - max_num_batched_tokens=max_num_batched_tokens, - distributed_executor_backend=distributed_executor_backend, - load_format=load_format, - num_scheduler_steps=num_scheduler_steps, - use_v2_block_manager=use_v2_block_manager, - disable_async_output_proc=disable_async_output_proc, - worker_use_ray=False, - disable_log_requests=True, - ) async with build_async_engine_client_from_engine_args( engine_args, disable_frontend_multiprocessing) as llm: @@ -324,7 +233,16 @@ def main(args: argparse.Namespace): args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: # Synthesize a prompt with the given input length. - prompt = "hi" * (args.input_len - 1) + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. 
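A toy illustration of why the search in the loop below is needed: tokenizers typically add special tokens such as BOS, so repeating a word `input_len` times rarely tokenizes to exactly `input_len` tokens. The `ToyTokenizer` here is a fake stand-in, not a Hugging Face tokenizer.

```python
class ToyTokenizer:
    """Fake tokenizer that, like many real ones, prepends a BOS token."""
    def __call__(self, text):
        return ["<bos>"] + text.split()

tokenizer = ToyTokenizer()
input_len = 16
for i in range(-10, 10):
    prompt = "hi " * (input_len + i)
    if len(tokenizer(prompt)) == input_len:
        break
else:
    raise ValueError(f"Failed to synthesize a prompt with {input_len} tokens.")
print(i, len(tokenizer(prompt)))  # -> -1 16 (one fewer "hi" offsets the BOS token)
```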
+ for i in range(-10, 10): + prompt = "hi " * (args.input_len + i) + tokenized_prompt = tokenizer(prompt).input_ids + if len(tokenized_prompt) == args.input_len: + break + else: + raise ValueError( + f"Failed to synthesize a prompt with {args.input_len} tokens.") requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] else: @@ -332,24 +250,17 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - run_args = [ - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.num_scheduler_steps, - args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc - ] - if args.async_engine: - run_args.append(args.disable_frontend_multiprocessing) - elapsed_time = uvloop.run(run_vllm_async(*run_args)) + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + )) else: - elapsed_time = run_vllm(*run_args) + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -361,8 +272,10 @@ def main(args: argparse.Namespace): raise ValueError(f"Unknown backend: {args.backend}") total_num_tokens = sum(prompt_len + output_len for _, prompt_len, output_len in requests) + total_output_tokens = sum(output_len for _, _, output_len in requests) print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") # Output JSON results if specified if args.output_json: @@ -396,13 +309,6 @@ def main(args: argparse.Namespace): default=None, help="Output length for each request. Overrides the " "output length from the dataset.") - parser.add_argument("--model", type=str, default="facebook/opt-125m") - parser.add_argument("--tokenizer", type=str, default=None) - parser.add_argument('--quantization', - '-q', - choices=[*QUANTIZATION_METHODS, None], - default=None) - parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", type=int, default=1, @@ -411,127 +317,15 @@ def main(args: argparse.Namespace): type=int, default=1000, help="Number of prompts to process.") - parser.add_argument("--seed", type=int, default=0) parser.add_argument("--hf-max-batch-size", type=int, default=None, help="Maximum batch size for HF backend.") - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument( - '--max-model-len', - type=int, - default=None, - help='Maximum length of a sequence (including prompt and output). ' - 'If None, will be derived from the model.') - parser.add_argument( - '--dtype', - type=str, - default='auto', - choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], - help='data type for model weights and activations. 
' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=0.9, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument("--enforce-eager", - action="store_true", - help="enforce eager execution") - parser.add_argument( - '--kv-cache-dtype', - type=str, - choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], - default="auto", - help='Data type for kv cache storage. If "auto", will use model ' - 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' - 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=str, - default=None, - help='Path to the JSON file containing the KV cache scaling factors. ' - 'This should generally be supplied, when KV cache dtype is FP8. ' - 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' - 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' - 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' - 'instead supported for common inference criteria.') - parser.add_argument("--device", - type=str, - default="auto", - choices=DEVICE_OPTIONS, - help='device type for vLLM execution') - parser.add_argument( - "--num-scheduler-steps", - type=int, - default=1, - help="Maximum number of forward steps per scheduler call.") - parser.add_argument("--use-v2-block-manager", - action='store_true', - default=EngineArgs.use_v2_block_manager, - help="Enable block manager v2.") - parser.add_argument( - "--enable-prefix-caching", - action='store_true', - help="Enable automatic prefix caching for vLLM backend.") - parser.add_argument("--enable-chunked-prefill", - action='store_true', - help="enable chunked prefill for vLLM backend.") - parser.add_argument('--max-num-batched-tokens', - type=int, - default=None, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--download-dir', - type=str, - default=None, - help='directory to download and load the weights, ' - 'default to the default cache dir of huggingface') parser.add_argument( '--output-json', type=str, default=None, help='Path to save the throughput results in JSON format.') - parser.add_argument( - '--distributed-executor-backend', - choices=['ray', 'mp'], - default=None, - help='Backend to use for distributed serving. When more than 1 GPU ' - 'is used, will be automatically set to "ray" if installed ' - 'or "mp" (multiprocessing) otherwise.') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=[ - 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer', - 'bitsandbytes' - ], - help='The format of the model weights to load.\n\n' - '* "auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available.\n' - '* "pt" will load the weights in the pytorch bin format.\n' - '* "safetensors" will load the weights in the safetensors format.\n' - '* "npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading.\n' - '* "dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.\n' - '* "tensorizer" will load the weights using tensorizer from ' - 'CoreWeave. 
See the Tensorize vLLM Model script in the Examples' - 'section for more information.\n' - '* "bitsandbytes" will load the weights using bitsandbytes ' - 'quantization.\n') - parser.add_argument( - "--disable-async-output-proc", - action='store_true', - default=False, - help="Disable async output processor for vLLM backend.") parser.add_argument("--async-engine", action='store_true', default=False, @@ -540,6 +334,7 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable decoupled async engine frontend.") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 73fc9e9dbf46..784b1cf9844e 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -31,7 +31,7 @@ def benchmark_rope_kernels_multi_lora( # batched RoPE can take multiple scaling factors batched_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": tuple(scaling_factors) }) # non-batched RoPE takes only one scaling factor, we create multiple @@ -41,7 +41,7 @@ def benchmark_rope_kernels_multi_lora( non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": (scaling_factor, ) })) diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 203699e9a8d0..d16d6f9fba44 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -16,7 +16,6 @@ def main(args): enforce_eager=True, enable_prefix_caching=True, tensor_parallel_size=args.tensor_parallel_size, - use_v2_block_manager=args.use_v2_block_manager, ) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -56,8 +55,5 @@ def main(args): parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index bc5f24d3f591..7237d246ddf5 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,5 +1,8 @@ +include(FetchContent) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 17) # # Define environment variables for special configurations @@ -82,15 +85,40 @@ else() message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") endif() +# +# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) +# +if (AVX512_FOUND AND NOT AVX512_DISABLED) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.5.3 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + FetchContent_MakeAvailable(oneDNN) + + list(APPEND LIBS dnnl) +endif() + 
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") list(APPEND LIBS numa) -# Appending the dnnl library for the AVX2 and AVX512, as it is not utilized by Power architecture. -if (AVX2_FOUND OR AVX512_FOUND) - list(APPEND LIBS dnnl) -endif() - # # _C extension # diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 5ed1dc3b8f79..839dc36ba4e2 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] namespace vllm { +template +__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { + const float f = (float)x; + return (T)(f > threshold ? f : 0.0f); +} + +template +__global__ void act_and_mul_kernel_with_param( + scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, + const float param) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = ACT_FN(x, param) * y; + } +} + +} // namespace vllm + +#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ + vllm::act_and_mul_kernel_with_param> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d, \ + PARAM); \ + }); + +void fatrelu_and_mul(torch::Tensor& out, // [..., d], + torch::Tensor& input, // [..., 2 * d] + double threshold) { + LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); +} +namespace vllm { + // Element-wise activation kernel template. template __global__ void activation_kernel( diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index 0e1f360d74bd..408e736d5bc0 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -1,6 +1,7 @@ #pragma once -#include +// For TORCH_CHECK +#include namespace vllm { @@ -9,12 +10,7 @@ namespace vllm { // in particular it can be used to represent sub-byte data types (something // that torch.dtype currently does not support). // -// ScalarTypeTorch is a subclass of ScalarType that is compatible with -// TORCH_LIBRARY, making it accessible from Python as well meaning this class -// can be used as a argument for custom operators, helping to simplify these -// interfaces. -// -// The type definitions on the Python side can be found in: vllm/_core_ext.pyi +// The type definitions on the Python side can be found in: vllm/scalar_type.py // these type definitions should be kept up to date with any Python API changes // here. // @@ -308,204 +304,7 @@ class ScalarType { } }; -// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. 
inherit from -// torch::CustomClassHolder), we use multiple inheritance here since we cannot -// have ScalarType inherit from torch::CustomClassHolder and have a constexpr -// constructor at the same time (torch::CustomClassHolder does not have a -// constexpr destructor) -// See also: -// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA -class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { - public: - ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, - bool _signed) - : ScalarType(exponent, mantissa, bias, _signed){}; - - ScalarTypeTorch(ScalarType type) : ScalarType(type){}; - - using Base = ScalarType; - using Self = ScalarTypeTorch; - using SelfPtr = c10::intrusive_ptr; - - static void check_size_bits(int64_t size_bits, bool signed_) { - TORCH_CHECK( - size_bits <= - std::numeric_limits().mantissa)>::max(), - "size_bits bit width is too large to be represented"); - } - - static void check_bias(int64_t bias) { - using Bias = decltype(std::declval().bias); - TORCH_CHECK(bias <= std::numeric_limits::max() && - bias >= std::numeric_limits::min(), - "bias too large or small to be represented"); - } - - static void check_exponent(int64_t exponent) { - TORCH_CHECK( - exponent <= - std::numeric_limits().exponent)>::max(), - "exponent bit width is too large to be represented"); - } - - static void check_mantissa(int64_t mantissa) { - TORCH_CHECK( - mantissa <= - std::numeric_limits().mantissa)>::max(), - "mantissa bit width is too large to be represented"); - } - - static SelfPtr int_(int64_t size_bits, c10::optional bias) { - check_size_bits(size_bits, true); - check_bias(bias.value_or(0)); - return c10::make_intrusive( - ScalarType::int_(size_bits, bias.value_or(0))); - } - - static SelfPtr uint(int64_t size_bits, c10::optional bias) { - check_size_bits(size_bits, true); - check_bias(bias.value_or(0)); - return c10::make_intrusive( - ScalarType::uint(size_bits, bias.value_or(0))); - } - - static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { - check_mantissa(mantissa); - check_exponent(exponent); - return c10::make_intrusive( - ScalarType::float_IEEE754(exponent, mantissa)); - } - - static SelfPtr float_(int64_t exponent, int64_t mantissa, - bool finite_values_only, int64_t nan_repr) { - check_mantissa(mantissa); - check_exponent(exponent); - return c10::make_intrusive(ScalarType::float_( - exponent, mantissa, finite_values_only, NanRepr(nan_repr))); - } - - // This needs to be implemented and throw a TypeError in order for - // PyTorch's opcheck to work on ops that use ScalarTypes. - int64_t len() const { - throw c10::TypeError({__func__, __FILE__, static_cast(__LINE__)}, - "__len__ not implemented"); - return 0; - } - - // Serialize a ScalarType into a tuple of pairs. Where each pair - // is a (fieldname, value). - // For simplicity, we are just going to convert to a ScalarTypeId. - std::tuple> obj_flatten() const { - return {{"ScalarType", id()}}; - } - - // Deserialize a scalar type that has been serialized by obj_flatten, - // ostensibly from a tuple of (member name, value) pairs, but in reality - // just a ScalarTypeId. 
- static SelfPtr obj_unflatten( - std::tuple> const& flat_type) { - return c10::make_intrusive( - from_id(std::get<1>(std::get<0>(flat_type)))); - } - - template - static void bind_readonly_property(torch::class_& cls, - std::string const& name, T Base::*field) { - auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) { - if constexpr (std::is_member_function_pointer_v) { - return (self.get()->*field)(); - } else { - return self.get()->*field; - } - }; - - auto getter_func = [field = std::move(field), - getter_func_helper = std::move(getter_func_helper)]( - SelfPtr const& self) { - auto val = getter_func_helper(self); - // upconvert uint8_t, int32_t etc. to int64_t for python - if constexpr (std::is_integral_v) { - return static_cast(val); - } else { - return val; - } - }; - - cls.def_property(name, getter_func); - } - - template - static void bind_function(torch::class_& cls, const std::string& name, - MemberFunc Cls::*member) { - cls.def(name, [member = std::move(member)](SelfPtr const& self) { - return (self.get()->*member)(); - }); - } - - template - static void bind_function(torch::class_& cls, const std::string& name, - Func func) { - cls.def(name, func); - } - - template - static void bind_static_function(torch::class_& cls, - const std::string& name, Func func) { - cls.def_static(name, func); - } - - static void bind_class(torch::Library& lib) { - auto cls = lib.class_("ScalarType") - .def(torch::init()); - - // Bind Properties - bind_readonly_property(cls, "mantissa", &Base::mantissa); - bind_readonly_property(cls, "exponent", &Base::exponent); - bind_readonly_property(cls, "bias", &Base::bias); - bind_readonly_property(cls, "signed", &Base::is_signed); - bind_readonly_property(cls, "size_bits", &Base::size_bits); - - // Bind member functions - bind_function(cls, "is_signed", &Base::is_signed); - bind_function(cls, "is_integer", &Base::is_integer); - bind_function(cls, "is_floating_point", &Base::is_floating_point); - bind_function(cls, "is_ieee_754", &Base::is_ieee_754); - bind_function(cls, "has_nans", &Base::has_nans); - bind_function(cls, "has_infs", &Base::has_infs); - bind_function(cls, "has_bias", &Base::has_bias); - - bind_function(cls, "max", [](SelfPtr const& self) { - return std::visit([](auto arg) { return c10::IValue(arg); }, - self.get()->max()); - }); - bind_function(cls, "min", [](SelfPtr const& self) { - return std::visit([](auto arg) { return c10::IValue(arg); }, - self.get()->min()); - }); - - bind_function(cls, "__len__", &ScalarTypeTorch::len); - bind_function(cls, "__str__", &Base::str); - bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { - return *self == *other; - }); - bind_function(cls, "__repr__", [](SelfPtr const& self) { - return "ScalarType." 
+ self.get()->str(); - }); - - bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten); - bind_static_function(cls, "__obj_unflatten__", - &ScalarTypeTorch::obj_unflatten); - - // Bind static functions (convenience constructors) - bind_static_function(cls, "int_", &ScalarTypeTorch::int_); - bind_static_function(cls, "uint", &ScalarTypeTorch::uint); - bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754); - bind_static_function(cls, "float_", &ScalarTypeTorch::float_); - } -}; - -using ScalarTypeId = int64_t; -using ScalarTypeTorchPtr = c10::intrusive_ptr; +using ScalarTypeId = ScalarType::Id; // "rust style" names generally following: // https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 diff --git a/csrc/core/torch_bindings.cpp b/csrc/core/torch_bindings.cpp deleted file mode 100644 index f60254189a2f..000000000000 --- a/csrc/core/torch_bindings.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include - -#include "scalar_type.hpp" -#include "registration.h" - -// Note the CORE exstension will be built for (almost) all hardware targets so -// new additions must account for this. (currently not built for TPU and Neuron) - -TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) { - // ScalarType, a custom class for representing data types that supports - // quantized types, declared here so it can be used when creating interfaces - // for custom ops. - vllm::ScalarTypeTorch::bind_class(lib); -} - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 5b1d3d6442b2..a325153b470c 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec { void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } }; +#ifdef __AVX512F__ +struct INT32Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512i reg; + int32_t values[VEC_ELEM_NUM]; + }; + + __m512i reg; + + explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {} + + void save(int32_t* ptr) const { + _mm512_storeu_epi32(ptr, reg); + } + + void save(int32_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_epi32(ptr, mask, reg); + } +}; +#endif + #ifdef __AVX512F__ struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; @@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(__m512 data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - explicit FP32Vec16(const FP32Vec4 &data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( @@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16 &v) + : reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {} + FP32Vec16 operator*(const FP32Vec16 &b) const { return FP32Vec16(_mm512_mul_ps(reg, b.reg)); } @@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); } + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(_mm512_min_ps(reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); + } + FP32Vec16 abs() const { 
return FP32Vec16(_mm512_abs_ps(reg)); } @@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec { float reduce_max() const { return _mm512_reduce_max_ps(reg); } + float reduce_min() const { return _mm512_reduce_min_ps(reg); } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 2d7abe6145fe..b493fd793818 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -5,25 +5,29 @@ namespace { template struct KernelVecType { using load_vec_type = void; + using azp_adj_load_vec_type = void; using cvt_vec_type = void; }; template <> struct KernelVecType { using load_vec_type = vec_op::FP32Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; using cvt_vec_type = vec_op::FP32Vec16; }; template <> struct KernelVecType { using load_vec_type = vec_op::BF16Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; using cvt_vec_type = vec_op::FP32Vec16; }; #ifdef __AVX512F__ -template +template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int num_tokens, + const float* scale, const int32_t* azp, + const int num_tokens, const int hidden_size) { using load_vec_t = typename KernelVecType::load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; @@ -37,62 +41,110 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const cvt_vec_t i8_min_vec(i8_min); const cvt_vec_t i8_max_vec(i8_max); + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { int j = 0; for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); vec_op::INT8Vec16 elems_int8(elems_fp32); elems_int8.save(output + i * hidden_size + j); } load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec); - vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_fp32 = elems_fp32 * inv_scale; - if (j + vec_elem_num == hidden_size) { - elems_int8.save(output + i * hidden_size + j); - } else { - elems_int8.save(output + i * hidden_size + j, hidden_size - j); + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); } } -template +template void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, const int num_tokens, + float* scale, int32_t* azp, + const int num_tokens, const int hidden_size) { using load_vec_t = typename KernelVecType::load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { - cvt_vec_t max_abs(0.0); + cvt_vec_t max_value(std::numeric_limits::lowest()); + 
cvt_vec_t min_value(std::numeric_limits::max()); { int j = 0; for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); - max_abs = max_abs.max(elems_fp32.abs()); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } } load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); if (j + vec_elem_num == hidden_size) { - max_abs = max_abs.max(elems_fp32.abs()); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } } else { - max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } } } - float scale_val = max_abs.reduce_max() / 127.0f; - scale[i] = scale_val; + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); { int j = 0; @@ -100,6 +152,11 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); vec_op::INT8Vec16 elems_int8(elems_fp32); elems_int8.save(output + i * hidden_size + j); } @@ -107,34 +164,111 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, load_vec_t elems(input + i * hidden_size + j); cvt_vec_t elems_fp32(elems); elems_fp32 = (elems_fp32 * inv_scale); - vec_op::INT8Vec16 elems_int8(elems_fp32); - if (j + vec_elem_num == hidden_size) { - elems_int8.save(output + i * hidden_size + j); - } else { - elems_int8.save(output + i * hidden_size + j, hidden_size - j); + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); } } } -template -void dynamic_output_scale_impl(const float* input, scalar_t* output, - const float* scale, const scalar_t* bias, - const int num_tokens, const int hidden_size) { +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + 
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; using cvt_vec_t = typename KernelVecType::cvt_vec_type; constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; #pragma omp parallel for for (int i = 0; i < num_tokens; ++i) { int j = 0; - cvt_vec_t token_scale_vec(scale[i]); + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { cvt_vec_t elems_fp32(input + i * hidden_size + j); elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + if constexpr (Bias) { load_vec_t bias_vec(bias + j); cvt_vec_t bias_vec_fp32(bias_vec); @@ -148,6 +282,19 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output, cvt_vec_t elems_fp32(input + i * hidden_size + j); elems_fp32 = elems_fp32 * token_scale_vec; + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + if constexpr (Bias) { load_vec_t bias_vec(bias + j); cvt_vec_t bias_vec_fp32(bias_vec); @@ -155,32 +302,41 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output, } load_vec_t elems_out(elems_fp32); - - if (j + vec_elem_num == hidden_size) { - elems_out.save(output + i * hidden_size + j); - } else { - elems_out.save(output + i * hidden_size + j, hidden_size - j); - } + elems_out.save(output + i * hidden_size + j, hidden_size - j); } } #else template void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - const float* scale, const int num_tokens, + const float* scale, const int32_t* azp, + const int 
num_tokens, const int hidden_size) { TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") } template void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, - float* scale, const int num_tokens, + float* scale, int32_t* azp, + const int num_tokens, const int hidden_size) { TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") } +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") +} + template -void dynamic_output_scale_impl() { - TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.") +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_with_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") } #endif } // namespace @@ -214,39 +370,52 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major bias->dim() == 1); } - VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] { + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { if (a_scales.numel() != 1) { // per-token // Note: oneDNN doesn't support per-token activation quantization + // Ideally we want to fuse the GEMM and the scale procedure with oneDNN + // JIT, the intermediate data is cached in registers or L1. But for now + // the oneDNN GEMM code generation only supports two quantization + // patterns: per-tensor or per-output-channel of weight. + // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * + // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN + // GEMM, then the per-token scale (and bias) is applied with the epilogue + // C=s_a * C_inter + bias. 
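A small NumPy check of the decomposition described in the comment above: oneDNN produces C_inter = s_b * (A@B), and the epilogue then applies the per-token activation scale and bias, C = s_a * C_inter + bias, which equals the fused C = s_a * s_b * (A@B) + bias. Shapes and values are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 8, 5
A_q = rng.integers(-128, 128, size=(M, K)).astype(np.int32)  # int8 activations
B_q = rng.integers(-128, 128, size=(K, N)).astype(np.int32)  # int8 weights
s_a = rng.random((M, 1)).astype(np.float32)                  # per-token act. scale
s_b = rng.random((1, N)).astype(np.float32)                  # per-channel wt. scale
bias = rng.random((1, N)).astype(np.float32)

ref = s_a * s_b * (A_q @ B_q) + bias   # fused: C = s_a * s_b * (A@B) + bias
c_inter = s_b * (A_q @ B_q)            # oneDNN GEMM output: C_inter = s_b * (A@B)
out = s_a * c_inter + bias             # epilogue: C = s_a * C_inter + bias
print(np.allclose(ref, out))           # True
```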
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); - DNNLPrimitiveHelper::gemm_s8s8_jit( + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), - tmp_fp32_out.data_ptr(), (void*)(0), a.size(0), b.size(1), - a.size(1), (float*)(0), b_scales.data_ptr(), 0, - b_scales.numel()); + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); if (bias.has_value()) { - dynamic_output_scale_impl( + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), bias->data_ptr(), c.size(0), - c.size(1)); + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); } else { - dynamic_output_scale_impl( + // Compute C=s_a * C_inter + dynamic_quant_epilogue( tmp_fp32_out.data_ptr(), c.data_ptr(), - a_scales.data_ptr(), (scalar_t*)(0), c.size(0), c.size(1)); + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); } } else { // per-tensor if (bias.has_value()) { + // Compute C=s_a * s_b * (A@B) + bias DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), c.data_ptr(), bias->data_ptr(), a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); } else { - DNNLPrimitiveHelper::gemm_s8s8_jit( + // Compute C=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( a.data_ptr(), b.data_ptr(), c.data_ptr(), - (void*)(0), a.size(0), b.size(1), a.size(1), + nullptr, a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); } @@ -254,6 +423,127 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major }); } +void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const torch::Tensor& azp_adj, // [OC] + const c10::optional& azp, // [1] or [M] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_azp only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); + } + if (azp) { + TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); + } + TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); + + // azp & bias types + TORCH_CHECK(azp_adj.dtype() == torch::kInt32); + TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); + TORCH_CHECK(!bias || bias->dtype() == c.dtype(), + "currently bias dtype must match output dtype ", c.dtype()); + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { + torch::Tensor tmp_fp32_out = 
torch::empty_like(c, ::at::ScalarType::Float); + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C_inter=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), bias->data_ptr(), + a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), + b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); + } else { + // Compute C_inter=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + + // Compute C=C_inter - s_a * s_b * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } else { + // Per-Tensor + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } + } + }); +} + // static-per-tensor quantization. 
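A NumPy sketch of the zero-point correction that `int8_scaled_mm_azp` performs, following the per-tensor formula in the comments above, C = s_a * s_b * (A@B) - s_a * s_b * azp * azp_adj. It assumes `azp_adj` holds the per-output-column sums of the quantized weight (an assumption made for this illustration), and it quantizes the activations asymmetrically with the same formulas as the dynamic kernel: s_a = (max - min) / 255 and azp = round(-128 - min / s_a). Values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 64, 8
A = rng.random((M, K), dtype=np.float32)                     # fp32 activations (>= 0)
B_q = rng.integers(-128, 128, size=(K, N)).astype(np.int32)  # int8 weights
s_b = np.float32(0.02)                                       # per-tensor weight scale

# Asymmetric per-tensor activation quantization (same formulas as the kernel).
s_a = float(A.max() - A.min()) / 255.0
azp = int(np.rint(-128.0 - A.min() / s_a))
A_q = np.clip(np.rint(A / s_a) + azp, -128, 127).astype(np.int32)

azp_adj = B_q.sum(axis=0)        # assumed: per-output-column sums of the weight

corrected = s_a * s_b * (A_q @ B_q) - s_a * s_b * azp * azp_adj
uncorrected = s_a * s_b * (A_q @ B_q)
ref = A @ (s_b * B_q)            # dequantized fp32 reference
print(np.abs(corrected - ref).max())    # small quantization error
print(np.abs(uncorrected - ref).max())  # large error without the azp correction
```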
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const torch::Tensor& input, // [..., hidden_size] @@ -263,15 +553,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); - TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); const int hidden_size = input.size(-1); const int num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_impl", [&] { - static_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } }); } @@ -284,14 +581,20 @@ void dynamic_scaled_int8_quant( CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); - TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU."); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { - dynamic_scaled_int8_quant_impl( - input.data_ptr(), out.data_ptr(), - scale.data_ptr(), num_tokens, hidden_size); + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } }); } diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index ab697e3e6aef..03beefbc6de7 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, const torch::Tensor& b_scales, const c10::optional& bias); +void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const torch::Tensor& azp_adj, + const c10::optional& azp, + const c10::optional& bias); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? 
bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif } diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 30831efdfa1a..3a464c5f327a 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -55,6 +55,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor out, const c10::optional& bias, bool silu_activation, + int64_t pad_slot_id, const c10::optional& query_start_loc = std::nullopt, const c10::optional& cache_indices = std::nullopt, const c10::optional& has_initial_state = std::nullopt) { @@ -66,6 +67,7 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, params.dim = dim; params.seqlen = seqlen; params.width = width; + params.pad_slot_id = pad_slot_id; params.silu_activation = silu_activation; @@ -90,14 +92,16 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, } -at::Tensor -causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, const c10::optional &conv_states, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - bool silu_activation) { + bool silu_activation, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -153,12 +157,13 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, CHECK_SHAPE(cache_indices_, batch_size); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, silu_activation, + pad_slot_id, query_start_loc, cache_indices, has_initial_state @@ -183,18 +188,19 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); }); - return out; } -at::Tensor -causal_conv1d_update(const at::Tensor &x, +void causal_conv1d_update(const at::Tensor &x, const at::Tensor &conv_state, const at::Tensor &weight, const c10::optional &bias_, bool silu_activation, const c10::optional &cache_seqlens_, - const c10::optional &conv_state_indices_) { + const c10::optional &conv_state_indices_, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -227,12 +233,13 @@ causal_conv1d_update(const at::Tensor &x, CHECK_SHAPE(bias, dim); } - at::Tensor out = torch::empty_like(x); + at::Tensor out = x; ConvParamsBase params; set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, bias_, - silu_activation); + silu_activation, + pad_slot_id); params.conv_state_ptr = conv_state.data_ptr(); params.conv_state_len = conv_state_len; // All stride are in elements, not bytes. 
@@ -274,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); }); - return out; } template @@ -340,7 +346,10 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; - + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; @@ -528,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr ? batch_id : params.conv_state_indices_ptr[batch_id]; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + conv_state_batch_coord * params.conv_state_batch_stride + channel_id * params.conv_state_c_stride; diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 49e37ee4528b..e26684a2b98b 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -13,6 +13,7 @@ struct ConvParamsBase { using index_t = uint32_t; int batch, dim, seqlen, width; + int64_t pad_slot_id; bool silu_activation; index_t x_batch_stride; diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 580d0b2e17e7..563d2fe4ef65 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -21,6 +21,7 @@ struct SSMParamsBase { int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; + int64_t pad_slot_id; bool delta_softplus; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 6b225b41d295..71624696338d 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -115,6 +115,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr : reinterpret_cast(params.cache_indices_ptr); const int cache_index = cache_indices == nullptr ? 
batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride @@ -387,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const size_t seqlen, const size_t dstate, const size_t n_groups, - const size_t n_chunks, const bool is_variable_B, const bool is_variable_C, // device pointers @@ -407,7 +410,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - bool varlen) { + bool varlen, + int64_t pad_slot_id) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -417,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.seqlen = seqlen; params.dstate = dstate; params.n_groups = n_groups; - params.n_chunks = n_chunks; params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; params.delta_softplus = delta_softplus; @@ -507,7 +511,10 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const c10::optional &query_start_loc, const c10::optional &cache_indices, const c10::optional &has_initial_state, - const torch::Tensor &ssm_states) { + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -618,18 +625,14 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, out_z = z; - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout at::Tensor out = delta; TORCH_CHECK(ssm_states.scalar_type() == input_type); TORCH_CHECK(ssm_states.is_cuda()); TORCH_CHECK(ssm_states.stride(-1) == 1); - CHECK_SHAPE(ssm_states, batch_size, dim, dstate); SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, D_, delta_bias_, @@ -639,7 +642,8 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, query_start_loc, cache_indices, has_initial_state, - varlen + varlen, + pad_slot_id ); diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index e2db4e4196b6..5f12483e951e 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -484,21 +484,22 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& topk_ids, const torch::Tensor& b_scales, torch::Tensor& b_zeros, const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, bool replicate_input, bool apply_weights) { + vllm::ScalarType 
const b_q_type = vllm::ScalarType::from_id(b_q_type_id); bool has_zp = b_zeros.size(1) != 0; if (has_zp) { TORCH_CHECK( - *b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str()); + b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK( - *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str()); } - int pack_factor = 32 / b_q_type->size_bits(); + int pack_factor = 32 / b_q_type.size_bits(); int max_par = 4; @@ -575,7 +576,7 @@ torch::Tensor marlin_gemm_moe( topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, num_experts, topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 18fbc57ac783..019c6cedd3d8 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -13,8 +13,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " - "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); // conditionally compiled so impl registration is in source file diff --git a/csrc/ops.h b/csrc/ops.h index fce545f95a7c..11a297069554 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -48,6 +48,9 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); +void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input, + double threshold); + void gelu_new(torch::Tensor& out, torch::Tensor& input); void gelu_fast(torch::Tensor& out, torch::Tensor& input); @@ -157,21 +160,23 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, const c10::optional& query_start_loc, const c10::optional& cache_indices, const c10::optional& has_initial_state, - const torch::Tensor& ssm_states); - -at::Tensor causal_conv1d_update( - const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias_, bool silu_activation, - const c10::optional& cache_seqlens_, - const c10::optional& conv_state_indices_); - -at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, - const c10::optional& bias_, - const c10::optional& conv_states, - const c10::optional& query_start_loc, - const c10::optional& cache_indices, - const c10::optional& has_initial_state, - bool silu_activation); + const torch::Tensor& ssm_states, int64_t pad_slot_id); + +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool 
silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_, + int64_t pad_slot_id); + +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); #ifndef USE_ROCM using fptr_t = int64_t; diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aec9fa002f96..e9987535bd3e 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) / scale); + out[i] = float_to_int8_rn(static_cast(input[i]) / scale); } } @@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel( scale_type const* scale_ptr, azp_type const* azp_ptr, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; azp_type const azp = *azp_ptr; + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + for (int i = tid; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const val = static_cast(input[i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); - out[token_idx * hidden_size + i] = quant_val; + out[i] = quant_val; } } @@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type* scale, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; float absmax_val = 0.0f; float const zero = 0.0f; + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[token_idx * hidden_size + i]); + float val = static_cast(input[i]); val = val > zero ? val : -val; absmax_val = val > absmax_val ? 
val : absmax_val; } @@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel( float const tmp_scale = 127.0f / block_absmax_val; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); } } @@ -159,13 +169,17 @@ template __global__ void dynamic_scaled_int8_azp_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type* scale, azp_type* azp, const int hidden_size) { - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; + + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; // Scan for the min and max value for this token float max_val = std::numeric_limits::min(); float min_val = std::numeric_limits::max(); for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto val = static_cast(input[token_idx * hidden_size + i]); + auto val = static_cast(input[i]); max_val = std::max(max_val, val); min_val = std::min(min_val, val); } @@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel( // Quantize the values for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const val = static_cast(input[i]); auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); - out[token_idx * hidden_size + i] = quant_val; + out[i] = quant_val; } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 1657f7d0b16e..97a969cf5e3e 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -137,9 +137,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, return; } - // Turing - TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + if (version_num >= 75) { + // Turing + cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + return; + } #endif TORCH_CHECK_NOT_IMPLEMENTED( diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index 7e23f9225776..f2c609c1b68c 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -204,8 +204,10 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( int const tid = threadIdx.x; int const token_idx = blockIdx.x; - scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size]; - FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size]; + // Use int64 to avoid overflowing an int32 when calculating this offset + int64_t offset = static_cast(token_idx) * hidden_size; + scalar_t const* __restrict__ token_input = &input[offset]; + FP8_TYPE* __restrict__ token_output = &out[offset]; // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively. 
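The indexing rewrites above all address the same failure mode: `token_idx * hidden_size` was previously computed in 32-bit arithmetic, so a sufficiently large `num_tokens * hidden_size` product wraps around and produces a negative offset. A plain-Python illustration of the threshold (the sizes are made-up examples, not vLLM limits):

```python
hidden_size = 16_384   # hypothetical hidden dimension
num_tokens = 200_000   # hypothetical number of tokens in one kernel launch

offset = (num_tokens - 1) * hidden_size      # 64-bit result: 3_276_783_616
wrapped = (offset + 2**31) % 2**32 - 2**31   # what a signed 32-bit multiply would yield

print(offset)    # valid element offset for the last token
print(wrapped)   # negative after wraparound -> out-of-bounds pointer arithmetic
assert offset > 2**31 - 1  # exceeds INT32_MAX, hence the promotion of token_idx to int64_t
```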
diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 5efe15d2b2f6..6dbf9594e849 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -80,7 +80,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp) { TORCH_CHECK_NOT_IMPLEMENTED(false, @@ -2132,22 +2132,23 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_scales, torch::Tensor& b_zeros, torch::Tensor& g_idx, torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, bool has_zp, bool use_fp32_reduce) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); if (has_zp) { - TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", - b_q_type->str()); + TORCH_CHECK( + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { TORCH_CHECK( - *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", - b_q_type->str()); + b_q_type.str()); } - int pack_factor = 32 / b_q_type->size_bits(); + int pack_factor = 32 / b_q_type.size_bits(); // Verify A TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), @@ -2279,7 +2280,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else if (a.scalar_type() == at::ScalarType::BFloat16) { @@ -2288,7 +2289,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, c.data_ptr(), c_tmp.data_ptr(), b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp, + workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); } else { @@ -2302,4 +2303,4 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("gptq_marlin_gemm", &gptq_marlin_gemm); -} \ No newline at end of file +} diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index ff037756f55a..9f9073ded619 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -38,9 +38,10 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { // Interface // -std::vector 
supported_schedules(ScalarTypeTorchPtr const& btype) { +std::vector supported_schedules(ScalarTypeId const btype_id) { #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - return scalar_type_dispatch(*btype, [&](auto BType) { + vllm::ScalarType b_type = ScalarType::from_id(btype_id); + return scalar_type_dispatch(b_type, [&](auto BType) { return GemmDispatcher::supported_schedules(); }); #else @@ -49,7 +50,7 @@ std::vector supported_schedules(ScalarTypeTorchPtr const& btype) { } torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeTorchPtr const& btype, + ScalarTypeId const btype_id, c10::optional const& scales, c10::optional const& zeros, c10::optional group_size, @@ -57,6 +58,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, c10::optional alpha, c10::optional beta, c10::optional schedule) { #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 + ScalarType const btype = ScalarType::from_id(btype_id); auto args = PyTorchArguments{.A = A, .B = B, .scales = scales, @@ -67,7 +69,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, .beta = beta, .schedule = schedule}; - return scalar_type_dispatch(*btype, [&](auto BType) { + return scalar_type_dispatch(btype, [&](auto BType) { return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( A.scalar_type(), "machete_gemm", [&] { using ComputeType = equivalent_cutlass_type_t; @@ -79,9 +81,9 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, #endif } -torch::Tensor prepack_B(torch::Tensor const& B, - vllm::ScalarTypeTorchPtr const& btype) { - return scalar_type_dispatch(*btype, [&](auto BType) { +torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) { + ScalarType const btype = ScalarType::from_id(btype_id); + return scalar_type_dispatch(btype, [&](auto BType) { return PrepackBDispatcher::dispatch(B); }); } diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 908e4f70ab1e..a33e2660d760 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -89,7 +89,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -1029,13 +1029,14 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, + vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k) { + vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); // Verify num_bits - TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, - "num_bits must be uint4b8 or uint8b128. Got = ", b_q_type->str()); - int pack_factor = 32 / b_q_type->size_bits(); + TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, + "num_bits must be uint4b8 or uint8b128. 
Got = ", b_q_type.str()); + int pack_factor = 32 / b_q_type.size_bits(); // Verify M TORCH_CHECK(size_m == a.size(0), @@ -1130,8 +1131,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, marlin_24::marlin_cuda_2_4( a.data_ptr(), b_q_weight.data_ptr(), b_meta.data_ptr(), c.data_ptr(), b_scales.data_ptr(), size_n, size_m, size_k, workspace.data_ptr(), - b_q_type->size_bits(), groupsize, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_m, sms, max_par); + b_q_type.size_bits(), groupsize, dev, at::cuda::getCurrentCUDAStream(dev), + thread_k, thread_m, sms, max_par); return c; } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index a0100b4a85ed..826f918c82e7 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -60,6 +60,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + // FATReLU implementation. + ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()"); + ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul); + // GELU implementation used in GPT-2. ops.def("gelu_new(Tensor! out, Tensor input) -> ()"); ops.impl("gelu_new", torch::kCUDA, &gelu_new); @@ -140,13 +144,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantized GEMM for AWQ. ops.def( "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, int split_k_iters) -> Tensor"); + "Tensor _zeros, SymInt split_k_iters) -> Tensor"); ops.impl("awq_gemm", torch::kCUDA, &awq_gemm); // Dequantization for AWQ. ops.def( "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, " - "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor"); + "Tensor _zeros, SymInt split_k_iters, int thx, int thy) -> Tensor"); ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize); // Note about marlin kernel 'workspace' arguments: @@ -166,32 +170,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Marlin (Dense) Optimized Quantized GEMM for GPTQ. ops.def( "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"); + "Tensor! workspace, SymInt size_m, SymInt size_n, SymInt size_k) -> " + "Tensor"); // conditionally compiled so impl in source file // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, " "Tensor b_scales, Tensor workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, " - "int size_m, int size_n, int size_k) -> Tensor"); + "int b_q_type, " + "SymInt size_m, SymInt size_n, SymInt size_k) -> Tensor"); // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. + ops.def("machete_supported_schedules(int btype) -> str[]"); ops.def( - "machete_supported_schedules(" - " __torch__.torch.classes._core_C.ScalarType btype" - ") -> str[]"); - ops.def( - "machete_gemm(Tensor A, Tensor B," - " __torch__.torch.classes._core_C.ScalarType btype," - " Tensor? scales, Tensor? zeros, int? group_size," + "machete_gemm(Tensor A, Tensor B, int btype, " + " Tensor? scales, Tensor? zeros, int? group_size, " " Tensor? C, float? alpha, float? beta, str? 
schedule)" "-> Tensor"); - ops.def( - "machete_prepack_B(Tensor B," - " __torch__.torch.classes._core_C.ScalarType btype)" - "-> Tensor"); + ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor"); // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); @@ -201,8 +199,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " - "__torch__.torch.classes._core_C.ScalarType b_q_type, " - "int size_m, int size_n, int size_k, bool is_k_full, " + "int b_q_type, " + "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " "bool has_zp, bool use_fp32_reduce) -> Tensor"); // conditionally compiled so impl registration is in source file @@ -219,32 +217,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl registrations are in source file // Dequantization for GGML. - ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"); + ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"); ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); // mmvq kernel for GGML. ops.def( - "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) " + "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) " "-> Tensor"); ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8); // mmq kernel for GGML. - ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor"); + ops.def( + "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def( "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, int num_bits, int size_m, int size_n, " - "int size_k) -> Tensor"); + "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, " + "SymInt size_k) -> Tensor"); // conditionally compiled so impl registration is in source file // marlin_qqq_gemm for QQQ. ops.def( "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " "Tensor s_tok, Tensor s_ch, Tensor s_group, " - "Tensor! workspace, int size_m, int size_n, " - "int size_k) -> Tensor"); + "Tensor! workspace, SymInt size_m, SymInt size_n, " + "SymInt size_k) -> Tensor"); // conditionally compiled so impl registration is in source file // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column @@ -278,7 +277,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? has_initial_state," - "Tensor! ssm_states) -> ()"); + "Tensor! ssm_states," + "int pad_slot_id) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( @@ -288,7 +288,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? bias_," "bool silu_activation," "Tensor? cache_seqlens_," - "Tensor? conv_state_indices) -> Tensor"); + "Tensor? conv_state_indices," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( @@ -298,7 +299,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor? query_start_loc," "Tensor? cache_indices," "Tensor? 
has_initial_state," - "bool silu_activation) -> Tensor"); + "bool silu_activation," + "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/docs/source/dev/input_processing/model_inputs_index.rst b/docs/source/dev/input_processing/model_inputs_index.rst index 5d895837590b..f0ec1fea15dd 100644 --- a/docs/source/dev/input_processing/model_inputs_index.rst +++ b/docs/source/dev/input_processing/model_inputs_index.rst @@ -25,7 +25,7 @@ Module Contents LLM Engine Inputs ----------------- -.. autoclass:: vllm.inputs.LLMInputs +.. autoclass:: vllm.inputs.DecoderOnlyInputs :members: :show-inheritance: diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index c8947beb3494..d12aeebbbc18 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -3,7 +3,13 @@ Installation with CPU ======================== -vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. +vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32 and BF16. vLLM CPU backend supports the following vLLM features: + +- Tensor Parallel (``-tp = N``) +- Quantization (``INT8 W8A8, AWQ``) + +.. note:: + FP16 data type and more advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. Table of contents: @@ -59,20 +65,6 @@ Build from source $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu -- Third, build and install oneDNN library from source: - -.. code-block:: console - - $ git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git - $ cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ - -DONEDNN_BUILD_DOC=OFF \ - -DONEDNN_BUILD_EXAMPLES=OFF \ - -DONEDNN_BUILD_TESTS=OFF \ - -DONEDNN_BUILD_GRAPH=OFF \ - -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ - -DONEDNN_ENABLE_PRIMITIVE=MATMUL - $ cmake --build ./oneDNN/build --target install --config Release - - Finally, build and install vLLM CPU backend: .. code-block:: console @@ -155,5 +147,20 @@ Performance tips - If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross NUMA node memory access. +CPU Backend Considerations +-------------------------- + +- The CPU backend significantly differs from the GPU backend since the vLLM architecture was originally optimized for GPU use. A number of optimizations are needed to enhance its performance. + +- Decouple the HTTP serving components from the inference components. In a GPU backend configuration, the HTTP serving and tokenization tasks operate on the CPU, while inference runs on the GPU, which typically does not pose a problem. However, in a CPU-based setup, the HTTP serving and tokenization can cause significant context switching and reduced cache efficiency. Therefore, it is strongly recommended to segregate these two components for improved performance. + +- On CPU based setup with NUMA enabled, the memory access performance may be largely impacted by the `topology `_. For NUMA architecture, two optimizations are to recommended: Tensor Parallel or Data Parallel. 
+ + * Using Tensor Parallel for a latency constraints deployment: following GPU backend design, a Megatron-LM's parallel algorithm will be used to shard the model, based on the number of NUMA nodes (e.g. TP = 2 for a two NUMA node system). With `TP feature on CPU `_ merged, Tensor Parallel is supported for serving and offline inferencing. In general each NUMA node is treated as one GPU card. Below is the example script to enable Tensor Parallel = 2 for serving: + + .. code-block:: console + + $ VLLM_CPU_KVCACHE_SPACE=40 VLLM_CPU_OMP_THREADS_BIND="0-31|32-63" vllm serve meta-llama/Llama-2-7b-chat-hf -tp=2 --distributed-executor-backend mp + * Using Data Parallel for maximum throughput: to launch an LLM serving endpoint on each NUMA node along with one additional load balancer to dispatch the requests to those endpoints. Common solutions like `Nginx <../serving/deploying_with_nginx.html>`_ or HAProxy are recommended. Anyscale Ray project provides the feature on LLM `serving `_. Here is the example to setup a scalable LLM serving with `Ray Serve `_. \ No newline at end of file diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index cfd2dcb3bd5d..91978065faf4 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -107,15 +107,15 @@ If GPU/CPU communication cannot be established, you can use the following Python If you are testing with a single node, adjust ``--nproc-per-node`` to the number of GPUs you want to use: -.. code-block:: shell +.. code-block:: console - NCCL_DEBUG=TRACE torchrun --nproc-per-node= test.py + $ NCCL_DEBUG=TRACE torchrun --nproc-per-node= test.py If you are testing with multi-nodes, adjust ``--nproc-per-node`` and ``--nnodes`` according to your setup and set ``MASTER_ADDR`` to the correct IP address of the master node, reachable from all nodes. Then, run: -.. code-block:: shell +.. code-block:: console - NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py + $ NCCL_DEBUG=TRACE torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR test.py If the script runs successfully, you should see the message ``sanity check is successful!``. diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index 99c695ac4ddb..a706b285eded 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -7,14 +7,14 @@ Installation vLLM is a Python library that also contains pre-compiled C++ and CUDA (12.1) binaries. Requirements -=========================== +============ * OS: Linux -* Python: 3.8 -- 3.12 +* Python: 3.8 - 3.12 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) Install released versions -=========================== +========================= You can install vLLM using pip: @@ -51,9 +51,9 @@ You can install vLLM using pip: .. _install-the-latest-code: Install the latest code -========================= +======================= -LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on x86 platform with cuda 12 for every commit since v0.5.3. 
You can download and install the latest one with the following command: +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on a x86 platform with CUDA 12 for every commit since ``v0.5.3``. You can download and install it with the following command: .. code-block:: console @@ -66,7 +66,7 @@ If you want to access the wheels for previous commits, you can specify the commi $ export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl -Note that the wheels are built with Python 3.8 abi (see `PEP 425 `_ for more details about abi), so **they are compatible with Python 3.8 and later**. The version string in the wheel file name (``1.0.0.dev``) is just a placeholder to have a unified URL for the wheels. The actual versions of wheels are contained in the wheel metadata. +Note that the wheels are built with Python 3.8 ABI (see `PEP 425 `_ for more details about ABI), so **they are compatible with Python 3.8 and later**. The version string in the wheel file name (``1.0.0.dev``) is just a placeholder to have a unified URL for the wheels. The actual versions of wheels are contained in the wheel metadata. Another way to access the latest code is to use the docker images: @@ -77,17 +77,17 @@ Another way to access the latest code is to use the docker images: These docker images are used for CI and testing only, and they are not intended for production use. They will be expired after several days. -Latest code can contain bugs and may not be stable. Please use it with caution. +The latest code can contain bugs and may not be stable. Please use it with caution. .. _build_from_source: Build from source -================== +================= .. _python-only-build: Python-only build (without compilation) ----------------------------------------- +--------------------------------------- If you only need to change Python code, you can simply build vLLM without compilation. @@ -116,28 +116,28 @@ The script will: Now, you can edit the Python code in the current directory, and the changes will be reflected when you run vLLM. -Once you have finished editing or want to install another vLLM wheel, you should exit the development environment using `the same script `_ with the ``--quit-dev``(or ``-q`` for short) flag: +Once you have finished editing or want to install another vLLM wheel, you should exit the development environment using `the same script `_ with the ``--quit-dev`` (or ``-q`` for short) flag: .. code-block:: console $ python python_only_dev.py --quit-dev -The script with ``--quit-dev`` flag will: +The ``--quit-dev`` flag will: * Remove the symbolic link from the current directory to the vLLM package. * Restore the original vLLM package from the backup. -If you update the vLLM wheel and want to rebuild from the source and make further edits, you will need to start `all above <#python-only-build>`_ over again. +If you update the vLLM wheel and rebuild from the source to make further edits, you will need to repeat the `Python-only build <#python-only-build>`_ steps again. .. 
note:: There is a possibility that your source code may have a different commit ID compared to the latest vLLM wheel, which could potentially lead to unknown errors. - It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to `the above section <#install-the-latest-code>`_ for instructions on how to install a specified wheel. + It is recommended to use the same commit ID for the source code as the vLLM wheel you have installed. Please refer to `the section above <#install-the-latest-code>`_ for instructions on how to install a specified wheel. Full build (with compilation) ---------------------------------- +----------------------------- -If you want to modify C++ or CUDA code, you'll need to build vLLM from source. This can take several minutes: +If you want to modify C++ or CUDA code, you'll need to build vLLM from source. This can take several minutes: .. code-block:: console @@ -153,7 +153,7 @@ If you want to modify C++ or CUDA code, you'll need to build vLLM from source. T Use an existing PyTorch installation -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ There are scenarios where the PyTorch dependency cannot be easily installed via pip, e.g.: * Building vLLM with PyTorch nightly or a custom PyTorch build. @@ -171,7 +171,7 @@ To build vLLM using an existing PyTorch installation: Troubleshooting -~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~ To avoid your system being overloaded, you can limit the number of compilation jobs to be run simultaneously, via the environment variable ``MAX_JOBS``. For example: @@ -207,7 +207,7 @@ Here is a sanity check to verify that the CUDA Toolkit is correctly installed: Unsupported OS build ----------------------- +-------------------- vLLM can fully run only on Linux but for development purposes, you can still build it on other systems (for example, macOS), allowing for imports and a more convenient development environment. The binaries will not be compiled and won't work on non-Linux systems. diff --git a/docs/source/index.rst b/docs/source/index.rst index d20e46b4a365..c328c049b430 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -80,6 +80,7 @@ Documentation serving/openai_compatible_server serving/deploying_with_docker serving/deploying_with_k8s + serving/deploying_with_nginx serving/distributed_serving serving/metrics serving/env_vars diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst index 50468f25b922..b02c80aebec6 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/models/spec_decode.rst @@ -30,7 +30,6 @@ The following code configures vLLM in an offline mode to use speculative decodin tensor_parallel_size=1, speculative_model="facebook/opt-125m", num_speculative_tokens=5, - use_v2_block_manager=True, ) outputs = llm.generate(prompts, sampling_params) @@ -44,10 +43,10 @@ To perform the same with an online mode launch the server: .. code-block:: bash python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000 --model facebook/opt-6.7b \ - --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \ - --num_speculative_tokens 5 --gpu_memory_utilization 0.8 + --seed 42 -tp 1 --speculative_model facebook/opt-125m --use-v2-block-manager \ + --num_speculative_tokens 5 --gpu_memory_utilization 0.8 - Then use a client: +Then use a client: .. code-block:: python @@ -104,7 +103,6 @@ matching n-grams in the prompt. For more information read `this thread. `_. 
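For reference, a prompt-lookup (n-gram) speculative configuration in the offline API might look like the sketch below; the model name and token counts are placeholders rather than recommended settings.

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-6.7b",    # placeholder target model
    speculative_model="[ngram]",  # draft tokens are proposed by n-gram lookup in the prompt
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,    # longest n-gram matched against the prompt
)
outputs = llm.generate("The future of AI is", SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```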
-The following is the list of model architectures that are currently supported by vLLM. +vLLM supports a variety of generative and embedding models from `HuggingFace (HF) Transformers `_. +This page lists the model architectures that are currently supported by vLLM. Alongside each architecture, we include some popular models that use it. +For other models, you can check the :code:`config.json` file inside the model repository. +If the :code:`"architectures"` field contains a model architecture listed below, then it should be supported in theory. + +.. tip:: + The easiest way to check if your model is really supported at runtime is to run the program below: + + .. code-block:: python + + from vllm import LLM + + llm = LLM(model=...) # Name or path of your model + output = llm.generate("Hello, my name is") + print(output) + + If vLLM successfully generates text, it indicates that your model is supported. + +Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` +for instructions on how to implement your model in vLLM. +Alternatively, you can `open an issue on GitHub `_ to request vLLM support. + +.. note:: + To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: + + .. code-block:: shell + + $ export VLLM_USE_MODELSCOPE=True + + And use with :code:`trust_remote_code=True`. + + .. code-block:: python + + from vllm import LLM + + llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model + output = llm.generate("Hello, my name is") + print(output) + Text-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -19,7 +56,7 @@ Text Generation * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`AquilaForCausalLM` @@ -87,6 +124,11 @@ Text Generation - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - - āœ…ļøŽ + * - :code:`FalconMambaForCausalLM` + - FalconMamba + - :code:`tiiuae/falcon-mamba-7b`, :code:`tiiuae/falcon-mamba-7b-instruct`, etc. + - āœ…ļøŽ + - * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. @@ -139,7 +181,7 @@ Text Generation - āœ…ļøŽ * - :code:`JAISLMHeadModel` - Jais - - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. + - :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc. - - āœ…ļøŽ * - :code:`JambaForCausalLM` @@ -155,11 +197,11 @@ Text Generation * - :code:`MambaForCausalLM` - Mamba - :code:`state-spaces/mamba-130m-hf`, :code:`state-spaces/mamba-790m-hf`, :code:`state-spaces/mamba-2.8b-hf`, etc. - - āœ…ļøŽ + - - * - :code:`MiniCPMForCausalLM` - MiniCPM - - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. + - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc. - āœ…ļøŽ - āœ…ļøŽ * - :code:`MiniCPM3ForCausalLM` @@ -280,7 +322,7 @@ Text Embedding * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Gemma2Model` @@ -294,6 +336,10 @@ Text Embedding - - āœ…ļøŽ +.. important:: + Some model architectures support both generation and embedding tasks. + In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. 
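As a concrete illustration of the note above, a checkpoint that supports both tasks can be pinned to embedding mode. The sketch assumes the offline `LLM` entry point accepts the same `task` option as the `--task` flag, and the model name is just one entry from the table:

```python
from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embedding")

outputs = llm.encode("Hello, my name is")  # embedding mode returns pooled vectors, not text
print(len(outputs[0].outputs.embedding))   # dimensionality of the embedding
```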
+ Reward Modeling --------------- @@ -303,7 +349,7 @@ Reward Modeling * - Architecture - Models - - Example HuggingFace Models + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Qwen2ForRewardModel` @@ -316,7 +362,22 @@ Reward Modeling As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. Multimodal Language Models -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The following modalities are supported depending on the model: + +- **T**\ ext +- **I**\ mage +- **V**\ ideo +- **A**\ udio + +Any combination of modalities joined by :code:`+` are supported. + +- e.g.: :code:`T + I` means that the model supports text-only, image-only, and text-with-image inputs. + +On the other hand, modalities separated by :code:`/` are mutually exclusive. + +- e.g.: :code:`T / I` means that the model supports text-only and image-only inputs, but not text-with-image inputs. .. _supported_vlms: @@ -324,126 +385,132 @@ Text Generation --------------- .. list-table:: - :widths: 25 25 25 25 5 5 + :widths: 25 25 15 25 5 5 :header-rows: 1 * - Architecture - Models - - Modalities - - Example HuggingFace Models + - Inputs + - Example HF Models - :ref:`LoRA ` - :ref:`PP ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - - Image\ :sup:`E` + - T + I\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - - āœ…ļøŽ * - :code:`ChameleonForConditionalGeneration` - Chameleon - - Image + - T + I - :code:`facebook/chameleon-7b` etc. - - āœ…ļøŽ * - :code:`FuyuForCausalLM` - Fuyu - - Image + - T + I - :code:`adept/fuyu-8b` etc. - - āœ…ļøŽ * - :code:`ChatGLMModel` - GLM-4V - - Image + - T + I - :code:`THUDM/glm-4v-9b` etc. - - āœ…ļøŽ * - :code:`InternVLChatModel` - InternVL2 - - Image\ :sup:`E+` - - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. + - T + I\ :sup:`E+` + - :code:`OpenGVLab/Mono-InternVL-2B`, :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - - āœ…ļøŽ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - - āœ…ļøŽ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - - āœ…ļøŽ * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - - Video + - T + V - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - - āœ…ļøŽ * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - - Image\ :sup:`+` / Video + - T + I\ :sup:`+` + V - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - āœ…ļøŽ * - :code:`MiniCPMV` - MiniCPM-V - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - āœ…ļøŽ - āœ…ļøŽ * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - - Image + - T + I - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - - * - :code:`MolmoForCausalLM` - Molmo - - Image + - T + I - :code:`allenai/Molmo-7B-D-0924`, :code:`allenai/Molmo-72B-0924`, etc. - - āœ…ļøŽ * - :code:`NVLM_D_Model` - NVLM-D 1.0 - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`nvidia/NVLM-D-72B`, etc. 
- - āœ…ļøŽ * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - - Image\ :sup:`E` + - T + I\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - - āœ…ļøŽ * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - āœ…ļøŽ * - :code:`PixtralForConditionalGeneration` - Pixtral - - Image\ :sup:`+` - - :code:`mistralai/Pixtral-12B-2409` + - T + I\ :sup:`+` + - :code:`mistralai/Pixtral-12B-2409`, :code:`mistral-community/pixtral-12b` etc. - - āœ…ļøŽ * - :code:`QWenLMHeadModel` - Qwen-VL - - Image\ :sup:`E+` + - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - - āœ…ļøŽ + * - :code:`Qwen2AudioForConditionalGeneration` + - Qwen2-Audio + - T + A\ :sup:`+` + - :code:`Qwen/Qwen2-Audio-7B-Instruct` + - + - āœ…ļøŽ * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - - Image\ :sup:`E+` / Video\ :sup:`+` + - T + I\ :sup:`E+` + V\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - - āœ…ļøŽ * - :code:`UltravoxModel` - Ultravox - - Audio\ :sup:`E+` + - T + A\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - āœ…ļøŽ @@ -455,43 +522,35 @@ Text Generation For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 ----- - -If your model uses one of the above model architectures, you can seamlessly run your model with vLLM. -Otherwise, please refer to :ref:`Adding a New Model ` and :ref:`Enabling Multimodal Inputs ` -for instructions on how to implement support for your model. -Alternatively, you can raise an issue on our `GitHub `_ project. - -.. tip:: - The easiest way to check if your model is supported is to run the program below: - - .. code-block:: python - - from vllm import LLM - - llm = LLM(model=...) # Name or path of your model - output = llm.generate("Hello, my name is") - print(output) - - If vLLM successfully generates text, it indicates that your model is supported. - -.. tip:: - To use models from `ModelScope `_ instead of HuggingFace Hub, set an environment variable: - - .. code-block:: shell - - $ export VLLM_USE_MODELSCOPE=True - - And use with :code:`trust_remote_code=True`. - - .. code-block:: python +Multimodal Embedding +-------------------- - from vllm import LLM +.. list-table:: + :widths: 25 25 15 25 5 5 + :header-rows: 1 - llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model - output = llm.generate("Hello, my name is") - print(output) + * - Architecture + - Models + - Inputs + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`LlavaNextForConditionalGeneration` + - LLaVA-NeXT-based + - T / I + - :code:`royokong/e5-v` + - + - āœ…ļøŽ + * - :code:`Phi3VForCausalLM` + - Phi-3-Vision-based + - T + I + - :code:`TIGER-Lab/VLM2Vec-Full` + - 🚧 + - āœ…ļøŽ +.. important:: + Some model architectures support both generation and embedding tasks. + In this case, you have to pass :code:`--task embedding` to run the model in embedding mode. 
Model Support Policy ===================== diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index a3ee5da04422..a47902ab4fc9 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -181,8 +181,8 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc .. code-block:: bash - vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ - --trust-remote-code --limit-mm-per-prompt image=2 + vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 .. important:: Since OpenAI Vision API is based on `Chat Completions `_ API, @@ -241,15 +241,15 @@ To consume the server, you can use the OpenAI client like in the example below: print("Chat completion output:", chat_response.choices[0].message.content) -A full code example can be found in `examples/openai_vision_api_client.py `_. +A full code example can be found in `examples/openai_api_client_for_multimodal.py `_. .. note:: By default, the timeout for fetching images through http url is ``5`` seconds. You can override this by setting the environment variable: - .. code-block:: shell + .. code-block:: console - export VLLM_IMAGE_FETCH_TIMEOUT= + $ export VLLM_IMAGE_FETCH_TIMEOUT= .. note:: There is no need to format the prompt in the API request since it will be handled by the server. diff --git a/docs/source/serving/deploying_with_nginx.rst b/docs/source/serving/deploying_with_nginx.rst new file mode 100644 index 000000000000..b5dff02b6bae --- /dev/null +++ b/docs/source/serving/deploying_with_nginx.rst @@ -0,0 +1,142 @@ +.. _nginxloadbalancer: + +Deploying with Nginx Loadbalancer +================================= + +This document shows how to launch multiple vLLM serving containers and use Nginx to act as a load balancer between the servers. + +Table of contents: + +#. :ref:`Build Nginx Container ` +#. :ref:`Create Simple Nginx Config file ` +#. :ref:`Build vLLM Container ` +#. :ref:`Create Docker Network ` +#. :ref:`Launch vLLM Containers ` +#. :ref:`Launch Nginx ` +#. :ref:`Verify That vLLM Servers Are Ready ` + +.. _nginxloadbalancer_nginx_build: + +Build Nginx Container +--------------------- + +This guide assumes that you have just cloned the vLLM project and you're currently in the vllm root directory. + +.. code-block:: console + + export vllm_root=`pwd` + +Create a file named ``Dockerfile.nginx``: + +.. code-block:: console + + FROM nginx:latest + RUN rm /etc/nginx/conf.d/default.conf + EXPOSE 80 + CMD ["nginx", "-g", "daemon off;"] + +Build the container: + +.. code-block:: console + + docker build . -f Dockerfile.nginx --tag nginx-lb + +.. _nginxloadbalancer_nginx_conf: + +Create Simple Nginx Config file +------------------------------- + +Create a file named ``nginx_conf/nginx.conf``. Note that you can add as many servers as you'd like. In the below example we'll start with two. To add more, add another ``server vllmN:8000 max_fails=3 fail_timeout=10000s;`` entry to ``upstream backend``. + +.. code-block:: console + + upstream backend { + least_conn; + server vllm0:8000 max_fails=3 fail_timeout=10000s; + server vllm1:8000 max_fails=3 fail_timeout=10000s; + } + server { + listen 80; + location / { + proxy_pass http://backend; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + } + +.. 
_nginxloadbalancer_nginx_vllm_container: + +Build vLLM Container +-------------------- + +.. code-block:: console + + cd $vllm_root + docker build -f Dockerfile . --tag vllm + + +If you are behind proxy, you can pass the proxy settings to the docker build command as shown below: + +.. code-block:: console + + cd $vllm_root + docker build -f Dockerfile . --tag vllm --build-arg http_proxy=$http_proxy --build-arg https_proxy=$https_proxy + +.. _nginxloadbalancer_nginx_docker_network: + +Create Docker Network +--------------------- + +.. code-block:: console + + docker network create vllm_nginx + + +.. _nginxloadbalancer_nginx_launch_container: + +Launch vLLM Containers +---------------------- + +Notes: + +* If you have your HuggingFace models cached somewhere else, update ``hf_cache_dir`` below. +* If you don't have an existing HuggingFace cache you will want to start ``vllm0`` and wait for the model to complete downloading and the server to be ready. This will ensure that ``vllm1`` can leverage the model you just downloaded and it won't have to be downloaded again. +* The below example assumes GPU backend used. If you are using CPU backend, remove ``--gpus all``, add ``VLLM_CPU_KVCACHE_SPACE`` and ``VLLM_CPU_OMP_THREADS_BIND`` environment variables to the docker run command. +* Adjust the model name that you want to use in your vLLM servers if you don't want to use ``Llama-2-7b-chat-hf``. + +.. code-block:: console + + mkdir -p ~/.cache/huggingface/hub/ + hf_cache_dir=~/.cache/huggingface/ + docker run -itd --ipc host --privileged --network vllm_nginx --gpus all --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8081:8000 --name vllm0 vllm --model meta-llama/Llama-2-7b-chat-hf + docker run -itd --ipc host --privileged --network vllm_nginx --gpus all --shm-size=10.24gb -v $hf_cache_dir:/root/.cache/huggingface/ -p 8082:8000 --name vllm1 vllm --model meta-llama/Llama-2-7b-chat-hf + +.. note:: + If you are behind proxy, you can pass the proxy settings to the docker run command via ``-e http_proxy=$http_proxy -e https_proxy=$https_proxy``. + +.. _nginxloadbalancer_nginx_launch_nginx: + +Launch Nginx +------------ + +.. code-block:: console + + docker run -itd -p 8000:80 --network vllm_nginx -v ./nginx_conf/:/etc/nginx/conf.d/ --name nginx-lb nginx-lb:latest + +.. _nginxloadbalancer_nginx_verify_nginx: + +Verify That vLLM Servers Are Ready +---------------------------------- + +.. code-block:: console + + docker logs vllm0 | grep Uvicorn + docker logs vllm1 | grep Uvicorn + +Both outputs should look like this: + +.. code-block:: console + + INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 9132e12a36ba..8ee83a34adb8 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -103,6 +103,23 @@ vllm serve --chat-template ./path-to-chat-template.jinja vLLM community provides a set of chat templates for popular models. You can find them in the examples directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) +With the inclusion of multi-modal chat APIs, the OpenAI spec now accepts chat messages in a new format which specifies +both a `type` and a `text` field. 
An example is provided below: +```python +completion = client.chat.completions.create( + model="NousResearch/Meta-Llama-3-8B-Instruct", + messages=[ + {"role": "user", "content": [{"type": "text", "text": "Classify this sentiment: vLLM is wonderful!"}]} + ] +) +``` +Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like +`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which +format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify +between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match +this, unless explicitly specified. + + ## Command line arguments for the server ```{argparse} @@ -157,7 +174,7 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral` or `llama3_json` or `internlm`. Additional tool parsers +* `--tool-call-parser` -- select the tool parser to use (listed below). Additional tool parsers will continue to be added in the future, and also can register your own tool parsers in the `--tool-parser-plugin`. * `--tool-parser-plugin` -- **optional** tool parser plugin used to register user defined tool parsers into vllm, the registered tool parser name can be specified in `--tool-call-parser`. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages @@ -168,7 +185,7 @@ from HuggingFace; and you can find an example of this in a `tokenizer_config.jso If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! -#### Hermes Models +#### Hermes Models (`hermes`) All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. * `NousResearch/Hermes-2-Pro-*` * `NousResearch/Hermes-2-Theta-*` @@ -180,7 +197,7 @@ step in their creation_. Flags: `--tool-call-parser hermes` -#### Mistral Models +#### Mistral Models (`mistral`) Supported models: * `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) * Additional mistral function-calling models are compatible as well. @@ -199,7 +216,7 @@ when tools are provided, that results in much better reliability when working wi Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` -#### Llama Models +#### Llama Models (`llama3_json`) Supported models: * `meta-llama/Meta-Llama-3.1-8B-Instruct` * `meta-llama/Meta-Llama-3.1-70B-Instruct` @@ -219,16 +236,60 @@ it works better with vLLM. Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` -#### Internlm Models +#### IBM Granite + +Supported models: +* `ibm-granite/granite-20b-functioncalling` + +Flags: `--tool-call-parser granite-20b-fc` +`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. 
It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. + +* `ibm-granite/granite-8b-instruct` + +Flags: `--tool-call-parser granite` +`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. + + +#### InternLM Models (`internlm`) Supported models: * `internlm/internlm2_5-7b-chat` (confirmed) * Additional internlm2.5 function-calling models are compatible as well Known issues: -* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. +* Although this implementation also supports InternLM2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` +#### Jamba Models (`jamba`) +AI21's Jamba-1.5 models are supported. +* `ai21labs/AI21-Jamba-1.5-Mini` +* `ai21labs/AI21-Jamba-1.5-Large` + + +Flags: `--tool-call-parser jamba` + + ### How to write a tool parser plugin @@ -287,5 +348,5 @@ Then you can use this plugin in the command line like this. --tool-parser-plugin --tool-call-parser example \ --chat-template \ -``` +``` diff --git a/docs/source/serving/tensorizer.rst b/docs/source/serving/tensorizer.rst index a44696507fb9..96a93db94871 100644 --- a/docs/source/serving/tensorizer.rst +++ b/docs/source/serving/tensorizer.rst @@ -9,4 +9,7 @@ shorter Pod startup times and CPU memory usage. Tensor encryption is also suppor For more information on CoreWeave's Tensorizer, please refer to `CoreWeave's Tensorizer documentation `_. For more information on serializing a vLLM model, as well a general usage guide to using Tensorizer with vLLM, see -the `vLLM example script `_. \ No newline at end of file +the `vLLM example script `_. + +.. note:: + To use this feature you will need to install `tensorizer` by running `pip install vllm[tensorizer]`.
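Returning to the tool-calling options documented above, a short client-side sketch may help tie the flags together. It assumes a server already launched with a Hermes-series model and `--enable-auto-tool-choice --tool-call-parser hermes`; the tool schema, question, and base URL are illustrative assumptions, not part of this change.

```python
# Hypothetical client-side check of automatic tool choice. Assumes the server was
# started along the lines of:
#   vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \
#       --enable-auto-tool-choice --tool-call-parser hermes
# The tool schema and user question below are made up for illustration.
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")
model = client.models.list().data[0].id

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "Name of the city",
                },
            },
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    tool_choice="auto",
)

# When the model elects to call the tool, the parsed call appears in tool_calls;
# otherwise message.content holds a normal text reply.
print(response.choices[0].message.tool_calls)
print(response.choices[0].message.content)
```

The same request shape should work for the other parsers listed above; only the server-side `--tool-call-parser` value and, where noted, the chat template change.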
diff --git a/examples/florence2_inference.py b/examples/florence2_inference.py new file mode 100644 index 000000000000..b58ac2e1f7ed --- /dev/null +++ b/examples/florence2_inference.py @@ -0,0 +1,44 @@ +''' +Demonstrate prompting of text-to-text +encoder/decoder models, specifically Florence-2 +''' +# TODO(Isotr0py): +# Move to offline_inference_vision_language.py after porting vision backbone +from vllm import LLM, SamplingParams + +dtype = "float" + +# Create a Florence-2 encoder/decoder model instance +llm = LLM( + model="microsoft/Florence-2-base", + tokenizer="facebook/bart-base", + dtype=dtype, + trust_remote_code=True, +) + +prompts = [ + "", "", "", + "", "", "", + "", "", "" +] +# Create a sampling params object. +sampling_params = SamplingParams( + temperature=0, + top_p=1.0, + min_tokens=0, + max_tokens=20, +) + +# Generate output tokens from the prompts. The output is a list of +# RequestOutput objects that contain the prompt, generated +# text, and other information. +outputs = llm.generate(prompts, sampling_params) + +# Print the outputs. +for output in outputs: + prompt = output.prompt + encoder_prompt = output.encoder_prompt + generated_text = output.outputs[0].text + print(f"Encoder prompt: {encoder_prompt!r}, " + f"Decoder prompt: {prompt!r}, " + f"Generated text: {generated_text!r}") diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 1c6ac06123bb..37ec667d96a7 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -12,14 +12,15 @@ from vllm.utils import FlexibleArgumentParser audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] -question_per_audio_count = [ - "What is recited in the audio?", - "What sport and what nursery rhyme are referenced?" -] +question_per_audio_count = { + 0: "What is 1+1?", + 1: "What is recited in the audio?", + 2: "What sport and what nursery rhyme are referenced?" +} # Ultravox 0.3 -def run_ultravox(question, audio_count): +def run_ultravox(question: str, audio_count: int): model_name = "fixie-ai/ultravox-v0_3" tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -42,9 +43,29 @@ def run_ultravox(question, audio_count): return llm, prompt, stop_token_ids -model_example_map = { - "ultravox": run_ultravox, -} +# Qwen2-Audio +def run_qwen2_audio(question: str, audio_count: int): + model_name = "Qwen/Qwen2-Audio-7B-Instruct" + + llm = LLM(model=model_name, + max_model_len=4096, + max_num_seqs=5, + limit_mm_per_prompt={"audio": audio_count}) + + audio_in_prompt = "".join([ + f"Audio {idx+1}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count) + ]) + + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + f"{audio_in_prompt}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = None + return llm, prompt, stop_token_ids + + +model_example_map = {"ultravox": run_ultravox, "qwen2_audio": run_qwen2_audio} def main(args): @@ -54,7 +75,7 @@ def main(args): audio_count = args.num_audios llm, prompt, stop_token_ids = model_example_map[model]( - question_per_audio_count[audio_count - 1], audio_count) + question_per_audio_count[audio_count], audio_count) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. 
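The hunks above and below change the audio example in several places at once, so here is a condensed, single-clip sketch of what `run_qwen2_audio` plus the `multi_modal_data` plumbing amounts to. The model name, constructor arguments, prompt tokens, and audio asset are taken from this diff; treat it as a sketch rather than a drop-in replacement for the script.

```python
# Condensed single-audio version of the Qwen2-Audio path added above.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(model="Qwen/Qwen2-Audio-7B-Instruct",
          max_model_len=4096,
          max_num_seqs=5,
          limit_mm_per_prompt={"audio": 1})

question = "What is recited in the audio?"
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\nAudio 1: <|audio_bos|><|AUDIO|><|audio_eos|>\n"
          f"{question}<|im_end|>\n"
          "<|im_start|>assistant\n")

# The audio clip is passed alongside the prompt as multi-modal data.
outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {
            "audio": [AudioAsset("mary_had_lamb").audio_and_sample_rate],
        },
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```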
@@ -62,16 +83,17 @@ def main(args): max_tokens=64, stop_token_ids=stop_token_ids) - assert args.num_prompts > 0 - inputs = { - "prompt": prompt, - "multi_modal_data": { + mm_data = {} + if audio_count > 0: + mm_data = { "audio": [ asset.audio_and_sample_rate for asset in audio_assets[:audio_count] ] - }, - } + } + + assert args.num_prompts > 0 + inputs = {"prompt": prompt, "multi_modal_data": mm_data} if args.num_prompts > 1: # Batch inference inputs = [inputs] * args.num_prompts @@ -100,7 +122,7 @@ def main(args): parser.add_argument("--num-audios", type=int, default=1, - choices=[1, 2], + choices=[0, 1, 2], help="Number of audio items per prompt.") args = parser.parse_args() diff --git a/examples/offline_inference_mlpspeculator.py b/examples/offline_inference_mlpspeculator.py index 5dec4a76afb2..8f0eb65e47f6 100644 --- a/examples/offline_inference_mlpspeculator.py +++ b/examples/offline_inference_mlpspeculator.py @@ -50,8 +50,6 @@ def time_generation(llm: LLM, prompts: List[str], llm = LLM( model="meta-llama/Llama-2-13b-chat-hf", speculative_model="ibm-fms/llama-13b-accelerator", - # These are currently required for MLPSpeculator decoding - use_v2_block_manager=True, ) print("With speculation") diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 4c88dcc2f087..83d2548a506e 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -1,6 +1,6 @@ """ -This example shows how to use vLLM for running offline inference -with the correct prompt format on vision language models. +This example shows how to use vLLM for running offline inference with +the correct prompt format on vision language models for text generation. For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
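Before the per-model hunks below, it may help to restate the generation pattern that every runner in `offline_inference_vision_language.py` ultimately reduces to: build an `LLM`, format a model-specific prompt with an image placeholder, and pass the image through `multi_modal_data`. The sketch uses LLaVA-1.5 and its `USER:`/`ASSISTANT:` prompt format as an assumed example; the image URL is reused from elsewhere in this change.

```python
# A minimal sketch of the shared calling pattern, assuming LLaVA-1.5; the model
# name and prompt template are illustrative and not part of this diff.
from vllm import LLM, SamplingParams
from vllm.multimodal.utils import fetch_image

llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096)

image = fetch_image(
    "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
)
prompt = "USER: <image>\nWhat is the content of this image?\nASSISTANT:"

# The image rides along with the prompt as multi-modal data.
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```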
@@ -267,6 +267,11 @@ def run_qwen2_vl(question: str, modality: str): model=model_name, max_model_len=8192, max_num_seqs=5, + # Note - mm_processor_kwargs can also be passed to generate/chat calls + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + }, ) prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" @@ -277,6 +282,22 @@ def run_qwen2_vl(question: str, modality: str): return llm, prompt, stop_token_ids +# Pixtral HF-format +def run_pixtral_hf(question: str, modality: str): + assert modality == "image" + + model_name = "mistral-community/pixtral-12b" + + llm = LLM( + model=model_name, + max_model_len=8192, + ) + + prompt = f"[INST]{question}\n[IMG][/INST]" + stop_token_ids = None + return llm, prompt, stop_token_ids + + # LLama 3.2 def run_mllama(question: str, modality: str): assert modality == "image" @@ -347,6 +368,7 @@ def run_glm4v(question: str, modality: str): "NVLM_D": run_nvlm_d, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "pixtral_hf": run_pixtral_hf, "mllama": run_mllama, "molmo": run_molmo, "glm4v": run_glm4v, @@ -433,7 +455,7 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' - 'vision language models') + 'vision language models for text generation') parser.add_argument('--model-type', '-m', type=str, diff --git a/examples/offline_inference_vision_language_embedding.py b/examples/offline_inference_vision_language_embedding.py new file mode 100644 index 000000000000..e1732d045f94 --- /dev/null +++ b/examples/offline_inference_vision_language_embedding.py @@ -0,0 +1,170 @@ +""" +This example shows how to use vLLM for running offline inference with +the correct prompt format on vision language models for multimodal embedding. + +For most models, the prompt format should follow corresponding examples +on HuggingFace model repository. 
+""" +from argparse import Namespace +from typing import Literal, NamedTuple, Optional, TypedDict, Union, get_args + +from PIL.Image import Image + +from vllm import LLM +from vllm.multimodal.utils import fetch_image +from vllm.utils import FlexibleArgumentParser + + +class TextQuery(TypedDict): + modality: Literal["text"] + text: str + + +class ImageQuery(TypedDict): + modality: Literal["image"] + image: Image + + +class TextImageQuery(TypedDict): + modality: Literal["text+image"] + text: str + image: Image + + +QueryModality = Literal["text", "image", "text+image"] +Query = Union[TextQuery, ImageQuery, TextImageQuery] + + +class ModelRequestData(NamedTuple): + llm: LLM + prompt: str + image: Optional[Image] + + +def run_e5_v(query: Query): + llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 + + if query["modality"] == "text": + text = query["text"] + prompt = llama3_template.format( + f"{text}\nSummary above sentence in one word: ") + image = None + elif query["modality"] == "image": + prompt = llama3_template.format( + "\nSummary above image in one word: ") + image = query["image"] + else: + modality = query['modality'] + raise ValueError(f"Unsupported query modality: '{modality}'") + + llm = LLM( + model="royokong/e5-v", + task="embedding", + max_model_len=4096, + ) + + return ModelRequestData( + llm=llm, + prompt=prompt, + image=image, + ) + + +def run_vlm2vec(query: Query): + if query["modality"] == "text": + text = query["text"] + prompt = f"Find me an everyday image that matches the given caption: {text}" # noqa: E501 + image = None + elif query["modality"] == "image": + prompt = "<|image_1|> Find a day-to-day image that looks similar to the provided image." # noqa: E501 + image = query["image"] + elif query["modality"] == "text+image": + text = query["text"] + prompt = f"<|image_1|> Represent the given image with the following question: {text}" # noqa: E501 + image = query["image"] + else: + modality = query['modality'] + raise ValueError(f"Unsupported query modality: '{modality}'") + + llm = LLM( + model="TIGER-Lab/VLM2Vec-Full", + task="embedding", + trust_remote_code=True, + mm_processor_kwargs={"num_crops": 4}, + ) + + return ModelRequestData( + llm=llm, + prompt=prompt, + image=image, + ) + + +def get_query(modality: QueryModality): + if modality == "text": + return TextQuery(modality="text", text="A dog sitting in the grass") + + if modality == "image": + return ImageQuery( + modality="image", + image=fetch_image( + "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/American_Eskimo_Dog.jpg/360px-American_Eskimo_Dog.jpg" # noqa: E501 + ), + ) + + if modality == "text+image": + return TextImageQuery( + modality="text+image", + text="A cat standing in the snow.", + image=fetch_image( + "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b6/Felis_catus-cat_on_snow.jpg/179px-Felis_catus-cat_on_snow.jpg" # noqa: E501 + ), + ) + + msg = f"Modality {modality} is not supported." 
+ raise ValueError(msg) + + +def run_encode(model: str, modality: QueryModality): + query = get_query(modality) + req_data = model_example_map[model](query) + + mm_data = {} + if req_data.image is not None: + mm_data["image"] = req_data.image + + outputs = req_data.llm.encode({ + "prompt": req_data.prompt, + "multi_modal_data": mm_data, + }) + + for output in outputs: + print(output.outputs.embedding) + + +def main(args: Namespace): + run_encode(args.model_name, args.modality) + + +model_example_map = { + "e5_v": run_e5_v, + "vlm2vec": run_vlm2vec, +} + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models for multimodal embedding') + parser.add_argument('--model-name', + '-m', + type=str, + default="vlm2vec", + choices=model_example_map.keys(), + help='The name of the embedding model.') + parser.add_argument('--modality', + type=str, + default="image", + choices=get_args(QueryModality), + help='Modality of the input.') + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index c4e4cdc0db95..e28514bf403f 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -1,7 +1,7 @@ """ This example shows how to use vLLM for running offline inference with -multi-image input on vision language models, using the chat template defined -by the model. +multi-image input on vision language models for text generation, +using the chat template defined by the model. """ from argparse import Namespace from typing import List, NamedTuple, Optional @@ -234,12 +234,35 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: ) +def load_mllama(question, image_urls: List[str]) -> ModelRequestData: + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + # The configuration below has been confirmed to launch on a single L40 GPU. + llm = LLM( + model=model_name, + max_model_len=4096, + max_num_seqs=16, + enforce_eager=True, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + prompt = f"<|image|><|image|><|begin_of_text|>{question}" + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=None, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None, + ) + + model_example_map = { "phi3_v": load_phi3v, "internvl_chat": load_internvl, "NVLM_D": load_nvlm_d, "qwen2_vl": load_qwen2_vl, "qwen_vl_chat": load_qwenvl_chat, + "mllama": load_mllama, } @@ -311,7 +334,8 @@ def main(args: Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' - 'vision language models that support multi-image input') + 'vision language models that support multi-image input for text ' + 'generation') parser.add_argument('--model-type', '-m', type=str, diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py index 3b3e0ae64a03..67b755a15596 100644 --- a/examples/offline_inference_with_prefix.py +++ b/examples/offline_inference_with_prefix.py @@ -1,4 +1,5 @@ from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory # NOTE: This is just a running example. For benchmarking purpose, # please see benchmarks/benchmark_prefix_caching.py @@ -28,12 +29,9 @@ # Create a sampling params object. 
sampling_params = SamplingParams(temperature=0.0) -# Create an LLM. +# Create an LLM without prefix caching as a baseline. regular_llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.4) -prefix_cached_llm = LLM(model="facebook/opt-125m", - enable_prefix_caching=True, - gpu_memory_utilization=0.4) print("Results without `enable_prefix_caching`") # Generate texts from the prompts. The output is a list of RequestOutput objects @@ -50,6 +48,15 @@ print("-" * 80) +# Destroy the LLM object and free up the GPU memory. +del regular_llm +cleanup_dist_env_and_memory() + +# Create an LLM with prefix caching enabled. +prefix_cached_llm = LLM(model="facebook/opt-125m", + enable_prefix_caching=True, + gpu_memory_utilization=0.4) + # Warmup so that the shared prompt's KV cache is computed. prefix_cached_llm.generate(generating_prompts[0], sampling_params) diff --git a/examples/offline_profile.py b/examples/offline_profile.py new file mode 100644 index 000000000000..1d415b82cddb --- /dev/null +++ b/examples/offline_profile.py @@ -0,0 +1,282 @@ +import inspect +import json +import os +import sys +from argparse import RawTextHelpFormatter +from dataclasses import asdict, dataclass +from typing import Optional + +import torch + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.profiler import layerwise_profile +from vllm.utils import FlexibleArgumentParser + +BATCH_SIZE_DEFAULT = 1 +PROMPT_LEN_DEFAULT = 256 +OUTPUT_LEN_DEFAULT = 2 + + +@dataclass +class ProfileContext: + engine_args: EngineArgs + prompt_len: int + output_len: int + batch_size: int + save_chrome_traces_folder: Optional[str] + + +def get_dtype(dtype: str): + if dtype == "torch.float": + return torch.float + else: + return dtype + + +def run_profile(context: ProfileContext, csv_output: Optional[str], + json_output: Optional[str]): + print("Run profile with:") + for key, value in asdict(context).items(): + print(f" {key} = {value}") + + # Create sampling params + sampling_params = SamplingParams(temperature=0.8, + top_p=0.95, + max_tokens=args.output_len, + ignore_eos=True) + + # Create LLM + llm = LLM(**asdict(context.engine_args)) + batch_size = context.batch_size + prompt_len = context.prompt_len + output_len = context.output_len + + scheduler_config = llm.llm_engine.scheduler_config + max_model_len = llm.llm_engine.model_config.max_model_len + max_num_batched_tokens = scheduler_config.max_num_batched_tokens + max_num_seqs = scheduler_config.max_num_seqs + + if batch_size * prompt_len > max_num_batched_tokens: + print(f"ERROR: chosen batch_size * prompt_len " + f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is " + f"larger than max_num_batched_tokens ({max_num_batched_tokens}) " + f"and therefore cannot be run in a single profile step, please " + f"choose a smaller batch size or prompt length, or increase " + f"--max-num-batched-tokens") + sys.exit(-1) + if batch_size >= max_num_seqs: + print( + f"ERROR: chosen batch_size ({batch_size}) is larger than " + f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " + f"single profile step, please choose a smaller batch size") + sys.exit(-1) + print("llm.llm_engine.model_config.max_model_len: ", + llm.llm_engine.model_config.max_model_len) + if prompt_len + output_len > llm.llm_engine.model_config.max_model_len: + print( + f"ERROR: chosen prompt_len + output_len ({prompt_len} + " + f"{output_len} = {prompt_len + output_len}) is larger than the " + f"model's max_model_len ({max_model_len}), please choose a smaller " + 
f"prompt_len or output_len, or increase --max-model-len") + sys.exit(-1) + + def add_requests(): + for i in range(batch_size): + prompt_token_ids = torch.randint( + llm.llm_engine.model_config.get_vocab_size(), + size=(prompt_len, )).tolist() + + llm.llm_engine.add_request( + request_id=f"seq{i}", + prompt={'prompt_token_ids': prompt_token_ids}, + params=sampling_params) + + def abort_requests(): + for i in range(batch_size): + llm.llm_engine.abort_request(f"seq{i}") + + # Warm up run + print("Warm up run ...") + add_requests() + llm.llm_engine.step() # Prefill + llm.llm_engine.step() # Decode + abort_requests() + + print("Profile run ...") + add_requests() + + with layerwise_profile() as prefill_prof: + llm.llm_engine.step() # First step is prefill + + decode_profs = [] + for x in range(args.output_len - 1): + with layerwise_profile() as decode_prof: + llm.llm_engine.step() + decode_profs.append(decode_prof) + + decode_results_list = [prof.results for prof in decode_profs] + prefill_results = prefill_prof.results + has_decode = len(decode_results_list) > 0 + + LINE_WIDTH = 80 + print("=" * LINE_WIDTH) + print(f"= Prefill Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + prefill_results.print_model_table() + + if has_decode: + print() + print("=" * LINE_WIDTH) + print(f"= First Decode Step Model Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + decode_results_list[0].print_model_table() + + print() + print("=" * LINE_WIDTH) + print(f"= Prefill Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + prefill_results.print_summary_table() + + if has_decode: + print() + print("=" * LINE_WIDTH) + print(f"= First Decode Step Summary Table " + f"(prompt_len={prompt_len}, batch_size={batch_size})") + print("=" * LINE_WIDTH) + print() + decode_results_list[0].print_summary_table() + + if csv_output: + csv_filename_base = csv_output.rstrip(".csv") + prefill_results.export_model_stats_table_csv( + csv_filename_base + "_prefill_model_table.csv") + prefill_results.export_summary_stats_table_csv( + csv_filename_base + "_prefill_summary_table.csv") + + if has_decode: + decode_results_list[0].export_model_stats_table_csv(\ + csv_filename_base + "_decode_model_table.csv") + decode_results_list[0].export_summary_stats_table_csv( + csv_filename_base + "_decode_summary_table.csv") + + if json_output: + cuda_devices = [ + torch.cuda.get_device_properties(dev_idx) + for dev_idx in range(torch.cuda.device_count()) + ] + + json_dict = { + "context": { + "python_version": f"{sys.version}", + "torch_version": f"{torch.__version__}", + "torch_cuda_version": f"{torch.version.cuda}", + "cuda_devices": f"{cuda_devices}", + **asdict(context) + }, + "prefill": prefill_results.convert_stats_to_dict(), + } + + if has_decode: + for idx, dr in enumerate(decode_results_list): + json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() + + for idx, dr in enumerate(decode_results_list[1:]): + json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() + + with open(json_output.rstrip(".json") + ".json", "w+") as f: + json.dump(json_dict, f, indent=2) + pass + + if context.save_chrome_traces_folder is not None: + os.makedirs(context.save_chrome_traces_folder, exist_ok=True) + prefill_prof.profiler.export_chrome_trace( + context.save_chrome_traces_folder + "/prefill.json") + for idx, decode_prof in enumerate(decode_profs): + 
decode_prof.profiler.export_chrome_trace( + context.save_chrome_traces_folder + f"/decode_{idx + 1}.json") + print("Traces saved as prefill.json and decode_1.json, etc." + f" in folder {context.save_chrome_traces_folder}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description=""" +Profile a model + + example: + ``` + python examples/offline_profile.py \\ + --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ + --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ + --enforce-eager + ``` + + then you can use various tools to analyze the json output + terminal ascii tables: + ``` + python tools/profiler/print_layerwise_table.py \\ + --json-trace Llama31-8b-FP8.json --phase prefill --table summary + ``` + or create matplotlib stacked bar charts: + ``` + python tools/profiler/visualize_layerwise_profile.py \\ + --json-trace Llama31-8b-FP8.json \\ + --output-directory profile_breakdown --plot-metric pct_cuda_time + ``` +""", + formatter_class=RawTextHelpFormatter) + parser.add_argument( + "--csv", + type=str, + default=None, + help="Export the results as multiple csv file. This should be the root " + "filename, will create _prefill_model_table.csv, " + "_prefill_summary_table.csv, " + "_decode_model_table.csv, and " + "_decode_summary_table.csv") + parser.add_argument( + "--json", + type=str, + default=None, + help="Export the results as a json file. This should be the filename") + parser.add_argument("--save-chrome-traces-folder", + type=str, + help="Save chrome traces for the prefill and decode " + "will save traces as prefill.json and decode_1.json, " + "etc. inside this folder") + parser.add_argument( + "--prompt-len", + type=int, + default=PROMPT_LEN_DEFAULT, + help=f"Length of the random prompt to use when profiling, all batched " + f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}") + parser.add_argument("--batch-size", + type=int, + default=BATCH_SIZE_DEFAULT, + help=f"Number of requests to run as a single batch, " + f"default={BATCH_SIZE_DEFAULT}") + parser.add_argument( + "--output-len", + type=int, + default=OUTPUT_LEN_DEFAULT, + help="Number of llm steps to run (includes prefill and decode) " + "- default={OUTPUT_LEN_DEFAULT}") + + EngineArgs.add_cli_args(parser) + + args = parser.parse_args() + + context = ProfileContext( + engine_args=EngineArgs.from_cli_args(args), + **{ + k: v + for k, v in vars(args).items() + if k in inspect.signature(ProfileContext).parameters + }) + run_profile(context, csv_output=args.csv, json_output=args.json) diff --git a/examples/openai_api_client_for_multimodal.py b/examples/openai_api_client_for_multimodal.py new file mode 100644 index 000000000000..beb83e494ed0 --- /dev/null +++ b/examples/openai_api_client_for_multimodal.py @@ -0,0 +1,236 @@ +"""An example showing how to use vLLM to serve multimodal models +and run online inference with OpenAI client. 
+ +Launch the vLLM server with the following command: + +(single image inference with Llava) +vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + +(multi-image inference with Phi-3.5-vision-instruct) +vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ + --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2 + +(audio inference with Ultravox) +vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096 +""" +import base64 + +import requests +from openai import OpenAI + +from vllm.assets.audio import AudioAsset +from vllm.utils import FlexibleArgumentParser + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + + +def encode_base64_content_from_url(content_url: str) -> str: + """Encode a content retrieved from a remote url to base64 format.""" + + with requests.get(content_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + +# Text-only inference +def run_text_only() -> None: + chat_completion = client.chat.completions.create( + messages=[{ + "role": "user", + "content": "What's the capital of France?" + }], + model=model, + max_tokens=64, + ) + + result = chat_completion.choices[0].message.content + print("Chat completion output:", result) + + +# Single-image input inference +def run_single_image() -> None: + + ## Use image url in the payload + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url + }, + }, + ], + }], + model=model, + max_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from image url:", result) + + ## Use base64 encoded image in the payload + image_base64 = encode_base64_content_from_url(image_url) + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + }, + ], + }], + model=model, + max_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from base64 encoded image:", result) + + +# Multi-image input inference +def run_multi_image() -> None: + image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" + image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?" 
+ }, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + }, + ], + }], + model=model, + max_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output:", result) + + +# Audio input inference +def run_audio() -> None: + # Any format supported by librosa is supported + audio_url = AudioAsset("winning_call").url + + # Use audio url in the payload + chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_tokens=64, + ) + + result = chat_completion_from_url.choices[0].message.content + print("Chat completion output from audio url:", result) + + audio_base64 = encode_base64_content_from_url(audio_url) + chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": f"data:audio/ogg;base64,{audio_base64}" + }, + }, + ], + }], + model=model, + max_tokens=64, + ) + + result = chat_completion_from_base64.choices[0].message.content + print("Chat completion output from base64 encoded audio:", result) + + +example_function_map = { + "text-only": run_text_only, + "single-image": run_single_image, + "multi-image": run_multi_image, + "audio": run_audio, +} + + +def main(args) -> None: + chat_type = args.chat_type + example_function_map[chat_type]() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using OpenAI client for online inference with ' + 'multimodal language models served with vLLM.') + parser.add_argument( + '--chat-type', + '-c', + type=str, + default="single-image", + choices=["text-only", "single-image", "multi-image", "audio"], + help='Conversation type with multimodal data.') + args = parser.parse_args() + main(args) diff --git a/examples/openai_audio_api_client.py b/examples/openai_audio_api_client.py deleted file mode 100644 index 80a972683871..000000000000 --- a/examples/openai_audio_api_client.py +++ /dev/null @@ -1,90 +0,0 @@ -"""An example showing how to use vLLM to serve VLMs. - -Launch the vLLM server with the following command: -vllm serve fixie-ai/ultravox-v0_3 -""" -import base64 - -import requests -from openai import OpenAI - -from vllm.assets.audio import AudioAsset - -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" - -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -# Any format supported by librosa is supported -audio_url = AudioAsset("winning_call").url - -# Use audio url in the payload -chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" 
- }, - { - "type": "audio_url", - "audio_url": { - "url": audio_url - }, - }, - ], - }], - model=model, - max_tokens=64, -) - -result = chat_completion_from_url.choices[0].message.content -print(f"Chat completion output:{result}") - - -# Use base64 encoded audio in the payload -def encode_audio_base64_from_url(audio_url: str) -> str: - """Encode an audio retrieved from a remote url to base64 format.""" - - with requests.get(audio_url) as response: - response.raise_for_status() - result = base64.b64encode(response.content).decode('utf-8') - - return result - - -audio_base64 = encode_audio_base64_from_url(audio_url=audio_url) -chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this audio?" - }, - { - "type": "audio_url", - "audio_url": { - # Any format supported by librosa is supported - "url": f"data:audio/ogg;base64,{audio_base64}" - }, - }, - ], - }], - model=model, - max_tokens=64, -) - -result = chat_completion_from_base64.choices[0].message.content -print(f"Chat completion output:{result}") diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py deleted file mode 100644 index 71ae03e4d148..000000000000 --- a/examples/openai_vision_api_client.py +++ /dev/null @@ -1,126 +0,0 @@ -"""An example showing how to use vLLM to serve VLMs. - -Launch the vLLM server with the following command: - -(single image inference with Llava) -vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja - -(multi-image inference with Phi-3.5-vision-instruct) -vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ - --trust-remote-code --limit-mm-per-prompt image=2 -""" -import base64 - -import requests -from openai import OpenAI - -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" - -client = OpenAI( - # defaults to os.environ.get("OPENAI_API_KEY") - api_key=openai_api_key, - base_url=openai_api_base, -) - -models = client.models.list() -model = models.data[0].id - -# Single-image input inference -image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" - -## Use image url in the payload -chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url - }, - }, - ], - }], - model=model, - max_tokens=64, -) - -result = chat_completion_from_url.choices[0].message.content -print("Chat completion output:", result) - - -## Use base64 encoded image in the payload -def encode_image_base64_from_url(image_url: str) -> str: - """Encode an image retrieved from a remote url to base64 format.""" - - with requests.get(image_url) as response: - response.raise_for_status() - result = base64.b64encode(response.content).decode('utf-8') - - return result - - -image_base64 = encode_image_base64_from_url(image_url=image_url) -chat_completion_from_base64 = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What's in this image?" 
- }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, - }, - ], - }], - model=model, - max_tokens=64, -) - -result = chat_completion_from_base64.choices[0].message.content -print(f"Chat completion output:{result}") - -# Multi-image input inference -image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" -image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" -chat_completion_from_url = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "What are the animals in these images?" - }, - { - "type": "image_url", - "image_url": { - "url": image_url_duck - }, - }, - { - "type": "image_url", - "image_url": { - "url": image_url_lion - }, - }, - ], - }], - model=model, - max_tokens=64, -) - -result = chat_completion_from_url.choices[0].message.content -print("Chat completion output:", result) diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja new file mode 100644 index 000000000000..adf06600925b --- /dev/null +++ b/examples/tool_chat_template_granite.jinja @@ -0,0 +1,46 @@ +{%- if tools %} + {{- '<|start_of_role|>available_tools<|end_of_role|> +' }} + {%- for tool in tools %} + {{- tool | tojson(indent=4) }} + {%- if not loop.last %} + {{- ' + +' }} + {%- endif %} + {%- endfor %} + {{- '<|end_of_text|> +' }} +{%- endif %} + +{%- if messages[0]["role"] == "system" %} + {%- set sys_prompt = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} + {% set sys_prompt = 'You are a helpful assistant.' 
%} +{%- endif %} +{{- '<|start_of_role|>system<|end_of_role|>' + sys_prompt + '<|end_of_text|> +' }} +{%- for message in loop_messages %} + {%- if message['role'] == 'user' %} + {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {% for tc in message.tool_calls %} + {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} + {% endfor %} + {{- '<|end_of_text|> +' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- elif message['role'] == 'tool_response' or message['role'] == 'tool' %} + {{- '<|start_of_role|>tool_response<|end_of_role|>' + message['content'] + '<|end_of_text|> +' }} + {%- endif %} + {%- if loop.last and add_generation_prompt %} + {{- '<|start_of_role|>assistant<|end_of_role|>' }} + {%- endif %} +{%- endfor %} \ No newline at end of file diff --git a/examples/tool_chat_template_granite_20b_fc.jinja b/examples/tool_chat_template_granite_20b_fc.jinja new file mode 100644 index 000000000000..cb52188ec72d --- /dev/null +++ b/examples/tool_chat_template_granite_20b_fc.jinja @@ -0,0 +1,130 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + +{%- if not full_function_description is defined %} + {%- set full_function_description = false %} +{%- endif %} + +{%- macro full_description(tool) %} + {{- tool.name + '(' }} + {%- if tool.parameters is defined %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- if tool.parameters is defined %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- endif %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} +{%- endmacro %} + +{%- macro simple_description(tool) %} + {{- tool.description }} +{%- endmacro %} + +{%- macro function_description(tool) %} + {%- if full_function_description %} + {{- full_description(tool) }} + {%- else %} + {{- simple_description(tool) }} + 
{%- endif %} +{%- endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set sys_prompt = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} + {% set sys_prompt = 'You are a helpful assistant with access to the following function calls. Your task is to understand the given conversation with function calls and responses and generate natural language response as the ASSISTANT to continue the conversation. You may use the following function calls to understand how to respond to the user query.' %} +{%- endif %} + +{{ 'SYSTEM: ' + sys_prompt }} +{% if tools is iterable and tools | length > 0 %} +<|function_call_library|> + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + function_description(tool) }} + {{- ', "parameters": ' }} + {%- if not tool.parameters is defined or tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +If none of the functions are relevant or the given question lacks the parameters required by the function, please output \" {\"name\": \"no_function\", \"arguments\": {}}\". +{%- endif %} + + + +{% for message in messages %} + {% if message['role'] == 'user' %} + {{- '\nUSER: ' + message['content'] }} + {% elif message['role'] == 'assistant' and message.tool_calls is defined %} + {{- '\nASSISTANT:' }} + {% for tc in message.tool_calls %} + {{- ' ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} + {% endfor %} + {{- '<|endoftext|>' }} + {% elif message['role'] == 'assistant' %} + {{- '\nASSISTANT: ' + message['content'] + ' <|endoftext|>' }} + {% elif message['role'] == 'tool' %} + {{- ' ' + message['content'] }} + {%- else %} + {{- raise_exception("Unexpected combination of role and message content") }} + {% endif %} + {% if loop.last and add_generation_prompt %} + {{- '\nASSISTANT: ' }} + {% endif %} +{% endfor %} diff --git a/format.sh b/format.sh index 1ac028d00e3a..be6ee0ce46dc 100755 --- a/format.sh +++ b/format.sh @@ -21,6 +21,20 @@ builtin cd "$(dirname "${BASH_SOURCE:-$0}")" ROOT="$(git rev-parse --show-toplevel)" builtin cd "$ROOT" || exit 1 +check_command() { + if ! command -v "$1" &> /dev/null; then + echo "ā“ā“$1 is not installed, please run \`pip install -r requirements-lint.txt\`" + exit 1 + fi +} + +check_command yapf +check_command ruff +check_command mypy +check_command codespell +check_command isort +check_command clang-format + YAPF_VERSION=$(yapf --version | awk '{print $2}') RUFF_VERSION=$(ruff --version | awk '{print $2}') MYPY_VERSION=$(mypy --version | awk '{print $2}') @@ -31,7 +45,7 @@ CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') # # params: tool name, tool version, required version tool_version_check() { if [[ $2 != $3 ]]; then - echo "Wrong $1 version installed: $3 is required, not $2." + echo "ā“ā“Wrong $1 version installed: $3 is required, not $2." exit 1 fi } @@ -281,10 +295,12 @@ tools/actionlint.sh -color echo 'vLLM actionlint: Done' if ! git diff --quiet &>/dev/null; then - echo 'Reformatted files. Please review and stage the changes.' 
- echo 'Changes not staged for commit:' - echo + echo + echo "šŸ”šŸ”There are files changed by the format checker or by you that are not added and committed:" git --no-pager diff --name-only + echo "šŸ”šŸ”Format checker passed, but please add, commit and push all the files above to include changes made by the format checker." exit 1 +else + echo "āœØšŸŽ‰ Format check passed! Congratulations! šŸŽ‰āœØ" fi diff --git a/python_only_dev.py b/python_only_dev.py index 72d4e78ee14f..4ab203bb6f9d 100644 --- a/python_only_dev.py +++ b/python_only_dev.py @@ -39,7 +39,6 @@ files_to_copy = [ "vllm/_C.abi3.so", - "vllm/_core_C.abi3.so", "vllm/_moe_C.abi3.so", "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", "vllm/vllm_flash_attn/flash_attn_interface.py", diff --git a/requirements-common.txt b/requirements-common.txt index aa165ff6d6a5..d72cc4476272 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -4,7 +4,7 @@ numpy < 2.0.0 requests >= 2.26.0 tqdm py-cpuinfo -transformers >= 4.45.0 # Required for Llama 3.2. +transformers >= 4.45.2 # Required for Llama 3.2 and Qwen2-VL. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' @@ -31,3 +31,4 @@ pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. +compressed-tensors == 0.7.1 # required for compressed-tensors diff --git a/requirements-test.txt b/requirements-test.txt index 997df9afac76..9787fa2a4a48 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -17,7 +17,6 @@ requests ray[adag]==2.35 sentence-transformers # required for embedding soundfile # required for audio test -compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test matplotlib # required for qwen-vl test diff --git a/setup.py b/setup.py index 9ea4e85c0754..8abeb0ba739d 100644 --- a/setup.py +++ b/setup.py @@ -157,6 +157,14 @@ def configure(self, ext: CMakeExtension) -> None: # on subsequent calls to python. cmake_args += ['-DVLLM_PYTHON_PATH={}'.format(":".join(sys.path))] + # Override the base directory for FetchContent downloads to $ROOT/.deps + # This allows sharing dependencies between profiles, + # and plays more nicely with sccache. + # To override this, set the FETCHCONTENT_BASE_DIR environment variable. 
+ fc_base_dir = os.path.join(ROOT_DIR, ".deps") + fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir) + cmake_args += ['-DFETCHCONTENT_BASE_DIR={}'.format(fc_base_dir)] + # # Setup parallelism and build tool # @@ -290,10 +298,6 @@ def _build_custom_ops() -> bool: return _is_cuda() or _is_hip() or _is_cpu() -def _build_core_ext() -> bool: - return not (_is_neuron() or _is_tpu() or _is_openvino() or _is_xpu()) - - def get_hipcc_rocm_version(): # Run the hipcc --version command result = subprocess.run(['hipcc', '--version'], @@ -456,9 +460,6 @@ def _read_requirements(filename: str) -> List[str]: ext_modules = [] -if _build_core_ext(): - ext_modules.append(CMakeExtension(name="vllm._core_C")) - if _is_cuda() or _is_hip(): ext_modules.append(CMakeExtension(name="vllm._moe_C")) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 1903a7582dc8..8a04693ba676 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -12,11 +12,11 @@ from vllm import SamplingParams from vllm.config import ParallelConfig +from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine from vllm.outputs import RequestOutput as RealRequestOutput from vllm.sampling_params import RequestOutputKind -from ..conftest import cleanup from ..utils import wait_for_gpu_memory_to_clear @@ -157,7 +157,7 @@ async def async_engine(): engine.shutdown_background_loop() del engine await asyncio.sleep(0.1) - cleanup() + cleanup_dist_env_and_memory() @pytest.fixture() diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 0fe88e792520..3c2ca1bddd90 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -19,7 +19,7 @@ MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-3.2-1B", ] TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index e8819688c9e8..51aec8c873d1 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -12,20 +12,14 @@ import pytest from ..models.utils import check_logprobs_close, check_outputs_equal -from ..utils import check_deprecated_block_manager_usage, multi_gpu_test +from ..utils import multi_gpu_test MODELS = [ "facebook/opt-125m", - "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-3.2-1B", ] -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/basic_correctness/test_chunked_prefill.py') - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @@ -197,7 +191,6 @@ def test_models_with_fp8_kv_cache( @pytest.mark.parametrize("max_tokens", [16]) @pytest.mark.parametrize("enforce_eager", [False]) @pytest.mark.parametrize("chunk_size", [30, 32]) -@pytest.mark.parametrize("use_v2_block_manager", [False, True]) # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. 
@pytest.mark.parametrize("tensor_parallel_size", [1]) @@ -206,7 +199,6 @@ def test_with_prefix_caching( max_tokens: int, enforce_eager: bool, chunk_size: int, - use_v2_block_manager: bool, tensor_parallel_size: int, ) -> None: """ @@ -234,7 +226,6 @@ def test_with_prefix_caching( enable_chunked_prefill=True, enable_prefix_caching=enable, tensor_parallel_size=tensor_parallel_size, - use_v2_block_manager=use_v2_block_manager, enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, ) as vllm_model: diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index a5df5639cf94..d7f36a781280 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -2,5 +2,5 @@ def test_cpu_offload(): - compare_two_settings("meta-llama/Llama-2-7b-hf", [], - ["--cpu-offload-gb", "4"]) + compare_two_settings("meta-llama/Llama-3.2-1B", [], + ["--cpu-offload-gb", "1"]) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index b6ec7413978f..77c56d91d0a8 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -13,8 +13,7 @@ @pytest.mark.parametrize( "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph", [ - ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate", - True), + ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True), ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", ["--quantization", "compressed-tensors" ], 1, 1, "FLASH_ATTN", "generate", True), diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 5386eb0e3795..c69343b51ae0 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -69,11 +69,11 @@ def check_full_graph_support(model, os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level) os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1" - # Inductor doesn't support fp8/gptq_marlin_24 yet. + # Inductor doesn't support fp8 and the base meta llama uses too + # much memory. 
quantization = model_kwargs.get("quantization") - if (quantization == "fp8" or quantization == "gptq_marlin" - or quantization == "gptq_marlin_24" - ) and optimization_level >= CompilationLevel.INDUCTOR: + if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B") + and optimization_level >= CompilationLevel.INDUCTOR): return prompts = [ diff --git a/tests/conftest.py b/tests/conftest.py index baa6bae03a45..b11bbcb4ab7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -import contextlib -import gc import json import os import sys @@ -25,19 +23,19 @@ from vllm import LLM, SamplingParams from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset -from vllm.config import TokenizerPoolConfig +from vllm.config import TaskOption, TokenizerPoolConfig from vllm.connections import global_http_connection -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel, +from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from vllm.logger import init_logger from vllm.outputs import RequestOutput +from vllm.platforms import current_platform from vllm.sampling_params import BeamSearchParams from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - identity, is_cpu) + identity) logger = init_logger(__name__) @@ -45,10 +43,12 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] -PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]] -PromptAudioInput = Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]] -PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]] +_M = TypeVar("_M") +_PromptMultiModalInput = Union[List[_M], List[List[_M]]] + +PromptImageInput = _PromptMultiModalInput[Image.Image] +PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]] +PromptVideoInput = _PromptMultiModalInput[np.ndarray] def _read_prompts(filename: str) -> List[str]: @@ -140,17 +140,7 @@ def dist_init(): ) initialize_model_parallel(1, 1) yield - cleanup() - - -def cleanup(): - destroy_model_parallel() - destroy_distributed_environment() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - gc.collect() - if not is_cpu(): - torch.cuda.empty_cache() + cleanup_dist_env_and_memory() @pytest.fixture() @@ -167,7 +157,7 @@ def should_do_global_cleanup_after_test(request) -> bool: def cleanup_fixture(should_do_global_cleanup_after_test: bool): yield if should_do_global_cleanup_after_test: - cleanup() + cleanup_dist_env_and_memory() @pytest.fixture(autouse=True) @@ -249,7 +239,8 @@ class HfRunner: def wrap_device(self, input: _T, device: Optional[str] = None) -> _T: if device is None: - return self.wrap_device(input, "cpu" if is_cpu() else "cuda") + return self.wrap_device( + input, "cpu" if current_platform.is_cpu() else "cuda") if hasattr(input, "device") and input.device.type == device: return input @@ -263,6 +254,8 @@ def __init__( *, model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, + is_sentence_transformer: bool = False, + skip_tokenizer_init: bool = False, auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, @@ -271,7 +264,7 @@ def __init__( self.model_name = 
model_name - if is_embedding_model: + if is_sentence_transformer: # Lazy init required for AMD CI from sentence_transformers import SentenceTransformer self.model = self.wrap_device( @@ -290,11 +283,12 @@ def __init__( **model_kwargs, )) - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ) + if not skip_tokenizer_init: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) # don't put this import at the top level # it will call torch.cuda.device_count() @@ -304,33 +298,64 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ) + if skip_tokenizer_init: + self.tokenizer = self.processor.tokenizer self.postprocess_inputs = postprocess_inputs - def generate( + def get_inputs( self, prompts: List[str], images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, - **kwargs: Any, - ) -> List[Tuple[List[List[int]], List[str]]]: - if images: + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[BatchEncoding]: + if images is not None: assert len(prompts) == len(images) - outputs: List[Tuple[List[List[int]], List[str]]] = [] + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + + all_inputs: List[BatchEncoding] = [] for i, prompt in enumerate(prompts): processor_kwargs: Dict[str, Any] = { "text": prompt, "return_tensors": "pt", } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - if videos is not None and videos[i] is not None: - processor_kwargs["videos"] = videos[i] + if images is not None and (image := images[i]) is not None: + processor_kwargs["images"] = image + if videos is not None and (video := videos[i]) is not None: + processor_kwargs["videos"] = video + if audios is not None and (audio_tuple := audios[i]) is not None: + audio, sr = audio_tuple + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr inputs = self.processor(**processor_kwargs) inputs = self.postprocess_inputs(inputs) + all_inputs.append(inputs) + + return all_inputs + + def generate( + self, + prompts: List[str], + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + **kwargs: Any, + ) -> List[Tuple[List[List[int]], List[str]]]: + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + + outputs: List[Tuple[List[List[int]], List[str]]] = [] + for inputs in all_inputs: output_ids = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -350,12 +375,16 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, do_sample=False, max_new_tokens=max_tokens, images=images, + videos=videos, + audios=audios, **kwargs) return [(output_ids[0], output_str[0]) @@ -387,23 +416,17 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[List[torch.Tensor]]: - all_logprobs: List[List[torch.Tensor]] = [] 
- for i, prompt in enumerate(prompts): - processor_kwargs: Dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - if videos is not None and videos[i] is not None: - processor_kwargs["videos"] = videos[i] - - inputs = self.processor(**processor_kwargs) - inputs = self.postprocess_inputs(inputs) + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + all_logprobs: List[List[torch.Tensor]] = [] + for inputs in all_inputs: output = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -472,31 +495,19 @@ def generate_greedy_logprobs_limit( num_logprobs: int, images: Optional[PromptImageInput] = None, audios: Optional[PromptAudioInput] = None, - videos: Optional[List[np.ndarray]] = None, + videos: Optional[PromptVideoInput] = None, **kwargs: Any, ) -> List[TokensTextLogprobs]: + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + all_logprobs: List[List[Dict[int, float]]] = [] all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] - for i, prompt in enumerate(prompts): - processor_kwargs: Dict[str, Any] = { - "text": prompt, - "return_tensors": "pt", - } - if images is not None and images[i] is not None: - processor_kwargs["images"] = images[i] - - if audios is not None: - audio, sr = audios[i] - processor_kwargs["audio"] = audio - processor_kwargs["sampling_rate"] = sr - - if videos is not None: - processor_kwargs["videos"] = videos[i] - inputs = self.processor(**processor_kwargs) - inputs = self.postprocess_inputs(inputs) - + for inputs in all_inputs: output = self.model.generate( **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, @@ -529,6 +540,7 @@ def generate_encoder_decoder_greedy_logprobs_limit( encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]], max_tokens: int, num_logprobs: int, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[TokensTextLogprobs]: ''' @@ -539,11 +551,17 @@ def generate_encoder_decoder_greedy_logprobs_limit( all_output_ids: List[List[int]] = [] all_output_strs: List[str] = [] - for (encoder_prompt, - decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): + for i, (encoder_prompt, decoder_prompt) in enumerate( + to_enc_dec_tuple_list(encoder_decoder_prompts)): + processor_kwargs: Dict[str, Any] = { + "text": encoder_prompt, + "return_tensors": "pt", + } + if images is not None and images[i] is not None: + processor_kwargs["images"] = images[i] encoder_input_ids = self.wrap_device( - self.tokenizer(encoder_prompt, return_tensors="pt").input_ids, + self.processor(**processor_kwargs).input_ids, device=self.model.device.type, ) @@ -591,7 +609,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): del self.model - cleanup() + cleanup_dist_env_and_memory() @pytest.fixture(scope="session") @@ -604,6 +622,7 @@ class VllmRunner: def __init__( self, model_name: str, + task: TaskOption = "auto", tokenizer_name: Optional[str] = None, # Use smaller max model length, otherwise bigger model cannot run due # to kv cache size limit. 
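The tests/conftest.py changes above collapse the three per-modality input unions into one generic alias and move the duplicated per-prompt processor-kwargs loop into a single get_inputs helper shared by generate(), generate_greedy_logprobs(), and generate_greedy_logprobs_limit(). A simplified, self-contained sketch of that pattern; the processor is treated as an opaque callable and the exact return type (BatchEncoding in the real file) is left generic:

```python
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
from PIL import Image

# Generic alias covering both "one item per prompt" and "list of items per
# prompt", as introduced in tests/conftest.py above.
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]


def get_inputs(processor,
               prompts: List[str],
               images: Optional[PromptImageInput] = None,
               videos: Optional[PromptVideoInput] = None,
               audios: Optional[PromptAudioInput] = None) -> List[Any]:
    """Build one processor input per prompt, attaching whichever media is set."""
    for media in (images, videos, audios):
        if media is not None:
            assert len(prompts) == len(media)

    all_inputs: List[Any] = []
    for i, prompt in enumerate(prompts):
        kwargs: Dict[str, Any] = {"text": prompt, "return_tensors": "pt"}
        if images is not None and (image := images[i]) is not None:
            kwargs["images"] = image
        if videos is not None and (video := videos[i]) is not None:
            kwargs["videos"] = video
        if audios is not None and (audio_tuple := audios[i]) is not None:
            audio, sr = audio_tuple
            kwargs["audio"] = audio
            kwargs["sampling_rate"] = sr
        all_inputs.append(processor(**kwargs))
    return all_inputs
```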
@@ -619,6 +638,7 @@ def __init__( ) -> None: self.model = LLM( model=model_name, + task=task, tokenizer=tokenizer_name, trust_remote_code=True, dtype=dtype, @@ -632,19 +652,52 @@ def __init__( **kwargs, ) - def generate( + def get_inputs( self, prompts: List[str], - sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, - ) -> List[Tuple[List[List[int]], List[str]]]: + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[TextPrompt]: if images is not None: assert len(prompts) == len(images) + if videos is not None: + assert len(prompts) == len(videos) + + if audios is not None: + assert len(prompts) == len(audios) + inputs = [TextPrompt(prompt=prompt) for prompt in prompts] if images is not None: for i, image in enumerate(images): - inputs[i]["multi_modal_data"] = {"image": image} + if image is not None: + inputs[i]["multi_modal_data"] = {"image": image} + + if videos is not None: + for i, video in enumerate(videos): + if video is not None: + inputs[i]["multi_modal_data"] = {"video": video} + + if audios is not None: + for i, audio in enumerate(audios): + if audio is not None: + inputs[i]["multi_modal_data"] = {"audio": audio} + + return inputs + + def generate( + self, + prompts: List[str], + sampling_params: SamplingParams, + images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[Tuple[List[List[int]], List[str]]]: + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) req_outputs = self.model.generate(inputs, sampling_params=sampling_params) @@ -687,24 +740,10 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - if images is not None: - assert len(prompts) == len(images) - - if videos is not None: - assert len(prompts) == len(videos) - - inputs = [TextPrompt(prompt=prompt) for prompt in prompts] - if images is not None: - for i, image in enumerate(images): - inputs[i]["multi_modal_data"] = {"image": image} - - if audios is not None: - for i, audio in enumerate(audios): - inputs[i]["multi_modal_data"] = {"audio": audio} - - if videos is not None: - for i, video in enumerate(videos): - inputs[i]["multi_modal_data"] = {"video": video} + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) req_outputs = self.model.generate(inputs, sampling_params=sampling_params) @@ -741,9 +780,15 @@ def generate_greedy( prompts: List[str], max_tokens: int, images: Optional[PromptImageInput] = None, + videos: Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, greedy_params, images=images) + outputs = self.generate(prompts, + greedy_params, + images=images, + videos=videos, + audios=audios) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] @@ -809,20 +854,27 @@ def generate_beam_search( returned_outputs.append((token_ids, texts)) return returned_outputs - def encode(self, prompts: List[str]) -> List[List[float]]: - req_outputs = self.model.encode(prompts) - outputs = [] - for req_output in req_outputs: - embedding = req_output.outputs.embedding - outputs.append(embedding) - return outputs + def encode( + self, + prompts: List[str], + images: Optional[PromptImageInput] = None, + videos: 
Optional[PromptVideoInput] = None, + audios: Optional[PromptAudioInput] = None, + ) -> List[List[float]]: + inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + + req_outputs = self.model.encode(inputs) + return [req_output.outputs.embedding for req_output in req_outputs] def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): del self.model - cleanup() + cleanup_dist_env_and_memory() @pytest.fixture(scope="session") diff --git a/tests/core/block/e2e/conftest.py b/tests/core/block/e2e/conftest.py index e870597b7a01..70577ec052a2 100644 --- a/tests/core/block/e2e/conftest.py +++ b/tests/core/block/e2e/conftest.py @@ -3,10 +3,9 @@ import pytest from vllm import LLM +from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.utils import set_random_seed -from ....conftest import cleanup - @pytest.fixture def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs, @@ -37,7 +36,7 @@ def generator_inner(): yield llm del llm - cleanup() + cleanup_dist_env_and_memory() for llm in generator_inner(): yield llm diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index b3f626714d35..86502f613b18 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -2,18 +2,11 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import SamplingParams from .conftest import get_token_ids_from_llm_generator -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/block/e2e/test_correctness.py') - - @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -28,32 +21,32 @@ def check_deprecated_block_manager(): "num_gpu_blocks_override": 5 * (64 + 1), }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "use_v2_block_manager": True, "preemption_mode": "swap" }, { - "use_v2_block_manager": True, "preemption_mode": "recompute" }]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify block manager v2 produces same outputs as block manager v1, even - when there is preemption. +def test_block_manager_with_preemption(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit is decided such that as the sequences in the batch grow, sequences must be preempted and removed from cache. If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted in the v2 block manager. + cache is not corrupted. NOTE: We want a significant number of generated tokens so that any incorrect KV mapping has time to build up error. + + NOTE(Kuntai): Though we have removed block manager v1, this test is still + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + keep this test. 
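The conftest hunks above (tests/conftest.py and tests/core/block/e2e/conftest.py) replace the local cleanup() helper with vllm.distributed.cleanup_dist_env_and_memory wherever a runner or fixture tears down. A hedged sketch of the resulting fixture shape for the e2e block tests; the kwargs merging is simplified and the seed handling follows the imports shown above rather than the exact original body:

```python
import pytest

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.utils import set_random_seed


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    def generator_inner():
        llm = LLM(**kwargs)
        set_random_seed(seed)
        yield llm
        # Drop the engine reference first, then tear down the distributed
        # environment and free GPU memory before the next parametrization.
        del llm
        cleanup_dist_env_and_memory()

    for llm in generator_inner():
        yield llm
```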
""" output_len = 1024 temperature = 0.0 @@ -77,11 +70,9 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids from block manager v1') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids from block manager v2') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -104,9 +95,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, # skip cuda graph creation for fast test. "enforce_eager": True, - - # Lookahead scheduling only supported in v2 block manager. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -218,26 +206,22 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, "max_num_seqs": 10, }]) @pytest.mark.parametrize("baseline_llm_kwargs", [ - { - "use_v2_block_manager": False, - }, + {}, ]) @pytest.mark.parametrize("test_llm_kwargs", [ { - "use_v2_block_manager": True, "num_lookahead_slots": 0, }, { - "use_v2_block_manager": True, "num_lookahead_slots": 5, }, ]) @pytest.mark.parametrize("batch_size", [4]) @pytest.mark.parametrize("seed", [1]) -def test_chunked_prefill_block_manager_v2(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify that chunked prefill works with BlockManagerV2, with and without - lookahead scheduling. +def test_chunked_prefill_block_manager(baseline_llm_generator, + test_llm_generator, batch_size): + """Verify that chunked prefill works with SelfAttnBlockSpaceManager, + with and without lookahead scheduling. """ output_len = 32 temperature = 0.0 @@ -258,11 +242,11 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, temperature=temperature, ) - print('Getting token ids with BlockManagerV1') + print('Getting token ids with BlockManager') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids with BlockManagerV2') + print('Getting token ids with BlockManager, with lookahead slots.') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -290,32 +274,32 @@ def test_chunked_prefill_block_manager_v2(baseline_llm_generator, "enable_prefix_caching": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{ - "use_v2_block_manager": True, "preemption_mode": "swap" }, { - "use_v2_block_manager": True, "preemption_mode": "recompute" }]) @pytest.mark.parametrize("batch_size", [10]) @pytest.mark.parametrize("seed", [1]) -def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( +def test_block_manager_prefix_caching_enabled_with_preemption( baseline_llm_generator, test_llm_generator, batch_size): - """Verify block manager v2 produces same outputs as block manager v1, even - when there is preemption. + """Verify block manager produces same outputs even when there is preemption. This constructs two LLM, each with limited number of GPU blocks. The limit is decided such that as the sequences in the batch grow, sequences must be preempted and removed from cache. If the output token ids are equivalent, then we have confidence that the KV - cache is not corrupted in the v2 block manager. + cache is not corrupted. 
NOTE: We want a significant number of generated tokens so that any incorrect KV mapping has time to build up error. + + NOTE(Kuntai): Though we have removed block manager v1, this test is still + useful as it asserts the behavior of block manager v2 (now it is called + SelfAttnBlockSpaceManager) is the same when swapping / preemption, so we + keep this test. """ output_len = 1024 temperature = 0.0 @@ -339,11 +323,11 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( temperature=temperature, ) - print('Getting token ids from block manager v1') + print('Getting token ids from block manager') baseline_token_ids = get_token_ids_from_llm_generator( baseline_llm_generator, prompts, sampling_params) - print('Getting token ids from block manager v2') + print('Getting token ids from block manager, with preemption') test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, prompts, sampling_params) @@ -366,9 +350,6 @@ def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption( # Allow only 5 sequences of ~1024 tokens in worst case. "block_size": 16, "num_gpu_blocks_override": 5 * (64 + 1), - - # Test APC in v2 block - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ @@ -444,9 +425,6 @@ def test_auto_prefix_caching_with_preemption(baseline_llm_generator, "max_model_len": 48, "block_size": 16, "num_gpu_blocks_override": 3, - - # Test APC in v2 block - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{ diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 731131984b0e..9320a9ef6231 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -3,7 +3,6 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import LLM, SamplingParams from .conftest import get_text_from_llm_generator @@ -13,12 +12,6 @@ BLOCK_SIZE = 16 -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/block/e2e/test_correctness_sliding_window.py') - - @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -31,10 +24,8 @@ def check_deprecated_block_manager(): "num_gpu_blocks_override": 100000 // BLOCK_SIZE, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{"use_v2_block_manager": True}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, @@ -55,7 +46,6 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, prompts, answer, indices = prep_prompts(batch_size) - print('Getting token ids from block manager v1') baseline_texts = get_text_from_llm_generator(baseline_llm_generator, prompts, sampling_params, @@ -91,10 +81,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, "num_gpu_blocks_override": 100000 // BLOCK_SIZE, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - 
"use_v2_block_manager": True, - "enable_chunked_prefill": True -}]) +@pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager.py similarity index 91% rename from tests/core/block/test_block_manager_v2.py rename to tests/core/block/test_block_manager.py index e67883367879..cfd749ad5869 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager.py @@ -2,7 +2,7 @@ from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager_v2 import BlockSpaceManagerV2 +from vllm.core.block_manager import SelfAttnBlockSpaceManager from vllm.core.interfaces import AllocStatus from vllm.sequence import Logprob, SequenceStatus from vllm.utils import chunk_list @@ -17,7 +17,7 @@ @pytest.mark.parametrize("watermark", [0.0, 0.5]) def test_can_allocate_seq_group(block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -63,7 +63,7 @@ def test_can_allocate_seq_group_encoder_decoder(block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -117,16 +117,16 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, ''' SWA short for Sliding Window Attention. - At time of writing block manager v2 does not support SWA. + At time of writing block manager does not support SWA. - However even when SWA is implemented for block manager v2, + However even when SWA is implemented for block manager, there will still most likely be a separate workstream required to enable SWA for encoder/decoder models. Therefore this test enforces that one of the following cases hold true: - 1. Block manager v2 does not support SWA at all (true at time of writing) - 2. Block manager v2 fails with NotImplementError when SWA is enabled + 1. Block manager does not support SWA at all (true at time of writing) + 2. Block manager fails with NotImplementError when SWA is enabled AND a SequenceGroup with an encoder sequence (i.e. in support of an encoder/decoder model) is passed into can_allocate() as an argument @@ -135,7 +135,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, ''' with pytest.raises((NotImplementedError, AssertionError)) as exc_info: - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -158,7 +158,7 @@ def test_can_allocate_encoder_decoder_fails_with_swa(block_size: int, block_manager.can_allocate(seq_group) # Assert that either - # 1. Block manager v2 constructor fails with assertion that sliding window + # 1. Block manager constructor fails with assertion that sliding window # is not yet supported (most likely near-term outcome at time of # writing), or # 2. 
can_allocate() fails with NotImplementedError due to combination of @@ -177,7 +177,7 @@ def test_can_allocate_encoder_decoder_fails_with_prefix_cache( block_size: int, num_seqs_per_group: int, num_gpu_blocks: int, watermark: float): - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=1024, @@ -217,7 +217,7 @@ def test_append_slots(block_size, prompt_len, num_slots_to_append, num_gpu_blocks = 1024 watermark = 0.1 - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, @@ -269,14 +269,15 @@ def test_swap(block_size, num_cpu_blocks, num_gpu_blocks, num_lookahead_slots, """Verify blocks number on src/desc device is correct after swapping in/out sequence group (not missing or extra blocks). """ - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) prompt.status = SequenceStatus.WAITING block_manager.allocate(seq_group) + # Emulate a forward pass by appending a single token. # The block manager then knows how many unprocessed # tokens will be written in the next forward pass. @@ -321,11 +322,11 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, can be swapped in/out. """ num_cpu_blocks = num_gpu_blocks - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt, seq_group = create_dummy_prompt( "1", prompt_length=(num_gpu_blocks - 1) * block_size - 1) prompt.status = SequenceStatus.WAITING @@ -382,11 +383,11 @@ def test_swap_in_infeasible(num_lookahead_slots, enable_caching): block_size = 8 num_cpu_blocks = 1 num_gpu_blocks = 1 - block_manager = BlockSpaceManagerV2(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=enable_caching) + block_manager = SelfAttnBlockSpaceManager(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) prompt_length = block_size - 3 assert prompt_length > 0 prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) @@ -434,7 +435,7 @@ def test_sliding_window(block_size, prompt_len, num_slots_to_append, num_gpu_blocks = 1024 watermark = 0.1 - block_manager = BlockSpaceManagerV2( + block_manager = SelfAttnBlockSpaceManager( block_size=block_size, num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=0, @@ -474,7 +475,7 @@ def num_blocks(num_tokens): seq.data.update_num_computed_tokens(prompt_len) check_used(num_blocks(prompt_len)) - # this is how we compute it in BlockSpaceManagerV2.__init__ + # this is how we compute it in SelfAttnBlockSpaceManager.__init__ sliding_blocks = (sliding_window // block_size) + 2 # plus one block for null block sliding_blocks += 1 diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py deleted file mode 100644 index 2ee9f20824f2..000000000000 --- a/tests/core/test_block_manager.py +++ /dev/null @@ -1,637 +0,0 @@ -import time -from collections import defaultdict -from typing import List - -import pytest - -from vllm 
import SamplingParams -from vllm.block import PhysicalTokenBlock -from vllm.core.block.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, - STR_NOT_IMPL_ENC_DEC_SWA) -from vllm.core.block_manager_v1 import (BlockSpaceManagerV1, - UncachedBlockAllocator) -from vllm.core.interfaces import AllocStatus -from vllm.sequence import Logprob, Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -from .utils import create_dummy_prompt, create_dummy_prompt_encoder_decoder - - -def test_block_allocator_allocate(): - block_size = 4 - num_cpu_blocks = 4 - cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) - - # Allocate all available cpu blocks. - num_free = num_cpu_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - for _ in range(num_cpu_blocks): - block = cpu_allocator.allocate() - num_free -= 1 - - assert block not in cpu_allocator.free_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - - with pytest.raises(ValueError): - cpu_allocator.allocate() - - -def test_block_allocator_free(): - block_size = 4 - num_cpu_blocks = 4 - cpu_allocator = UncachedBlockAllocator(Device.CPU, block_size, - num_cpu_blocks) - - # Allocate all available cpu blocks. - blocks: List[PhysicalTokenBlock] = [] - for _ in range(num_cpu_blocks): - block = cpu_allocator.allocate() - blocks.append(block) - assert block not in cpu_allocator.free_blocks - - # Free all allocated cpu blocks. - num_free = 0 - assert cpu_allocator.get_num_free_blocks() == num_free - for block in blocks: - cpu_allocator.free(block) - num_free += 1 - assert block in cpu_allocator.free_blocks - assert cpu_allocator.get_num_free_blocks() == num_free - - with pytest.raises(ValueError): - cpu_allocator.free(block) - - -def test_allocate(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks): - _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - # Allocate same sequence group to all available gpu blocks. - # Use watermark to reserve one gpu block. - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=1 / num_gpu_blocks) - for i in range(num_gpu_blocks - 1): - _, seq_group = create_dummy_prompt(str(i), block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - -def test_allocate_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_req_per_seq_group = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same sequence group to all available gpu blocks. - for i in range(num_gpu_blocks // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - # Allocate same sequence group to all available gpu blocks. - # Use watermark to reserve one gpu block. 
- block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=1 / num_gpu_blocks) - for i in range((num_gpu_blocks - 1) // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - str(i), - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - assert block_manager.can_allocate(seq_group) == AllocStatus.OK - block_manager.allocate(seq_group) - assert block_manager.can_allocate(seq_group) != AllocStatus.OK - - -def test_allocate_encoder_decoder_fails_with_swa(): - # SWA short for sliding window attention - - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - sliding_window=5) # swa - - # Allocate same sequence group to all available gpu blocks. - _, _, seq_group = create_dummy_prompt_encoder_decoder( - "0", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - - # Assert that can_allocate() fails due to SWA - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - - # Assert that allocate() fails due to SWA - with pytest.raises(NotImplementedError) as exc_info: - block_manager.allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_SWA - - -def test_allocate_encoder_decoder_fails_with_prefix_caching(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0, - enable_caching=True) # Prefix cache - - # Allocate same sequence group to all available gpu blocks. - _, _, seq_group = create_dummy_prompt_encoder_decoder( - "0", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - - # Assert that can_allocate() fails due to prefix caching - with pytest.raises(NotImplementedError) as exc_info: - block_manager.can_allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - # Assert that allocate() fails due to prefix caching - with pytest.raises(NotImplementedError) as exc_info: - block_manager.allocate(seq_group) - - assert str(exc_info.value) == STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE - - -def test_append_slot_single_seq(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate single seq to gpu block. - prompt, seq_group = create_dummy_prompt("1", block_size) - block_manager.allocate(seq_group) - - # Nothing to append. Sequence has no new logical blocks. - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slots(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks == after_blocks - - # Add block_size number of new tokens and append slot. 
- for i in range(block_size): - token_id = i + 5 - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - assert not block_manager.append_slots(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks - after_blocks == 1 - - -def test_append_slot_cow(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size=block_size, - num_cpu_blocks=num_cpu_blocks, - num_gpu_blocks=num_gpu_blocks, - watermark=0) - - # Allocate prompt to gpu block. There is one slot left in the block. - prompt = Sequence(seq_id=1, - inputs={ - "prompt": "one two three", - "prompt_token_ids": [1, 2, 3], - }, - block_size=block_size) - - # Fork the sequence, such that a COW will be required when we append a new - # token id. - child = prompt.fork(new_seq_id=2) - - # Allocate space for the sequence group. - seq_group = SequenceGroup(request_id="1", - seqs=[prompt, child], - arrival_time=time.time(), - sampling_params=SamplingParams()) - block_manager.allocate(seq_group) - - # Fork and append a new token id. We expect a COW to be scheduled. - token_id = 4 - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.fork(prompt, child) - - assert block_manager.can_append_slots(seq_group) - before_blocks = block_manager.get_num_free_gpu_blocks() - - cows = block_manager.append_slots(child) - assert cows - dict_cows = defaultdict(list) - for src_block, dst_block in cows: - dict_cows[src_block].append(dst_block) - for src_block, dst_blocks in dict_cows.items(): - assert src_block not in dst_blocks - - after_blocks = block_manager.get_num_free_gpu_blocks() - assert before_blocks - after_blocks == 1 - - -def test_fork(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", - block_size - 1, - block_size=block_size) - block_manager.allocate(seq_group) - - # Fork prompt and copy block tables. - child = prompt.fork(2) - block_manager.fork(prompt, child) - assert block_manager.get_block_table( - prompt) == block_manager.get_block_table(child) - token_id = 4 - # Append token to child. Block is shared so copy on write occurs. - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(child) - assert block_manager.get_block_table( - prompt) != block_manager.get_block_table(child) - - -def test_swap(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1) - prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - prompt.status = SequenceStatus.RUNNING - prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap seq group from GPU -> CPU. 
- gpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - prompt.status = SequenceStatus.SWAPPED - - # Swap seq group from CPU -> GPU. - cpu_blocks = block_manager.get_block_table(prompt) - assert block_manager.can_swap_in(seq_group) == AllocStatus.OK - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -def test_swap_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - decoder_prompt, encoder_prompt, seq_group = \ - create_dummy_prompt_encoder_decoder( - "1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - decoder_prompt.status = SequenceStatus.WAITING - encoder_prompt.status = SequenceStatus.WAITING - block_manager.allocate(seq_group) - - # Emulate a forward pass by appending a single token. - # The block manager then knows how many unprocessed - # tokens will be written in the next forward pass. - token_id = 0 - decoder_prompt.status = SequenceStatus.RUNNING - decoder_prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) - - # Swap encoder/decoder seq group from GPU -> CPU. - decoder_gpu_blocks = block_manager.get_block_table(decoder_prompt) - cross_gpu_blocks = block_manager.get_cross_block_table(seq_group) - gpu_blocks = decoder_gpu_blocks + cross_gpu_blocks - assert block_manager.can_swap_out(seq_group) - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_out(seq_group) - assert [x[0] for x in mapping] == gpu_blocks - #assert list(mapping.keys()) == gpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks) - assert before_gpu_blocks + len(gpu_blocks) == after_gpu_blocks - decoder_prompt.status = SequenceStatus.SWAPPED - - # Swap encoder/decoder seq group from CPU -> GPU. 
- decoder_cpu_blocks = block_manager.get_block_table(decoder_prompt) - cross_cpu_blocks = block_manager.get_cross_block_table(seq_group) - cpu_blocks = decoder_cpu_blocks + cross_cpu_blocks - assert block_manager.can_swap_in(seq_group) == AllocStatus.OK - before_cpu_blocks = block_manager.get_num_free_cpu_blocks() - before_gpu_blocks = block_manager.get_num_free_gpu_blocks() - mapping = block_manager.swap_in(seq_group) - assert [x[0] for x in mapping] == cpu_blocks - after_cpu_blocks = block_manager.get_num_free_cpu_blocks() - after_gpu_blocks = block_manager.get_num_free_gpu_blocks() - assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks - assert before_gpu_blocks == after_gpu_blocks + len(cpu_blocks) - - -def test_free(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - prompt, seq_group = create_dummy_prompt("1", block_size) - block_manager.allocate(seq_group) - - # Free allocated seq. - prompt_blocks = len(block_manager.get_block_table(prompt)) - before_blocks = block_manager.get_num_free_gpu_blocks() - block_manager.free(prompt) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert after_blocks == before_blocks + prompt_blocks - - # Block table for freed seq is deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(prompt) - - -def test_free_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - decoder_prompt, encoder_prompt, seq_group = \ - create_dummy_prompt_encoder_decoder( - "1", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - block_manager.allocate(seq_group) - - # Free allocated seq. - decoder_prompt_blocks = len(block_manager.get_block_table(decoder_prompt)) - encoder_prompt_blocks = len(block_manager.get_cross_block_table(seq_group)) - prompt_blocks = decoder_prompt_blocks + encoder_prompt_blocks - before_blocks = block_manager.get_num_free_gpu_blocks() - block_manager.free(decoder_prompt) - block_manager.free_cross(seq_group) - after_blocks = block_manager.get_num_free_gpu_blocks() - assert after_blocks == before_blocks + prompt_blocks - - # Block table for freed encoder & decoder seq's are deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(decoder_prompt) - - # Block table for freed encoder & decoder seq's are deleted. - with pytest.raises(KeyError): - block_manager.get_block_table(encoder_prompt) - - -def test_reset(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same seq group on all available gpu blocks. - original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks): - _, seq_group = create_dummy_prompt(str(i), block_size) - block_manager.allocate(seq_group) - assert block_manager.get_num_free_gpu_blocks() == 0 - - # Resetting block manager frees all allocated blocks. - block_manager.reset() - assert block_manager.get_num_free_gpu_blocks() == original_blocks - - -def test_reset_encoder_decoder(): - block_size = 4 - num_cpu_blocks = 4 - num_gpu_blocks = 4 - block_req_per_seq_group = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - watermark=0) - - # Allocate same seq group on all available gpu blocks. 
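The file being deleted here exercised the legacy BlockSpaceManagerV1; its surviving scenarios now run in tests/core/block/test_block_manager.py against the renamed class, as shown earlier in this diff. A minimal sketch of the renamed API using the same dummy-prompt helper; the import path for create_dummy_prompt is an assumption based on the deleted file's layout:

```python
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.core.interfaces import AllocStatus

from .utils import create_dummy_prompt  # assumed helper location, as in the old file

block_size = 16
block_manager = SelfAttnBlockSpaceManager(
    block_size=block_size,
    num_gpu_blocks=1024,
    num_cpu_blocks=1024,
    watermark=0.1,
)

# Allocation flow is unchanged from the deleted v1 tests, only the class differs.
_, seq_group = create_dummy_prompt("1", prompt_length=block_size - 1)
assert block_manager.can_allocate(seq_group) == AllocStatus.OK
block_manager.allocate(seq_group)
```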
- original_blocks = block_manager.get_num_free_gpu_blocks() - for i in range(num_gpu_blocks // block_req_per_seq_group): - _, _, seq_group = create_dummy_prompt_encoder_decoder( - f"{i}", - decoder_prompt_length=block_size, - encoder_prompt_length=block_size) - block_manager.allocate(seq_group) - assert block_manager.get_num_free_gpu_blocks() == 0 - - # Resetting block manager frees all allocated blocks. - block_manager.reset() - assert block_manager.get_num_free_gpu_blocks() == original_blocks - - -def test_sliding_window_multi_seq(): - """ - Tests that memory allocation and deallocation is handled - correctly with multiple sequences that exceed the sliding - window's capacity. - """ - block_size = 1 - num_cpu_blocks = 8 - num_gpu_blocks = 8 - sliding_window = 2 - block_manager = BlockSpaceManagerV1(block_size, - num_cpu_blocks, - num_gpu_blocks, - sliding_window=sliding_window, - watermark=0) - - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - - parent = Sequence(seq_id=1, - inputs={ - "prompt": "one two three", - "prompt_token_ids": [0, 1, 2], - }, - block_size=block_size) - seq_group = SequenceGroup(request_id="1", - seqs=[parent], - arrival_time=time.time(), - sampling_params=SamplingParams(), - lora_request=None) - block_manager.allocate(seq_group) - - # assert the number of blocks allocated is correct - # the parent seq has len 3, but since sliding_window is 2, - # we will use at most 2 blocks - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # Fork prompt and copy block tables. - child = parent.fork(2) - block_manager.fork(parent, child) - - # assert the number of blocks allocated is correct - # forking does not increase memory consumption - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # assert both parent and child share all blocks - assert block_manager.get_block_table( - parent) == block_manager.get_block_table(child) - - token_id = 4 - # Append token to child. Block is shared so copy on write occurs. - child.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(child) - - # assert the number of blocks allocated is correct - # we will use now one block more. Each seq will use 2 blocks, - # but only one can be shared - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - 1 - - token_id = 5 - parent.append_token_id(token_id, {token_id: Logprob(0.0)}) - block_manager.append_slots(parent) - - # assert the number of blocks allocated is correct - # no change, because both sequences are still just sharing one block - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - 1 - - block_table_parent = block_manager.get_block_table(parent) - block_table_child = block_manager.get_block_table(child) - - assert block_table_parent != block_table_child - - # assert both blocks are sharing the second-last block - assert block_table_parent[-2] == block_table_child[-2] - - # now let's clean up... - block_manager.free(parent) - - # assert the number of blocks allocated is correct - # We have freed one seq, reducing the ref count of two blocks by one. - # One of the two was only used by the parent seq, so this is now free. 
- # The child seq still consumes sliding_window blocks - assert block_manager.get_num_free_gpu_blocks( - ) == num_gpu_blocks - sliding_window - - # free all blocks - block_manager.free(child) - - # assert all blocks are free now - assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - - -def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill(): - """When prefix cache and chunked prefill are enabled, the block manager - should only mark a chunk of blocks as computed instead of all blocks. - """ - - block_size = 4 - num_cpu_blocks = 0 - num_gpu_blocks = 16 - block_manager = BlockSpaceManagerV1(block_size, - num_gpu_blocks, - num_cpu_blocks, - watermark=0, - enable_caching=True) - - # Set prompt size to have num_gpu_blocks - 1 full blocks. - prompt_length = block_size * num_gpu_blocks - 1 - - # Allocate (reserve) all blocks. - _, seq_group = create_dummy_prompt("0", - prompt_length, - block_size=block_size) - block_manager.allocate(seq_group) - assert seq_group.seqs[0].n_blocks == num_gpu_blocks - - # 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed. - token_chunk_size = int(block_size * 2.5) - block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) - computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) - assert len(computed_blocks) == 2 - - # Actual computed tokens. - seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size) - - # 2nd chunk: Complete 3rd block and additional 4 blocks. - token_chunk_size = int(block_size * 4.5) - block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) - computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) - assert len(computed_blocks) == 7 diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index c9495fd50d7c..acd82065ae45 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -4,11 +4,9 @@ import pytest # noqa from vllm.config import CacheConfig, SchedulerConfig -from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler from vllm.sequence import Logprob, SequenceGroup -from ..utils import check_deprecated_block_manager_usage from .utils import create_dummy_prompt @@ -28,25 +26,17 @@ def schedule_and_update_computed_tokens(scheduler): return metas, out -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/core/test_chunked_prefill_scheduler.py') - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_simple(use_v2_block_manager: bool): +def test_simple(): """Verify basic scheduling works.""" block_size = 4 num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig( - max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -81,19 +71,19 @@ def test_simple(use_v2_block_manager: bool): assert len(seq_group_meta) == num_seq_group -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunk(use_v2_block_manager: bool): +def test_chunk(): """Verify prefills are chunked properly.""" block_size = 4 max_seqs = 
60 max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 32 cache_config.num_gpu_blocks = 32 @@ -131,18 +121,18 @@ def test_chunk(use_v2_block_manager: bool): assert out.num_batched_tokens == 57 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_complex(use_v2_block_manager: bool): +def test_complex(): block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 64 cache_config.num_gpu_blocks = 64 @@ -201,19 +191,19 @@ def test_complex(use_v2_block_manager: bool): assert running[2].is_prefill() -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_maximal_decoding(use_v2_block_manager: bool): +def test_maximal_decoding(): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 max_model_len = 8 max_num_batched_tokens = 2 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -295,19 +285,19 @@ def test_maximal_decoding(use_v2_block_manager: bool): assert out.num_batched_tokens == 2 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prompt_limit(use_v2_block_manager: bool): +def test_prompt_limit(): """Verify max_num_batched_tokens < max_model_len is possible.""" block_size = 4 max_seqs = 32 max_model_len = 64 max_num_batched_tokens = 32 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -330,13 +320,13 @@ def test_prompt_limit(use_v2_block_manager: bool): assert out.num_batched_tokens == 32 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prompt_limit_exceed(use_v2_block_manager: bool): +def test_prompt_limit_exceed(): block_size = 4 max_seqs = 64 max_model_len = 32 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, + scheduler_config = SchedulerConfig("generate", + max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True) @@ -356,171 +346,19 @@ def test_prompt_limit_exceed(use_v2_block_manager: bool): assert out.ignored_seq_groups[0] == seq_group -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_swap(use_v2_block_manager: bool): - """Verify swapping works with chunked prefill requests""" - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 16 - cache_config.num_gpu_blocks = 16 - scheduler = 
Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The last request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # The running prefill is now swapped. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - - # Add 1 more task. Swap should be prioritized over new prefill. - _, seq_group = create_dummy_prompt("2", prompt_length=60) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): - block_size = 4 - max_seqs = 30 - max_model_len = 200 - max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig( - max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 32 - cache_config.num_gpu_blocks = 32 - scheduler = Scheduler(scheduler_config, cache_config, None) - - _, seq_group = create_dummy_prompt("1", - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - _, out = schedule_and_update_computed_tokens(scheduler) - # The request is chunked. - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 1 - assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() - assert out.num_batched_tokens == max_num_batched_tokens - - # The request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "1" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - # The running prefill is now swapped. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 0 - assert out.num_batched_tokens == 0 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - - # Add 1 more task. Swap is not possible, so prefill is running. - scheduler.block_manager.can_swap_in = MagicMock() - scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - - _, seq_group2 = create_dummy_prompt("2", - prompt_length=60, - block_size=block_size) - scheduler.add_seq_group(seq_group2) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. 
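
Aside on the updated construction pattern: across these scheduler tests, the deprecated `use_v2_block_manager` flag is dropped and `SchedulerConfig` now takes the task ("generate") as its first argument. A minimal sketch of the pattern the updated tests follow is below; the meanings of the positional `CacheConfig` arguments are inferred from the call sites and should be treated as assumptions.

```python
from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler

# Task comes first; there is no use_v2_block_manager flag anymore.
scheduler_config = SchedulerConfig(
    "generate",
    max_num_batched_tokens=64,
    max_num_seqs=4,
    max_model_len=16,
    enable_chunked_prefill=True,
)

# Positional arguments mirror the tests: (block_size, gpu_memory_utilization,
# swap_space, cache_dtype). The middle two names are assumptions.
cache_config = CacheConfig(4, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
cache_config.num_gpu_blocks = 8

scheduler = Scheduler(scheduler_config, cache_config, None)
```
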
- assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - - # Now although swap is possible, running prefill is prioritized. - scheduler.block_manager.can_swap_in.return_value = AllocStatus.OK - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert not seq_group2.is_prefill() - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - append_new_token(seq_group2, 1) - - # Decoding is prioritized. - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 1 - assert out.blocks_to_swap_in == [] - assert out.blocks_to_swap_out == [] - assert not seq_group2.is_prefill() - assert out.scheduled_seq_groups[0].seq_group == seq_group2 - append_new_token(seq_group2, 1) - - # Since we abort the sequence group, we can finally swap. - scheduler.abort_seq_group(seq_group2.request_id) - _, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 1 - assert out.num_batched_tokens == 30 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunked_prefill_preempt(use_v2_block_manager: bool): +def test_chunked_prefill_preempt(): """Verify preempt works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -575,18 +413,18 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): +def test_chunked_prefill_max_seqs(): block_size = 4 max_seqs = 2 max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 128 cache_config.num_gpu_blocks = 128 @@ -629,19 +467,19 @@ def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): assert not running[1].is_prefill() -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_perfix_caching(use_v2_block_manager: bool): +def test_perfix_caching(): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 max_model_len = 80 max_num_batched_tokens = 64 scheduler_config = SchedulerConfig( + "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index f3ec24e7bee3..bd4accab7f37 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ 
b/tests/core/test_num_computed_tokens_update.py @@ -31,7 +31,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, # Make a vllm engine runner = VllmRunner(model_name=MODEL, gpu_memory_utilization=0.7, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, enable_chunked_prefill=enable_chunked_prefill, enforce_eager=enforce_eager) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 5cdf743a4509..5ff32be61159 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,32 +3,28 @@ from typing import List, Set, Tuple from unittest.mock import MagicMock -import pytest +import pytest # noqa from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup, SequenceStatus +from vllm.sequence import SequenceGroup -from ..utils import check_deprecated_block_manager_usage from .utils import (append_new_token, append_new_token_seq_group, create_dummy_prompt, get_sequence_groups, schedule_and_update_computed_tokens) -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - "tests/core/test_chunked_prefill_scheduler.py") - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_add_seq_group(use_v2_block_manager: bool): +def test_scheduler_add_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, 64, 1, use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=1, + ) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -44,11 +40,14 @@ def test_scheduler_add_seq_group(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == i + 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_abort_seq_group(use_v2_block_manager: bool): +def test_scheduler_abort_seq_group(): block_size = 4 scheduler_config = SchedulerConfig( - 100, 64, 1, use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=1, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -68,16 +67,16 @@ def test_scheduler_abort_seq_group(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_schedule_simple(use_v2_block_manager: bool): +def test_scheduler_schedule_simple(): block_size = 4 num_seq_group = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, - num_seq_group, - max_model_len, - use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=64, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -112,17 +111,17 @@ def test_scheduler_schedule_simple(use_v2_block_manager: bool): append_new_token(out, 1) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): +def test_scheduler_prefill_prioritized(): """Verify running batched tokens are not applied to 
prefill requests.""" block_size = 4 max_model_len = 30 max_batched_num_tokens = 30 scheduler_config = SchedulerConfig( - max_batched_num_tokens, - 2, - max_model_len, - use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=max_batched_num_tokens, + max_num_seqs=2, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 cache_config.num_gpu_blocks = 16 @@ -146,12 +145,15 @@ def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): assert get_sequence_groups(out) == [seq_group_b] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): +def test_scheduler_schedule_preempt_abort(): block_size = 4 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=64, + max_num_seqs=2, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 @@ -201,17 +203,17 @@ def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): assert scheduler.get_num_unfinished_seq_groups() == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_max_seqs(use_v2_block_manager: bool): +def test_scheduler_max_seqs(): block_size = 4 num_seq_group = 4 max_seq_group = 2 max_model_len = 16 scheduler_config = SchedulerConfig( - 64, - max_seq_group, - max_model_len, - use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=64, + max_num_seqs=max_seq_group, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -249,15 +251,15 @@ def test_scheduler_max_seqs(use_v2_block_manager: bool): assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_scheduler_delay_factor(use_v2_block_manager: bool): +def test_scheduler_delay_factor(): block_size = 4 scheduler_config = SchedulerConfig( - 100, - 64, - 16, + "generate", + max_num_batched_tokens=100, + max_num_seqs=64, + max_model_len=16, delay_factor=0.5, - use_v2_block_manager=use_v2_block_manager) + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -294,74 +296,23 @@ def test_scheduler_delay_factor(use_v2_block_manager: bool): append_new_token(out, 1) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_swapped_out_prioritized(use_v2_block_manager: bool): - block_size = 4 - scheduler = initialize_scheduler(max_num_seqs=6, - block_size=block_size, - use_v2_block_manager=use_v2_block_manager, - num_cpu_blocks=64, - num_gpu_blocks=64) - # best_of=2 * 3 == 6 sequences. - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - # prefill scheduled now. - assert len(out.scheduled_seq_groups) == 3 - append_new_token(out, 1) - - # The last request should be swapped out. 
- scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "2" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - assert len(out.scheduled_seq_groups) == 2 - assert out.num_batched_tokens == 2 - assert out.blocks_to_swap_out != [] - assert out.blocks_to_swap_in == [] - append_new_token(out, 1) - - # Add 1 more task. Swap should be prioritized over prefill. - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler.add_seq_group(seq_group) - seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) - append_new_token(out, 1) - assert len(out.scheduled_seq_groups) == 3 - # 3 decodes. It is swapped in. - assert out.num_batched_tokens == 3 - assert out.blocks_to_swap_in != [] - assert out.blocks_to_swap_out == [] - - def initialize_scheduler( *, max_num_seqs=1000, max_token_budget=1000, max_model_len=1000, lora_config=None, - use_v2_block_manager=False, block_size=4, num_cpu_blocks=8, num_gpu_blocks=8, ): block_size = block_size scheduler_config = SchedulerConfig( - max_token_budget, - max_num_seqs, - max_model_len, - use_v2_block_manager=use_v2_block_manager) + "generate", + max_num_batched_tokens=max_token_budget, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks cache_config.num_gpu_blocks = num_gpu_blocks @@ -386,15 +337,12 @@ def add_token_budget(budget: SchedulingBudget, budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): +def test_prefill_schedule_max_prompt_len(): """ Test prompt longer than max_prompt_len is aborted. """ block_size = 4 - scheduler = initialize_scheduler(max_model_len=30, - use_v2_block_manager=use_v2_block_manager, - block_size=block_size) + scheduler = initialize_scheduler(max_model_len=30, block_size=block_size) _, seq_group = create_dummy_prompt("0", prompt_length=60, block_size=block_size) @@ -409,14 +357,12 @@ def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_token_budget(use_v2_block_manager: bool): +def test_prefill_schedule_token_budget(): """ Test token budget respected. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) budget = create_token_budget(token_budget=0) @@ -446,8 +392,7 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. 
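
The budget-driven tests above exercise `SchedulingBudget` bookkeeping through the `create_token_budget`/`add_token_budget` helpers. A standalone sketch of that bookkeeping follows; the constructor fields are assumed from how those helpers are used in this file.

```python
from vllm.core.scheduler import SchedulingBudget

# Field names are an assumption based on create_token_budget(token_budget=...)
# and create_token_budget(max_num_seqs=...) in these tests.
budget = SchedulingBudget(token_budget=60, max_num_seqs=2)

# The scheduler charges tokens and sequences against the budget per request id,
# exactly as add_token_budget() does above.
budget.add_num_batched_tokens("0", 30)
budget.add_num_seqs("0", 1)

assert budget.num_batched_tokens == 30
assert budget.num_curr_seqs == 1
```
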
- scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=16, num_gpu_blocks=16) budget = create_token_budget(token_budget=60) @@ -474,14 +419,12 @@ def test_prefill_schedule_token_budget(use_v2_block_manager: bool): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): +def test_prefill_schedule_max_seqs(): """ Test max seq respected. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) budget = create_token_budget(max_num_seqs=2) @@ -515,15 +458,13 @@ def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): assert len(remaining_waiting) == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_max_lora(use_v2_block_manager: bool): +def test_prefill_schedule_max_lora(): """ Test max lora is respected and prioritized. """ block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config, - use_v2_block_manager=use_v2_block_manager, block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) @@ -570,14 +511,12 @@ def test_prefill_schedule_max_lora(use_v2_block_manager: bool): assert budget.num_batched_tokens == 60 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): +def test_prefill_schedule_no_block_manager_capacity(): """ Test sequence cannot be scheduled due to block manager has no capacity. """ block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_gpu_blocks=128, num_cpu_blocks=128) budget = create_token_budget() @@ -614,14 +553,12 @@ def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): assert len(remaining_waiting) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_decode_schedule_preempted(use_v2_block_manager: bool): +def test_decode_schedule_preempted(): """ Test decodes cannot be scheduled and preempted. 
""" block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=64, num_gpu_blocks=64) curr_loras = None @@ -660,70 +597,12 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_decode_swap_beam_search(use_v2_block_manager: bool): - """ - Test best_of > 1 swap out blocks - """ - block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, - num_gpu_blocks=64, - num_cpu_blocks=64) - curr_loras = None - budget = create_token_budget() - for i in range(3): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - best_of=2, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - scheduler._add_seq_group_to_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) - budget.add_num_batched_tokens( - seq_group.request_id, seq_group.num_seqs(SequenceStatus.RUNNING)) - - # The last request should be swapped out. - scheduler.block_manager.can_append_slots = MagicMock() - - def cannot_append_second_group(seq_group, num_lookahead_slots): - return seq_group.request_id != "2" - - scheduler.block_manager.can_append_slots.side_effect = ( - cannot_append_second_group) - scheduler.block_manager.swap_out = MagicMock() - expected_swap_mapping = [("5", "7")] - scheduler.block_manager.swap_out.return_value = expected_swap_mapping - - output = scheduler._schedule_running(budget, curr_loras) - remainig_running = scheduler.running - assert len(remainig_running) == 0 - assert len(output.decode_seq_groups) == 2 - assert len(output.prefill_seq_groups) == 0 - assert output.decode_seq_groups[0].seq_group.request_id == "0" - assert output.decode_seq_groups[1].seq_group.request_id == "1" - assert len(output.preempted) == 0 - assert len(output.swapped_out) == 1 - # Budget should refledct preempted requests. - assert budget.num_batched_tokens == 2 - # since there are 2 sequences, 2 should be subtracted. - assert budget.num_curr_seqs == 4 - # Both should be preempted, not swapped. - assert output.blocks_to_swap_out == expected_swap_mapping - # Nothing is copied. - assert output.blocks_to_copy == [] - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): +def test_schedule_decode_blocks_to_copy_update(): """ Verify blocks_to_copy is updated. 
""" block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=4, + scheduler = initialize_scheduler(block_size=4, num_cpu_blocks=16, num_gpu_blocks=16) _, seq_group = create_dummy_prompt("1", @@ -754,117 +633,10 @@ def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): assert output.blocks_to_copy == [(2, 3)] -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_simple(use_v2_block_manager: bool): - block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - _, seq_group = create_dummy_prompt("1", - prompt_length=4, - best_of=2, - block_size=block_size) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(4, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget() - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 0 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - # swap in is the reverse of swap out - blocks_to_swap_in_reverse = [] - for swapin, swapout in output.blocks_to_swap_in: - blocks_to_swap_in_reverse.append((swapout, swapin)) - assert blocks_to_swap_out == blocks_to_swap_in_reverse - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): - block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, - num_cpu_blocks=32, - num_gpu_blocks=32) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget(token_budget=1) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 1 - assert len(output.prefill_seq_groups) == 0 - - # Verify num_batched_tokens are respected. 
- budget = create_token_budget(token_budget=1) - add_token_budget(budget, 1, 0) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 1 - assert budget.num_batched_tokens == 1 - assert budget.num_curr_seqs == 0 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): - block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, - num_cpu_blocks=64, - num_gpu_blocks=64) - curr_loras = None - blocks_to_swap_out: List[Tuple[int, int]] = [] - for i in range(4): - _, seq_group = create_dummy_prompt(str(i), - prompt_length=60, - block_size=4) - scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) - scheduler._swap_out(seq_group, blocks_to_swap_out) - scheduler._add_seq_group_to_swapped(seq_group) - - budget = create_token_budget(max_num_seqs=2) - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 2 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 2 - assert len(output.prefill_seq_groups) == 0 - - # Verify num_curr_seqs are respected. - output = scheduler._schedule_swapped(budget, curr_loras) - remaining_swapped = scheduler.swapped - assert len(remaining_swapped) == 2 - assert budget.num_batched_tokens == 2 - assert budget.num_curr_seqs == 2 - assert len(output.decode_seq_groups) == 0 - assert len(output.prefill_seq_groups) == 0 - - -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_max_loras(use_v2_block_manager: bool): +def test_schedule_swapped_max_loras(): block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) scheduler = initialize_scheduler(lora_config=lora_config, - use_v2_block_manager=use_v2_block_manager, block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) @@ -894,11 +666,9 @@ def test_schedule_swapped_max_loras(use_v2_block_manager: bool): assert len(curr_loras) == 1 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): +def test_schedule_swapped_cannot_swap_in(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None @@ -927,11 +697,9 @@ def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_infeasible_swap(use_v2_block_manager: bool): +def test_infeasible_swap(): block_size = 4 - scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None @@ -961,11 +729,9 @@ def test_infeasible_swap(use_v2_block_manager: bool): assert len(output.prefill_seq_groups) == 0 -@pytest.mark.parametrize('use_v2_block_manager', [True, False]) -def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool): +def test_schedule_swapped_blocks_to_copy(): block_size = 4 - scheduler = 
initialize_scheduler(use_v2_block_manager=use_v2_block_manager, - block_size=block_size, + scheduler = initialize_scheduler(block_size=block_size, num_cpu_blocks=32, num_gpu_blocks=32) curr_loras = None diff --git a/tests/core/test_scheduler_encoder_decoder.py b/tests/core/test_scheduler_encoder_decoder.py index 50c047f30b80..7cd0416d321e 100644 --- a/tests/core/test_scheduler_encoder_decoder.py +++ b/tests/core/test_scheduler_encoder_decoder.py @@ -36,7 +36,12 @@ def test_scheduler_schedule_simple_encoder_decoder(): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + task="generate", + max_num_batched_tokens=64, + max_num_seqs=num_seq_group, + max_model_len=max_model_len, + ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 88d0a4ba7f57..ed6360f9d614 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -11,6 +11,7 @@ import pytest +from vllm.config import TaskOption from vllm.logger import init_logger from ..utils import compare_two_settings, fork_new_process_for_each_test @@ -27,18 +28,26 @@ class ParallelSetup(NamedTuple): chunked_prefill: bool +class PPTestOptions(NamedTuple): + multi_node_only: bool + trust_remote_code: bool + tokenizer_mode: Optional[str] + + @dataclass class PPTestSettings: parallel_setups: List[ParallelSetup] distributed_backends: List[str] - trust_remote_code: bool - tokenizer_mode: Optional[str] + task: TaskOption + test_options: PPTestOptions @staticmethod def detailed( *, tp_base: int = 1, pp_base: int = 2, + multi_node_only: bool = False, + task: TaskOption = "auto", trust_remote_code: bool = False, tokenizer_mode: Optional[str] = None, ): @@ -66,8 +75,10 @@ def detailed( chunked_prefill=False), ], distributed_backends=["mp", "ray"], - trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, + task=task, + test_options=PPTestOptions(multi_node_only=multi_node_only, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode), ) @staticmethod @@ -75,6 +86,8 @@ def fast( *, tp_base: int = 1, pp_base: int = 2, + task: TaskOption = "auto", + multi_node_only: bool = False, trust_remote_code: bool = False, tokenizer_mode: Optional[str] = None, ): @@ -86,25 +99,27 @@ def fast( chunked_prefill=False), ], distributed_backends=["mp"], - trust_remote_code=trust_remote_code, - tokenizer_mode=tokenizer_mode, + task=task, + test_options=PPTestOptions(multi_node_only=multi_node_only, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode), ) def iter_params(self, model_name: str): + opts = self.test_options + for parallel_setup in self.parallel_setups: for distributed_backend in self.distributed_backends: yield (model_name, parallel_setup, distributed_backend, - self.trust_remote_code, self.tokenizer_mode) + self.task, opts) # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU # The values displayed here are only a rough indicator of the size of the model # yapf: disable -GENERATION_MODEL_SETTINGS = { - # [DETAILED TESTS] - "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), - # [FAST TESTS] +TEXT_GENERATION_MODELS = { + # [Decoder-only] # Uses Llama # "BAAI/AquilaChat-7B": 
PPTestSettings.fast(), "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501 @@ -130,9 +145,10 @@ def iter_params(self, model_name: str): # Uses Llama # "internlm/internlm-chat-7b": PPTestSettings.fast(), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), - "core42/jais-13b-chat": PPTestSettings.fast(), + "inceptionai/jais-13b-chat": PPTestSettings.fast(), # TODO: Implement PP # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), + "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), # Uses Llama @@ -145,10 +161,9 @@ def iter_params(self, model_name: str): "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), "microsoft/phi-2": PPTestSettings.fast(), - "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(), + "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.detailed(trust_remote_code=True, multi_node_only=True), # noqa: E501 "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 - # FIXME: https://github.com/vllm-project/vllm/issues/8553 - # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "adept/persimmon-8b-chat": PPTestSettings.fast(), "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True), "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(), @@ -156,41 +171,43 @@ def iter_params(self, model_name: str): "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), "bigcode/starcoder2-3b": PPTestSettings.fast(), "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), - # FIXME: Cannot load tokenizer in latest transformers version + # FIXME: Cannot load tokenizer in latest transformers version. 
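
A note on the refactored pipeline-parallel settings: each model entry now expands into `(model_name, parallel_setup, distributed_backend, task, PPTestOptions)` tuples via `iter_params`, with the per-model flags bundled into `PPTestOptions`. A hedged sketch of that expansion, assuming it runs from the repository root so the test module is importable:

```python
# Illustration only: PPTestSettings and PPTestOptions are defined in
# tests/distributed/test_pipeline_parallel.py, so this assumes the repo's
# tests package is on the import path.
from tests.distributed.test_pipeline_parallel import PPTestSettings

settings = PPTestSettings.fast(trust_remote_code=True)

for (model_name, parallel_setup, distributed_backend, task,
     test_options) in settings.iter_params("internlm/internlm2-chat-7b"):
    # task defaults to "auto"; trust_remote_code, tokenizer_mode and
    # multi_node_only now travel inside PPTestOptions.
    print(model_name, distributed_backend, task,
          test_options.trust_remote_code, test_options.multi_node_only)
```
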
+ # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf` # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), + # [Encoder-only] + # TODO: Implement PP + # "facebook/bart-base": PPTestSettings.fast(), } -EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated] - # [FAST TESTS] +EMBEDDING_MODELS = { # type: ignore[var-annotated] + # [Text-only] "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(), "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501 } -MULTIMODAL_MODEL_SETTINGS = { - # [FAST TESTS] +MULTIMODAL_MODELS = { + # [Decoder-only] "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), "facebook/chameleon-7b": PPTestSettings.fast(), "adept/fuyu-8b": PPTestSettings.fast(), + "THUDM/glm-4v-9b": PPTestSettings.fast(trust_remote_code=True), "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True), "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(), "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(), "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(), "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True), - # TODO: Implement PP - # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), + "allenai/Molmo-7B-D-0924": PPTestSettings.fast(trust_remote_code=True), "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501 "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), + "Qwen/Qwen2-Audio-7B-Instruct": PPTestSettings.fast(), "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), "fixie-ai/ultravox-v0_3": PPTestSettings.fast(), -} - -CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated] - # [FAST TESTS] + # [Encoder-decoder] # TODO: Implement PP - # "facebook/bart-base": PPTestSettings.fast(), + # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), } # yapf: enable @@ -199,6 +216,7 @@ def iter_params(self, model_name: str): # [LANGUAGE GENERATION] "meta-llama/Meta-Llama-3-8B", "ibm/PowerLM-3b", + "microsoft/Phi-3-mini-4k-instruct", # [LANGUAGE EMBEDDING] "intfloat/e5-mistral-7b-instruct", "BAAI/bge-multilingual-gemma2", @@ -213,19 +231,22 @@ def _compare_tp( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, - trust_remote_code: bool, - tokenizer_mode: Optional[str], + task: TaskOption, + test_options: PPTestOptions, num_gpus_available: int, *, - method: Literal["generate", "encode"] = "encode", + method: Literal["generate", "encode"], ): tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup + multi_node_only, trust_remote_code, tokenizer_mode = test_options if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") + if multi_node_only and not VLLM_MULTI_NODE: + pytest.skip("Not in multi-node setting") common_args = [ # use half precision for speed and memory savings in CI environment @@ -240,6 +261,8 @@ def _compare_tp( common_args.append("--enable-chunked-prefill") if eager_mode: common_args.append("--enforce-eager") + if task != "auto": + common_args.extend(["--task", task]) if trust_remote_code: 
common_args.append("--trust-remote-code") if tokenizer_mode: @@ -297,10 +320,10 @@ def _compare_tp( @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "trust_remote_code", "tokenizer_mode"), + ("model_name", "parallel_setup", "distributed_backend", "task", + "test_options"), [ - params for model_name, settings in GENERATION_MODEL_SETTINGS.items() + params for model_name, settings in TEXT_GENERATION_MODELS.items() for params in settings.iter_params(model_name) if model_name in TEST_MODELS ], @@ -310,24 +333,24 @@ def test_tp_language_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, - trust_remote_code: bool, - tokenizer_mode: Optional[str], + task: TaskOption, + test_options: PPTestOptions, num_gpus_available, ): _compare_tp(model_name, parallel_setup, distributed_backend, - trust_remote_code, - tokenizer_mode, + task, + test_options, num_gpus_available, method="generate") @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "trust_remote_code", "tokenizer_mode"), + ("model_name", "parallel_setup", "distributed_backend", "task", + "test_options"), [ - params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items() + params for model_name, settings in EMBEDDING_MODELS.items() for params in settings.iter_params(model_name) if model_name in TEST_MODELS ], @@ -337,24 +360,24 @@ def test_tp_language_embedding( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, - trust_remote_code: bool, - tokenizer_mode: Optional[str], + task: TaskOption, + test_options: PPTestOptions, num_gpus_available, ): _compare_tp(model_name, parallel_setup, distributed_backend, - trust_remote_code, - tokenizer_mode, + task, + test_options, num_gpus_available, method="encode") @pytest.mark.parametrize( - ("model_name", "parallel_setup", "distributed_backend", - "trust_remote_code", "tokenizer_mode"), + ("model_name", "parallel_setup", "distributed_backend", "task", + "test_options"), [ - params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items() + params for model_name, settings in MULTIMODAL_MODELS.items() for params in settings.iter_params(model_name) if model_name in TEST_MODELS ], @@ -364,14 +387,14 @@ def test_tp_multimodal_generation( model_name: str, parallel_setup: ParallelSetup, distributed_backend: str, - trust_remote_code: bool, - tokenizer_mode: Optional[str], + task: TaskOption, + test_options: PPTestOptions, num_gpus_available, ): _compare_tp(model_name, parallel_setup, distributed_backend, - trust_remote_code, - tokenizer_mode, + task, + test_options, num_gpus_available, method="generate") diff --git a/tests/encoder_decoder/test_e2e_correctness.py b/tests/encoder_decoder/test_e2e_correctness.py index 9324a737a779..bef0c515b907 100644 --- a/tests/encoder_decoder/test_e2e_correctness.py +++ b/tests/encoder_decoder/test_e2e_correctness.py @@ -7,8 +7,8 @@ import pytest from transformers import AutoModelForSeq2SeqLM +from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu from ..conftest import DecoderPromptType from ..models.utils import check_logprobs_close @@ -35,7 +35,7 @@ def vllm_to_hf_output( @pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType)) @pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.skipif( - is_cpu(), + current_platform.is_cpu(), reason="CPU backend is not currently supported with encoder/decoder models" ) def test_encoder_decoder_e2e( @@ -50,7 +50,7 @@ def 
test_encoder_decoder_e2e( enforce_eager: bool, ) -> None: ''' - End-to-End (E2E) test for the encoder-decoder framework. + End-to-End (E2E) test for the encoder-decoder framework. This test evaluates the encoder-decoder functionality using the BART model. We compare the outputs of the Hugging Face and vLLM implementations to ensure that both implementations produce consistent diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py new file mode 100644 index 000000000000..fc66386fd2d2 --- /dev/null +++ b/tests/entrypoints/llm/test_chat.py @@ -0,0 +1,92 @@ +from typing import List + +import pytest + +from vllm import LLM + +from ..openai.test_vision import TEST_IMAGE_URLS + + +def test_chat(): + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") + + prompt1 = "Explain the concept of entropy." + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + outputs = llm.chat(messages) + assert len(outputs) == 1 + + +def test_multi_chat(): + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +def test_chat_multi_image(image_urls: List[str]): + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + dtype="bfloat16", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + trust_remote_code=True, + limit_mm_per_prompt={"image": 2}, + ) + + messages = [{ + "role": + "user", + "content": [ + *({ + "type": "image_url", + "image_url": { + "url": image_url + } + } for image_url in image_urls), + { + "type": "text", + "text": "What's in this image?" 
+ }, + ], + }] + outputs = llm.chat(messages) + assert len(outputs) >= 0 diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index 1885f2e168d8..4c9f796e5ed7 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -4,8 +4,7 @@ import pytest from vllm import LLM, EmbeddingRequestOutput, PoolingParams - -from ...conftest import cleanup +from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "intfloat/e5-mistral-7b-instruct" @@ -41,7 +40,7 @@ def llm(): del llm - cleanup() + cleanup_dist_env_and_memory() def assert_outputs_equal(o1: List[EmbeddingRequestOutput], diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index 6543c4bb1b58..7d2b37775272 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -4,9 +4,7 @@ import pytest from vllm import LLM, RequestOutput, SamplingParams - -from ...conftest import cleanup -from ..openai.test_vision import TEST_IMAGE_URLS +from vllm.distributed import cleanup_dist_env_and_memory MODEL_NAME = "facebook/opt-125m" @@ -40,7 +38,7 @@ def llm(): del llm - cleanup() + cleanup_dist_env_and_memory() def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): @@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM): # sampling_params is None, default params should be applied outputs = llm.generate(PROMPTS, sampling_params=None) assert len(PROMPTS) == len(outputs) - - -def test_chat(): - - llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") - - prompt1 = "Explain the concept of entropy." - messages = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ] - outputs = llm.chat(messages) - assert len(outputs) == 1 - - -def test_multi_chat(): - - llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") - - prompt1 = "Explain the concept of entropy." - prompt2 = "Explain what among us is." - - conversation1 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt1 - }, - ] - - conversation2 = [ - { - "role": "system", - "content": "You are a helpful assistant" - }, - { - "role": "user", - "content": prompt2 - }, - ] - - messages = [conversation1, conversation2] - - outputs = llm.chat(messages) - assert len(outputs) == 2 - - -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) -def test_chat_multi_image(image_urls: List[str]): - llm = LLM( - model="microsoft/Phi-3.5-vision-instruct", - dtype="bfloat16", - max_model_len=4096, - max_num_seqs=5, - enforce_eager=True, - trust_remote_code=True, - limit_mm_per_prompt={"image": 2}, - ) - - messages = [{ - "role": - "user", - "content": [ - *({ - "type": "image_url", - "image_url": { - "url": image_url - } - } for image_url in image_urls), - { - "type": "text", - "text": "What's in this image?" 
- }, - ], - }] - outputs = llm.chat(messages) - assert len(outputs) >= 0 diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 9f5727ecd040..eb2113692e7b 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -5,10 +5,9 @@ from huggingface_hub import snapshot_download from vllm import LLM +from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -from ...conftest import cleanup - MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" PROMPTS = [ @@ -39,7 +38,7 @@ def llm(): del llm - cleanup() + cleanup_dist_env_and_memory() @pytest.fixture(scope="module") diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 2841dfc6bd9c..67c79415f322 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -5,12 +5,11 @@ import jsonschema import pytest +from vllm.distributed import cleanup_dist_env_and_memory from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput from vllm.sampling_params import GuidedDecodingParams, SamplingParams -from ...conftest import cleanup - MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -23,7 +22,7 @@ def llm(): with llm.deprecate_legacy_api(): yield weakref.proxy(llm) del llm - cleanup() + cleanup_dist_env_and_memory() @pytest.mark.skip_global_cleanup diff --git a/tests/entrypoints/llm/test_init.py b/tests/entrypoints/llm/test_init.py new file mode 100644 index 000000000000..c9a4ad44fea3 --- /dev/null +++ b/tests/entrypoints/llm/test_init.py @@ -0,0 +1,22 @@ +import pytest + +from vllm import LLM + +from ...utils import error_on_warning + +MODEL_NAME = "facebook/opt-125m" + + +def test_pos_args_deprecated(): + with error_on_warning(DeprecationWarning): + LLM(model=MODEL_NAME, tokenizer=MODEL_NAME) + + with error_on_warning(DeprecationWarning): + LLM(MODEL_NAME, tokenizer=MODEL_NAME) + + with pytest.warns(DeprecationWarning, match="'tokenizer'"): + LLM(MODEL_NAME, MODEL_NAME) + + with pytest.warns(DeprecationWarning, + match="'tokenizer', 'tokenizer_mode'"): + LLM(MODEL_NAME, MODEL_NAME, "auto") diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py index 39480531f586..cbfb0cc32c1c 100644 --- a/tests/entrypoints/llm/test_lazy_outlines.py +++ b/tests/entrypoints/llm/test_lazy_outlines.py @@ -1,6 +1,7 @@ import sys from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory def test_lazy_outlines(sample_regex): @@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex): ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + # Create an LLM without guided decoding as a baseline. llm = LLM(model="facebook/opt-125m", enforce_eager=True, gpu_memory_utilization=0.3) @@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex): # make sure outlines is not imported assert 'outlines' not in sys.modules + # Destroy the LLM object and free up the GPU memory. + del llm + cleanup_dist_env_and_memory() + + # Create an LLM with guided decoding enabled. 
llm = LLM(model="facebook/opt-125m", enforce_eager=True, guided_decoding_backend="lm-format-enforcer", - gpu_memory_utilization=0.3) + gpu_memory_utilization=0.6) sampling_params = SamplingParams(temperature=0.8, top_p=0.95) outputs = llm.generate( prompts=[ diff --git a/tests/entrypoints/offline_mode/test_offline_mode.py b/tests/entrypoints/offline_mode/test_offline_mode.py index 0b6026a89c75..65699e609e4a 100644 --- a/tests/entrypoints/offline_mode/test_offline_mode.py +++ b/tests/entrypoints/offline_mode/test_offline_mode.py @@ -1,51 +1,56 @@ """Tests for HF_HUB_OFFLINE mode""" import importlib import sys -import weakref import pytest from vllm import LLM - -from ...conftest import cleanup - -MODEL_NAME = "facebook/opt-125m" +from vllm.distributed import cleanup_dist_env_and_memory + +MODEL_CONFIGS = [ + { + "model": "facebook/opt-125m", + "enforce_eager": True, + "gpu_memory_utilization": 0.20, + "max_model_len": 64, + "max_num_batched_tokens": 64, + "max_num_seqs": 64, + "tensor_parallel_size": 1, + }, + { + "model": "mistralai/Mistral-7B-Instruct-v0.1", + "enforce_eager": True, + "gpu_memory_utilization": 0.95, + "max_model_len": 64, + "max_num_batched_tokens": 64, + "max_num_seqs": 64, + "tensor_parallel_size": 1, + "tokenizer_mode": "mistral", + }, +] @pytest.fixture(scope="module") -def llm(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) - - with llm.deprecate_legacy_api(): - yield weakref.proxy(llm) +def cache_models(): + # Cache model files first + for model_config in MODEL_CONFIGS: + LLM(**model_config) + cleanup_dist_env_and_memory() - del llm - - cleanup() + yield @pytest.mark.skip_global_cleanup -def test_offline_mode(llm: LLM, monkeypatch): - # we use the llm fixture to ensure the model files are in-cache - del llm - +@pytest.mark.usefixtures("cache_models") +def test_offline_mode(monkeypatch): # Set HF to offline mode and ensure we can still construct an LLM try: monkeypatch.setenv("HF_HUB_OFFLINE", "1") # Need to re-import huggingface_hub and friends to setup offline mode _re_import_modules() # Cached model files should be used in offline mode - LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + for model_config in MODEL_CONFIGS: + LLM(**model_config) finally: # Reset the environment after the test # NB: Assuming tests are run in online mode diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 0fbc4cca83bd..d1aebbd70d25 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -16,9 +16,6 @@ # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" -# technically this needs Mistral-7B-v0.1 as base, but we're not testing -# generation quality here -LORA_NAME = "typeof/zephyr-7b-beta-lora" @pytest.fixture(scope="module") @@ -433,18 +430,28 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI, model=model_name, messages=messages, max_tokens=10, + extra_body=dict(min_tokens=10), temperature=0.0, stream=True, stream_options={ "include_usage": True, - "continuous_usage_stats": True + "continuous_usage_stats": True, }, ) + last_completion_tokens = 0 async for chunk in stream: assert chunk.usage.prompt_tokens >= 0 - assert chunk.usage.completion_tokens >= 0 + assert 
last_completion_tokens == 0 or \ + chunk.usage.completion_tokens > last_completion_tokens or \ + ( + not chunk.choices and + chunk.usage.completion_tokens == last_completion_tokens + ) assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + chunk.usage.completion_tokens) + last_completion_tokens = chunk.usage.completion_tokens + + assert last_completion_tokens == 10 # NOTE: Not sure why, but when I place this after `test_guided_regex_chat` @@ -841,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): @pytest.mark.asyncio async def test_response_format_json_schema(client: openai.AsyncOpenAI): + prompt = 'what is 1+1? The format is "result": 2' + # Check that this prompt cannot lead to a valid JSON without json_schema for _ in range(2): resp = await client.chat.completions.create( model=MODEL_NAME, messages=[{ - "role": - "user", - "content": ('what is 1+1? please respond with a JSON object, ' - 'the format is {"result": 2}') + "role": "user", + "content": prompt + }], + ) + content = resp.choices[0].message.content + assert content is not None + with pytest.raises((json.JSONDecodeError, AssertionError)): + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded + + for _ in range(2): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": prompt }], response_format={ "type": "json_schema", diff --git a/tests/entrypoints/openai/test_chunked_prompt.py b/tests/entrypoints/openai/test_chunked_prompt.py new file mode 100644 index 000000000000..61d66365130c --- /dev/null +++ b/tests/entrypoints/openai/test_chunked_prompt.py @@ -0,0 +1,126 @@ +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + # lora config below + "--max-num-seqs", + "128", + "--enable-chunked-prefill", + "--max-num-batched-tokens", + "1000", + # large prompts create a lot of output + "--disable-log-requests", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_completion_stream_options_and_logprobs_with_long_prompts( + client: openai.AsyncOpenAI): + # Test stream with long prompt + prompt = "What is the capital of France?" 
* 400 + + stream = await client.completions.create( + model=MODEL_NAME, + prompt=prompt, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": True, + }, + logprobs=5, + ) + + tokens_received = 0 + finished = False + async for chunk in stream: + assert chunk.usage.prompt_tokens >= 0 + assert chunk.usage.completion_tokens >= 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + if not finished: + tokens_received += 1 + assert chunk.choices[0].text + + if chunk.choices[0].finish_reason is not None: + finished = True + + if finished: + assert chunk.usage.completion_tokens == tokens_received + + +@pytest.mark.asyncio +async def test_chat_completion_stream_options_and_logprobs_with_long_prompts( + client: openai.AsyncOpenAI): + # Test stream with long prompt + messages = [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "What is the capital of France?" * 400 + }] + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=5, + temperature=0.0, + stream=True, + stream_options={ + "include_usage": True, + "continuous_usage_stats": True, + }, + logprobs=True, + top_logprobs=5, + ) + + tokens_received = 0 + empty_chunks_received = 0 + finished = False + async for chunk in stream: + assert chunk.usage.prompt_tokens >= 0 + assert chunk.usage.completion_tokens >= 0 + assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens + + chunk.usage.completion_tokens) + + if not finished: + if chunk.choices[0].delta.content == "": + # when there is no tokens generated + assert chunk.usage.completion_tokens == 0 + assert chunk.choices[0].logprobs is None + empty_chunks_received += 1 + else: + tokens_received += 1 + + if chunk.choices[0].finish_reason is not None: + finished = True + + if finished: + assert chunk.usage.completion_tokens == tokens_received + + assert empty_chunks_received <= 1 diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index cc72a49ebbbd..f03bdb045f64 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-pa"], +) +async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str): + """Streaming for parallel sampling. + The tokens from multiple samples, are flattened into a single stream, + with an index to indicate which sample the token belongs to. + """ + + prompt = "What is an LLM?" 
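
The docstring of `test_parallel_streaming` above describes how parallel samples (`n > 1`) are flattened into a single completion stream, with `choices[0].index` indicating which sample each chunk belongs to. As a standalone illustration of de-multiplexing such a stream, here is a client-side sketch; the server URL, API key and model name are placeholders and assume a vLLM OpenAI-compatible server is already running.

```python
import asyncio
from typing import List

import openai


async def collect_parallel_samples(n: int = 3) -> List[str]:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    stream = await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="What is an LLM?",
        max_tokens=5,
        n=n,
        stream=True,
    )
    texts = [""] * n
    async for chunk in stream:
        choice = chunk.choices[0]
        # choice.index tells us which of the n samples this token belongs to.
        texts[choice.index] += choice.text
    return texts


if __name__ == "__main__":
    print(asyncio.run(collect_parallel_samples()))
```
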
+ n = 3 + max_tokens = 5 + + stream = await client.completions.create(model=model_name, + prompt=prompt, + max_tokens=max_tokens, + n=n, + stream=True) + chunks: List[List[str]] = [[] for i in range(n)] + finish_reason_count = 0 + async for chunk in stream: + index = chunk.choices[0].index + text = chunk.choices[0].text + chunks[index].append(text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == n + for chunk in chunks: + assert len(chunk) == max_tokens + print("".join(chunk)) + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index ec550fe82c70..e969d33775d8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -22,12 +22,13 @@ class MockHFConfig: @dataclass class MockModelConfig: + task = "generate" tokenizer = MODEL_NAME trust_remote_code = False tokenizer_mode = "auto" + chat_template_text_format = "string" max_model_len = 100 tokenizer_revision = None - embedding_mode = False multimodal_config = MultiModalConfig() hf_config = MockHFConfig() diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 25ab91ef6933..6fcc92022855 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -6,7 +6,7 @@ from ...utils import RemoteOpenAIServer -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +MODEL_NAME = "meta-llama/Llama-3.2-1B" @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 81d79601124a..8311a5cb3c2d 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -23,6 +23,8 @@ @pytest.fixture(scope="module") def server(): args = [ + "--task", + "generate", "--dtype", "bfloat16", "--max-model-len", diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 6ded5102c931..5fa466f8f041 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -8,21 +8,25 @@ from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (parse_chat_messages, parse_chat_messages_futures) +from vllm.entrypoints.llm import apply_hf_chat_template from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" +MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def phi3v_model_config(): return ModelConfig(PHI3V_MODEL_ID, - PHI3V_MODEL_ID, + task="generate", + tokenizer=PHI3V_MODEL_ID, tokenizer_mode="auto", trust_remote_code=True, dtype="bfloat16", seed=0, + chat_template_text_format="string", limit_mm_per_prompt={ "image": 2, }) @@ -38,6 +42,30 @@ def phi3v_tokenizer(): ) +@pytest.fixture(scope="module") +def mllama_model_config(): + return ModelConfig(MLLAMA_MODEL_ID, + task="generate", + tokenizer=MLLAMA_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + +@pytest.fixture(scope="module") +def mllama_tokenizer(): + return TokenizerGroup( + MLLAMA_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + @pytest.fixture(scope="module") def image_url(): 
image = ImageAsset('cherry_blossom') @@ -303,6 +331,51 @@ def test_parse_chat_messages_multiple_images_across_messages( _assert_mm_data_is_image_input(mm_data, 2) +def test_parse_chat_messages_context_text_format( + phi3v_model_config, + phi3v_tokenizer, +): + phi3v_model_config.chat_template_text_format = "openai" + conversation, mm_data = parse_chat_messages( + [{ + "role": "user", + "content": [{ + "type": "text", + "text": "What's in this text?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": "user", + "content": "What about this one?" + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [ + { + "role": "user", + "content": [{ + "type": "text", + "text": "What's in this text?" + }] + }, + { + "role": "assistant", + "content": [{ + "type": "text", + "text": "Some stuff." + }] + }, + { + "role": "user", + "content": [{ + "type": "text", + "text": "What about this one?" + }] + }, + ] + + def test_parse_chat_messages_rejects_too_many_images_in_one_message( phi3v_model_config, phi3v_tokenizer, @@ -387,3 +460,179 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages( "text": "What about these two?" }] }], phi3v_model_config, phi3v_tokenizer) + + +def test_parse_chat_messages_multiple_images_uncommon_input( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [ + "What's in these images?", { + "image_url": image_url + }, { + "image_url": image_url + } + ] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +### Mllama currently wraps images / texts as interleaved dictionaries +def test_mllama_single_image( + mllama_model_config, + mllama_tokenizer, + image_url, +): + """Ensures that a single image is parsed correctly mllama.""" + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + "image_url": image_url + }] + }], mllama_model_config, mllama_tokenizer) + _assert_mm_data_is_image_input(mm_data, 1) + assert conversation == [{ + 'role': + 'user', + 'content': [{ + 'type': 'text', + 'text': 'The content of this image is:' + }, { + 'type': 'image' + }] + }] + + +def test_mllama_interleaved_images( + mllama_model_config, + mllama_tokenizer, + image_url, +): + """Ensures that multiple image are parsed as interleaved dicts.""" + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + { + "image_url": image_url + }, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + { + "image_url": image_url + }, + ] + }], mllama_model_config, mllama_tokenizer) + _assert_mm_data_is_image_input(mm_data, 2) + assert conversation == [{ + 'role': + 'user', + 'content': [{ + 'type': 'text', + 'text': 'The content of the first image is:' + }, { + 'type': 'image' + }, { + 'type': 'text', + 'text': 'The content of the second image is:' + }, { + 'type': 'image' + }] + }] + + +@pytest.mark.parametrize("model", [MLLAMA_MODEL_ID]) +def test_multimodal_image_parsing_matches_hf(model, image_url): + """Checks end to end hf alignment for multimodal [image] parsing.""" + + def get_conversation(is_hf: bool): + img_part = {"type": "image_url", "image_url": {"url": image_url}} + if is_hf: + img_part = 
{'type': 'image'} + return [{ + 'role': + 'user', + 'content': [ + { + 'type': 'text', + 'text': 'The content of the first image is:' + }, + img_part, + { + 'type': 'text', + 'text': 'The content of the second image is:' + }, + img_part, + { + 'type': 'text', + 'text': 'What animal is in the first image?' + }, + ] + }] + + # Build a config for the model + model_config = ModelConfig(model, + task="generate", + tokenizer=MLLAMA_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + # Build the tokenizer group and grab the underlying tokenizer + tokenizer_group = TokenizerGroup( + MLLAMA_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + tokenizer = tokenizer_group.tokenizer + + # Build and parse a conversation with {"type": "image"} using the tokenizer + hf_conversation = get_conversation(is_hf=True) + hf_result = tokenizer.apply_chat_template( + hf_conversation, + tokenize=False, + add_generation_prompt=True, + ) + + # Now parse with vLLMs chat utils & apply the template + vllm_conversation = get_conversation(is_hf=False) + conversation, _ = parse_chat_messages( + vllm_conversation, + model_config, + tokenizer_group, + ) + + vllm_result = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=None, + add_generation_prompt=True, + ) + + assert hf_result == vllm_result diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 9b476585fa19..0e3d3c3a2e98 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -1,12 +1,13 @@ +import random from typing import Type import pytest import torch from tests.kernels.utils import opcheck -from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul, - NewGELU, QuickGELU, - SiluAndMul) +from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul, + GeluAndMul, NewGELU, + QuickGELU, SiluAndMul) from vllm.utils import seed_everything from .allclose_default import get_default_atol, get_default_rtol @@ -20,7 +21,8 @@ ] -@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"]) +@pytest.mark.parametrize("activation", + ["silu", "gelu", "gelu_tanh", "fatrelu"]) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @@ -47,16 +49,23 @@ def test_act_and_mul( elif activation == "gelu_tanh": layer = GeluAndMul(approximate="tanh") fn = torch.ops._C.gelu_tanh_and_mul + elif activation == "fatrelu": + threshold = random.uniform(0, 1) + layer = FatreluAndMul(threshold) + fn = torch.ops._C.fatrelu_and_mul out = layer(x) ref_out = layer.forward_native(x) - # The SiLU and GELU implementations are equivalent to the native PyTorch - # implementations, so we can do exact comparison. + # The SiLU, GELU and FatReLU implementations are equivalent to the native + # PyTorch implementations, so we can do exact comparison. 
torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0) d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - opcheck(fn, (out, x)) + if activation == "fatrelu": + opcheck(fn, (out, x, threshold)) + else: + opcheck(fn, (out, x)) @pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast), diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index f471dcee938b..8bcee9840377 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch): override_backend_env_variable(monkeypatch, name) if device == "cpu": - with patch("vllm.attention.selector.is_cpu", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + with patch("vllm.attention.selector.current_platform.is_cpu", + return_value=True): + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "TORCH_SDPA" elif device == "hip": with patch("vllm.attention.selector.is_hip", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "ROCM_FLASH" elif device == "openvino": with patch("vllm.attention.selector.is_openvino", return_value=True): - backend = which_attn_to_use(16, None, torch.float16, torch.float16, - 16, False) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, + False) assert backend.name == "OPENVINO" else: - backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16, + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) assert backend.name == name @@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch): # Unsupported CUDA arch with patch("torch.cuda.get_device_capability", return_value=(7, 5)): - backend = which_attn_to_use(16, None, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported data type - backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False) + backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported kv cache data type - backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False) + backend = which_attn_to_use(16, torch.float16, "fp8", 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported block size - backend = which_attn_to_use(16, None, torch.float16, None, 8, False) - assert backend.name != STR_FLASH_ATTN_VAL - - # Unsupported sliding window - backend = which_attn_to_use(16, 1, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 8, False) assert backend.name != STR_FLASH_ATTN_VAL # flash-attn is not installed with patch.dict('sys.modules', {'vllm_flash_attn': None}): - backend = which_attn_to_use(16, None, torch.float16, None, 16, False) + backend = which_attn_to_use(16, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Unsupported head size - backend = which_attn_to_use(17, None, torch.float16, None, 16, False) + backend = which_attn_to_use(17, torch.float16, None, 16, False) assert backend.name != STR_FLASH_ATTN_VAL # Attention-free models should bypass env and use PlaceholderAttention - backend = which_attn_to_use(16, None, 
torch.float16, torch.float16, 16, - True) + backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True) assert backend.name != STR_FLASH_ATTN_VAL @@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" override_backend_env_variable(monkeypatch, STR_INVALID_VAL) with pytest.raises(ValueError): - which_attn_to_use(16, None, torch.float16, None, 16, False) + which_attn_to_use(16, torch.float16, None, 16, False) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 069020a536d0..277d7e4977d7 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -6,6 +6,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.utils import seed_everything @@ -114,16 +115,15 @@ def causal_conv1d_update_ref(x, @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) -def causal_conv1d_opcheck_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - cu_seq_len: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): +def causal_conv1d_opcheck_fn(x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID): """ x: (batch, dim, seqlen) weight: (dim, width) @@ -141,16 +141,9 @@ def causal_conv1d_opcheck_fn( x = x.contiguous() bias = bias.contiguous() if bias is not None else None - opcheck(torch.ops._C.causal_conv1d_fwd, ( - x, - weight, - bias, - conv_states, - cu_seq_len, - cache_indices, - has_initial_state, - activation in ["silu", "swish"], - )) + opcheck(torch.ops._C.causal_conv1d_fwd, + (x, weight, bias, conv_states, cu_seq_len, cache_indices, + has_initial_state, activation in ["silu", "swish"], pad_slot_id)) @pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) @@ -233,17 +226,11 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, seed_everything(0) batch = 2 x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_ref = x.clone() conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state.detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -251,7 +238,7 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, weight, bias, activation=activation) - out_ref = causal_conv1d_update_ref(x, + out_ref = causal_conv1d_update_ref(x_ref, conv_state_ref, weight, bias, @@ -260,15 +247,9 @@ def 
test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - None, - )) + opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, None, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", @@ -278,37 +259,48 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, @pytest.mark.parametrize("seqlen", [1, 4, 5]) @pytest.mark.parametrize("width", [2, 3, 4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_causal_conv1d_update_with_batch_gather(with_padding, dim, width, + seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - # set )seed + # set seed seed_everything(0) - batch = 64 - x = torch.randn(batch, dim, 1, device=device, dtype=itype) + batch_size = 3 + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + total_entries = 10 * batch_size - total_entries = 10 * batch + x = torch.randn(padded_batch_size, dim, 1, device=device, dtype=itype) + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat([ + conv_state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) conv_state = torch.randn(total_entries, dim, width - 1, device=device, dtype=itype) - conv_state_indices = torch.randperm(total_entries)[:batch].to( - dtype=torch.int32, device=device) + conv_state_for_padding_test = conv_state.clone() - weight = torch.randn(dim, - width, - device=device, - dtype=itype, - requires_grad=True) - if has_bias: - bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) - else: - bias = None + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None conv_state_ref = conv_state[conv_state_indices, :].detach().clone() activation = None if not silu_activation else "silu" out = causal_conv1d_update(x, @@ -316,45 +308,50 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, weight, bias, activation=activation, - conv_state_indices=conv_state_indices) - out_ref = causal_conv1d_update_ref(x, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) + out_ref = causal_conv1d_update_ref(x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation) assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + assert torch.equal(conv_state[unused_states_bool], + conv_state_for_padding_test[unused_states_bool]) - opcheck(torch.ops._C.causal_conv1d_update, ( - x, - conv_state, - weight, - bias, - activation in ["silu", "swish"], - None, - conv_state_indices, - )) 
+ opcheck(torch.ops._C.causal_conv1d_update, + (x, conv_state, weight, bias, activation + in ["silu", "swish"], None, padded_state_indices, PAD_SLOT_ID)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [True]) @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize('seqlen', - [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +@pytest.mark.parametrize( + 'seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 2049, 4096]) @pytest.mark.parametrize('dim', [64, 4096]) -def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, - itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize('with_padding', [True, False]) +def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, + silu_activation, itype): device = "cuda" + torch.cuda.empty_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) - batch = 1 seqlens = [] - nsplits = 3 + batch_size = 4 + if seqlen < 10: + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + nsplits = padded_batch_size - 1 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( @@ -364,10 +361,11 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) - x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + x = torch.randn(1, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :] weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None @@ -375,7 +373,7 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, weight_ref = weight.clone() bias_ref = bias.clone() if bias is not None else None activation = None if not silu_activation else "silu" - final_states = torch.randn(nsplits + 1, + final_states = torch.randn(total_entries, dim, width - 1, device=x.device, @@ -385,18 +383,27 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=x.device) - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=x.device) + device=x.device)[:batch_size] + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) + out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation, PAD_SLOT_ID) out_ref = [] out_ref_b = [] splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] for i in range(len(seqlens[0])): x_s = [v[i].unsqueeze(0) for v in splits][0] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_b.append( causal_conv1d_ref( x_s, @@ -404,21 +411,17 @@ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, bias_ref, activation=activation, return_final_states=True, - 
final_states_out=final_states_ref[cache_indices[i]].unsqueeze( - 0), - initial_states=final_states_ref[cache_indices[i]].unsqueeze(0) - if has_initial_states[i] else None)) + final_states_out=final_states_ref[ + padded_state_indices[i]].unsqueeze(0), + initial_states=final_states_ref[padded_state_indices[i]]. + unsqueeze(0) if has_initial_states[i] else None)) out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) - out_ref = torch.cat(out_ref, dim=0) - - print(f"Output max diff: {(out - out_ref).abs().max().item()}") - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") - print("Output state max diff" - f":{(final_states - final_states_ref).abs().max()}") - print("Output state mean diff" - f":{(final_states - final_states_ref).abs().mean()}") - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + out_ref_tensor = torch.cat(out_ref, dim=0) + + unpadded_out = out[:, :out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), - cache_indices, has_initial_states, final_states, - activation) + padded_state_indices, has_initial_states, + final_states, activation) diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 3e9b4d9a4f8a..35c29c5bd102 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -78,6 +78,7 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("sliding_window", [None, 256]) @torch.inference_mode() def test_flash_attn_with_paged_kv( kv_lens: List[int], @@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv( block_size: int, soft_cap: Optional[float], num_blocks: int, + sliding_window: Optional[int], ) -> None: torch.set_default_device("cuda") seed_everything(0) @@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_kv_len = max(kv_lens) scale = head_size**-0.5 + window_size = ((sliding_window - 1, 0) if sliding_window is not None else + (-1, -1)) query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) key_cache = torch.randn(num_blocks, @@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv( block_table=block_tables, cache_seqlens=kv_lens_tensor, softcap=soft_cap if soft_cap is not None else 0, + window_size=window_size, ).squeeze(1) - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap, - ) + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap, + sliding_window=sliding_window) torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" @@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None]) +@pytest.mark.parametrize("sliding_window", [None, 256]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) 
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @@ -166,8 +170,7 @@ def test_varlen_with_paged_kv( assert num_query_heads % num_kv_heads == 0 max_query_len = max(query_lens) max_kv_len = max(kv_lens) - window_size = ((sliding_window, - sliding_window) if sliding_window is not None else + window_size = ((sliding_window - 1, 0) if sliding_window is not None else (-1, -1)) scale = head_size**-0.5 diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index 0fc2984a68de..59c0a24753c3 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor, w_q = w_q.t().contiguous().t() # convert to col major w_q_machete = ops.machete_prepack_B(w_q, wtype) - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype)) + opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id)) return w_ref, w_q_machete, w_s, w_zp @@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype, schedule=schedule, ) - opcheck(torch.ops._C.machete_gemm, - (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) + opcheck( + torch.ops._C.machete_gemm, + (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints( + w_zp, w_s), group_size, None, None, None, schedule)) # Relax atol as our reduction dim becomes larger (more rounding error) # Relax atol when we have zeropoints since the way machete applies diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 8fa55e75f6c1..e92d401368a7 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -5,6 +5,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 +from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) from vllm.utils import seed_everything @@ -174,7 +175,8 @@ def selective_scan_opcheck_fn(u, cu_seq_len=None, cache_indices=None, has_initial_state=None, - ssm_states=None): + ssm_states=None, + pad_slot_id=PAD_SLOT_ID): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -203,7 +205,7 @@ def selective_scan_opcheck_fn(u, # a bogus error. 
opcheck(torch.ops._C.selective_scan_fwd, (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, - cache_indices, has_initial_state, ssm_states), + cache_indices, has_initial_state, ssm_states, pad_slot_id), test_utils=["test_schema", "test_faketensor"]) @@ -404,9 +406,12 @@ def test_selective_state_update(dim, dstate, has_z, itype): @pytest.mark.parametrize("varBC_groups", [1, 2]) @pytest.mark.parametrize("is_variable_C", [True]) @pytest.mark.parametrize("is_variable_B", [True]) -def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, - has_D, has_z, has_delta_bias, delta_softplus, - return_last_state, seqlen, itype, wtype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [False, True]) +def test_selective_scan_varlen(with_padding, is_variable_B, is_variable_C, + varBC_groups, has_D, has_z, has_delta_bias, + delta_softplus, return_last_state, seqlen, + itype, wtype): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -420,18 +425,27 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, # set seed torch.random.manual_seed(0) seqlens = [] - nsplits = 3 + batch_size = 4 if seqlen < 10: - nsplits = 0 + batch_size = 1 + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + + if with_padding and seqlen < padded_batch_size: + pytest.skip() + + nsplits = padded_batch_size - 1 eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values seqlens.append( torch.diff( torch.cat( [torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen assert all(s > 0 for s in seqlens[-1]) + total_entries = batch_size * 10 cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0).cuda() @@ -462,22 +476,33 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_ref = delta.clone() out = None out_ref = None - prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + + prev_state_shape = (total_entries, u.shape[0], int(A.shape[1])) prev_state = torch.randn(prev_state_shape, device=u.device, dtype=itype, requires_grad=False) prev_state_ref = prev_state.clone() - cache_indices = torch.randperm(cumsum.shape[0] - 1, + state_indices = torch.randperm(total_entries, dtype=torch.int32, - device=u.device) + device=u.device)[:batch_size] + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1) has_initial_state = torch.randint(0, 2, (cumsum.shape[0] - 1, ), dtype=torch.bool, device=u.device) out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state) outs_ref = [] splits = [ @@ -486,6 +511,8 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, ] for i in range(len(seqlens[0])): u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] + if padded_state_indices[i] == PAD_SLOT_ID: + continue out_ref_s, _ = selective_scan_ref( u_s, delta_s, @@ -497,21 +524,22 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, delta_bias=delta_bias, 
delta_softplus=delta_softplus, return_last_state=return_last_state, - prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0) + prev_state=prev_state_ref[padded_state_indices[i]].unsqueeze(0) if has_initial_state[i] else None, - final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0)) + final_state_out=prev_state_ref[padded_state_indices[i]].unsqueeze( + 0)) outs_ref.append(out_ref_s) - out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + out_ref = torch.cat(outs_ref, dim=-1)[0] - print("Output diff max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + unpadded_out = out[:, :out_ref[0].shape[-1]] + print("Output diff max", (unpadded_out - out_ref).max()) + print("Output diff mean", (unpadded_out - out_ref).mean()) print("Output state diff max", (prev_state - prev_state_ref).max()) print("Output state diff mean", (prev_state - prev_state_ref).mean()) assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) - + assert torch.allclose(unpadded_out, out_ref, rtol=rtol, atol=atol) selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, - delta_softplus, cumsum, cache_indices, + delta_softplus, cumsum, padded_state_indices, has_initial_state, prev_state) @@ -520,7 +548,10 @@ def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, @pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, + has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: @@ -530,21 +561,32 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): # set seed torch.random.manual_seed(0) batch_size = 3 - + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) state_indices = torch.randperm(total_entries)[:batch_size].to( dtype=torch.int32, device=device) - - x = torch.randn(batch_size, dim, device=device, dtype=itype) - dt = torch.randn(batch_size, dim, device=device, dtype=itype) + unused_states_bool = torch.ones(total_entries, + dtype=torch.bool, + device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat([ + state_indices, + torch.as_tensor( + [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device) + ], + dim=0) + x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) dt_bias = torch.rand(dim, device=device) - 4.0 A = -torch.rand(dim, dstate, device=device) - 1.0 - B = torch.randn(batch_size, dstate, device=device) - C = torch.randn(batch_size, dstate, device=device) + B = torch.randn(padded_batch_size, dstate, device=device) + C = torch.randn(padded_batch_size, dstate, device=device) D = torch.randn(dim, device=device) z = torch.randn_like(x) if has_z else None - state_ref = state[state_indices, :].detach().clone() + state_ref = state[state_indices, :].clone() + state_before = state.clone() out = selective_state_update(state, x, dt, @@ -555,15 
+597,16 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, - x, - dt, + x[:batch_size], + dt[:batch_size], A, - B, - C, + B[:batch_size], + C[:batch_size], D=D, - z=z, + z=z[:batch_size], dt_bias=dt_bias, dt_softplus=True) @@ -572,11 +615,21 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) + # test padded entries stay the same + if with_padding: + assert torch.equal(state_before[unused_states_bool], + state[unused_states_bool]) + assert torch.equal(x[batch_size + 1:], x[batch_size + 1:]) + assert torch.equal(dt[batch_size + 1:], dt[batch_size + 1:]) + assert torch.equal(B[batch_size + 1:], B[batch_size + 1:]) + assert torch.equal(C[batch_size + 1:], C[batch_size + 1:]) + + # test "real" entries assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) - assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) @pytest.mark.parametrize("itype", @@ -645,7 +698,8 @@ def test_selective_state_update_with_heads_with_batch_indices( z=z, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices) + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID) out_ref = selective_state_update_ref(state_ref, x, dt, diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index a9bb72156c39..5cfd4d6da7a8 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -225,7 +225,7 @@ def test_gptq_marlin_gemm( opcheck( torch.ops._C.gptq_marlin_gemm, (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1], + workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full, False, use_fp32_reduce), test_utils=DEFAULT_OPCHECK_TEST_UTILS) @@ -254,6 +254,16 @@ def test_gptq_marlin_gemm( assert max_diff < 0.04 +# TODO: find better way to test this? 
+@torch.compile(fullgraph=True) +def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s, scratch, quant_type, size_m, size_n, + size_k): + return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta, + marlin_24_s, scratch, quant_type, size_m, + size_n, size_k) + + @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS) @@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, opcheck(torch.ops._C.gptq_marlin_24_gemm, (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, - workspace_24.scratch, quant_type, a_input.shape[0], + workspace_24.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1], a_input.shape[1]), test_utils=DEFAULT_OPCHECK_TEST_UTILS) - output = ops.gptq_marlin_24_gemm( + output = marlin_24_gemm_tester( a_input, marlin_24_q_w_comp, marlin_24_meta, diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index b73c45b9cd19..b87fbc3f1937 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -240,8 +240,8 @@ def test_fused_marlin_moe( requires_grad=False) opcheck(torch.ops._moe_C.marlin_gemm_moe, (a, qweight1, sorted_token_ids, topk_weights, topk_ids, - scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m, - 2 * n, k, True, e, topk, block_size_m, True, False)) + scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id, + m, 2 * n, k, True, e, topk, block_size_m, True, False)) @pytest.mark.skip("This test is here for the sake of debugging, " diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index ba9d2d4389b2..94da00915d40 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -105,7 +105,7 @@ def test_batched_rotary_embedding( if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": (1, ) }) rope = rope.to(dtype=dtype) @@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora( rotary_dim = head_size scaling_factors: List[int] = [1, 2, 4] rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": tuple(scaling_factors) }) rope = rope.to(dtype=dtype) @@ -211,10 +211,10 @@ def test_rope_module_cache(): MAX_POSITIONS = [123, 1234] BASES = [10000, 1000000] ROPE_SCALINGS = (None, { - "type": "linear", + "rope_type": "linear", "factor": (1, ) }, { - "type": "dynamic", + "rope_type": "dynamic", "factor": 1 }) settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE, diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 405c0d0efad6..8d2d85984e7f 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -1,20 +1,16 @@ -import contextlib -import gc import tempfile from collections import OrderedDict from typing import Dict, List, TypedDict from unittest.mock import MagicMock, patch import pytest -import ray import torch import torch.nn as nn from huggingface_hub import snapshot_download import vllm from vllm.config import LoRAConfig -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel, +from vllm.distributed import (cleanup_dist_env_and_memory, init_distributed_environment, initialize_model_parallel) from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -48,16 +44,6 @@ class 
ContextInfo(TypedDict): }] -def cleanup(): - destroy_model_parallel() - destroy_distributed_environment() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - gc.collect() - torch.cuda.empty_cache() - ray.shutdown() - - @pytest.fixture() def should_do_global_cleanup_after_test(request) -> bool: """Allow subdirectories to skip global cleanup by overriding this fixture. @@ -72,7 +58,7 @@ def should_do_global_cleanup_after_test(request) -> bool: def cleanup_fixture(should_do_global_cleanup_after_test: bool): yield if should_do_global_cleanup_after_test: - cleanup() + cleanup_dist_env_and_memory(shutdown_ray=True) @pytest.fixture @@ -87,7 +73,7 @@ def dist_init(): ) initialize_model_parallel(1, 1) yield - cleanup() + cleanup_dist_env_and_memory(shutdown_ray=True) @pytest.fixture @@ -166,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id): return snapshot_download(repo_id=sql_lora_huggingface_id) +@pytest.fixture(scope="session") +def lora_bias_files(): + return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias") + + @pytest.fixture(scope="session") def mixtral_lora_files(): # Note: this module has incorrect adapter_config.json to test @@ -238,7 +229,7 @@ def long_context_lora_files_32k(): def long_context_infos(long_context_lora_files_16k_1, long_context_lora_files_16k_2, long_context_lora_files_32k): - cleanup() + cleanup_dist_env_and_memory(shutdown_ray=True) infos: Dict[int, ContextInfo] = {} for lora_checkpoint_info in LONG_LORA_INFOS: lora_id = lora_checkpoint_info["lora_id"] @@ -259,7 +250,7 @@ def long_context_infos(long_context_lora_files_16k_1, @pytest.fixture def llama_2_7b_engine_extra_embeddings(): - cleanup() + cleanup_dist_env_and_memory(shutdown_ray=True) get_model_old = get_model def get_model_patched(*, model_config, device_config, **kwargs): @@ -272,7 +263,7 @@ def get_model_patched(*, model_config, device_config, **kwargs): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) yield engine.llm_engine del engine - cleanup() + cleanup_dist_env_and_memory(shutdown_ray=True) @pytest.fixture diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index cbc366899781..0ba2ce3617b6 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -3,10 +3,9 @@ import pytest import vllm +from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -from .conftest import cleanup - MODEL_PATH = "baichuan-inc/Baichuan-7B" PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. 
concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 @@ -80,7 +79,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files, output_tp1 = do_sample(llm_tp1, baichuan_lora_files, lora_id=1) del llm_tp1 - cleanup() + cleanup_dist_env_and_memory() llm_tp2 = vllm.LLM(MODEL_PATH, enable_lora=True, @@ -93,7 +92,7 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files, output_tp2 = do_sample(llm_tp2, baichuan_lora_files, lora_id=2) del llm_tp2 - cleanup() + cleanup_dist_env_and_memory() assert output_tp1 == output_tp2 @@ -108,6 +107,6 @@ def test_baichuan_tensor_parallel_equality(baichuan_lora_files, output_tp4 = do_sample(llm_tp4, baichuan_lora_files, lora_id=2) del llm_tp4 - cleanup() + cleanup_dist_env_and_memory() assert output_tp1 == output_tp4 diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index e3233c6b6069..db877219a285 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -951,7 +951,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, lora_rope.create_lora_weights(max_loras, lora_config) linear_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": scaling_factors }) linear_rope = linear_rope.to(dtype=dtype) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index ad8490353998..e2a4f1ed0496 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -4,10 +4,9 @@ import ray import vllm +from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest -from .conftest import cleanup - MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -93,7 +92,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) del llm_tp1 - cleanup() + cleanup_dist_env_and_memory() llm_tp2 = vllm.LLM(MODEL_PATH, enable_lora=True, @@ -103,7 +102,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) del llm_tp2 - cleanup() + cleanup_dist_env_and_memory() assert output_tp1 == output_tp2 @@ -115,7 +114,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) del llm_tp4 - cleanup() + cleanup_dist_env_and_memory() assert output_tp1 == output_tp4 diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 389a3ccbc17e..c8edb02a88d4 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -28,9 +28,15 @@ def _create_lora_request(lora_id, long_context_infos): context_len = long_context_infos[lora_id]["context_length"] scaling_factor = context_len_to_scaling_factor[context_len] - return LoRARequest(context_len, lora_id, - long_context_infos[lora_id]["lora"], None, - 4096 * scaling_factor) + return LoRARequest( + # There are 2 LoRAs for 16K, we need to add lora_id to indicate + # they are different LoRAs. 
+ context_len + str(lora_id), + lora_id, + long_context_infos[lora_id]["lora"], + None, + 4096 * scaling_factor, + ) def evaluate_json_response(model_response, golden_response): @@ -108,14 +114,17 @@ def lora_llm(long_context_infos): for info in long_context_infos.values() ] - llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf", - enable_lora=True, - max_num_seqs=16, - max_loras=2, - long_lora_scaling_factors=tuple(scaling_factors), - max_num_batched_tokens=4096 * 8, - tensor_parallel_size=4, - distributed_executor_backend="mp") + llm = vllm.LLM( + "meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=4, + # FIXME enable async output processor + disable_async_output_proc=True, + distributed_executor_backend="mp") yield llm del llm diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py new file mode 100644 index 000000000000..63ab9ee9da3e --- /dev/null +++ b/tests/lora/test_lora_bias_e2e.py @@ -0,0 +1,52 @@ +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "ibm-granite/granite-3b-code-base" + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + generated_texts: List[str] = [] + for output in outputs: + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + return generated_texts + + +@pytest.mark.parametrize("lora_bias", [True, False]) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_lora_rank=8, + max_loras=1, + enable_lora_bias=lora_bias, + tensor_parallel_size=1, + fully_sharded_loras=fully_sharded) + + print("lora adapter created") + output1 = do_sample(llm, lora_bias_files, lora_id=0) + + print("lora") + output2 = do_sample(llm, lora_bias_files, lora_id=1) + + if lora_bias: + assert output1 != output2 + else: + assert output1 == output2 diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 81b8188e638c..be040060d02b 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -61,6 +61,7 @@ def test_minicpmv_lora(minicpmv_lora_files): max_loras=4, max_lora_rank=64, trust_remote_code=True, + gpu_memory_utilization=0.97 # This model is pretty big for CI gpus ) output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 5636c9643502..d004c6592941 100644 --- a/tests/lora/test_quant_model.py +++ 
b/tests/lora/test_quant_model.py @@ -6,11 +6,10 @@ import pytest import vllm +from vllm.distributed import cleanup_dist_env_and_memory from vllm.lora.request import LoRARequest from vllm.utils import is_hip -from .conftest import cleanup - @dataclass class ModelWithQuantization: @@ -160,7 +159,7 @@ def expect_match(output, expected_output): print("removing lora") del llm - cleanup() + cleanup_dist_env_and_memory() @pytest.mark.parametrize("model", MODELS) @@ -181,7 +180,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1) del llm_tp1 - cleanup() + cleanup_dist_env_and_memory() llm_tp2 = vllm.LLM( model=model.model_path, @@ -194,6 +193,6 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1) del llm_tp2 - cleanup() + cleanup_dist_env_and_memory() assert output_tp1 == output_tp2 diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 732e91a52c0a..2f7ac8550742 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -15,7 +15,8 @@ def test_worker_apply_lora(sql_lora_files): worker = Worker( model_config=ModelConfig( "meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-2-7b-hf", + task="auto", + tokenizer="meta-llama/Llama-2-7b-hf", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -27,7 +28,7 @@ def test_worker_apply_lora(sql_lora_files): load_format="dummy", ), parallel_config=ParallelConfig(1, 1, False), - scheduler_config=SchedulerConfig(32, 32, 32), + scheduler_config=SchedulerConfig("generate", 32, 32, 32), device_config=DeviceConfig("cuda"), cache_config=CacheConfig(block_size=16, gpu_memory_utilization=1., diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index f1003221ab51..7a361ef32081 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -6,13 +6,12 @@ from prometheus_client import REGISTRY from vllm import EngineArgs, LLMEngine +from vllm.distributed import cleanup_dist_env_and_memory from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.metrics import RayPrometheusStatLogger from vllm.sampling_params import SamplingParams -from ..conftest import cleanup - MODELS = [ "facebook/opt-125m", ] @@ -85,6 +84,45 @@ def test_metric_counter_generation_tokens( f"metric: {metric_count!r}") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("max_tokens", [128, 129]) +@pytest.mark.parametrize("disable_async_output_proc", [True, False]) +def test_metric_counter_generation_tokens_multi_step( + vllm_runner, + example_prompts, + model: str, + max_tokens: int, + disable_async_output_proc: bool, +) -> None: + num_scheduler_steps = 8 + with vllm_runner( + model, + disable_log_stats=False, + gpu_memory_utilization=0.4, + num_scheduler_steps=num_scheduler_steps, + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) + tokenizer = vllm_model.model.get_tokenizer() + stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] + metric_count = stat_logger.metrics.counter_generation_tokens.labels( + **stat_logger.labels)._value.get() + vllm_generation_count = 0 + for i in range(len(example_prompts)): + vllm_output_ids, vllm_output_str = vllm_outputs[i] + prompt_ids = tokenizer.encode(example_prompts[i]) + # vllm_output_ids contains both prompt tokens and 
generation tokens. + # We're interested only in the count of the generation tokens. + vllm_generation_count += len(vllm_output_ids) - len(prompt_ids) + + # The multi-step scheduling will continue to execute forward even when + # encountering EOS, leading to slightly imprecise metrics. + assert abs(vllm_generation_count - metric_count) <\ + len(example_prompts) * num_scheduler_steps, \ + (f"generation token count: {vllm_generation_count!r}\n" + f"metric: {metric_count!r}") + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize( @@ -185,13 +223,14 @@ def test_metric_spec_decode( ) -> None: k = 5 - with vllm_runner(model, - dtype=dtype, - disable_log_stats=False, - gpu_memory_utilization=0.4, - speculative_model=model, - num_speculative_tokens=k, - use_v2_block_manager=True) as vllm_model: + with vllm_runner( + model, + dtype=dtype, + disable_log_stats=False, + gpu_memory_utilization=0.4, + speculative_model=model, + num_speculative_tokens=k, + ) as vllm_model: # Force log interval to be 0 to catch all metrics. stat_logger = vllm_model.model.llm_engine.stat_loggers['prometheus'] @@ -242,7 +281,6 @@ def test_metric_spec_decode_interval( gpu_memory_utilization=0.4, speculative_model=model, num_speculative_tokens=k, - use_v2_block_manager=True, enforce_eager=True) engine = LLMEngine.from_engine_args(engine_args) @@ -307,7 +345,7 @@ def test_metric_spec_decode_interval( finally: del engine - cleanup() + cleanup_dist_env_and_memory() def assert_metrics(engine: LLMEngine, disable_log_stats: bool, diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py new file mode 100644 index 000000000000..af267f804ffa --- /dev/null +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -0,0 +1,92 @@ +import os +from typing import List + +import pytest + +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.activation import (GeluAndMul, + ReLUSquaredActivation, + SiluAndMul) +from vllm.model_executor.layers.layernorm import RMSNorm + + +# Registered subclass for test +@CustomOp.register("relu3") +class Relu3(ReLUSquaredActivation): + pass + + +@pytest.mark.parametrize( + "env, torch_level, ops_enabled, default_on", + [ + # Default values based on compile level + ("", 0, [True] * 4, True), + ("", 1, [True] * 4, True), + ("", 2, [True] * 4, True), # All by default + ("", 3, [False] * 4, False), + ("", 4, [False] * 4, False), # None by default + # Explicitly enabling/disabling + # + # Default: all + # + # All but SiluAndMul + ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True), + # Only ReLU3 + ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False), + # All but SiluAndMul + ("all,-silu_and_mul", 1, [1, 0, 1, 1], True), + # All but ReLU3 (even if ReLU2 is on) + ("-relu3,relu2", 1, [1, 1, 1, 0], True), + # GeluAndMul and SiluAndMul + ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False), + # All but RMSNorm + ("-rms_norm", 2, [0, 1, 1, 1], True), + # + # Default: none + # + # Only ReLU3 + ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False), + # All but RMSNorm + ("all,-rms_norm", 4, [0, 1, 1, 1], True), + ]) +def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int], + default_on: bool): + os.environ["VLLM_CUSTOM_OPS"] = env + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level) + + # Reset default_on (computed once): + CustomOp.default_on.cache_clear() + + assert CustomOp.default_on() == default_on + + ops_enabled = [bool(x) for x in ops_enabled] 
+ + assert RMSNorm(1024).enabled() == ops_enabled[0] + assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] + + assert SiluAndMul().enabled() == ops_enabled[1] + assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] + + assert GeluAndMul().enabled() == ops_enabled[2] + assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] + + # If registered, subclasses should follow their own name + assert Relu3().enabled() == ops_enabled[3] + assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] + + # Unregistered subclass + class SiluAndMul2(SiluAndMul): + pass + + # Subclasses should not require registration + assert SiluAndMul2().enabled() == SiluAndMul().enabled() + + +@pytest.mark.parametrize( + "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"]) +def test_enabled_ops_invalid(env: str): + os.environ["VLLM_CUSTOM_OPS"] = env + CustomOp.default_on.cache_clear() + + with pytest.raises(AssertionError): + RMSNorm(1024).enabled() diff --git a/tests/models/decoder_only/language/test_big_models.py b/tests/models/decoder_only/language/test_big_models.py index fcc158639748..75625b35209c 100644 --- a/tests/models/decoder_only/language/test_big_models.py +++ b/tests/models/decoder_only/language/test_big_models.py @@ -21,10 +21,14 @@ ] if not current_platform.is_cpu(): - # MiniCPM requires fused_moe which is not supported by CPU - MODELS.append("openbmb/MiniCPM3-4B") + MODELS += [ + # fused_moe which not supported on CPU + "openbmb/MiniCPM3-4B", + # Head size isn't supported on CPU + "h2oai/h2o-danube3-4b-base", + ] -#TODO: remove this after CPU float16 support ready +# TODO: remove this after CPU float16 support ready target_dtype = "float" if current_platform.is_cpu() else "half" diff --git a/tests/models/decoder_only/language/test_danube3_4b.py b/tests/models/decoder_only/language/test_danube3_4b.py deleted file mode 100644 index bdd498edc293..000000000000 --- a/tests/models/decoder_only/language/test_danube3_4b.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Compare the outputs of HF and vLLM when using greedy sampling. - -This tests danube3 separately because its head size isn't supported on CPU yet. - -Run `pytest tests/models/test_danube3_4b.py`. -""" -import pytest - -from ...utils import check_outputs_equal - -MODELS = ["h2oai/h2o-danube3-4b-base"] - -target_dtype = "half" - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [32]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, -) -> None: - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", [target_dtype]) -def test_model_print( - vllm_runner, - model: str, - dtype: str, -) -> None: - with vllm_runner(model, dtype=dtype) as vllm_model: - # This test is for verifying whether the model's extra_repr - # can be printed correctly. - print(vllm_model.model.llm_engine.model_executor.driver_worker. 
- model_runner.model) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 408d12cd5ff5..384ec77e5455 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,5 +1,6 @@ import pytest +from tests.utils import multi_gpu_test from vllm.sampling_params import SamplingParams from vllm.worker.model_runner import _get_graph_batch_size @@ -270,6 +271,30 @@ def test_state_cleanup( "could be related to finished_requests_ids") +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +def test_jamba_distributed_produces_identical_generation( + vllm_runner, model: str, dtype: str, max_tokens: int, + example_prompts) -> None: + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=2) as vllm_model: + vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + with vllm_runner(model, dtype=dtype, tensor_parallel_size=1) as vllm_model: + vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, + max_tokens) + + check_outputs_equal( + outputs_0_lst=vllm_outputs_tp_1, + outputs_1_lst=vllm_outputs_tp_2, + name_0="vllm_tp_1", + name_1="vllm_tp_2", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) def test_model_print( diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index c27bf6a60a4f..2dc231c595ff 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -10,7 +10,7 @@ from ...utils import check_outputs_equal -MODELS = ["state-spaces/mamba-130m-hf"] +MODELS = ["state-spaces/mamba-130m-hf", "tiiuae/falcon-mamba-tiny-dev"] # Use lower-level interfaces to create this greedy generator, as mamba will diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index 89afbcf1c03a..c997359a2781 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -5,7 +5,7 @@ import pytest import torch -from vllm.utils import is_cpu +from vllm.platforms import current_platform from ....utils import large_gpu_test from ...utils import check_logprobs_close @@ -70,7 +70,7 @@ def test_phimoe_routing_function(): assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) -@pytest.mark.skipif(condition=is_cpu(), +@pytest.mark.skipif(condition=current_platform.is_cpu(), reason="This test takes a lot time to run on CPU, " "and vllm CI's disk space is not enough for this model.") @large_gpu_test(min_gb=80) diff --git a/tests/models/decoder_only/vision_language/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py index 7827ecb19a74..1affcd10ee72 100644 --- a/tests/models/decoder_only/vision_language/test_fuyu.py +++ b/tests/models/decoder_only/vision_language/test_fuyu.py @@ -3,8 +3,8 @@ import pytest from vllm.multimodal.utils import rescale_image_size +from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from ...utils import check_logprobs_close @@ -46,7 +46,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. 
- For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -103,7 +103,7 @@ def run_test( target_dtype = "half" -if is_cpu(): +if current_platform.is_cpu(): target_dtype = "bfloat16" diff --git a/tests/models/decoder_only/vision_language/test_intern_vit.py b/tests/models/decoder_only/vision_language/test_intern_vit.py index 3c3b95b38baa..98f313eb9b9a 100644 --- a/tests/models/decoder_only/vision_language/test_intern_vit.py +++ b/tests/models/decoder_only/vision_language/test_intern_vit.py @@ -6,7 +6,7 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from ....conftest import _ImageAssets, cleanup +from ....conftest import _ImageAssets # we use snapshot_download to prevent conflicts between # dynamic_module and trust_remote_code for hf_runner @@ -45,12 +45,13 @@ def run_intern_vit_test( for pixel_value in pixel_values ] + from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.models.intern_vit import InternVisionModel vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) del hf_model - cleanup() + cleanup_dist_env_and_memory() vllm_model = vllm_model.to("cuda", dtype) vllm_outputs_per_image = [ @@ -58,7 +59,7 @@ def run_intern_vit_test( for pixel_value in pixel_values ] del vllm_model - cleanup() + cleanup_dist_env_and_memory() cos_similar = nn.CosineSimilarity(dim=-1) for vllm_output, hf_output in zip(vllm_outputs_per_image, diff --git a/tests/models/decoder_only/vision_language/test_internvl.py b/tests/models/decoder_only/vision_language/test_internvl.py index 49cab75d8ea5..fc842ec4a617 100644 --- a/tests/models/decoder_only/vision_language/test_internvl.py +++ b/tests/models/decoder_only/vision_language/test_internvl.py @@ -7,7 +7,6 @@ from transformers import AutoConfig from vllm.multimodal.utils import rescale_image_size -from vllm.utils import is_cpu from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) @@ -19,15 +18,20 @@ "cherry_blossom": "<|im_start|>User\n\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 }) -HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: \nImage-2: \nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 +HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: \nImage-2: \nDescribe the two images in short.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 models = [ "OpenGVLab/InternVL2-1B", "OpenGVLab/InternVL2-2B", + # NOTE: Mono-InternVL-2B doesn't work with fp16, + # it will result NaN during inference. 
+ # See: https://huggingface.co/OpenGVLab/Mono-InternVL-2B/discussions/9 + "OpenGVLab/Mono-InternVL-2B", # Broken due to outdated implementation of Phi-3 # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3 # "OpenGVLab/InternVL2-4B", ] +target_dtype = "bfloat16" # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py @@ -52,9 +56,15 @@ def generate( input_embeds = input_embeds.reshape(B, N, C) - outputs = self.language_model.generate( + forward_kwargs = dict( inputs_embeds=input_embeds, attention_mask=attention_mask, + ) + if getattr(self, "use_visual_token_mask", False): + visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype) + forward_kwargs["visual_token_mask"] = visual_token_mask + outputs = self.language_model.generate( + **forward_kwargs, **generate_kwargs, ) @@ -78,7 +88,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -243,11 +253,6 @@ def run_awq_test( ) -target_dtype = "half" -if is_cpu(): - target_dtype = "bfloat16" - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index 00c1b9975ef3..dfe10629f1c6 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,17 +1,18 @@ import os import re -from typing import Callable, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import pytest import torch from transformers import AutoImageProcessor, AutoTokenizer -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import InputContext, token_inputs from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size +from vllm.platforms import current_platform from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu, is_hip +from vllm.utils import is_hip from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) @@ -49,7 +50,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, target_dtype = "half" -if is_cpu(): +if current_platform.is_cpu(): target_dtype = "bfloat16" # ROCm Triton FA can run into shared memory issues with these models, @@ -89,6 +90,7 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, + task="generate", max_model_len=4096, max_num_seqs=2, dtype=dtype, @@ -311,7 +313,7 @@ def test_input_mapper_override(model: str, image_assets: _ImageAssets, (4, 781), (16, 2653), ]) -def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, +def test_max_tokens_override(get_max_phi3v_image_tokens, model: str, num_crops: int, expected_max_tokens: int): """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" # NOTE: mm_processor_kwargs on the context in this test is unused, since @@ -343,8 +345,8 @@ def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, (16, 2653, 1), (16, 2653, 2), ]) -def 
test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, - num_crops: int, toks_per_img: int, num_imgs: int): +def test_dummy_data_override(dummy_data_for_phi3v, model: str, num_crops: int, + toks_per_img: int, num_imgs: int): """Ensure dummy_data_for_phi3v handles num_crops properly.""" # Same as the previous test - don't initialize mm_processor_kwargs # in this test and assume that the kwargs will be correctly expanded by @@ -374,7 +376,7 @@ def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, (16, 1921, 1), (16, 1921, 2), ]) -def test_input_processor_override(input_processor_for_phi3v: Callable, +def test_input_processor_override(input_processor_for_phi3v, image_assets: _ImageAssets, model: str, num_crops: int, expected_toks_per_img: int, num_imgs: int): @@ -393,16 +395,14 @@ def test_input_processor_override(input_processor_for_phi3v: Callable, prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" images = [image_assets[0].pil_image] * num_imgs - llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), - prompt=prompt, - multi_modal_data={"image": images}) + inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) - proc_llm_inputs = input_processor_for_phi3v( - ctx=ctx, - llm_inputs=llm_inputs, - num_crops=num_crops, - ) + processed_inputs = input_processor_for_phi3v(ctx, + inputs, + num_crops=num_crops) # Ensure we have the right number of placeholders per num_crops size - img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + img_tok_count = processed_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index d2d0c62f5b2c..db5ab485f872 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,7 +5,7 @@ import torch from PIL.Image import Image -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import InputContext, token_inputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size @@ -71,12 +71,12 @@ def test_input_processor_valid_mm_data(input_processor_for_qwen, """Happy cases for image inputs to Qwen's multimodal input processor.""" prompt = "".join( [f"Picture {num}: \n" for num in range(1, num_images + 1)]) - inputs = LLMInputs( + inputs = token_inputs( prompt=prompt, # When processing multimodal data for a multimodal model, the qwen # input processor will overwrite the provided prompt_token_ids with # the image prompts - prompt_token_ids=None, + prompt_token_ids=[], multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)}, ) proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs) @@ -134,9 +134,9 @@ def test_input_processor_invalid_mm_data(input_processor_for_qwen, trust_remote_code=True) prompt = "Picture 1: \n" prompt_token_ids = tokenizer.encode(prompt) - inputs = LLMInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=mm_data) + inputs = token_inputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=mm_data) # Should fail since we have too many or too few dimensions for embeddings with pytest.raises(ValueError): input_processor_for_qwen(qwen_vl_context, inputs) diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py 
b/tests/models/decoder_only/vision_language/test_qwen2_vl.py new file mode 100644 index 000000000000..d3de5fb26d4b --- /dev/null +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -0,0 +1,160 @@ +from typing import Any, Dict, Tuple + +import pytest +import torch +from PIL.Image import Image +from transformers import AutoTokenizer + +from vllm.inputs import InputContext, token_inputs +from vllm.multimodal import MultiModalRegistry + +from ....conftest import _ImageAssets +from ...utils import build_model_context + +MODEL = "Qwen/Qwen2-VL-2B-Instruct" +MIN_PIXELS = "min_pixels" +MAX_PIXELS = "max_pixels" + + +# Fixtures lazy import to avoid initializing CUDA during test collection +# NOTE: Qwen2vl supports multiple input modalities, so it registers multiple +# input mappers. +@pytest.fixture() +def image_input_mapper_for_qwen2_vl(): + from vllm.model_executor.models.qwen2_vl import ( + image_input_mapper_for_qwen2_vl) + return image_input_mapper_for_qwen2_vl + + +@pytest.fixture() +def input_processor_for_qwen2_vl(): + from vllm.model_executor.models.qwen2_vl import ( + input_processor_for_qwen2_vl) + return input_processor_for_qwen2_vl + + +@pytest.fixture() +def qwen2_vl_context() -> InputContext: + return build_model_context(model_name=MODEL) + + +@pytest.fixture() +def get_max_qwen2_vl_image_tokens(): + from vllm.model_executor.models.qwen2_vl import ( + get_max_qwen2_vl_image_tokens) + return get_max_qwen2_vl_image_tokens + + +@pytest.fixture() +def dummy_data_for_qwen2_vl(): + from vllm.model_executor.models.qwen2_vl import dummy_data_for_qwen2_vl + return dummy_data_for_qwen2_vl + + +@pytest.mark.parametrize("mm_processor_kwargs,expected_max_tokens", [ + ({}, 1225), + ({ + MIN_PIXELS: 64**2, + MAX_PIXELS: 512**2 + }, 324), +]) +def test_qwen2_vl_max_image_tokens(get_max_qwen2_vl_image_tokens, + qwen2_vl_context: InputContext, + mm_processor_kwargs: Dict[str, Any], + expected_max_tokens: int): + """Ensure that the max token calc handles min/max pixels properly.""" + actual_max_tokens = get_max_qwen2_vl_image_tokens(qwen2_vl_context, + **mm_processor_kwargs) + assert actual_max_tokens == expected_max_tokens + + +@pytest.mark.parametrize("mm_processor_kwargs,token_count,img_size", [ + [{}, 1225, (980, 980)], + [{ + MIN_PIXELS: 64**2, + MAX_PIXELS: 512**2 + }, 324, (504, 504)], +]) +def test_qwen2_vl_dummy_data(dummy_data_for_qwen2_vl, + qwen2_vl_context: InputContext, + mm_processor_kwargs: Dict[str, Any], + token_count: int, img_size: Tuple[int, int]): + """Ensure that the dummy data handles min/max pixels properly.""" + seq_len = 3000 + hf_config = qwen2_vl_context.get_hf_config() + image_token_id = hf_config.image_token_id + + # NOTE: video value is required, but isn't actually used + # when making the dummy data except for error handling currently + seq_data, mm_data = dummy_data_for_qwen2_vl(qwen2_vl_context, seq_len, { + "image": 1, + "video": 0 + }, **mm_processor_kwargs) + + # Ensure we have the right number of placeholders for min/max pixel values + assert seq_data.get_token_ids().count(image_token_id) == token_count + + # Ensure the images were resized correctly + image = mm_data["image"] + assert isinstance(image, Image) + assert image.size == img_size + + +@pytest.mark.parametrize("mm_processor_kwargs,num_placeholders", [ + ({}, 1426), + ({ + MIN_PIXELS: 64**2, + MAX_PIXELS: 512**2 + }, 330), +]) +def test_input_processor(input_processor_for_qwen2_vl, + qwen2_vl_context: InputContext, + image_assets: _ImageAssets, num_placeholders: int, + mm_processor_kwargs: 
Dict[str, Any]): + """Ensure that the image processor handles min/max pixels properly.""" + tokenizer = AutoTokenizer.from_pretrained(MODEL) + prompt = "<|vision_start|><|image_pad|><|vision_end|>" + + image = image_assets[0].pil_image + hf_config = qwen2_vl_context.get_hf_config() + image_token_id = hf_config.image_token_id + + inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": [image]}) + + processed_inputs = input_processor_for_qwen2_vl(qwen2_vl_context, inputs, + **mm_processor_kwargs) + assert processed_inputs["prompt_token_ids"].count( + image_token_id) == num_placeholders + assert len(processed_inputs["multi_modal_data"]["image"]) == 1 + + +@pytest.mark.parametrize("mm_processor_kwargs,pixels_shape", [ + ({}, [5704, 1176]), + ({ + MIN_PIXELS: 64**2, + MAX_PIXELS: 512**2 + }, [1320, 1176]), +]) +def test_image_mapper_override(qwen2_vl_context: InputContext, + image_assets: _ImageAssets, + mm_processor_kwargs: Dict[str, Any], + pixels_shape: Tuple[int, int]): + """Ensure that the image mapper handles min/max pixels properly.""" + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(qwen2_vl_context.model_config) + + image = image_assets[0].pil_image + + mapped_output = mm_registry.map_input( + qwen2_vl_context.model_config, + {"image": image}, + mm_processor_kwargs=mm_processor_kwargs, + ) + + # Dimension 0 of pixel values should match the product of image_grid_thw + actual_pixels_shape = mapped_output["pixel_values"].shape + assert list(actual_pixels_shape) == pixels_shape + assert actual_pixels_shape[0] == torch.prod( + mapped_output["image_grid_thw"]) diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index be316c6e12da..39b6bbaf4318 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -1,34 +1,36 @@ -"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. +"""Compare the embedding outputs of HF and vLLM models. Run `pytest tests/models/embedding/language/test_embedding.py`. """ import pytest -import torch -import torch.nn.functional as F +from ..utils import check_embeddings_close + +# Model, Guard MODELS = [ "intfloat/e5-mistral-7b-instruct", + "BAAI/bge-base-en-v1.5", "BAAI/bge-multilingual-gemma2", ] - -def compare_embeddings(embeddings1, embeddings2): - similarities = [ - F.cosine_similarity(torch.tensor(e1), torch.tensor(e2), dim=0) - for e1, e2 in zip(embeddings1, embeddings2) - ] - return similarities +ENCODER_ONLY = [ + "BAAI/bge-base-en-v1.5", +] @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models( + monkeypatch, hf_runner, vllm_runner, example_prompts, - model: str, + model, dtype: str, ) -> None: + if model in ENCODER_ONLY: + monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") + # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -37,15 +39,17 @@ def test_models( # So we need to strip the input texts to avoid test failing. 
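# --- Illustration (not part of the test above) -------------------------------
# The check_embeddings_close() helper introduced in
# tests/models/embedding/utils.py (added later in this change) treats two
# embeddings as matching when their cosine similarity is within `tol` of 1.0.
# A minimal numeric example of that criterion, using the same 1e-2 tolerance
# this test passes:
import torch
import torch.nn.functional as F

a = torch.tensor([1.0, 0.0, 1.0])
b = torch.tensor([1.0, 0.01, 1.0])  # almost parallel to `a`
sim = F.cosine_similarity(a, b, dim=0)
assert sim >= 1 - 1e-2  # cosine similarity ~0.99998, so the check passes
# ------------------------------------------------------------------------------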
example_prompts = [str(s).strip() for s in example_prompts] - with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) - similarities = compare_embeddings(hf_outputs, vllm_outputs) - all_similarities = torch.stack(similarities) - tolerance = 1e-2 - assert torch.all((all_similarities <= 1.0 + tolerance) - & (all_similarities >= 1.0 - tolerance) - ), f"Not all values are within {tolerance} of 1.0" + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py new file mode 100644 index 000000000000..fd1c44d9c117 --- /dev/null +++ b/tests/models/embedding/utils.py @@ -0,0 +1,30 @@ +from typing import List, Sequence + +import torch +import torch.nn.functional as F + + +def check_embeddings_close( + *, + embeddings_0_lst: Sequence[List[float]], + embeddings_1_lst: Sequence[List[float]], + name_0: str, + name_1: str, + tol: float = 1e-3, +) -> None: + assert len(embeddings_0_lst) == len(embeddings_1_lst) + + for prompt_idx, (embeddings_0, embeddings_1) in enumerate( + zip(embeddings_0_lst, embeddings_1_lst)): + assert len(embeddings_0) == len(embeddings_1), ( + f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") + + sim = F.cosine_similarity(torch.tensor(embeddings_0), + torch.tensor(embeddings_1), + dim=0) + + fail_msg = (f"Test{prompt_idx}:" + f"\n{name_0}:\t{embeddings_0!r}" + f"\n{name_1}:\t{embeddings_1!r}") + + assert sim >= 1 - tol, fail_msg diff --git a/tests/models/embedding/vision_language/__init__.py b/tests/models/embedding/vision_language/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py new file mode 100644 index 000000000000..52aef8c34d6f --- /dev/null +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -0,0 +1,135 @@ +from typing import List, Type + +import pytest +import torch.nn.functional as F +from transformers import AutoModelForVision2Seq + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ....utils import large_gpu_test +from ..utils import check_embeddings_close + +llama3_template = '<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n \n' # noqa: E501 + +HF_TEXT_PROMPTS = [ + # T -> X + llama3_template.format( + "The label of the object is stop sign\nSummary above sentence in one word: " # noqa: E501 + ), + # T -> X + llama3_template.format( + "cherry blossom\nSummary above sentence in one word: "), +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + # I -> X + "stop_sign": + llama3_template.format("\nSummary above image in one word: "), + # I -> X + "cherry_blossom": + llama3_template.format("\nSummary above image in one word: "), +}) + +MODELS = ["royokong/e5-v"] + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + input_texts: List[str], + input_images: PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. 
+ # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, + task="embedding", + dtype=dtype, + max_model_len=4096, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.encode(input_texts, images=input_images) + + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: + # Patch the issue where image_token_id + # exceeds the maximum allowed vocab size + hf_model.model.resize_token_embeddings( + hf_model.model.language_model.vocab_size + 1) + + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + # Based on: https://huggingface.co/royokong/e5-v + outputs = hf_model.model( + **hf_model.wrap_device(inputs, + device=hf_model.model.device.type), + return_dict=True, + output_hidden_states=True, + ) + pooled_output = F.normalize(outputs.hidden_states[-1][0, -1, :], + dim=-1) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) + for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) diff --git a/tests/models/embedding/vision_language/test_phi3v.py b/tests/models/embedding/vision_language/test_phi3v.py new file mode 100644 index 000000000000..ee411472ba28 --- /dev/null +++ b/tests/models/embedding/vision_language/test_phi3v.py @@ -0,0 +1,124 @@ +from typing import List, Type + +import pytest +import torch.nn.functional as F + +from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner +from ....utils import large_gpu_test +from ..utils import check_embeddings_close + +HF_TEXT_PROMPTS = [ + # T -> X + "Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501 + # T -> X + "Retrieve an image of this caption: cherry blossom", +] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + # T + I -> X + "stop_sign": + "<|image_1|> Select the portion of the image that isolates the object of the given label: The label of the object is stop sign", # noqa: E501 + # I -> X + "cherry_blossom": + "<|image_1|> Represent the given image for classification", # noqa: E501 +}) + +MODELS = ["TIGER-Lab/VLM2Vec-Full"] + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + input_texts: List[str], + input_images: 
PromptImageInput, + model: str, + *, + dtype: str, +) -> None: + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + with vllm_runner(model, task="embedding", dtype=dtype, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.encode(input_texts, images=input_images) + + # use eager mode for hf runner, since phi3_v didn't work with flash_attn + hf_model_kwargs = {"_attn_implementation": "eager"} + with hf_runner(model, dtype=dtype, + model_kwargs=hf_model_kwargs) as hf_model: + all_inputs = hf_model.get_inputs(input_texts, images=input_images) + + all_outputs = [] + for inputs in all_inputs: + # Based on: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/db3b951bccabba220c1f53ab46a734e50dd2fc08/src/model.py + outputs = hf_model.model( + **hf_model.wrap_device(inputs, + device=hf_model.model.device.type), + return_dict=True, + output_hidden_states=True, + ) + last_hidden_state = outputs.hidden_states[-1][0] + reps = last_hidden_state[inputs.attention_mask[0].sum() - 1] + pooled_output = F.normalize(reps, p=2, dim=-1) + + all_outputs.append(pooled_output.tolist()) + + hf_outputs = all_outputs + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_text( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [(text, None) for text in HF_TEXT_PROMPTS] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, # type: ignore + model, + dtype=dtype, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +def test_models_image( + hf_runner, + vllm_runner, + image_assets, + model: str, + dtype: str, +) -> None: + input_texts_images = [ + (text, asset.pil_image) + for text, asset in zip(HF_IMAGE_PROMPTS, image_assets) + ] + input_texts = [text for text, _ in input_texts_images] + input_images = [image for _, image in input_texts_images] + + _run_test( + hf_runner, + vllm_runner, + input_texts, + input_images, + model, + dtype=dtype, + ) diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/encoder_decoder/vision_language/test_florence2.py new file mode 100644 index 000000000000..483773f06913 --- /dev/null +++ b/tests/models/encoder_decoder/vision_language/test_florence2.py @@ -0,0 +1,102 @@ +from functools import partial +from typing import List, Optional, Tuple, Type + +import pytest +from PIL import Image + +from vllm.inputs.data import ExplicitEncoderDecoderPrompt +from vllm.sequence import SampleLogprobs + +from ....conftest import HfRunner, VllmRunner +from ...utils import check_logprobs_close + +Florence2Prompt = partial(ExplicitEncoderDecoderPrompt, + decoder_prompt=None, + mm_processor_kwargs=None) + +MODELS = ["microsoft/Florence-2-base"] +# Florence-2 uses BartFastTokenizer which can't be loaded from AutoTokenizer +# Therefore, we borrow the BartTokenizer from the original Bart model +TOKENIZER = "facebook/bart-base" +PROMPTS = [ + Florence2Prompt(encoder_prompt=""), + 
Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), + Florence2Prompt(encoder_prompt=""), +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], ): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = "" + output_str + "" + + return output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts: List[ExplicitEncoderDecoderPrompt], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +) -> None: + with vllm_runner(model, + tokenizer_name=TOKENIZER, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs( + prompts, max_tokens, num_logprobs) + + # Florence-2 processors require image inputs + dummy_image = Image.new(mode="RGB", size=(2, 2)) + with hf_runner(model, dtype=dtype, skip_tokenizer_init=True) as hf_model: + hf_model.model.get_output_embeddings = lambda: \ + hf_model.model.language_model.lm_head + hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( + prompts, + max_tokens, + num_logprobs, + images=[dummy_image] * len(prompts), + )) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output) for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, model, dtype, max_tokens, + num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + PROMPTS, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 78a5c8158e16..52f74ec88594 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -12,7 +12,7 @@ from ....utils import large_gpu_test from ...utils import check_logprobs_close -_LIMIT_IMAGE_PER_PROMPT = 1 +_LIMIT_IMAGE_PER_PROMPT = 3 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -244,8 +244,9 @@ def process(hf_inputs: BatchEncoding): @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, - max_tokens, num_logprobs) -> None: +def test_models_single_leading_image(hf_runner, vllm_runner, image_assets, + model, sizes, dtype, max_tokens, + num_logprobs) -> None: run_test( hf_runner, vllm_runner, @@ -257,3 +258,81 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) 
+@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multi_leading_images(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 + "<|image|><|image|><|begin_of_text|>Describe 2 images.", # noqa: E501 + "<|image|><|image|><|image|><|begin_of_text|>Describe 3 images.", # noqa: E501 + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes. + [ + stop_sign.resize((512, 512)), + stop_sign, + ], + [ + stop_sign, + stop_sign.resize((512, 1536)), + cherry_blossom.resize((512, 1024)), + ], + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@large_gpu_test(min_gb=48) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model, + dtype, max_tokens, num_logprobs) -> None: + + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "<|begin_of_text|>The content of the image <|image|> is", # noqa: E501 + "<|begin_of_text|>Between the first image <|image|> and the second image<|image|>, " # noqa: E501 + "which is a stop sign and which is a cherry blossom?", # noqa: E501 + ], + [ + [stop_sign], + [stop_sign, cherry_blossom], + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/utils.py b/tests/models/utils.py index 86a624483c58..f7802d98ad67 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -3,10 +3,10 @@ import torch -from vllm.config import ModelConfig +from vllm.config import ModelConfig, TaskOption from vllm.inputs import InputContext +from vllm.platforms import current_platform from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs -from vllm.utils import is_cpu TokensText = Tuple[List[int], str] @@ -19,7 +19,7 @@ def check_outputs_equal( name_1: str, ): """ - Compare the two sequences generated by different models, + Compare the two sequences generated by different models, which should be equal. """ assert len(outputs_0_lst) == len(outputs_1_lst) @@ -248,13 +248,14 @@ def check_logprobs_close( def build_model_context(model_name: str, + task: TaskOption = "auto", tokenizer_name: Optional[str] = None, trust_remote_code: bool = False, dtype: Optional[Union[str, torch.dtype]] = None, mm_processor_kwargs: Optional[Dict] = None, limit_mm_per_prompt: Optional[Dict] = None): """Creates an InputContext for a given model. - + Args: model_name: Name of the model being considered. tokenizer_name: Name of the tokenizer being considered. 
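# --- Illustration (not part of the diff above) --------------------------------
# With the new `task` parameter threaded through build_model_context (next
# hunk), call sites can request a task explicitly, mirroring the updated usage
# in tests/multimodal/test_processor_kwargs.py. The model name and the import
# path below are illustrative assumptions about the tests package layout.
from tests.models.utils import build_model_context

ctx = build_model_context(
    "microsoft/Phi-3.5-vision-instruct",  # example model only
    task="generate",                      # new TaskOption argument
    trust_remote_code=True,
    limit_mm_per_prompt={"image": 1},
)
# -------------------------------------------------------------------------------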
@@ -269,11 +270,12 @@ def build_model_context(model_name: str, if tokenizer_name is None: tokenizer_name = model_name if dtype is None: - dtype = "bfloat16" if is_cpu() else "half" + dtype = "bfloat16" if current_platform.is_cpu() else "half" model_config = ModelConfig( model_name, - tokenizer_name, + task=task, + tokenizer=tokenizer_name, tokenizer_mode="auto", trust_remote_code=trust_remote_code, dtype=dtype, diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 616a15a1328d..205ab00aa6b1 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket): await asyncio.sleep(2.0) await client.check_health() - # Throws an error in first forward pass. - with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", - sampling_params=SamplingParams(), - request_id=uuid.uuid4()): - pass - assert client.errored - - # Engine is errored, should get ENGINE_DEAD_ERROR. + # Throws an error that should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): async for _ in client.generate(prompt="Hello my name is", sampling_params=SamplingParams(), @@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket): client = await engine.make_client() assert client.is_running - # Firsh check health should work. + # First check health should work. await client.check_health() # Trigger an abort on the client side. @@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket): client.close() +@pytest.mark.asyncio +async def test_batch_error(tmp_socket): + with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, + ipc_path=tmp_socket, + run_fn=run_with_evil_abort) as engine: + + client = await engine.make_client() + assert client.is_running + + # First check health should work. + await client.check_health() + + # Batch of requests + async def do_generate(client): + # min_tokens=2048 to keep the engine busy + # long enough to process the request + # that will crash the engine + params = SamplingParams(min_tokens=2048, max_tokens=2048) + async for _ in client.generate(prompt="Hello my name is", + sampling_params=params, + request_id=uuid.uuid4()): + pass + + tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)] + + # This request will force a processing batch to raise + # an exception, after which the engine becomes errored + await client.abort(request_id="foo") + + # The batch of those requests fails, so they + # should all get the same MQEngineDeadError.
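# --- Illustration (not part of the test above) -------------------------------
# asyncio.gather(..., return_exceptions=True), used on the next line of the
# test, returns the raised exception objects in order instead of propagating
# the first failure, which is what lets every request's error be inspected.
# Standalone sketch:
import asyncio

async def boom(i: int) -> None:
    raise KeyError(f"request-{i}")

async def demo() -> None:
    results = await asyncio.gather(*(boom(i) for i in range(3)),
                                   return_exceptions=True)
    assert all(isinstance(r, KeyError) for r in results)

asyncio.run(demo())
# ------------------------------------------------------------------------------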
+ errors = await asyncio.gather(*tasks, return_exceptions=True) + for e in errors: + assert isinstance(e, MQEngineDeadError) + assert "KeyError" in repr(e) + + client.close() + + @pytest.mark.asyncio async def test_bad_request(tmp_socket): with RemoteMQLLMEngine(engine_args=ENGINE_ARGS, diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 000c923ef3e6..7203d635c2fa 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -17,7 +17,6 @@ DEFAULT_SERVER_ARGS: List[str] = [ "--disable-log-requests", - "--use-v2-block-manager", "--worker-use-ray", "--gpu-memory-utilization", "0.85", diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index f45428675bde..cc1fd1925201 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -76,7 +76,6 @@ def test_multi_step_llm( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, enable_chunked_prefill=enable_chunked_prefill, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: @@ -169,7 +168,6 @@ def test_multi_step_llm_w_prompt_logprobs( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( @@ -305,7 +303,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, num_scheduler_steps=num_scheduler_steps, max_model_len=48, max_num_batched_tokens=48, @@ -324,7 +321,6 @@ def test_multi_step_llm_chunked_prefill_prefix_cache( enforce_eager=enforce_eager, gpu_memory_utilization=0.7, tensor_parallel_size=tp_size, - use_v2_block_manager=True, enable_chunked_prefill=True, enable_prefix_caching=True, num_scheduler_steps=num_scheduler_steps, diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py index 7d09b81060ef..13ad4a7966b9 100644 --- a/tests/multimodal/test_mapper.py +++ b/tests/multimodal/test_mapper.py @@ -24,6 +24,7 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -67,6 +68,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype, model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -109,6 +111,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, @@ -139,6 +142,7 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images): model_config = ModelConfig( model=MODEL_NAME, + task="auto", tokenizer=MODEL_NAME, tokenizer_mode="auto", trust_remote_code=False, diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py index efc6903c373b..5044740c3e73 100644 --- a/tests/multimodal/test_processor_kwargs.py +++ b/tests/multimodal/test_processor_kwargs.py @@ -5,7 +5,7 @@ import pytest import torch -from vllm.inputs import InputContext, LLMInputs +from vllm.inputs import DecoderOnlyInputs, InputContext, token_inputs from 
vllm.inputs.registry import InputRegistry from vllm.multimodal import MultiModalRegistry from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -31,7 +31,7 @@ def use_processor_mock(): """Patches the internal model input processor with an override callable.""" def custom_processor(ctx: InputContext, - llm_inputs: LLMInputs, + inputs: DecoderOnlyInputs, *, num_crops=DEFAULT_NUM_CROPS): # For testing purposes, we don't worry about the llm inputs / return @@ -84,7 +84,7 @@ def test_default_processor_is_a_noop(): dummy_registry = InputRegistry() ctx = build_model_context(DUMMY_MODEL_ID) processor = dummy_registry.create_input_processor(ctx.model_config) - proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") + proc_inputs = token_inputs(prompt_token_ids=[], prompt="") proc_outputs = processor(inputs=proc_inputs) assert proc_inputs is proc_outputs @@ -125,9 +125,9 @@ def test_input_processor_kwargs(use_processor_mock, init_num_crops, ctx = build_model_context(DUMMY_MODEL_ID, mm_processor_kwargs=init_kwargs) processor = dummy_registry.create_input_processor(ctx.model_config) num_crops_val = processor( - LLMInputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=inference_kwargs)) + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=inference_kwargs)) assert num_crops_val == expected_seq_count @@ -154,9 +154,9 @@ def test_processor_with_sad_kwarg_overrides(use_processor_mock, processor = dummy_registry.create_input_processor(ctx.model_config) # Should filter out the inference time kwargs num_crops_val = processor( - LLMInputs(prompt_token_ids=[], - prompt="", - mm_processor_kwargs=mm_processor_kwargs)) + token_inputs(prompt_token_ids=[], + prompt="", + mm_processor_kwargs=mm_processor_kwargs)) assert num_crops_val == DEFAULT_NUM_CROPS @@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops): expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) @@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops): def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): """Ensure that max token calcs filters out invalid mm_processor_kwargs""" ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) @@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): ### Test overrides for the mapper @pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) -def test_default_mapper_with_processer_kwargs(image_assets, num_crops): +def test_default_mapper_with_processor_kwargs(image_assets, num_crops): """Ensure that the mapper processor kwargs can fall back to HF models.""" # NOTE - we don't validate bad inputs for the default mapper, because it's # through the automodel interface in transformers, so we can't easily # inspect what kwargs are or are not allowed. 
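# --- Illustration (not part of the test above) -------------------------------
# The kwarg "filtering" these tests rely on can be pictured as dropping any
# mm_processor_kwargs a callable does not declare. This is only a sketch of
# the idea, not vLLM's actual implementation:
import inspect

def filter_kwargs(fn, kwargs: dict) -> dict:
    """Keep only the kwargs that `fn` explicitly declares."""
    allowed = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in allowed}

def fake_processor(ctx, inputs, *, num_crops=16):
    return num_crops

assert filter_kwargs(fake_processor, {"num_crops": 4, "bogus": 1}) == {"num_crops": 4}
# ------------------------------------------------------------------------------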
ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs={"num_crops": num_crops}, limit_mm_per_prompt={"image": 1}) @@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops, init_num_crops, inference_num_crops) ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=init_kwargs, limit_mm_per_prompt={"image": 1}) @@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets, """Ensure that custom mappers filters out invalid mm_processor_kwargs""" # Should filter out the init time kwargs ctx = build_model_context(MULTIMODAL_MODEL_ID, + task="generate", trust_remote_code=True, mm_processor_kwargs=mm_processor_kwargs, limit_mm_per_prompt={"image": 1}) diff --git a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py index 1d61f6b74f52..21958b164020 100644 --- a/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py +++ b/tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_gemma_embedding.py @@ -3,7 +3,7 @@ import torch from vllm.attention import AttentionMetadata -from vllm.model_executor.models.gemma2_embedding import Gemma2EmbeddingModel +from vllm.model_executor.models.gemma2 import Gemma2EmbeddingModel from vllm.sequence import IntermediateTensors diff --git a/tests/prefix_caching/test_disable_sliding_window.py b/tests/prefix_caching/test_disable_sliding_window.py index eeac6ab43c05..5a28943b7ecb 100644 --- a/tests/prefix_caching/test_disable_sliding_window.py +++ b/tests/prefix_caching/test_disable_sliding_window.py @@ -4,8 +4,8 @@ """ import pytest -from tests.conftest import cleanup from vllm import LLM +from vllm.distributed import cleanup_dist_env_and_memory MODEL_LEN_LEN = [ # Example models with sliding window. @@ -31,7 +31,7 @@ def test_disable_sliding_window(model_len_len, ): model_config.max_model_len) del vllm_disabled_model - cleanup() + cleanup_dist_env_and_memory() vllm_enabled_model = LLM(model, disable_sliding_window=False) vllm_enabled_model.generate("Hi my name is") @@ -41,4 +41,4 @@ def test_disable_sliding_window(model_len_len, ): model_config.max_model_len) del vllm_enabled_model - cleanup() + cleanup_dist_env_and_memory() diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 88437425feb3..366b030eaa39 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,15 +2,9 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. 
""" -from typing import List - import pytest from tests.kernels.utils import override_backend_env_variable -from tests.utils import check_deprecated_block_manager_usage -from vllm.block import PhysicalTokenBlock -from vllm.core.block_manager_v1 import CachedBlockAllocator -from vllm.utils import Device from ..models.utils import check_outputs_equal @@ -19,92 +13,11 @@ ] -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/prefix_caching/test_prefix_caching.py') - - -@pytest.mark.parametrize("block_size", [16]) -@pytest.mark.parametrize("num_blocks", [16]) -def test_block_allocator( - block_size: int, - num_blocks: int, -): - block_hash = 1 - block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) - - # Allocate two PysicalTokenBlocks with the same hash and check - # that they are the same PhysicalTokenBlock - first_block = block_allocator.allocate(block_hash, 0) - second_block = block_allocator.allocate(block_hash, 0) - assert (first_block == second_block) - assert (second_block.ref_count == 2) - - # Check metric: 1 hit of 2 queries - assert block_allocator.get_prefix_cache_hit_rate() == 0.5 - - # Free the first_block and confirm that the ref_count is correctly - # decremented on the second block - block_allocator.free(first_block) - assert (second_block.ref_count == 1) - - # Free the second block - block_allocator.free(second_block) - - # Reallocate the first block and confirm that, even after the block - # had its ref_count go to 0, we still get the same block back - first_block = block_allocator.allocate(block_hash, 0) - assert (first_block == second_block) - assert (first_block.block_hash == block_hash) - - # Allocate one more time to get 3/4 hit rate for easy checking - block_allocator.allocate(block_hash, 0) - assert block_allocator.get_prefix_cache_hit_rate() == 0.75 - - -@pytest.mark.parametrize("num_blocks", [16]) -def test_eviction(num_blocks: int, ): - block_size = 16 - block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks) - blocks: List[PhysicalTokenBlock] = [] - - for i in range(num_blocks): - # use i as the block_hash - blocks.append(block_allocator.allocate(i, 0)) - - #Free all blocks - for block in blocks: - block_allocator.free(block) - - # Allocate a new block and confirm that it's the first block freed. 
- # I.E The Least Recently Used block - new_block_hash = block_size - new_block = block_allocator.allocate(new_block_hash, 0) - assert (new_block == blocks[0]) - assert (new_block.block_hash == new_block_hash) - - # Reallocate the second in blocks to remove it from the free list - realloc_block_hash = 1 - realloc_block = block_allocator.allocate(realloc_block_hash, 0) - assert (realloc_block == blocks[realloc_block_hash]) - assert (realloc_block.block_hash == realloc_block_hash) - - # Allocate a new block and confirm that it's not the realloc_block, - # since the realloc_block shouldn't be in the free list - new_block_hash = block_size + 1 - new_block = block_allocator.allocate(new_block_hash, 0) - assert (realloc_block != new_block) - assert (new_block.block_hash == new_block_hash) - assert (new_block.block_number == 2) - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) -@pytest.mark.parametrize("use_v2_block_manager", [False, True]) def test_mixed_requests( hf_runner, vllm_runner, @@ -114,7 +27,6 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, - use_v2_block_manager: bool, monkeypatch, ) -> None: """ @@ -132,7 +44,6 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, - use_v2_block_manager=use_v2_block_manager, ) as vllm_model: # Run the first prompt so the cache is populated vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index f2acf0d70afe..0f01f5f819ea 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -107,8 +107,7 @@ def validate_generated_texts(hf_runner, quantization='bitsandbytes', load_format='bitsandbytes', tensor_parallel_size=vllm_tp_size, - enforce_eager=False, - gpu_memory_utilization=0.8) as llm: + enforce_eager=False) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 5cdb8a8e8228..03097569b2b3 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -6,13 +6,12 @@ import pytest import torch +from compressed_tensors.quantization import QuantizationType from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) -from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( - QuantizationType) @pytest.mark.parametrize( diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py index d18233fe1aea..cf77ccec7a19 100644 --- a/tests/quantization/test_configs.py +++ b/tests/quantization/test_configs.py @@ -57,7 +57,8 @@ def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None: try: model_config = ModelConfig(model_path, - model_path, + task="auto", + tokenizer=model_path, tokenizer_mode="auto", trust_remote_code=False, seed=0, diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index b450ef97c89d..b9cb3858c006 100644 
--- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -4,10 +4,10 @@ import pytest from vllm import LLM, SamplingParams +from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.utils import set_random_seed from vllm.sequence import PromptLogprobs, SampleLogprobs -from ...conftest import cleanup from ...models.utils import (TokensTextLogprobs, TokensTextLogprobsPromptLogprobs, check_logprobs_close, check_outputs_equal) @@ -44,7 +44,7 @@ def generate(): yield llm del llm - cleanup() + cleanup_dist_env_and_memory() return generate diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 69ea81cfffed..629074188a6c 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -1,27 +1,15 @@ import pytest -from tests.utils import check_deprecated_block_manager_usage from vllm import SamplingParams from .conftest import get_output_from_llm_generator -@pytest.fixture(scope="module", autouse=True) -def check_deprecated_block_manager(): - check_deprecated_block_manager_usage( - 'tests/spec_decode/e2e/test_compatibility.py') - - -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "JackFram/llama-68m", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { "enable_chunked_prefill": True, @@ -51,16 +39,11 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator): sampling_params) -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - "model": "meta-llama/Llama-2-7b-chat-hf", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - - # Required for spec decode. - "use_v2_block_manager": True - }]) +@pytest.mark.parametrize("common_llm_kwargs", [{ + "model": "meta-llama/Llama-2-7b-chat-hf", + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 5, +}]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", [ @@ -101,34 +84,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): with pytest.raises(ValueError, match="cannot be larger than"): get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) - - -@pytest.mark.parametrize("common_llm_kwargs", [{ - "model": "JackFram/llama-68m", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "use_v2_block_manager": False, -}]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_block_manager_v1(test_llm_generator): - """Verify that speculative decoding with block manager v1 fails. 
- """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, - match="Speculative decoding requires usage of the V2"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index d7ca8815ec25..5bc70de9dac5 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -43,9 +43,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -86,9 +83,6 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -143,9 +137,6 @@ def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, [{ "enforce_eager": False, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -191,9 +182,6 @@ def test_eagle_e2e_greedy_correctness_cuda_graph( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -235,9 +223,6 @@ def test_eagle_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -283,9 +268,6 @@ def test_eagle_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index d04e312689bc..b89e5849727f 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -12,8 +12,6 @@ @pytest.mark.parametrize( "common_llm_kwargs", [{ - # Required for spec decode. - "use_v2_block_manager": True, # Verify equality when cuda graphs allowed. "enforce_eager": False, @@ -57,9 +55,6 @@ def test_spec_decode_cuda_graph(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -111,9 +106,6 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 3, }]) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 679a6ded9ee7..b829d1a5be78 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -17,9 +17,6 @@ [[ # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. 
- "--use-v2-block-manager", "--tensor-parallel-size", "2" ]]) @@ -74,9 +71,6 @@ def test_target_model_tp_gt_1(common_llm_kwargs, per_test_common_llm_kwargs, [[ # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. - "--use_v2_block_manager", "--tensor_parallel_size", "2", diff --git a/tests/spec_decode/e2e/test_integration_dist_tp4.py b/tests/spec_decode/e2e/test_integration_dist_tp4.py index 3f7c5d749e4f..555aef99218c 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp4.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp4.py @@ -19,9 +19,6 @@ [[ # Skip cuda graph recording for fast test. "--enforce_eager", - - # Required for spec decode. - "--use-v2-block-manager", "--tensor-parallel-size", "4", ]]) @@ -71,9 +68,6 @@ def test_draft_model_tp_lt_target_model_tp4(common_llm_kwargs, # Skip cuda graph recording for fast test. "--enforce-eager", - - # Required for spec decode. - "--use-v2-block-manager", "--tensor-parallel-size", "4", ]]) diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index b7d54991e053..4cfca8b78e79 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -14,9 +14,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -67,9 +64,6 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -119,9 +113,6 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -173,9 +164,6 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -251,8 +239,6 @@ def test_logprobs_temp_1(vllm_runner, common_llm_kwargs, "model_name": "JackFram/llama-160m", # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 0b36e712a11b..b8965606b3d0 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -45,9 +45,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -93,9 +90,6 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. 
"disable_log_stats": False, @@ -151,9 +145,6 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, [{ "enforce_eager": False, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -204,9 +195,6 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -253,9 +241,6 @@ def test_medusa_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -306,9 +291,6 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -356,9 +338,6 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 52b48a33c309..5ecc0d4e9571 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -47,9 +47,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -94,9 +91,6 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -149,9 +143,6 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -195,9 +186,6 @@ def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, @@ -258,9 +246,6 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -311,9 +296,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -366,9 +348,6 @@ def patched_pad_vocab_size(vocab_size, pad_to=None): # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -419,9 +398,6 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Precision "dtype": PRECISION, @@ -469,9 +445,6 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. 
"enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, "speculative_model": SPEC_MODEL, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index df6f12d57b40..5f240d42d9e0 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -55,9 +55,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True, }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -124,9 +121,6 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -190,9 +184,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -246,9 +237,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs( [{ # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -303,9 +291,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -353,9 +338,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1( # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -404,9 +386,6 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -454,9 +433,6 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize( "per_test_common_llm_kwargs", @@ -514,9 +490,6 @@ def test_spec_decode_different_block_size(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -570,9 +543,6 @@ def test_skip_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -611,9 +581,6 @@ def test_disable_speculation(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. 
- "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -660,9 +627,6 @@ def test_many_k(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 586245938316..31bedad48028 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -35,9 +35,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -82,9 +79,6 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. - "use_v2_block_manager": True, - # Print spec metrics. "disable_log_stats": False, }]) @@ -145,9 +139,6 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [ { @@ -195,9 +186,6 @@ def test_ngram_e2e_greedy_correctness_with_preemption( # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -254,9 +242,6 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, # Skip cuda graph recording for fast test. "enforce_eager": True, - - # Required for spec decode. - "use_v2_block_manager": True }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @@ -303,7 +288,6 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, "enforce_eager": True, # Required for spec decode. - "use_v2_block_manager": True, "speculative_model": "[ngram]", "num_speculative_tokens": 5, "ngram_prompt_lookup_max": 3, diff --git a/tests/spec_decode/e2e/test_seed.py b/tests/spec_decode/e2e/test_seed.py index b17013216ae2..e42cf416b159 100644 --- a/tests/spec_decode/e2e/test_seed.py +++ b/tests/spec_decode/e2e/test_seed.py @@ -17,9 +17,6 @@ # Skip cuda graph recording for fast test. "enforce_eager": True, - # Required for spec decode. 
- "use_v2_block_manager": True, - # speculative model "speculative_model": "JackFram/llama-160m", diff --git a/tests/tensorizer_loader/conftest.py b/tests/tensorizer_loader/conftest.py index 07b9c6b3c6be..2a4565362244 100644 --- a/tests/tensorizer_loader/conftest.py +++ b/tests/tensorizer_loader/conftest.py @@ -1,27 +1,18 @@ -import contextlib import functools import gc from typing import Callable, TypeVar import pytest -import ray import torch from typing_extensions import ParamSpec -from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel) +from vllm.distributed import cleanup_dist_env_and_memory from vllm.model_executor.model_loader.tensorizer import TensorizerConfig @pytest.fixture(autouse=True) def cleanup(): - destroy_model_parallel() - destroy_distributed_environment() - with contextlib.suppress(AssertionError): - torch.distributed.destroy_process_group() - ray.shutdown() - gc.collect() - torch.cuda.empty_cache() + cleanup_dist_env_and_memory(shutdown_ray=True) _P = ParamSpec("_P") diff --git a/tests/test_config.py b/tests/test_config.py index 225d71c0bc0e..69918b67607d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,6 +2,42 @@ from vllm.config import ModelConfig + +@pytest.mark.parametrize(("model_id", "expected_task"), [ + ("facebook/opt-125m", "generate"), + ("intfloat/e5-mistral-7b-instruct", "embedding"), +]) +def test_auto_task(model_id, expected_task): + config = ModelConfig( + model_id, + task="auto", + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + assert config.task == expected_task + + +@pytest.mark.parametrize(("model_id", "bad_task"), [ + ("facebook/opt-125m", "embedding"), + ("intfloat/e5-mistral-7b-instruct", "generate"), +]) +def test_incorrect_task(model_id, bad_task): + with pytest.raises(ValueError, match=r"does not support the .* task"): + ModelConfig( + model_id, + task=bad_task, + tokenizer=model_id, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + ) + + MODEL_IDS_EXPECTED = [ ("Qwen/Qwen1.5-7B", 32768), ("mistralai/Mistral-7B-v0.1", 4096), @@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected): model_id, expected = model_id_expected model_config = ModelConfig( model_id, - model_id, + task="auto", + tokenizer=model_id, tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -32,7 +69,8 @@ def test_get_sliding_window(): # when use_sliding_window is False. 
qwen2_model_config = ModelConfig( "Qwen/Qwen1.5-7B", - "Qwen/Qwen1.5-7B", + task="auto", + tokenizer="Qwen/Qwen1.5-7B", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -49,7 +87,8 @@ def test_get_sliding_window(): mistral_model_config = ModelConfig( "mistralai/Mistral-7B-v0.1", - "mistralai/Mistral-7B-v0.1", + task="auto", + tokenizer="mistralai/Mistral-7B-v0.1", tokenizer_mode="auto", trust_remote_code=False, seed=0, @@ -64,13 +103,14 @@ def test_get_sliding_window(): def test_rope_customization(): - TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0} + TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0} TEST_ROPE_THETA = 16_000_000.0 - LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0} + LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0} llama_model_config = ModelConfig( "meta-llama/Meta-Llama-3-8B-Instruct", - "meta-llama/Meta-Llama-3-8B-Instruct", + task="auto", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -82,7 +122,8 @@ def test_rope_customization(): llama_model_config = ModelConfig( "meta-llama/Meta-Llama-3-8B-Instruct", - "meta-llama/Meta-Llama-3-8B-Instruct", + task="auto", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -98,7 +139,8 @@ def test_rope_customization(): longchat_model_config = ModelConfig( "lmsys/longchat-13b-16k", - "lmsys/longchat-13b-16k", + task="auto", + tokenizer="lmsys/longchat-13b-16k", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", @@ -112,7 +154,8 @@ def test_rope_customization(): longchat_model_config = ModelConfig( "lmsys/longchat-13b-16k", - "lmsys/longchat-13b-16k", + task="auto", + tokenizer="lmsys/longchat-13b-16k", tokenizer_mode="auto", trust_remote_code=False, dtype="float16", diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py index 1201aaa92ea8..a9221f08c294 100644 --- a/tests/test_scalartype.py +++ b/tests/test_scalartype.py @@ -32,5 +32,5 @@ def test_scalar_type_min_max(type_tuple): max = torch.iinfo(torch_type).max print(t, min, max, t.min(), t.max()) - assert min == t.min() - assert max == t.max() + assert min == t.min(), f"min: {min} != {t.min()}" + assert max == t.max(), f"max: {max} != {t.max()}" diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index f5d9569046a6..2412da5037ec 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -46,9 +46,10 @@ def test_filter_subtensors(): @pytest.fixture(scope="module") def llama_2_7b_files(): with TemporaryDirectory() as cache_dir: - input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", + input_dir = snapshot_download("meta-llama/Llama-3.2-1B", cache_dir=cache_dir, - ignore_patterns="*.bin*") + ignore_patterns=["*.bin*", "original/*"]) + yield input_dir @@ -58,9 +59,12 @@ def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): # Dump worker states to output directory llm_sharded_writer.llm_engine.model_executor.save_sharded_state( path=output_dir) + # Copy metadata files to output directory for file in os.listdir(input_dir): - if not any(file.endswith(ext) for ext in weights_patterns): + if not any( + file.endswith(ext) and not os.path.isdir(file) + for ext in weights_patterns): shutil.copy(f"{input_dir}/{file}", output_dir) diff --git a/tests/test_utils.py b/tests/test_utils.py index 268e6f8194ab..0fed8e678fc7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -59,7 +59,7 @@ def dummy(*, 
old_arg: object = None, new_arg: object = None): with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) @@ -69,10 +69,10 @@ def test_deprecate_kwargs_never(): def dummy(*, old_arg: object = None, new_arg: object = None): pass - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) @@ -86,15 +86,15 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) is_deprecated = False - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(old_arg=1) - with error_on_warning(): + with error_on_warning(DeprecationWarning): dummy(new_arg=1) diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index f4551ed42efb..446a801bdb31 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Generator, List, Optional import pytest from transformers import AutoTokenizer @@ -7,11 +7,14 @@ from vllm.transformers_utils.detokenizer import (Detokenizer, detokenize_incrementally) from vllm.transformers_utils.tokenizer_group import get_tokenizer_group +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer TRUTH = [ "Hello here, this is a simple test", "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving", # noqa - "ęˆ‘å¾ˆę„Ÿč°¢ä½ ēš„ēƒ­ęƒ…" + "ęˆ‘å¾ˆę„Ÿč°¢ä½ ēš„ēƒ­ęƒ…", + # see https://github.com/vllm-project/vllm/pull/9625 + "THIS IS AN URGENCY", ] TOKENIZERS = [ "facebook/opt-125m", @@ -24,6 +27,7 @@ "tiiuae/falcon-7b", "meta-llama/Llama-2-7b-hf", "codellama/CodeLlama-7b-hf", + "mistralai/Pixtral-12B-2409", ] @@ -49,15 +53,38 @@ def _run_incremental_decode(tokenizer, all_input_ids, return decoded_text +@pytest.fixture +def tokenizer(tokenizer_name): + return (MistralTokenizer.from_pretrained(tokenizer_name) + if "mistral" in tokenizer_name else + AutoTokenizer.from_pretrained(tokenizer_name)) + + +# see https://github.com/vllm-project/vllm/pull/9625 +@pytest.mark.parametrize("tokenizer_name", ["mistralai/Pixtral-12B-2409"]) +def test_mistral_edge_case(tokenizer): + assert (_run_incremental_decode(tokenizer, [1492, 1176, 115679], + skip_special_tokens=True, + starting_index=0) == " ư") + + +@pytest.fixture +def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]: + if "mistral" in tokenizer_name: + yield ( + bool(True) if request.param else + pytest.skip("mistral doesn't support skip_special_tokens=False")) + else: + yield bool(True) if request.param else bool(False) + + @pytest.mark.parametrize("truth", TRUTH) @pytest.mark.parametrize("with_prompt", [True, False]) -@pytest.mark.parametrize("tokenizer_id", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", (True, False)) -def test_decode_streaming(tokenizer_id, truth, with_prompt, - skip_special_tokens): - tokenizer = AutoTokenizer.from_pretrained(tokenizer_id) +@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) +@pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True) +def 
test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens): if with_prompt: - truth_tokens = tokenizer(truth, add_special_tokens=False)["input_ids"] + truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids prompt_input_ids = truth_tokens[:len(truth) // 2] generated_input_ids = truth_tokens[len(truth) // 2:] all_input_ids = prompt_input_ids + generated_input_ids @@ -68,7 +95,7 @@ def test_decode_streaming(tokenizer_id, truth, with_prompt, else: generated = truth starting_index = 0 - all_input_ids = tokenizer(truth, add_special_tokens=False)["input_ids"] + all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids if skip_special_tokens: if tokenizer.bos_token_id is not None: all_input_ids = [tokenizer.bos_token_id] + all_input_ids @@ -98,7 +125,7 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: enable_lora=False, max_num_seqs=100, max_input_length=None, - tokenizer_mode="auto", + tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", trust_remote_code=False, revision=None, ) @@ -113,9 +140,8 @@ def detokenizer(tokenizer_name: str) -> Detokenizer: @pytest.fixture(name="complete_sequence_token_ids") def create_complete_sequence_token_ids(complete_sequence: str, - tokenizer_name: str) -> List[int]: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - complete_sequence_token_ids = tokenizer(complete_sequence)["input_ids"] + tokenizer) -> List[int]: + complete_sequence_token_ids = tokenizer(complete_sequence).input_ids return complete_sequence_token_ids @@ -150,7 +176,7 @@ def create_dummy_prompt_logprobs( @pytest.mark.parametrize("complete_sequence", TRUTH) @pytest.mark.parametrize("tokenizer_name", TOKENIZERS) -@pytest.mark.parametrize("skip_special_tokens", [True, False]) +@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) def test_decode_sequence_logprobs(complete_sequence: str, complete_sequence_token_ids: List[int], detokenizer: Detokenizer, @@ -208,9 +234,9 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: List[int], # decoded_prompt_logprobs doesn't contain the first token. 
token_ids = complete_sequence_token_ids - tokenzier = detokenizer.get_tokenizer_for_seq(seq) - text_full = tokenzier.decode(token_ids, skip_special_tokens=True) - text_first = tokenzier.decode(token_ids[0], skip_special_tokens=True) + tokenizer = detokenizer.get_tokenizer_for_seq(seq) + text_full = tokenizer.decode(token_ids, skip_special_tokens=True) + text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True) text = text_full[len(text_first):] # Text for logprobs for the chosen token should be the same as the diff --git a/tests/tool_use/test_jamba_tool_parser.py b/tests/tool_use/test_jamba_tool_parser.py new file mode 100644 index 000000000000..3095ef451679 --- /dev/null +++ b/tests/tool_use/test_jamba_tool_parser.py @@ -0,0 +1,275 @@ +import json +from typing import Generator, List, Optional + +import partial_json_parser +import pytest +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import JambaToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "ai21labs/Jamba-tiny-dev" + + +@pytest.fixture(scope="module") +def jamba_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def jamba_tool_parser(jamba_tokenizer): + return JambaToolParser(jamba_tokenizer) + + +def assert_tool_calls(actual_tool_calls: List[ToolCall], + expected_tool_calls: List[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) > 16 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function + + +def stream_delta_message_generator( + jamba_tool_parser: JambaToolParser, jamba_tokenizer: AnyTokenizer, + model_output: str) -> Generator[DeltaMessage, None, None]: + all_token_ids = jamba_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=jamba_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = jamba_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = previous_tokens + new_tokens if previous_tokens\ + else new_tokens + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(jamba_tool_parser): + model_output = "This is a test" + extracted_tool_calls = jamba_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] 
+ assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool", + "single_tool_with_content", + "parallel_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + ''' [\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], + None), + ( + ''' Sure! let me call the tool for you.[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], + " Sure! let me call the tool for you."), + ( + ''' [\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit" + }))) + ], + None) + ], +) +def test_extract_tool_calls(jamba_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = jamba_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool", + "single_tool_with_content", + "parallel_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ('''This is a test''', [], '''This is a test'''), + ( + ''' [\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], + " "), + ( + ''' Sure! let me call the tool for you.[\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))) + ], + " Sure! 
let me call the tool for you."), + ( + ''' [\n {"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}},\n {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}\n]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Dallas", + "state": "TX", + "unit": "fahrenheit" + }))), + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "Orlando", + "state": "FL", + "unit": "fahrenheit" + }))) + ], + " ") + ], +) +def test_extract_tool_calls_streaming(jamba_tool_parser, jamba_tokenizer, + model_output, expected_tool_calls, + expected_content): + other_content: str = '' + function_names: List[str] = [] + function_args_strs: List[str] = [] + tool_call_idx: int = -1 + tool_call_ids: List[Optional[str]] = [] + + for delta_message in stream_delta_message_generator( + jamba_tool_parser, jamba_tokenizer, model_output): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[ + tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall(id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR))) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index ce36515a2381..7970ad0d8893 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -88,6 +88,24 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally." 
}, + "granite20b": { + "model": + "ibm-granite/granite-20b-functioncalling", + "arguments": [ + "--tool-call-parser", "granite-20b-fc", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite_20b_fc.jinja") + ], + "supports_parallel": + False, + }, + "granite8b": { + "model": + "ibm-granite/granite-8b-instruct", + "arguments": [ + "--tool-call-parser", "granite", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_granite.jinja") + ], + }, "internlm": { "model": "internlm/internlm2_5-7b-chat", diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py index 64ed8e26f38e..fe5fc979c66a 100644 --- a/tests/tracing/test_tracing.py +++ b/tests/tracing/test_tracing.py @@ -87,8 +87,19 @@ def test_traces(trace_service): f"The fake trace service didn't receive a trace within " f"the {timeout} seconds timeout") - attributes = decode_attributes(trace_service.request.resource_spans[0]. - scope_spans[0].spans[0].attributes) + request = trace_service.request + assert len(request.resource_spans) == 1, ( + f"Expected 1 resource span, " + f"but got {len(request.resource_spans)}") + assert len(request.resource_spans[0].scope_spans) == 1, ( + f"Expected 1 scope span, " + f"but got {len(request.resource_spans[0].scope_spans)}") + assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( + f"Expected 1 span, " + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") + + attributes = decode_attributes( + request.resource_spans[0].scope_spans[0].spans[0].attributes) assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model assert attributes.get( SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id @@ -142,8 +153,19 @@ def test_traces_with_detailed_steps(trace_service): f"The fake trace service didn't receive a trace within " f"the {timeout} seconds timeout") - attributes = decode_attributes(trace_service.request.resource_spans[0]. - scope_spans[0].spans[0].attributes) + request = trace_service.request + assert len(request.resource_spans) == 1, ( + f"Expected 1 resource span, " + f"but got {len(request.resource_spans)}") + assert len(request.resource_spans[0].scope_spans) == 1, ( + f"Expected 1 scope span, " + f"but got {len(request.resource_spans[0].scope_spans)}") + assert len(request.resource_spans[0].scope_spans[0].spans) == 1, ( + f"Expected 1 span, " + f"but got {len(request.resource_spans[0].scope_spans[0].spans)}") + + attributes = decode_attributes( + request.resource_spans[0].scope_spans[0].spans[0].attributes) assert attributes.get(SpanAttributes.LLM_RESPONSE_MODEL) == model assert attributes.get( SpanAttributes.LLM_REQUEST_ID) == outputs[0].request_id diff --git a/tests/utils.py b/tests/utils.py index 924465057468..e983104e3cb0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,7 +8,7 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union import openai import pytest @@ -454,13 +454,13 @@ def multi_process_parallel( @contextmanager -def error_on_warning(): +def error_on_warning(category: Type[Warning] = Warning): """ Within the scope of this context manager, tests will fail if any warning - is emitted. + of the given category is emitted. 
""" with warnings.catch_warnings(): - warnings.simplefilter("error") + warnings.filterwarnings("error", category=category) yield @@ -587,7 +587,7 @@ def large_gpu_test(*, min_gb: int): ) def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - return test_skipif(fork_new_process_for_each_test(f)) + return test_skipif(f) return wrapper @@ -678,12 +678,3 @@ def get_client_text_logprob_generations( return [(text_generations, text, (None if x.logprobs is None else x.logprobs.top_logprobs)) for completion in completions for x in completion.choices] - - -def check_deprecated_block_manager_usage(test_name: str): - assert envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1 is True, ( - f"To allow the use of deprecated BlockSpaceManagerV1, set the " - f"environment variable VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. " - f"You can run the tests with: " - f"`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest {test_name}`" #noqa - ) diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 3dccc1b325d9..e75884a7395e 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -5,8 +5,9 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata -from vllm.utils import is_cpu, make_tensor_with_pad +from vllm.utils import make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import _get_graph_batch_size @@ -31,7 +32,7 @@ def _create_model_runner(model: str, *args, return model_runner -@pytest.mark.skipif(condition=is_cpu(), +@pytest.mark.skipif(condition=current_platform.is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") @@ -74,7 +75,7 @@ def test_empty_seq_group(): assert return_seq_lens is None -@pytest.mark.skipif(condition=is_cpu(), +@pytest.mark.skipif(condition=current_platform.is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") @@ -264,7 +265,7 @@ def test_prepare_prompt(batch_size): assert torch.equal(actual, expected) -@pytest.mark.skipif(condition=is_cpu(), +@pytest.mark.skipif(condition=current_platform.is_cpu(), reason="CPU backend is currently " "unsupported for encoder/ " "decoder models") @@ -490,7 +491,7 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group): def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): """ Tests that for encoder-decoder models with CUDA Graph capture and replay - enabled, the tensors used during the decode phase are correctly padded + enabled, the tensors used during the decode phase are correctly padded for varying input batch sizes. """ model_runner = _create_model_runner( diff --git a/tests/worker/test_profile.py b/tests/worker/test_profile.py new file mode 100644 index 000000000000..acd2ed683636 --- /dev/null +++ b/tests/worker/test_profile.py @@ -0,0 +1,70 @@ +import torch + +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.worker import Worker + + +def test_gpu_memory_profiling(): + # Tests the gpu profiling that happens in order to determine the number of + # KV cache blocks that we can allocate on the GPU. + # This test mocks the maximum available gpu memory so that it can run on + # any gpu setup. 
+ + # Set up engine args to build a worker. + engine_args = EngineArgs(model="facebook/opt-125m", + dtype="half", + load_format="dummy") + engine_config = engine_args.create_engine_config() + engine_config.cache_config.num_gpu_blocks = 1000 + engine_config.cache_config.num_cpu_blocks = 1000 + + # Create the worker. + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + worker = Worker( + model_config=engine_config.model_config, + parallel_config=engine_config.parallel_config, + scheduler_config=engine_config.scheduler_config, + device_config=engine_config.device_config, + cache_config=engine_config.cache_config, + load_config=engine_config.load_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + is_driver_worker=True, + ) + + # Load the model so we can profile it + worker.init_device() + worker.load_model() + + # Set 10GiB as the total gpu ram to be device-agnostic + def mock_mem_info(): + current_usage = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + mock_total_bytes = 10 * 1024**3 + free = mock_total_bytes - current_usage + + return (free, mock_total_bytes) + + from unittest.mock import patch + with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info): + gpu_blocks, _ = worker.determine_num_available_blocks() + + # Peak vram usage by torch should be 0.7077 GiB + # No memory should be allocated outside of torch + # 9.0 GiB should be the utilization target + # 8.2923 GiB should be available for the KV cache + block_size = CacheEngine.get_cache_block_size( + engine_config.cache_config, engine_config.model_config, + engine_config.parallel_config) + + expected_blocks = (8.2923 * 1024**3) // block_size + + # Check within a small tolerance for portability + # Hardware, kernel, or dependency changes could all affect memory + # utilization. + # A 10 block tolerance here should be about 6MB of wiggle room. + assert abs(gpu_blocks - expected_blocks) < 10 diff --git a/tools/check_repo.sh b/tools/check_repo.sh new file mode 100644 index 000000000000..48eba5bea836 --- /dev/null +++ b/tools/check_repo.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Checks whether the repo is clean and whether tags are available (necessary to correctly produce vllm version at build time) + +if ! git diff --quiet; then + echo "Repo is dirty" >&2 + + exit 1 +fi + +if ! git describe --tags; then + echo "No tags are present. Is this a shallow clone? 
git fetch --unshallow --tags" >&2 + + exit 1 +fi diff --git a/tools/mypy.sh b/tools/mypy.sh index e6187a08ffd9..14b0976a27da 100755 --- a/tools/mypy.sh +++ b/tools/mypy.sh @@ -2,6 +2,10 @@ CI=${1:-0} +if [ $CI -eq 1 ]; then + set -e +fi + run_mypy() { echo "Running mypy on $1" if [ $CI -eq 1 ] && [ -z "$1" ]; then @@ -13,24 +17,14 @@ run_mypy() { run_mypy # Note that this is less strict than CI run_mypy tests -run_mypy vllm/assets run_mypy vllm/attention -#run_mypy vllm/compilation -#run_mypy vllm/core +run_mypy vllm/compilation run_mypy vllm/distributed run_mypy vllm/engine -run_mypy vllm/entrypoints run_mypy vllm/executor -#run_mypy vllm/inputs -run_mypy vllm/logging run_mypy vllm/lora run_mypy vllm/model_executor -run_mypy vllm/multimodal -run_mypy vllm/platforms run_mypy vllm/plugins run_mypy vllm/prompt_adapter run_mypy vllm/spec_decode -run_mypy vllm/transformers_utils -run_mypy vllm/usage -#run_mypy vllm/vllm_flash_attn run_mypy vllm/worker diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py new file mode 100644 index 000000000000..bbd24b085e3a --- /dev/null +++ b/tools/profiler/print_layerwise_table.py @@ -0,0 +1,77 @@ +import argparse +import json +from typing import Dict + +from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry +from vllm.profiler.utils import TablePrinter, indent_string + + +def flatten_entries(entry_cls, profile_dict: Dict): + entries_and_depth = [] + + def get_entries(node, curr_depth=0): + entries_and_depth.append((entry_cls(**node["entry"]), curr_depth)) + + for child in node["children"]: + get_entries( + child, + curr_depth=curr_depth + 1, + ) + + for root in profile_dict: + get_entries(root) + + return entries_and_depth + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--json-trace", + type=str, + required=True, + help="json trace file output by " + "examples/offline_profile.py") + parser.add_argument("--phase", + type=str, + choices=["prefill", "decode_1"], + required=True, + help="The phase to print the table for.") + parser.add_argument("--table", + type=str, + choices=["summary", "model"], + default="summary", + help="Which table to print, the summary table or the " + "layerwise model table") + + args = parser.parse_args() + + with open(args.json_trace, "r") as f: + profile_data = json.load(f) + + if args.table == "summary": + entries_and_depths = flatten_entries( + SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) + column_widths = dict(name=80, + cuda_time_us=12, + pct_cuda_time=12, + invocations=15) + elif args.table == "model": + entries_and_depths = flatten_entries( + ModelStatsEntry, profile_data[args.phase]["model_stats"]) + column_widths = dict(name=60, + cpu_time_us=12, + cuda_time_us=12, + pct_cuda_time=12, + trace=60) + + # indent entry names based on the depth + entries = [] + for entry, depth in entries_and_depths: + entry.name = indent_string( + entry.name, + indent=depth, + indent_style=lambda indent: "|" + "-" * indent + " ") + entries.append(entry) + + TablePrinter(type(entries[0]), column_widths).print_table(entries) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py new file mode 100644 index 000000000000..65ee3ae108ae --- /dev/null +++ b/tools/profiler/visualize_layerwise_profile.py @@ -0,0 +1,522 @@ +import argparse +import copy +import json +import math +import os +from pathlib import Path +from typing import Any, List, Optional, Tuple + +import 
matplotlib.pyplot as plt +import pandas as pd + +## JSON parsing utils #### + + +def largest_dist_from_leaf(node: dict, depth: int = 0): + if len(node["children"]) == 0: + return depth + return max([ + largest_dist_from_leaf(child, depth=depth + 1) + for child in node["children"] + ]) + + +def get_entries_at_depth(depth: int, + entries_and_traces: List[Tuple[Any, Any]], + node: dict, + curr_depth: int = 0, + trace=()): + # assert that the query is at kernel or module level + assert depth == -1 or depth == -2 + + if curr_depth == 0 and largest_dist_from_leaf(node) <= (abs(depth) - 1): + # The tree is not tall enough! + entries_and_traces.append((node["entry"], trace)) + return + + if largest_dist_from_leaf(node) == (abs(depth) - 1): + entries_and_traces.append((node["entry"], trace)) + + trace = (node["entry"]["name"], ) + trace + for child in node["children"]: + get_entries_at_depth(depth, + entries_and_traces, + child, + curr_depth=curr_depth + 1, + trace=trace) + + +def fold_nodes(root: dict, nodes_to_fold: List[str]): + + stack: List[dict] = [root] + while len(stack) != 0: + node = stack.pop() + if node['entry']['name'] in nodes_to_fold: + node["children"] = [] + continue + for child in node["children"]: + stack.append(child) + return root + + +## Operation name cleanup utils #### + + +def trim_string_back(string: str, width: int) -> str: + if len(string) > width: + offset = len(string) - width + 3 + string = string[:-offset] + if len(string) > 3: + string = string + "..." + return string + + +def shorten_plot_legend_strings(legend, max_char_len: int): + for t in legend.get_texts(): + t.set_text( + trim_string_back(abbreviate_known_names(t.get_text()), + max_char_len)) + + +def abbreviate_known_names(name: str) -> str: + abbreviations = { + "MergedColumnParallelLinear": "MCPLinear", + "QKVParallelLinear": "QKVPLinear", + "RowParallelLinear": "RPLinear", + "weight=": "w=", + "bfloat16": "bf16", + "float16": "f16", + } + for key, value in abbreviations.items(): + name = name.replace(key, value) + return name + + +def attempt_to_make_names_unique(entries_and_traces): + names, non_unique_names = (set(), set()) + + def all_the_same(items) -> bool: + return all(i == items[0] for i in items) + + for entry, _ in entries_and_traces: + if entry["name"] in names: + non_unique_names.add(entry["name"]) + else: + names.add(entry["name"]) + + for name in non_unique_names: + entries_and_traces_with_name = [(entry, trace) + for entry, trace in entries_and_traces + if entry["name"] == name] + + zipped_traces = list( + zip(*[trace for _, trace in entries_and_traces_with_name])) + first_trace_difference = next( + (i for i, trace_eles in enumerate(zipped_traces) + if not all_the_same(trace_eles)), None) + + if first_trace_difference is None: + # can't create a unique name, leave them names as the + # are they will get aggregated by the pivot_table call + continue + + for entry, trace in entries_and_traces_with_name: + entry["name"] = " <- ".join((entry["name"], ) + + trace[:first_trace_difference + 1]) + + +## Operation grouping utils #### +''' + Group operations in the given dataframe by some high-level ops like, + - gemms + - attention + - rms_norm + etc. 
+''' + + +def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame: + + def is_rms_norm(op_name: str): + if "rms_norm_kernel" in op_name: + return True + + def is_attention_block(op_name: str): + if "flash_fwd" in op_name or \ + "reshape_and_cache_flash_kernel" in op_name: + return True + + def is_quant(op_name: str): + if "scaled_fp8_quant" in op_name or \ + "scaled_int8_quant" in op_name: + return True + + def is_gemm_op(op_name: str): + if is_quant(op_name): + return False + if "xmma_gemm" in op_name or \ + "gemv2T_kernel" in op_name or \ + "splitKreduce" in op_name or \ + "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name or \ + "s16816gemm" in op_name: + return True + + def is_elementwise_op(op_name: str): + return "elementwise_kernel" in op_name + + def is_mem_op(op_name: str): + return "memcpy" in op_name.lower() or \ + "memset" in op_name.lower() + + def is_vocab_embedding_op(op_name: str): + return "vocabparallelembed" in op_name.lower() + + # nccl ops + def is_nccl_op(op_name: str): + return "nccl" in op_name.lower() + + def is_nccl_all_reduce(op_name: str): + return is_nccl_op(op_name) and \ + ("all_reduce" in op_name.lower() or \ + "allreduce" in op_name.lower()) + + def is_nccl_gather(op_name: str): + return is_nccl_op(op_name) and \ + "gather" in op_name.lower() + + def is_nccl_broadcast(op_name: str): + return is_nccl_op(op_name) and \ + "broadcast" in op_name.lower() + + # Reduce ops types + def is_cross_device_reduce_1stage(op_name: str): + return "cross_device_reduce_1stage" in op_name + + def is_cross_device_reduce_2stage(op_name: str): + return "cross_device_reduce_2stage" in op_name + + def is_custom_ar_all_reduce_unreg(op_name: str): + return "_C_custom_ar::all_reduce_unreg" in op_name + + def is_reduce_kernel(op_name: str): + return "reduce_kernel" in op_name + + headers = list(trace_df) + ops = copy.deepcopy(headers) + + attention_ops = list(filter(lambda x: is_attention_block(x), ops)) + ops = list(filter(lambda x: x not in attention_ops, ops)) + + quant_ops = list(filter(lambda x: is_quant(x), ops)) + ops = list(filter(lambda x: x not in quant_ops, ops)) + + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in gemm_ops, ops)) + + rms_norm_ops = list(filter(lambda x: is_rms_norm(x), ops)) + ops = list(filter(lambda x: x not in rms_norm_ops, ops)) + + vocab_embed_ops = list(filter(lambda x: is_vocab_embedding_op(x), ops)) + ops = list(filter(lambda x: x not in vocab_embed_ops, ops)) + + mem_ops = list(filter(lambda x: is_mem_op(x), ops)) + ops = list(filter(lambda x: x not in mem_ops, ops)) + + elementwise_ops = list(filter(lambda x: is_elementwise_op(x), ops)) + ops = list(filter(lambda x: x not in elementwise_ops, ops)) + + nccl_all_reduce_ops = list(filter(lambda x: is_nccl_all_reduce(x), ops)) + ops = list(filter(lambda x: x not in nccl_all_reduce_ops, ops)) + + nccl_gather_ops = list(filter(lambda x: is_nccl_gather(x), ops)) + ops = list(filter(lambda x: x not in nccl_gather_ops, ops)) + + nccl_broadcast_ops = list(filter(lambda x: is_nccl_broadcast(x), ops)) + ops = list(filter(lambda x: x not in nccl_broadcast_ops, ops)) + + nccl_other_ops = list(filter(lambda x: is_nccl_op(x), ops)) + ops = list(filter(lambda x: x not in nccl_other_ops, ops)) + + cross_device_reduce_1stage_ops = list( + filter(lambda x: is_cross_device_reduce_1stage(x), ops)) + ops = list(filter(lambda x: x not in cross_device_reduce_1stage_ops, ops)) + + cross_device_reduce_2stage_ops = list( + filter(lambda 
x: is_cross_device_reduce_2stage(x), ops)) + ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops)) + + custom_ar_all_reduce_unreg_ops = list( + filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops)) + ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops)) + + reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops)) + ops = list(filter(lambda x: x not in reduce_kernel_ops, ops)) + + if len(attention_ops): + trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) + if len(quant_ops): + trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + if len(gemm_ops): + trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) + if len(rms_norm_ops): + trace_df['rms_norm_ops'] = trace_df[rms_norm_ops].agg("sum", axis=1) + if len(vocab_embed_ops): + trace_df['vocab_embed_ops'] = trace_df[vocab_embed_ops].agg("sum", + axis=1) + if len(mem_ops): + trace_df['mem_ops'] = trace_df[mem_ops].agg("sum", axis=1) + if len(elementwise_ops): + trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum", + axis=1) + + if len(nccl_all_reduce_ops): + trace_df['nccl_all_reduce_ops'] = trace_df[nccl_all_reduce_ops].agg( + "sum", axis=1) + if len(nccl_gather_ops): + trace_df['nccl_gather_ops'] = trace_df[nccl_gather_ops].agg("sum", + axis=1) + if len(nccl_broadcast_ops): + trace_df['nccl_broadcast_ops'] = trace_df[nccl_broadcast_ops].agg( + "sum", axis=1) + if len(nccl_other_ops): + trace_df['nccl_other_ops'] = trace_df[nccl_other_ops].agg("sum", + axis=1) + + if len(cross_device_reduce_1stage_ops): + trace_df['cross_device_reduce_1stage_ops'] = trace_df[ + cross_device_reduce_1stage_ops].agg("sum", axis=1) + if len(cross_device_reduce_2stage_ops): + trace_df['cross_device_reduce_2stage_ops'] = trace_df[ + cross_device_reduce_2stage_ops].agg("sum", axis=1) + if len(custom_ar_all_reduce_unreg_ops): + trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[ + custom_ar_all_reduce_unreg_ops].agg("sum", axis=1) + if len(reduce_kernel_ops): + trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", + axis=1) + + trace_df.drop( + attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + + mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + + nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + + cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops + + reduce_kernel_ops, + axis=1, + inplace=True) + return trace_df + + +## Data plotting utils #### + + +def plot_trace_df(traces_df: pd.DataFrame, + plot_metric: str, + plot_title: str, + output: Optional[Path] = None): + + phases = traces_df['phase'].unique() + traces_df = traces_df.pivot_table(index="phase", + columns="name", + values=plot_metric, + aggfunc="sum") + + traces_df = group_trace_by_operations(traces_df) + + # Make the figure + fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + + # Draw the stacked bars + ops = list(traces_df) + bottom = [0] * len(phases) + for op in ops: + values = [traces_df[op][phase] for phase in phases] + values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) + ax.bar(phases, values, label=op, bottom=bottom) + bottom = [bottom[j] + values[j] for j in range(len(phases))] + + # Write the values as text on the bars + for bar in ax.patches: + if bar.get_height() != 0: + ax.text(bar.get_x() + bar.get_width() / 2, + bar.get_height() / 2 + bar.get_y(), + f"{round(bar.get_height(), 2)}", + ha='center', + color='w', + weight='bold', + size=5) + + # Setup legend + handles, labels = 
plt.gca().get_legend_handles_labels() + legend = fig.legend(handles, + labels, + loc='center left', + bbox_to_anchor=(1, 1)) + shorten_plot_legend_strings(legend, 50) + + # Setup labels and title + plt.setp(ax.get_xticklabels(), rotation=90) + ax.set_ylabel(plot_metric) + plt.suptitle(plot_title) + + plt.savefig(output, bbox_inches='tight') + print("Created: ", output) + + +def main( + json_trace: Path, + output_directory: Path, + depth: int, # Fetch/Plot operations at this depth of the Json tree + plot_metric: str, + make_names_unique: bool, + top_k: int, + json_nodes_to_fold: List[str]): + + def prepare_data(profile_json: dict, step_keys: List[str]) -> pd.DataFrame: + + def get_entries_and_traces(key: str): + entries_and_traces: List[Tuple[Any, Any]] = [] + for root in profile_json[key]["summary_stats"]: + # Fold nodes in the traces as per user request. i.e. simply + # make the requested nodes leaf-nodes. + root = fold_nodes(root, json_nodes_to_fold) + get_entries_at_depth(depth, entries_and_traces, root) + return entries_and_traces + + def keep_only_top_entries(df: pd.DataFrame, + metric: str, + top_k: int = 9) -> pd.DataFrame: + df.loc[df.nsmallest(len(df) - top_k + 1, metric).index, + ["name"]] = "others" + return df + + # Get data for each key + traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) + + # Attempt some cleanup + if make_names_unique: + for trace in traces: + attempt_to_make_names_unique(trace) + + # To pandas dataframe + trace_dfs = list( + map(lambda t: pd.DataFrame([entry for entry, _ in t]).fillna(0), + traces)) + + # Respect top_k + if top_k: + trace_dfs = list( + map( + lambda trace_df: keep_only_top_entries( + trace_df, "cuda_time_us", top_k), trace_dfs)) + + # Fill in information about the step-keys + for trace_df, step_key in zip(trace_dfs, step_keys): + trace_df['phase'] = step_key + + # Combine all data frames so they can be put in a single plot + traces_df = pd.concat(trace_dfs) + + # Add a derived metric `cuda_time_ms` + traces_df["cuda_time_ms"] = traces_df["cuda_time_us"] / 1000 + traces_df = traces_df.fillna(0) + + return traces_df + + def make_plot_title_suffix(profile_json: dict) -> str: + context = profile_json["context"] + sparsity = context.get('sparsity', None) + return (f"{context['model']}\n" + f"Batch={context['batch_size']}, " + f"PromptLen={context['prompt_len']}, " + f"OutputLen={context['output_len']}," + f"NumGpus={context['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}") + + profile_json = None + with open(json_trace, "r") as f: + profile_json = json.load(f) + assert profile_json is not None + + # Get all `llm.generate.step()` profile + step_traces = list(profile_json.keys()) + assert (step_traces[0] == 'context') + step_traces = step_traces[1:] # have only prefill and decodes + prefills = list(filter(lambda x: "prefill" in x, step_traces)) + all_decodes = list(filter(lambda x: "decode" in x, step_traces)) + assert len(prefills) + len(all_decodes) == len(step_traces) + assert len(prefills) == 1 + + decodes = all_decodes[::args.step_plot_interval] + if decodes[-1] != all_decodes[-1]: + # Always have the last decode + decodes.append(all_decodes[-1]) + + prefill_traces = prepare_data(profile_json, prefills) + decode_traces = prepare_data(profile_json, decodes) + + plot_title_suffix = make_plot_title_suffix(profile_json) + + plot_trace_df(prefill_traces, plot_metric, "prefill " + plot_title_suffix, + output_directory / Path("prefill.png")) + plot_trace_df(decode_traces, plot_metric, "decodes " + 
plot_title_suffix, + output_directory / Path("decode_steps.png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--json-trace", + type=str, + required=True, + help="json trace file output by examples/offline_profile.py") + parser.add_argument("--output-directory", + type=str, + required=False, + help="Directory to output plots") + parser.add_argument("--level", + type=str, + default="module", + choices=["module", "kernel"]) + parser.add_argument("--top-k", + type=int, + default=12, + help="Only graph the top `top_k` entries by time.") + parser.add_argument("--fold-json-node", + nargs='+', + default=['Sampler', 'LogitsProcessor'], + help='Do not plot the children of these nodes. Let, \ + the node represent the aggregate of all its \ + children') + parser.add_argument("--plot-metric", + type=str, + default="cuda_time_ms", + help='Metric to plot. some options are cuda_time_ms, \ + pct_cuda_time') + parser.add_argument( + "--step-plot-interval", + type=int, + default=4, + help="For every `step_plot_interval` steps, plot 1 step") + + args = parser.parse_args() + + # Prepare/Extract relevant args + make_names_unique = False + if args.level == "module": + depth = -2 + make_names_unique = True + elif args.level == "kernel": + depth = -1 + else: + raise Exception(f"Unexpected level value ({args.level})") + + output_directory = args.output_directory if args.output_directory else Path( + args.json_trace).parent + + if not os.path.exists(output_directory): + os.makedirs(output_directory) + + main(Path(args.json_trace), output_directory, depth, args.plot_metric, + make_names_unique, args.top_k, args.fold_json_node) diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py index 3f9b68c2eccb..33431a33ac83 100644 --- a/tools/report_build_time_ninja.py +++ b/tools/report_build_time_ninja.py @@ -16,7 +16,6 @@ 2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time) 3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time) Longest build steps for .so (linking): - 0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time) 0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time) 0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time) 6.2 weighted s to build _C.abi3.so (6.2 s elapsed time) diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py deleted file mode 100644 index a27b8648bee4..000000000000 --- a/vllm/_core_ext.py +++ /dev/null @@ -1,278 +0,0 @@ -import importlib.util -from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union - -import torch - -from vllm.logger import init_logger - -logger = init_logger(__name__) -core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -if TYPE_CHECKING or not core_C_available: - # On platforms were we cannot use/build the C++ core extension (i.e. namely - # neuron and tpu), we define the mock ScalarType class here that partially - # mimics the C++ ScalarType class. - # - # We also use this provide type signatures to the Python LSP for the methods - # in the C++ ScalarType class. 
So these type signatures should be kept - # in sync with csrc/core/scalar_type.hpp - - from dataclasses import dataclass - - @dataclass(frozen=True) - class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - _finite_values_only: bool = False - """ - Private: if NANs are supported, used `has_infs()` instead. - """ - - nan_repr: int = NanRepr.IEEE_754.value - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - @property - def size_bits(self): - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. - (accounting for bias if there is one) - """ - raise NotImplementedError - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - raise NotImplementedError - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - ... - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - raise NotImplementedError - - def __repr__(self) -> str: - raise NotImplementedError - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." 
- return cls(size_bits - 1, size_bits, bias if bias else 0, True) - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - return cls(size_bits, size_bits, bias if bias else 0, False) - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - return cls(exponent, mantissa, 0, True) - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: int) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). - """ - return cls(exponent, mantissa, 0, True, finite_values_only, - nan_repr) - -elif core_C_available: - try: - import vllm._core_C # noqa: F401 - except ImportError as e: - logger.warning("Failed to import from vllm._core_C with %r", e) - - ScalarType = torch.classes._core_C.ScalarType - - if (hasattr(torch, "_library") - and hasattr(torch._library, "register_fake_class")): - # Needed for dynamo support of ScalarType. - @torch._library.register_fake_class("_core_C::ScalarType") - class FakeScalarType: - - def __init__(self, scalar_type): - self.ScalarType = scalar_type - - def bias_getter(self) -> int: - return self.ScalarType.bias - - def exponent_getter(self) -> int: - return self.ScalarType.exponent - - def mantissa_getter(self) -> int: - return self.ScalarType.mantissa - - def signed_getter(self) -> bool: - return self.ScalarType.signed - - def size_bits_getter(self) -> int: - return self.ScalarType.size_bits - - @property - def size_bits(self) -> int: - return self.ScalarType.size_bits - - def min(self) -> Union[int, float]: - return self.ScalarType.min() - - def max(self) -> Union[int, float]: - return self.ScalarType.max() - - def is_signed(self) -> bool: - return self.ScalarType.is_signed() - - def is_floating_point(self) -> bool: - return self.ScalarType.is_floating_point() - - def is_integer(self) -> bool: - return self.ScalarType.is_integer() - - def has_bias(self) -> bool: - return self.ScalarType.has_bias() - - def has_infs(self) -> bool: - return self.ScalarType.has_infs() - - def has_nans(self) -> bool: - return self.ScalarType.has_nans() - - def is_ieee_754(self) -> bool: - return self.ScalarType.is_ieee_754() - - def __str__(self) -> str: - return self.ScalarType.__str__() - - def __repr__(self) -> str: - return self.ScalarType.__repr__() - - def __len__(self) -> int: - return self.ScalarType.__len__() - - def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: - return torch.classes._core_C.ScalarType.__obj_flatten__( - self.ScalarType) - - @classmethod - def __obj_unflatten__( - cls, flat_type: Tuple[Tuple[str, Any], - ...]) -> 'ScalarType': - return cls( - torch.classes._core_C.ScalarType.__obj_unflatten__( - flat_type)) - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.int_(size_bits, bias) - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.uint(size_bits, bias) - - @classmethod - def float_IEEE754(cls, exponent: int, - mantissa: int) -> 'ScalarType': - return ScalarType.float_IEEE754(exponent, mantissa) - - @classmethod - def float_(cls, exponent: int, mantissa: int, - finite_values_only: bool, - nan_repr: int) -> 'ScalarType': - return ScalarType.float_(exponent, mantissa, - finite_values_only, nan_repr) diff --git a/vllm/_custom_ops.py 
b/vllm/_custom_ops.py index 3a23692285ef..60f458096c70 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -6,9 +6,9 @@ import torch.library import vllm.envs as envs -from vllm._core_ext import ScalarType from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType logger = init_logger(__name__) @@ -26,7 +26,8 @@ import vllm._moe_C # noqa: F401 supports_moe_ops = True -if TYPE_CHECKING: +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING or current_platform.is_neuron(): def register_fake(fn): return lambda name: fn @@ -78,6 +79,12 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: torch.ops._C.gelu_tanh_and_mul(out, x) +def fatrelu_and_mul(out: torch.Tensor, + x: torch.Tensor, + threshold: float = 0.0) -> None: + torch.ops._C.fatrelu_and_mul(out, x, threshold) + + def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: torch.ops._C.gelu_fast(out, x) @@ -306,7 +313,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, workspace: torch.Tensor, b_q_type: ScalarType, size_m: int, size_n: int, size_k: int) -> torch.Tensor: return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales, - workspace, b_q_type, size_m, + workspace, b_q_type.id, size_m, size_n, size_k) @@ -316,8 +323,9 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: int, - size_n: int, size_k: int) -> torch.Tensor: + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::gptq_marlin_gemm") @@ -329,17 +337,18 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, is_k_full: bool, has_zp: bool = False, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, - n: int) -> torch.Tensor: + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, + m: torch.SymInt, + n: torch.SymInt) -> torch.Tensor: return torch.empty((m, n), dtype=torch.float16, device=W.device) @register_fake("_C::ggml_mul_mat_vec_a8") @@ -347,7 +356,7 @@ def _ggml_mul_mat_vec_a8_fake( W: torch.Tensor, X: torch.Tensor, quant_type: int, - row: int, + row: torch.SymInt, ) -> torch.Tensor: return torch.empty((1, row), dtype=torch.float16, device=W.device) @@ -356,7 +365,7 @@ def _ggml_mul_mat_a8_fake( W: torch.Tensor, X: torch.Tensor, quant_type: int, - row: int, + row: torch.SymInt, ) -> torch.Tensor: batch = X.size(0) return torch.empty((batch, row), dtype=torch.float16, device=W.device) @@ -365,8 +374,8 @@ def _ggml_mul_mat_a8_fake( def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, s_group: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, - size_k: int) -> torch.Tensor: + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) @@ -374,16 +383,16 @@ def 
_marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @register_fake("_C::marlin_gemm") def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: int, size_n: int, - size_k: int) -> torch.Tensor: + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) @register_fake("_C::awq_dequantize") def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, - zeros: torch.Tensor, split_k_iters: int, thx: int, - thy: int) -> torch.Tensor: + zeros: torch.Tensor, split_k_iters: torch.SymInt, + thx: int, thy: int) -> torch.Tensor: in_c = qweight.size(0) qout_c = qweight.size(1) out_c = qout_c * 8 @@ -394,7 +403,7 @@ def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, @register_fake("_C::awq_gemm") def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, - split_k_iters: int) -> torch.Tensor: + split_k_iters: torch.SymInt) -> torch.Tensor: num_in_feats = input.size(0) return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), dtype=input.dtype, @@ -429,8 +438,9 @@ def _aqlm_dequant_fake( @register_fake("_C::fp8_marlin_gemm") def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) @register_fake("_C::machete_gemm") @@ -457,37 +467,6 @@ def machete_prepack_B_fake(b_q_weight: torch.Tensor, return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - @register_fake("_C::causal_conv1d_fwd") - def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], - conv_states: Optional[torch.Tensor], - cu_seq_len: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.empty_like(x) - - @register_fake("_C::causal_conv1d_update") - def causal_conv1d_update_fake( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.empty_like(x) - - @register_fake("_C::selective_scan_fwd") - def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, - A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], - delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, - cu_seq_len: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], - ssm_states: Optional[torch.Tensor]) -> None: - return None - # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: @@ -608,7 +587,7 @@ def gptq_marlin_gemm(a: torch.Tensor, has_zp: bool = False, use_fp32_reduce: bool = False) -> torch.Tensor: return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros, - g_idx, perm, workspace, b_q_type, + g_idx, perm, workspace, b_q_type.id, size_m, size_n, size_k, is_k_full, has_zp, use_fp32_reduce) @@ -624,7 +603,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete def 
machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type) + return torch.ops._C.machete_supported_schedules(b_type.id) def machete_gemm( @@ -639,13 +618,13 @@ def machete_gemm( beta: Optional[float] = None, schedule: Optional[str] = None, ) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros, + return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, b_group_size, c, alpha, beta, schedule) def machete_prepack_B(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type) + return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) if hasattr(torch.ops._C, "permute_cols"): @@ -800,33 +779,37 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, query_start_loc: Optional[torch.Tensor], cache_indices: Optional[torch.Tensor], has_initial_state: Optional[torch.Tensor], - silu_activation: bool) -> torch.Tensor: - return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, - query_start_loc, cache_indices, - has_initial_state, silu_activation) - - -def causal_conv1d_update( - x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, - bias_: Optional[torch.Tensor], silu_activation: bool, - cache_seqlens: Optional[torch.Tensor], - conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: - return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, - silu_activation, cache_seqlens, - conv_state_indices) - - -def selective_scan_fwd( - u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, - C: torch.Tensor, D_: Optional[torch.Tensor], - z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], - delta_softplus: bool, query_start_loc: Optional[torch.Tensor], - cache_indices: Optional[torch.Tensor], - has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): + silu_activation: bool, pad_slot_id: int): + torch.ops._C.causal_conv1d_fwd(x, weight, bias_, conv_states, + query_start_loc, cache_indices, + has_initial_state, silu_activation, + pad_slot_id) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor], + pad_slot_id: int): + torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation, cache_seqlens, + conv_state_indices, pad_slot_id) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: torch.Tensor, pad_slot_id: int): torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, delta_bias_, delta_softplus, query_start_loc, cache_indices, has_initial_state, - ssm_states) + ssm_states, pad_slot_id) # moe @@ -855,10 +838,10 @@ def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, topk_ids: torch.Tensor, b_scales: torch.Tensor, b_zero_points: torch.Tensor, g_idx: torch.Tensor, perm: torch.Tensor, workspace: torch.Tensor, - b_q_type: ScalarType, size_m: int, size_n: int, - size_k: int, is_k_full: bool, num_experts: int, - topk: int, moe_block_size: int, - replicate_input: bool, + b_q_type: ScalarType, size_m: 
torch.SymInt, + size_n: torch.SymInt, size_k: torch.SymInt, + is_k_full: bool, num_experts: int, topk: int, + moe_block_size: int, replicate_input: bool, apply_weights: bool) -> torch.Tensor: return torch.empty((size_m, topk, size_n), dtype=a.dtype, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96..9ea89eca01f5 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -15,8 +15,11 @@ class AttentionType(Enum): DECODER = auto() # Decoder attention between previous layer Q/K/V - ENCODER = auto() # Encoder attention between previous layer Q/K/V - ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + ENCODER = auto( + ) # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V + ENCODER_DECODER = auto( + ) # Attention between dec. Q and enc. K/V for encoder-decoder class AttentionBackend(ABC): diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 8457bde066eb..ffa05e80623a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -32,7 +32,7 @@ def get_supported_head_sizes() -> List[int]: @staticmethod def get_name() -> str: - return "flash-attn" + return "FLASH_ATTN" @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: @@ -305,8 +305,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.runner = input_builder.runner self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -355,9 +353,9 @@ def _add_seq_group( # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) @@ -526,8 +524,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - self.sliding_window = ((sliding_window, sliding_window) - if sliding_window is not None else (-1, -1)) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype if logits_soft_cap is None: # In flash-attn, setting logits_soft_cap as 0 means no soft cap. @@ -537,12 +535,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if sliding_window is not None: - # NOTE(woosuk): flash-attn's sliding window does not work with - # paged KV cache. 
- raise ValueError( - "Sliding window is not supported in FlashAttention.") - support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() if head_size not in support_head_sizes: raise ValueError( @@ -706,6 +698,7 @@ def unified_flash_attention( max_seqlen_k=max_seq_len, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, block_table=prefill_meta.block_tables, softcap=logits_soft_cap, @@ -727,6 +720,7 @@ def unified_flash_attention( max_seqlen_k=decode_meta.max_decode_seq_len, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, block_table=decode_meta.block_tables, @@ -741,6 +735,7 @@ def unified_flash_attention( cache_seqlens=decode_meta.seq_lens_tensor, softmax_scale=softmax_scale, causal=True, + window_size=window_size, alibi_slopes=alibi_slopes, softcap=logits_soft_cap, ).squeeze(1) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ba9b2d043c64..e43fb134a6a5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -17,6 +17,7 @@ import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, @@ -39,7 +40,7 @@ class FlashInferBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "flashinfer" + return "FLASHINFER" @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: @@ -124,7 +125,8 @@ def _get_decode_wrapper(self): self.runner.parallel_config)) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) - use_tensor_cores = num_qo_heads // num_kv_heads > 4 + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), "NHD", @@ -183,7 +185,8 @@ def graph_capture_get_metadata_for_batch( self.runner.parallel_config)) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) - use_tensor_cores = num_qo_heads // num_kv_heads > 4 + use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or ( + num_qo_heads // num_kv_heads > 4) self._graph_decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self._graph_decode_workspace_buffer, _indptr_buffer, @@ -475,8 +478,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout # for the precise definition of the following fields. @@ -542,9 +543,9 @@ def _add_seq_group( is_profile_run = is_block_tables_empty(block_tables) # Compute slot mapping. 
- start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 7398732ddfc9..1eb5fe10d76d 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "ipex-attn" + return "IPEX" @staticmethod def get_impl_cls() -> Type["IpexAttnBackendImpl"]: diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py index 8b3623073038..6fddfc200212 100644 --- a/vllm/attention/backends/openvino.py +++ b/vllm/attention/backends/openvino.py @@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "openvino" + return "OPENVINO" @staticmethod def get_impl_cls(): diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 86716602985a..6fee81de1442 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -11,6 +11,10 @@ class PallasAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "PALLAS" + @staticmethod def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: return PallasAttentionBackendImpl @@ -208,35 +212,54 @@ def forward( else: # Decoding run. assert kv_cache[0].numel() > 0 - + query = query.squeeze(dim=1) pages_per_compute_block = 16 # TODO(woosuk): Tune this value. - if self.megacore_mode == "batch" and batch_size % 2 != 0: - megacore_mode = None - else: - megacore_mode = self.megacore_mode - - # NOTE(woosuk): A temporary workaround to avoid the error: - # "xla::paged_attention() Expected a value of type 'str' for - # argument 'megacore_mode' but instead found type 'NoneType'." - if megacore_mode is not None: - output = torch.ops.xla.paged_attention( - query.squeeze(dim=1), + + assert attn_metadata.block_tables is not None + assert attn_metadata.context_lens is not None + # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire + # block table in SMEM. Therefore, if the block table is too large, + # the kernel compilation will fail. To avoid this, we split the + # batch dimension into smaller chunks and run the kernel multiple + # times. + MAX_SMEM_USAGE = 512 * 1024 + size_per_seq = 4 * attn_metadata.block_tables.shape[1] + max_num_seq = MAX_SMEM_USAGE // size_per_seq + + if batch_size <= max_num_seq: + output = paged_attention( + query, key_cache, value_cache, attn_metadata.context_lens, attn_metadata.block_tables, pages_per_compute_block, - megacore_mode=megacore_mode, + self.megacore_mode, ) else: - output = torch.ops.xla.paged_attention( - query.squeeze(dim=1), - key_cache, - value_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - pages_per_compute_block, - ) + chunk_size = max_num_seq + # Make sure the chunk size is a multiple of 2. 
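+            # (Illustrative numbers, not part of this change: with 2,048
+            # block-table columns per sequence, size_per_seq is 4 * 2048 =
+            # 8192 bytes, so max_num_seq = 512 * 1024 // 8192 = 64 and a
+            # batch of 100 decode sequences runs in two chunks of <= 64.)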
+ chunk_size = chunk_size // 2 * 2 + num_chunks = (batch_size + chunk_size - 1) // chunk_size + + output = torch.empty_like(query) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = chunk_start + chunk_size + # NOTE(woosuk): We skip this line because it causes Dynamo + # compilation error. Instead, we rely on the slice operation + # to handle the out-of-bound case. + # chunk_end = min(chunk_end, batch_size) + chunk_output = paged_attention( + query[chunk_start:chunk_end], + key_cache, + value_cache, + attn_metadata.context_lens[chunk_start:chunk_end], + attn_metadata.block_tables[chunk_start:chunk_end], + pages_per_compute_block, + self.megacore_mode, + ) + output[chunk_start:chunk_end] = chunk_output # Reshape the output tensor. return output.reshape(batch_size, seq_len, hidden_size) @@ -258,3 +281,43 @@ def write_to_kv_cache( value_cache = value_cache.flatten(0, 2) key_cache.index_copy_(0, slot_mapping, key) value_cache.index_copy_(0, slot_mapping, value) + + +def paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + pages_per_compute_block: int, + megacore_mode: Optional[str], +) -> torch.Tensor: + batch_size = query.shape[0] + if megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." + if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query, + key_cache, + value_cache, + context_lens, + block_tables, + pages_per_compute_block, + ) + return output diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 3987986f1786..4116fbf00020 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "placeholder-attn" + return "NO_ATTENTION" @staticmethod def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 682eac50126a..c2aec4aaa74e 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "rocm-flash-attn" + return "ROCM_FLASH" @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index ef8d57661683..f985f70728a6 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -10,9 +10,9 @@ AttentionMetadata, AttentionType) from vllm.attention.backends.utils import CommonAttentionState from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.utils import is_cpu +from vllm.platforms import current_platform -if is_cpu(): +if current_platform.is_cpu(): try: from vllm.attention.ops.ipex_attn import PagedAttention except ImportError: @@ -25,7 +25,7 @@ class 
TorchSDPABackend(AttentionBackend): @staticmethod def get_name() -> str: - return "torch-sdpa" + return "TORCH_SDPA" @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: @@ -234,10 +234,10 @@ def get_seq_len_block_table_args( on the type of attention operation. Decoder attn -> select entirely decoder self-attention-related fields - Encoder/decoder cross-attn -> select encoder sequence lengths & + Encoder/decoder cross-attn -> select encoder sequence lengths & cross-attn block-tables fields Encoder attn -> select encoder sequence lengths fields & no block tables - + Arguments: * attn_metadata: Attention metadata structure associated with attention diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 53e3a53badea..d1a44f3e8bfa 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -38,18 +38,12 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, - context_len: int, sliding_window: int, - use_v2_block_manager: bool): + context_len: int, sliding_window: int): """ Compute the start index of slot mapping. """ start_idx = 0 if is_prompt and sliding_window is not None: - assert use_v2_block_manager or context_len == 0, ( - "Prefix caching is currently not supported with " - "sliding window attention in V1 block manager") - # When prefill, we use it to not write slots to kv cache - # to save memory. start_idx = max(0, query_len - sliding_window) return start_idx @@ -138,8 +132,6 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.sliding_window = input_builder.sliding_window self.block_size = input_builder.block_size - self.use_v2_block_manager = ( - input_builder.scheduler_config.use_v2_block_manager) def _add_seq_group( self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", @@ -180,9 +172,9 @@ def _add_seq_group( # Compute slot mapping. is_profile_run = is_block_tables_empty(block_tables) - start_idx = compute_slot_mapping_start_idx( - is_prompt, query_len, context_len, self.sliding_window, - self.use_v2_block_manager) + start_idx = compute_slot_mapping_start_idx(is_prompt, query_len, + context_len, + self.sliding_window) compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) @@ -325,8 +317,8 @@ def graph_capture_get_metadata_for_batch( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ + assert self.runner.attn_backend.get_name() == "XFORMERS", \ + f"Expected attn_backend name to be 'XFORMERS', but "\ f" got '{self.runner.attn_backend.get_name()}'" self._update_captured_metadata_for_enc_dec_model( batch_size=batch_size, attn_metadata=attn_metadata) @@ -345,8 +337,8 @@ def get_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. 
- assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ + assert self.runner.attn_backend.get_name() == "XFORMERS", \ + f"Expected attn_backend name to be 'XFORMERS', but "\ f" got '{self.runner.attn_backend.get_name()}'" self._add_additonal_input_buffers_for_enc_dec_model( attn_metadata=attn_metadata, input_buffers=input_buffers) @@ -364,8 +356,8 @@ def prepare_graph_input_buffers( if is_encoder_decoder_model: # The encoder decoder model works only with XFormers backend. # Assert the same. - assert self.runner.attn_backend.get_name() == "xformers", \ - f"Expected attn_backend name to be 'xformers', but "\ + assert self.runner.attn_backend.get_name() == "XFORMERS", \ + f"Expected attn_backend name to be 'XFORMERS', but "\ f" got '{self.runner.attn_backend.get_name()}'" self._prepare_input_buffers_for_enc_dec_model( attn_metadata, input_buffers) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 25b86176f630..5aaf13d8ea74 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend): @staticmethod def get_name() -> str: - return "xformers" + return "XFORMERS" @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: @@ -287,13 +287,15 @@ def _get_attn_bias( * Appropriate attention bias value given the attention type ''' - if attn_type == AttentionType.DECODER: + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): return attn_metadata.attn_bias elif attn_type == AttentionType.ENCODER: return attn_metadata.encoder_attn_bias - else: - # attn_type == AttentionType.ENCODER_DECODER + elif attn_type == AttentionType.ENCODER_DECODER: return attn_metadata.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") def _set_attn_bias( @@ -313,7 +315,8 @@ def _set_attn_bias( encoder/decoder cross-attention ''' - if attn_type == AttentionType.DECODER: + if (attn_type == AttentionType.DECODER + or attn_type == AttentionType.ENCODER_ONLY): attn_metadata.attn_bias = attn_bias elif attn_type == AttentionType.ENCODER: attn_metadata.encoder_attn_bias = attn_bias @@ -371,6 +374,12 @@ def _get_seq_len_block_table_args( # No block tables associated with encoder attention return (attn_metadata.encoder_seq_lens_tensor, attn_metadata.max_encoder_seq_len, None) + elif attn_type == AttentionType.ENCODER_ONLY: + assert is_prompt, "Should not have decode for encoder only model." + + # No block tables associated with encoder attention + return (attn_metadata.seq_lens_tensor, + attn_metadata.max_prefill_seq_len, None) else: raise AttributeError(f"Invalid attention type {str(attn_type)}") @@ -479,7 +488,10 @@ def forward( * ENCODER: no KV caching; pass encoder sequence attributes (encoder_seq_lens/encoder_seq_lens_tensor/ max_encoder_seq_len) to kernel, in lieu of decoder - sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len). + Used for encoder branch of encoder-decoder models. + * ENCODER_ONLY: no kv_caching, uses the normal attention + attributes (seq_lens/seq_lens_tensor/max_seq_len). 
* ENCODER_DECODER: cross-attention behavior; use cross-attention block table for caching KVs derived from encoder hidden states; since KV sequence lengths @@ -509,6 +521,7 @@ def forward( and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER and (not attn_metadata.is_all_cross_attn_metadata_set)): raise AttributeError("Encoder/decoder cross-attention " @@ -609,6 +622,8 @@ def forward( assert out.shape == output[:num_prefill_tokens].shape output[:num_prefill_tokens] = out else: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have prefix attention.") assert prefill_meta.query_start_loc is not None assert prefill_meta.max_query_len is not None @@ -638,6 +653,8 @@ def forward( output[:num_prefill_tokens] = out if decode_meta := attn_metadata.decode_metadata: + assert attn_type != AttentionType.ENCODER_ONLY, ( + "Encoder-only models should not have decode metadata.") ( seq_lens_arg, @@ -703,36 +720,60 @@ def _run_memory_efficient_xformers_forward( None, :].expand(value.shape[0], self.num_kv_heads, self.num_queries_per_kv, value.shape[-1]) + # Set attention bias if not provided. This typically happens at # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. attn_bias = _get_attn_bias(attn_metadata, attn_type) if attn_bias is None: if self.alibi_slopes is None: + + # Cross attention block of decoder branch of encoder-decoder + # model uses seq_lens for dec / encoder_seq_lens for enc if (attn_type == AttentionType.ENCODER_DECODER): assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens is not None - # Default enc/dec cross-attention mask is non-causal + # Cross-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + + # Encoder branch of encoder-decoder model uses + # attn_metadata.encoder_seq_lens elif attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None - # Default encoder self-attention mask is non-causal + # Encoder self-attention mask is non-causal attn_bias = BlockDiagonalMask.from_seqlens( attn_metadata.encoder_seq_lens) - else: + + # Self-attention block of encoder-only model just + # uses the seq_lens directly. 
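+                # (Editorial note: this produces the same non-causal
+                # BlockDiagonalMask as the ENCODER branch above, only built
+                # from the decoder-side seq_lens instead of encoder_seq_lens.)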
+ elif attn_type == AttentionType.ENCODER_ONLY: assert attn_metadata.seq_lens is not None - # Default decoder self-attention mask is causal + # Encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens) + + # Self-attention block of decoder branch just + # uses the seq_lens directly + elif attn_type == AttentionType.DECODER: + assert attn_metadata.seq_lens is not None + + # Decoder self-attention mask is causal attn_bias = BlockDiagonalCausalMask.from_seqlens( attn_metadata.seq_lens) + else: + raise ValueError("Unknown AttentionType: %s", attn_type) + if self.sliding_window is not None: attn_bias = attn_bias.make_local_attention( self.sliding_window) attn_bias = [attn_bias] else: + assert attn_type == AttentionType.DECODER assert attn_metadata.seq_lens is not None attn_bias = _make_alibi_bias(self.alibi_slopes, self.num_kv_heads, query.dtype, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 0112f4987699..33d05cbd3fe0 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -78,10 +78,9 @@ def __init__( # During model initialization, the default dtype is set as the model # weight and activation dtype. dtype = torch.get_default_dtype() - attn_backend = get_attn_backend(head_size, sliding_window, dtype, - kv_cache_dtype, block_size, - is_attention_free, blocksparse_params - is not None) + attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype, + block_size, is_attention_free, + blocksparse_params is not None) impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, @@ -92,7 +91,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - kv_cache: Optional[torch.Tensor], + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, attn_type: AttentionType = AttentionType.DECODER, ) -> torch.Tensor: diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py index 1ead541f391b..e4dc576d2793 100644 --- a/vllm/attention/ops/blocksparse_attention/interface.py +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -3,7 +3,7 @@ import torch from vllm.platforms import current_platform -from vllm.utils import is_cpu, is_hip +from vllm.utils import is_hip from .utils import (dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask) @@ -32,7 +32,7 @@ def __init__( ): super().__init__() if use_spda is None: - use_spda = is_hip() or is_cpu() or not \ + use_spda = is_hip() or current_platform.is_cpu() or not \ IS_COMPUTE_8_OR_ABOVE device = device or (torch.cuda.current_device() if current_platform.is_cuda_alike() else "cpu") @@ -109,13 +109,13 @@ def varlen_attn(self, q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). Support grouped attention, with `q[:, i*r:(i*r + r)]` is correspondent to `k[:, i]`, where `r` is the q/k ratio. - cu_seqlens_k: shape=(batch_size + 1,), - indicating segment of samples, + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i cu_seqlens_q: shape=(batch_size + 1, ). Default None: same as cu_seqlens_k for prefilling or [0, 1, .., batch_size] for decoding. - The only case you need to specify is when q is a mix of + The only case you need to specify is when q is a mix of prefilling and decoding. sm_scale: softmax scale, default to 1/sqrt(head_size). 
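
The varlen_attn docstring above only describes the cu_seqlens packing convention in words; the following is an editorial sketch (not part of this diff) of how a packed key tensor and cu_seqlens_k relate. The sequence lengths, head counts, and tensor names are assumptions made up for the illustration.

import torch

# Three samples of 5, 3 and 7 tokens packed along the token dimension.
seq_lens = [5, 3, 7]
num_kv_heads, head_size = 4, 64
k = torch.randn(sum(seq_lens), num_kv_heads, head_size)

# shape = (batch_size + 1,); here tensor([0, 5, 8, 15]).
cu_seqlens_k = torch.tensor([0] + seq_lens).cumsum(dim=0)

# k[cu_seqlens_k[i]:cu_seqlens_k[i + 1]] recovers the keys of sample i.
assert k[cu_seqlens_k[1]:cu_seqlens_k[2]].shape[0] == seq_lens[1]

For a pure decode batch, cu_seqlens_q defaulting to [0, 1, ..., batch_size] simply means one query token per sample, which is the default-None behaviour the docstring mentions.
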
@@ -171,7 +171,7 @@ def transpose_and_unpad(x_padded, cu_seqlens): def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): """For CPU, V100 or other older GPUs. - NOTE: torch SPDA supports nested tensor, + NOTE: torch SPDA supports nested tensor, but seems extremely slow. Choose to pad instead. """ assert (cu_seqlens_q is None or @@ -201,8 +201,8 @@ def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): return self.transpose_and_unpad(spda_output, cu_seqlens) def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): - """Dispatch to `varlen_attn` (Ampere or newer) or - `self.spda`(cpu, Volta, Turing or older)based on + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on the type of device used and cuda compute capability. q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). @@ -213,8 +213,8 @@ def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): cu_seqlens_q: shape=(batch_size + 1, ). Default None: same as cu_seqlens_k for prefilling or [0, 1, .., batch_size] for decoding. - The only case you need to specify - is when q is a mix of prefilling + The only case you need to specify + is when q is a mix of prefilling and decoding. sm_scale: softmax scale, default to 1/sqrt(head_size). diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 7edb7676ea2c..cd3c642b8c8a 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -10,13 +10,14 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu +from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino logger = init_logger(__name__) class _Backend(enum.Enum): FLASH_ATTN = enum.auto() + FLASH_ATTN_VLLM_V1 = enum.auto() XFORMERS = enum.auto() ROCM_FLASH = enum.auto() TORCH_SDPA = enum.auto() @@ -90,7 +91,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]: @lru_cache(maxsize=None) def get_attn_backend( head_size: int, - sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, @@ -105,12 +105,16 @@ def get_attn_backend( BlocksparseFlashAttentionBackend) return BlocksparseFlashAttentionBackend - backend = which_attn_to_use(head_size, sliding_window, dtype, - kv_cache_dtype, block_size, is_attention_free) + backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size, + is_attention_free) if backend == _Backend.FLASH_ATTN: from vllm.attention.backends.flash_attn import ( # noqa: F401 FlashAttentionBackend) return FlashAttentionBackend + if backend == _Backend.FLASH_ATTN_VLLM_V1: + from vllm.v1.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend as FlashAttentionBackendV1) + return FlashAttentionBackendV1 if backend == _Backend.XFORMERS: logger.info("Using XFormers backend.") from vllm.attention.backends.xformers import ( # noqa: F401 @@ -122,7 +126,7 @@ def get_attn_backend( ROCmFlashAttentionBackend) return ROCmFlashAttentionBackend elif backend == _Backend.TORCH_SDPA: - assert is_cpu(), RuntimeError( + assert current_platform.is_cpu(), RuntimeError( "Torch SDPA backend is only used for the CPU device.") logger.info("Using Torch SDPA backend.") from vllm.attention.backends.torch_sdpa import TorchSDPABackend @@ -132,7 +136,7 @@ def get_attn_backend( from vllm.attention.backends.openvino import OpenVINOAttentionBackend return 
OpenVINOAttentionBackend elif backend == _Backend.IPEX: - assert is_xpu(), RuntimeError( + assert current_platform.is_xpu(), RuntimeError( "IPEX attention backend is only used for the XPU device.") logger.info("Using IPEX attention backend.") from vllm.attention.backends.ipex_attn import IpexAttnBackend @@ -155,7 +159,6 @@ def get_attn_backend( def which_attn_to_use( head_size: int, - sliding_window: Optional[int], dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, @@ -185,7 +188,7 @@ def which_attn_to_use( if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) - if is_cpu(): + if current_platform.is_cpu(): if selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) return _Backend.TORCH_SDPA @@ -195,7 +198,7 @@ def which_attn_to_use( logger.info("Cannot use %s backend on OpenVINO.", selected_backend) return _Backend.OPENVINO - if is_xpu(): + if current_platform.is_xpu(): if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) return _Backend.IPEX @@ -217,6 +220,9 @@ def which_attn_to_use( logger.info("%s is not supported in AMD GPUs.", selected_backend) return _Backend.ROCM_FLASH + if envs.VLLM_USE_V1: + return _Backend.FLASH_ATTN_VLLM_V1 + # FlashAttn in NVIDIA GPUs. if selected_backend == _Backend.FLASH_ATTN: if not current_platform.has_device_capability(80): @@ -243,10 +249,6 @@ def which_attn_to_use( "Cannot use FlashAttention-2 backend for block size not " "divisible by 16.") selected_backend = _Backend.XFORMERS - elif sliding_window is not None: - logger.info( - "Cannot use FlashAttention-2 backend due to sliding window.") - selected_backend = _Backend.XFORMERS # FlashAttn is valid for the model, checking if the package is installed. if selected_backend == _Backend.FLASH_ATTN: diff --git a/vllm/beam_search.py b/vllm/beam_search.py index 04624b8b9443..1b48538734da 100644 --- a/vllm/beam_search.py +++ b/vllm/beam_search.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional + +from vllm.sequence import Logprob @dataclass @@ -11,6 +13,7 @@ class BeamSearchSequence: """ # The tokens includes the prompt. 
tokens: List[int] + logprobs: List[Dict[int, Logprob]] cum_logprob: float = 0.0 text: Optional[str] = None @@ -28,7 +31,7 @@ class BeamSearchInstance: def __init__(self, prompt_tokens: List[int]): self.beams: List[BeamSearchSequence] = [ - BeamSearchSequence(tokens=prompt_tokens) + BeamSearchSequence(tokens=prompt_tokens, logprobs=[]) ] self.completed: List[BeamSearchSequence] = [] diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4780358cea51..6d9832e2c39c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -244,8 +244,8 @@ def compiled_graph_wrapper(*args): def select_default_backend(level: int) -> Union[str, Callable]: if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: - backend = "eager" - return backend + backend_str = "eager" + return backend_str assert level in [ CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE ], f"Invalid level {level}" diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 655c4c443017..0449f9354d0a 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -1,24 +1,58 @@ import inspect -from typing import Dict, List, Union +from typing import Dict, List, Optional, Union import torch import vllm.envs as envs from vllm.compilation.levels import CompilationLevel from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from vllm.utils import supports_dynamo +logger = init_logger(__name__) -def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]): + +def support_torch_compile( + cls: Optional[type] = None, + dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None): """ A decorator to add support for compiling the forward method of a class. + Usage 1: use directly as a decorator without arguments: + + ```python + @support_torch_compile + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + + Usage 2: use as a decorator with arguments: + + ```python + @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0}) + class MyModel(nn.Module): + def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): + ... + ``` + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic dimensions of the argument. The dynamic dimensions can be either a single integer or a list of integers. - Depending on the value of arguments: + if `dynamic_arg_dims` is `None`, it is inferred from the type annotation + of the `forward` method, based on the following default rules: + + - if the argument is annotated as `torch.Tensor` or + `Optional[torch.Tensor]`, the first dimension will be + marked as dynamic. + - if the argument is annotated as `IntermediateTensors`, the first + dimension of all the tensors in the intermediate tensors + will be marked as dynamic. + + During runtime, when we actually mark dimensions of tensors, + it depends on the value of arguments: - if it is a single integer, the corresponding dimension of the argument will be marked as dynamic. 
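As a rough illustration of the inference rule described in the `support_torch_compile` docstring above, the following standalone sketch inspects a `forward` signature and picks out the arguments annotated as `torch.Tensor` or `Optional[torch.Tensor]`, defaulting their first dimension to dynamic. It is a simplified stand-in for that behavior, not the decorator's actual implementation, and the `ToyModel` class is made up for the example.

```python
import inspect
from typing import Dict, Optional

import torch
from torch import nn


def infer_dynamic_arg_dims(forward) -> Dict[str, int]:
    """Map tensor-typed forward arguments to a default dynamic dim of 0."""
    sig = inspect.signature(forward)
    dims: Dict[str, int] = {}
    for name, param in sig.parameters.items():
        if param.annotation in (torch.Tensor, Optional[torch.Tensor]):
            dims[name] = 0  # first (batch/token) dimension is dynamic
    return dims


class ToyModel(nn.Module):

    def forward(self, x: torch.Tensor, y: Optional[torch.Tensor],
                flag: bool = False) -> torch.Tensor:
        return x if y is None else x + y


print(infer_dynamic_arg_dims(ToyModel.forward))  # {'x': 0, 'y': 0}
```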
@@ -35,12 +69,38 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]): def cls_decorator_helper(cls: type): # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` # to avoid too much indentation for `_support_torch_compile`` + if not hasattr(cls, 'forward'): + raise TypeError("decorated class should have a forward method.") sig = inspect.signature(cls.forward) - for k in dynamic_arg_dims: + inferred_dynamic_arg_dims = dynamic_arg_dims + if inferred_dynamic_arg_dims is None: + inferred_dynamic_arg_dims = {} + for k, v in sig.parameters.items(): + if v.annotation in [ + torch.Tensor, Optional[torch.Tensor], + IntermediateTensors, Optional[IntermediateTensors] + ]: + inferred_dynamic_arg_dims[k] = 0 + + logger.debug(("Inferred dynamic dimensions for " + "forward method of %s: %s"), cls, + list(inferred_dynamic_arg_dims.keys())) + + if len(inferred_dynamic_arg_dims) == 0: + raise ValueError( + "No dynamic dimensions found in the forward method of " + f"{cls}. Please provide dynamic_arg_dims explicitly.") + + for k in inferred_dynamic_arg_dims: if k not in sig.parameters: raise ValueError( f"Argument {k} not found in the forward method of {cls}") - return _support_torch_compile(cls, dynamic_arg_dims) + return _support_torch_compile(cls, inferred_dynamic_arg_dims) + + if cls is not None: + # use `support_torch_compile` as a decorator without arguments + assert isinstance(cls, type) + return cls_decorator_helper(cls) return cls_decorator_helper @@ -63,13 +123,13 @@ def _support_torch_compile(cls: type, # other than TorchCompileWrapperWithCustomDispatcher cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) - old_init = cls.__init__ + old_init = cls.__init__ # type: ignore def __init__(self, *args, **kwargs): old_init(self, *args, **kwargs) TorchCompileWrapperWithCustomDispatcher.__init__(self) - cls.__init__ = __init__ + cls.__init__ = __init__ # type: ignore def __call__(self, *args, **kwargs): # torch.compiler.is_compiling() means we are inside the compilation @@ -109,5 +169,5 @@ def __call__(self, *args, **kwargs): model_output = self.forward(*args, **kwargs) return model_output - cls.__call__ = __call__ + cls.__call__ = __call__ # type: ignore return cls diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 1594b64a61b9..7366ed4d16b0 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -73,7 +73,7 @@ def bytecode_hook(self, old_code: CodeType, new_code: CodeType): return # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 frame = sys._getframe() - while True: + while frame and frame.f_back: frame = frame.f_back code_name = frame.f_code.co_name file_name = frame.f_code.co_filename.split(os.path.sep)[-1] diff --git a/vllm/config.py b/vllm/config.py index 7a3248f4087a..ea6e45d09e20 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,8 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, - Optional, Tuple, Type, Union) +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal, + Mapping, Optional, Set, Tuple, Type, Union) import torch from transformers import PretrainedConfig @@ -17,8 +17,7 @@ get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - is_hip, is_neuron, is_openvino, is_xpu, - print_warning_once) 
+ is_hip, is_openvino, print_warning_once) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -33,6 +32,11 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 +TaskOption = Literal["auto", "generate", "embedding"] + +# "draft" is only used internally for speculative decoding +_Task = Literal["generate", "embedding", "draft"] + class ModelConfig: """Configuration for the model. @@ -40,7 +44,11 @@ class ModelConfig: Args: model: Name or path of the huggingface model to use. It is also used as the content for `model_name` tag in metrics - output when `served_model_name` is not specified. + output when `served_model_name` is not specified. + task: The task to use the model for. Each vLLM instance only supports + one task, even if the same model can be used for multiple tasks. + When the model only supports one task, "auto" can be used to select + it; otherwise, you must specify explicitly which task to use. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if available, "slow" will always use the slow tokenizer, and @@ -108,6 +116,7 @@ class ModelConfig: def __init__(self, model: str, + task: Union[TaskOption, _Task], tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -133,6 +142,7 @@ def __init__(self, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, config_format: ConfigFormat = ConfigFormat.AUTO, + chat_template_text_format: str = "string", mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer @@ -167,20 +177,27 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc + self.chat_template_text_format = chat_template_text_format self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: self.enforce_eager = False - if (not self.disable_sliding_window - and self.hf_text_config.model_type == "gemma2" - and self.hf_text_config.sliding_window is not None): + sliding_window = getattr(self.hf_text_config, "sliding_window", None) + has_interleaved_attention = (sliding_window is not None) and ( + isinstance(sliding_window, list) or + (self.hf_text_config.model_type in ["gemma2"])) + + if (not self.disable_sliding_window and has_interleaved_attention): + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + print_warning_once( - "Gemma 2 uses sliding window attention for every odd layer, " + f"{self.hf_text_config.model_type} has interleaved attention, " "which is currently not supported by vLLM. 
Disabling sliding " "window and capping the max length to the sliding window size " - f"({self.hf_text_config.sliding_window}).") + f"({sliding_window_len_min}).") self.disable_sliding_window = True self.max_model_len = _get_and_verify_max_len( @@ -199,9 +216,15 @@ def __init__(self, self.is_attention_free = self._init_attention_free() self.has_inner_state = self._init_has_inner_state() - self.override_neuron_config = override_neuron_config if is_neuron( - ) else None - self._verify_embedding_mode() + if current_platform.is_neuron(): + self.override_neuron_config = override_neuron_config + else: + self.override_neuron_config = None + + supported_tasks, task = self._resolve_task(task, self.hf_config) + self.supported_tasks = supported_tasks + self.task: Final = task + self._verify_quantization() self._verify_cuda_graph() self._verify_bnb_config() @@ -235,9 +258,44 @@ def _verify_tokenizer_mode(self) -> None: "either 'auto', 'slow' or 'mistral'.") self.tokenizer_mode = tokenizer_mode - def _verify_embedding_mode(self) -> None: - architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = ModelRegistry.is_embedding_model(architectures) + def _resolve_task( + self, + task_option: Union[TaskOption, _Task], + hf_config: PretrainedConfig, + ) -> Tuple[Set[_Task], _Task]: + if task_option == "draft": + return {"draft"}, "draft" + + architectures = getattr(hf_config, "architectures", []) + + task_support: Dict[_Task, bool] = { + # NOTE: Listed from highest to lowest priority, + # in case the model supports multiple of them + "generate": ModelRegistry.is_text_generation_model(architectures), + "embedding": ModelRegistry.is_embedding_model(architectures), + } + supported_tasks_lst: List[_Task] = [ + task for task, is_supported in task_support.items() if is_supported + ] + supported_tasks = set(supported_tasks_lst) + + if task_option == "auto": + selected_task = next(iter(supported_tasks_lst)) + + if len(supported_tasks) > 1: + logger.info( + "This model supports multiple tasks: %s. " + "Defaulting to '%s'.", supported_tasks, selected_task) + else: + if task_option not in supported_tasks: + msg = ( + f"This model does not support the '{task_option}' task. " + f"Supported tasks: {supported_tasks}") + raise ValueError(msg) + + selected_task = task_option + + return supported_tasks, selected_task def _parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -313,7 +371,7 @@ def _verify_quantization(self) -> None: "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" " is not set, enabling VLLM_USE_TRITON_AWQ.") envs.VLLM_USE_TRITON_AWQ = True - if is_neuron( + if current_platform.is_neuron( ) and self.quantization not in neuron_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not " @@ -386,7 +444,7 @@ def verify_async_output_proc(self, parallel_config, speculative_config, # Async postprocessor is not necessary with embedding mode # since there is no token generation - if self.embedding_mode: + if self.task == "embedding": self.use_async_output_proc = False # Reminder: Please update docs/source/serving/compatibility_matrix.rst @@ -422,7 +480,8 @@ def verify_with_parallel_config( "pipeline parallelism currently. 
Disabling it.") self.use_async_output_proc = False - def get_hf_config_sliding_window(self) -> Optional[int]: + def get_hf_config_sliding_window( + self) -> Union[Optional[int], List[Optional[int]]]: """Get the sliding window size, or None if disabled.""" # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in @@ -433,7 +492,7 @@ def get_hf_config_sliding_window(self) -> Optional[int]: return None return getattr(self.hf_text_config, "sliding_window", None) - def get_sliding_window(self) -> Optional[int]: + def get_sliding_window(self) -> Optional[Union[int, List[Optional[int]]]]: """Get the sliding window size, or None if disabled. """ # If user disables sliding window, return None. @@ -566,11 +625,6 @@ def is_encoder_decoder_model(self) -> bool: (hasattr(self.hf_config, "text_config") and getattr( self.hf_config.text_config, "is_encoder_decoder", False))) - @property - def is_embedding_model(self) -> bool: - """Extract the embedding model flag.""" - return self.embedding_mode - @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None @@ -610,13 +664,14 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb + self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() # Will be set after profiling. - self.num_gpu_blocks = None - self.num_cpu_blocks = None + self.num_gpu_blocks: Optional[int] = None + self.num_cpu_blocks: Optional[int] = None def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus @@ -693,7 +748,8 @@ def __post_init__(self): @classmethod def create_config( - cls, tokenizer_pool_size: int, tokenizer_pool_type: str, + cls, tokenizer_pool_size: int, + tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]], tokenizer_pool_extra_config: Optional[Union[str, dict]] ) -> Optional["TokenizerPoolConfig"]: """Create a TokenizerPoolConfig from the given parameters. @@ -925,13 +981,13 @@ class SchedulerConfig: """Scheduler configuration. Args: + task: The task to use the model for. max_num_batched_tokens: Maximum number of tokens to be processed in a single iteration. max_num_seqs: Maximum number of sequences to be processed in a single iteration. max_model_len: Maximum length of a sequence (including prompt and generated text). - use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. num_lookahead_slots: The number of slots to allocate per sequence per step, beyond the known token ids. This is used in speculative decoding to store KV activations of tokens which may or may not be @@ -940,7 +996,6 @@ class SchedulerConfig: prompt latency) before scheduling next prompt. enable_chunked_prefill: If True, prefill requests can be chunked based on the remaining max_num_batched_tokens. - embedding_mode: Whether the running model is for embedding. preemption_mode: Whether to perform preemption by swapping or recomputation. 
If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than @@ -955,14 +1010,13 @@ class SchedulerConfig: """ def __init__(self, + task: _Task, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: bool = False, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, @@ -986,7 +1040,7 @@ def __init__(self, # for higher throughput. max_num_batched_tokens = max(max_model_len, 2048) - if embedding_mode: + if task == "embedding": # For embedding, choose specific value for higher throughput max_num_batched_tokens = max( max_num_batched_tokens, @@ -1006,13 +1060,12 @@ def __init__(self, "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) + self.task: Final = task self.max_num_seqs = max_num_seqs self.max_model_len = max_model_len - self.use_v2_block_manager = use_v2_block_manager self.num_lookahead_slots = num_lookahead_slots self.delay_factor = delay_factor self.chunked_prefill_enabled = enable_chunked_prefill - self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps self.multi_step_stream_outputs = multi_step_stream_outputs @@ -1049,18 +1102,6 @@ def _verify_args(self) -> None: f"({self.num_scheduler_steps}) must be greater than or " "equal to 1.") - if (not self.use_v2_block_manager \ - and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1): - raise ValueError( - "The use of BlockSpaceManagerV1 is deprecated and will " - "be removed in a future release. Please switch to " - "BlockSpaceManagerV2 by setting --use-v2-block-manager to " - "True. If you wish to suppress this error temporarily, " - "you can set the environment variable " - "`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. If your use " - "case is not supported in BlockSpaceManagerV2, please " - "file an issue with detailed information.") - @property def is_multi_step(self) -> bool: return self.num_scheduler_steps > 1 @@ -1074,7 +1115,7 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if current_platform.is_cuda_alike(): self.device_type = "cuda" - elif is_neuron(): + elif current_platform.is_neuron(): self.device_type = "neuron" elif is_openvino(): self.device_type = "openvino" @@ -1082,7 +1123,7 @@ def __init__(self, device: str = "auto") -> None: self.device_type = "tpu" elif current_platform.is_cpu(): self.device_type = "cpu" - elif is_xpu(): + elif current_platform.is_xpu(): self.device_type = "xpu" else: raise RuntimeError("Failed to infer device type") @@ -1119,7 +1160,6 @@ def maybe_create_spec_config( speculative_disable_mqa_scorer: Optional[bool], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, - use_v2_block_manager: bool, disable_log_stats: bool, speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], @@ -1160,9 +1200,6 @@ def maybe_create_spec_config( enable_chunked_prefill (bool): Whether vLLM is configured to use chunked prefill or not. Used for raising an error since its not yet compatible with spec decode. - use_v2_block_manager (bool): Whether vLLM is configured to use the - v2 block manager or not. Used for raising an error since the v2 - block manager is required with spec decode. 
speculative_disable_by_batch_size (Optional[int]): Disable speculative decoding for new incoming requests when the number of enqueue requests is larger than this value, if provided. @@ -1213,11 +1250,6 @@ def maybe_create_spec_config( "Speculative decoding and chunked prefill are " f"currently mutually exclusive ({enable_chunked_prefill=}).") - if not use_v2_block_manager: - raise ValueError( - "Speculative decoding requires usage of the V2 " - "block manager. Enable it with --use-v2-block-manager.") - # TODO: The user should be able to specify revision/max model len # for the draft model. It is not currently supported. draft_revision = None @@ -1245,6 +1277,7 @@ def maybe_create_spec_config( ngram_prompt_lookup_min = 0 draft_model_config = ModelConfig( model=speculative_model, + task="draft", tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code, @@ -1378,11 +1411,11 @@ def create_draft_parallel_config( else: speculative_draft_tensor_parallel_size = \ target_parallel_config.tensor_parallel_size - elif speculative_draft_tensor_parallel_size != 1: - # TODO(wooyeon): allow tp values larger than 1 + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1") + f"other value than 1 or target model tensor_parallel_size") draft_parallel_config = ParallelConfig( pipeline_parallel_size=target_parallel_config. @@ -1528,11 +1561,12 @@ class LoRAConfig: max_loras: int fully_sharded_loras: bool = False max_cpu_loras: Optional[int] = None - lora_dtype: Optional[torch.dtype] = None + lora_dtype: Optional[Union[torch.dtype, str]] = None lora_extra_vocab_size: int = 256 # This is a constant. lora_vocab_padding_size: ClassVar[int] = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None + bias_enabled: bool = False def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast @@ -1680,7 +1714,7 @@ def _get_and_verify_max_len( hf_config: PretrainedConfig, max_model_len: Optional[int], disable_sliding_window: bool, - sliding_window_len: Optional[int], + sliding_window_len: Optional[Union[int, List[Optional[int]]]], spec_target_max_model_len: Optional[int] = None, ) -> int: """Get and verify the model's maximum length.""" @@ -1713,9 +1747,12 @@ def _get_and_verify_max_len( # If sliding window is manually disabled, max_length should be less # than the sliding window length in the model config. if disable_sliding_window and sliding_window_len is not None: + + sliding_window_len_min = get_min_sliding_window(sliding_window_len) max_len_key = "sliding_window" \ - if sliding_window_len < derived_max_model_len else max_len_key - derived_max_model_len = min(derived_max_model_len, sliding_window_len) + if sliding_window_len_min < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, + sliding_window_len_min) # If none of the keys were found in the config, use a default and # log a warning. 
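The interleaved-attention handling above reduces a per-layer `sliding_window` list to its smallest non-`None` entry and caps the derived maximum model length with it. A minimal sketch of that arithmetic, mirroring the intent of the new `get_min_sliding_window` helper with made-up numbers:

```python
from typing import List, Optional, Union


def min_sliding_window(sliding_window: Union[int, List[Optional[int]]]) -> int:
    # Models with interleaved attention report one window per layer,
    # with None marking full-attention layers.
    if isinstance(sliding_window, list):
        return min(s for s in sliding_window if s is not None)
    return sliding_window


# Example: alternating full-attention (None) and 4096-token window layers.
per_layer_windows = [None, 4096, None, 4096]
derived_max_model_len = 131072

window = min_sliding_window(per_layer_windows)    # 4096
capped_len = min(derived_max_model_len, window)   # 4096
print(window, capped_len)
```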
@@ -1739,16 +1776,10 @@ def _get_and_verify_max_len( rope_scaling = getattr(hf_config, "rope_scaling", None) if rope_scaling is not None: - if "type" in rope_scaling: - rope_type = rope_scaling["type"] - elif "rope_type" in rope_scaling: - rope_type = rope_scaling["rope_type"] - else: - raise ValueError( - "rope_scaling must have a 'type' or 'rope_type' key.") + # No need to consider "type" key because of patch_rope_scaling when + # loading HF config + rope_type = rope_scaling["rope_type"] - # The correct one should be "longrope", kept "su" here - # to be backward compatible if rope_type not in ("su", "longrope", "llama3"): if disable_sliding_window: # TODO(robertgshaw): Find a model that supports rope_scaling @@ -1758,11 +1789,10 @@ def _get_and_verify_max_len( "with rope_scaling. Please raise an issue so we can " "investigate.") - if rope_type == "mrope": - scaling_factor = 1 - else: - assert "factor" in rope_scaling - scaling_factor = rope_scaling["factor"] + # NOTE: rope_type == "default" does not define factor + # https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/modeling_rope_utils.py + scaling_factor = rope_scaling.get("factor", 1.0) + if rope_type == "yarn": derived_max_model_len = rope_scaling[ "original_max_position_embeddings"] @@ -1803,6 +1833,14 @@ def _get_and_verify_max_len( return int(max_model_len) +def get_min_sliding_window( + sliding_window: Union[int, List[Optional[int]]]) -> int: + if isinstance(sliding_window, list): + return min(s for s in sliding_window if s is not None) + + return sliding_window + + def get_served_model_name(model: str, served_model_name: Optional[Union[str, List[str]]]): """ diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 7c8a2bc49351..57527e39b9bd 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -7,7 +7,7 @@ from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) -from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor +from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor PrefixHash = int diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py index 28839437c33c..1c6578e4cc6a 100644 --- a/vllm/core/block/utils.py +++ b/vllm/core/block/utils.py @@ -4,28 +4,6 @@ STR_NOT_IMPL_ENC_DEC_SWA) -def _get_block_mgr_sliding_window_attr(block_mgr): - ''' - BlockManagerV1 and BlockManagerV2 have slightly different - members related to sliding window attention (SWA). This - function extracts the appropriate member to use for determining - whether SWA is enabled. 
- - Arguments: - - * block_mgr: BlockManagerV1 or BlockManagerV2 instance - ''' - - if hasattr(block_mgr, 'block_sliding_window'): - return block_mgr.block_sliding_window - if hasattr(block_mgr, 'max_block_sliding_window'): - return block_mgr.max_block_sliding_window - - raise AttributeError("Block manager instance has neither " + \ - "block_sliding_window nor " + \ - "max_block_sliding_window attributes.") - - def check_no_caching_or_swa_for_blockmgr_encdec( block_mgr, seq_group: SequenceGroup) -> None: ''' @@ -41,7 +19,7 @@ def check_no_caching_or_swa_for_blockmgr_encdec( ''' if seq_group.is_encoder_decoder(): - if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + if block_mgr.max_block_sliding_window is not None: raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) if block_mgr.enable_caching: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager.py similarity index 99% rename from vllm/core/block_manager_v2.py rename to vllm/core/block_manager.py index cb047c832e6c..61ed7afba12e 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager.py @@ -17,7 +17,7 @@ EncoderSeqId = str -class BlockSpaceManagerV2(BlockSpaceManager): +class SelfAttnBlockSpaceManager(BlockSpaceManager): """BlockSpaceManager which manages the allocation of KV cache. It owns responsibility for allocation, swapping, allocating memory for diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py deleted file mode 100644 index 8bc0ce2bc662..000000000000 --- a/vllm/core/block_manager_v1.py +++ /dev/null @@ -1,743 +0,0 @@ -"""A block manager that manages token blocks.""" -import math -from abc import ABC, abstractmethod -from itertools import count, takewhile -from os.path import commonprefix -from typing import Dict, List, Optional -from typing import Sequence as GenericSequence -from typing import Set, Tuple - -from vllm.block import BlockTable, PhysicalTokenBlock -from vllm.core.block.common import CacheMetricData -from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor -from vllm.core.interfaces import AllocStatus, BlockSpaceManager -from vllm.logger import init_logger -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus -from vllm.utils import Device - -logger = init_logger(__name__) - - -class BlockAllocatorBase(ABC): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - @abstractmethod - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU): - pass - - @abstractmethod - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - pass - - @abstractmethod - def free(self, block: PhysicalTokenBlock) -> None: - pass - - @abstractmethod - def get_num_free_blocks(self) -> int: - pass - - @abstractmethod - def get_num_total_blocks(self) -> int: - pass - - @abstractmethod - def contains_block(self, block_hash: int) -> bool: - pass - - @abstractmethod - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - pass - - @abstractmethod - def get_prefix_cache_hit_rate(self) -> float: - """Prefix cache hit rate. 
-1 means not supported or disabled.""" - pass - - -class CachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. - """ - - def __init__(self, - device: Device, - block_size: int, - num_blocks: int, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - self.current_num_blocks = 0 - self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} - - self.evictor: Evictor = make_evictor(eviction_policy) - - self.default_hash_ctr = count() - - self.cache_metric_data = CacheMetricData() - - def allocate_block(self, block_hash: int, - num_hashed_tokens: int) -> PhysicalTokenBlock: - if self.current_num_blocks == self.num_blocks: - block = self.evictor.evict() - block.block_hash = block_hash - block.num_hashed_tokens = num_hashed_tokens - return block - block = PhysicalTokenBlock(device=self.device, - block_number=self.current_num_blocks, - block_size=self.block_size, - block_hash=block_hash, - num_hashed_tokens=num_hashed_tokens) - self.current_num_blocks += 1 - return block - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if block_hash is None: - block_hash = next(self.default_hash_ctr) - - if block_hash in self.evictor: - assert block_hash not in self.cached_blocks - block = self.evictor.remove(block_hash) - assert block.ref_count == 0 - self.cached_blocks[block_hash] = block - - if block_hash in self.cached_blocks: - self.cache_metric_data.query(hit=True) - else: - self.cache_metric_data.query(hit=False) - self.cached_blocks[block_hash] = self.allocate_block( - block_hash, num_hashed_tokens) - block = self.cached_blocks[block_hash] - assert block.block_hash == block_hash - block.ref_count += 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - assert block.block_hash not in self.evictor - self.evictor.add(block) - - # Remove the block from the cached_blocks - del self.cached_blocks[block.block_hash] - - def get_num_free_blocks(self) -> int: - return (self.num_blocks - self.current_num_blocks + - self.evictor.num_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - return block_hash in self.cached_blocks or block_hash in self.evictor - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - # Update the hash of block and the cached_blocks dictionary. - assert not self.contains_block(block_hash) - old_hash = block.block_hash - block.block_hash = block_hash - del self.cached_blocks[old_hash] - self.cached_blocks[block_hash] = block - - def get_prefix_cache_hit_rate(self) -> float: - return self.cache_metric_data.get_hit_rate() - - -class UncachedBlockAllocator(BlockAllocatorBase): - """Manages free physical token blocks for a device. - - The allocator maintains a list of free blocks and allocates a block when - requested. When a block is freed, its reference count is decremented. If - the reference count becomes zero, the block is added back to the free list. 
- """ - - def __init__( - self, - device: Device, - block_size: int, - num_blocks: int, - ) -> None: - self.device = device - self.block_size = block_size - self.num_blocks = num_blocks - - # Initialize the free blocks. - self.free_blocks: List[PhysicalTokenBlock] = [] - for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size, - block_hash=-1, - num_hashed_tokens=0) - self.free_blocks.append(block) - - def allocate(self, - block_hash: Optional[int] = None, - num_hashed_tokens: int = 0) -> PhysicalTokenBlock: - if not self.free_blocks: - raise ValueError("Out of memory! No free blocks are available.") - block = self.free_blocks.pop() - block.ref_count = 1 - return block - - def free(self, block: PhysicalTokenBlock) -> None: - if block.ref_count == 0: - raise ValueError(f"Double free! {block} is already freed.") - block.ref_count -= 1 - if block.ref_count == 0: - self.free_blocks.append(block) - - def get_num_free_blocks(self) -> int: - return len(self.free_blocks) - - def get_num_total_blocks(self) -> int: - return self.num_blocks - - def contains_block(self, block_hash: int) -> bool: - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def update_hash(self, block_hash: int, block: PhysicalTokenBlock): - raise NotImplementedError( - "Invalid codepath for uncached block allocator.") - - def get_prefix_cache_hit_rate(self) -> float: - return -1 - - -class BlockSpaceManagerV1(BlockSpaceManager): - """Manages the mapping between logical and physical token blocks.""" - - def __init__( - self, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, - watermark: float = 0.01, - sliding_window: Optional[int] = None, - enable_caching: bool = False, - ) -> None: - self.block_size = block_size - self.num_total_gpu_blocks = num_gpu_blocks - self.num_total_cpu_blocks = num_cpu_blocks - - if enable_caching and sliding_window is not None: - raise NotImplementedError( - "Sliding window is not allowed with prefix caching enabled!") - - self.block_sliding_window = None - if sliding_window is not None: - # Round up to nearest block size to regularize sliding window - # allocation sizes. - self.block_sliding_window = math.ceil(sliding_window / block_size) - - self.watermark = watermark - assert watermark >= 0.0 - - self.enable_caching = enable_caching - - self.watermark_blocks = int(watermark * num_gpu_blocks) - - if self.enable_caching: - logger.info("Automatic prefix caching is enabled.") - self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - else: - self.gpu_allocator = UncachedBlockAllocator( - Device.GPU, block_size, num_gpu_blocks) - self.cpu_allocator = UncachedBlockAllocator( - Device.CPU, block_size, num_cpu_blocks) - # Mapping: seq_id -> BlockTable. - self.block_tables: Dict[int, BlockTable] = {} - - # Mapping: req_id -> BlockTable - # Note that each SequenceGroup has a unique - # request ID - self.cross_block_tables: Dict[str, BlockTable] = {} - - def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: - return 0 if seq is None else seq.n_blocks - - def can_allocate(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. 
- - assert (num_lookahead_slots == 0 - ), "lookahead allocation not supported in BlockSpaceManagerV1" - - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - self_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) - cross_num_required_blocks = self._get_seq_num_required_blocks( - seq_group.get_encoder_seq()) - num_required_blocks = self_num_required_blocks + \ - cross_num_required_blocks - - if self.block_sliding_window is not None: - - num_required_blocks = min(num_required_blocks, - self.block_sliding_window) - num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - - # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): - return AllocStatus.NEVER - if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _allocate_sequence(self, \ - seq: Optional[Sequence], \ - ref_count: int, \ - is_encoder_decoder: bool = True) -> BlockTable: - # Allocate new physical token blocks that will store the prompt tokens. - num_prompt_blocks = self._get_seq_num_required_blocks(seq) - - block_table: BlockTable = BlockTable() - assert seq is not None - for logical_idx in range(num_prompt_blocks): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): - block = block_table[logical_idx % self.block_sliding_window] - # Set the reference counts of the token blocks. - block.ref_count = ref_count - elif not is_encoder_decoder and self.enable_caching: - block = self.gpu_allocator.allocate( - seq.hash_of_block(logical_idx), - seq.num_hashed_tokens_of_block(logical_idx)) - else: - block = self.gpu_allocator.allocate() - # Set the reference counts of the token blocks. - block.ref_count = ref_count - block_table.append(block) - - return block_table - - def allocate(self, seq_group: SequenceGroup) -> None: - is_encoder_decoder = seq_group.is_encoder_decoder() - check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) - - # Allocate decoder sequences - # - # NOTE: Here we assume that all sequences in the group have the same - # decoder prompt. - wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) - seq = wait_seqs[0] - block_table: BlockTable = \ - self._allocate_sequence(seq, - seq_group.num_seqs(), - is_encoder_decoder) - - # Assign the self-attention block tables for each sequence. - if len(wait_seqs) == 1: - self.block_tables[seq.seq_id] = block_table - else: - for seq in wait_seqs: - self.block_tables[seq.seq_id] = block_table.copy() - - # Allocate encoder sequence - if is_encoder_decoder: - # A SequenceGroup has only a single encoder sequence (at most), - # thus allocate with a ref count of 1 - block_table = self._allocate_sequence(seq_group.get_encoder_seq(), - 1, is_encoder_decoder) - # Assign the cross-attention block table for the SequenceGroup. - self.cross_block_tables[seq_group.request_id] = block_table - - def can_append_slots(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> bool: - assert (num_lookahead_slots == 0 - ), "lookahead allocation not supported in BlockSpaceManagerV1" - - # Simple heuristic: If there is at least one free block - # for each sequence, we can append. 
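The `can_allocate` logic above boils down to a three-way admission decision based on the watermark. A self-contained sketch of that decision rule (names simplified; this is not the manager class itself):

```python
import enum


class AllocStatus(enum.Enum):
    OK = enum.auto()      # can be allocated now
    LATER = enum.auto()   # not now, but possible once blocks free up
    NEVER = enum.auto()   # will never fit on this GPU


def can_allocate(num_required_blocks: int, num_free_gpu_blocks: int,
                 num_total_gpu_blocks: int,
                 watermark_blocks: int) -> AllocStatus:
    # Even with an empty cache the request would not fit above the watermark.
    if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
        return AllocStatus.NEVER
    # Enough free blocks right now, keeping the watermark in reserve.
    if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER


print(can_allocate(50, 200, 1000, 10))   # AllocStatus.OK
print(can_allocate(995, 995, 1000, 10))  # AllocStatus.NEVER
```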
- num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() - num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) - return num_seqs <= num_free_gpu_blocks - - def _promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - assert self.enable_caching - - # Compute a new hash for the block so that it can be shared by other - # Sequences - new_hash = seq.hash_of_block(seq.n_blocks - 1) - - # if new_hash is already in the cached table, then free last_block - # and return the cached version - if self.gpu_allocator.contains_block(new_hash): - self.gpu_allocator.free(last_block) - return self.gpu_allocator.allocate(new_hash) - else: - self.gpu_allocator.update_hash(new_hash, last_block) - return last_block - - def _is_last_block_full( - self, - seq: Sequence, - ) -> bool: - token_ids_len = seq.data.get_len() - return token_ids_len > 0 and token_ids_len % seq.block_size == 0 - - def _maybe_promote_last_block( - self, - seq: Sequence, - last_block: PhysicalTokenBlock, - ) -> PhysicalTokenBlock: - if self._is_last_block_full(seq): - return self._promote_last_block(seq, last_block) - else: - return last_block - - def _allocate_last_physical_block( - self, - seq: Sequence, - ) -> PhysicalTokenBlock: - # Called before a new block is appended. - # This is in charge of allocating a new physical block (to be appended). - - # None if the last block is not full. Otherwise, we set it to the - # content hash. - if not self.enable_caching: - return self.gpu_allocator.allocate() - block_hash: Optional[int] = None - n_blocks = seq.n_blocks - if (self._is_last_block_full(seq)): - block_hash = seq.hash_of_block(n_blocks - 1) - num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1) - - # num_hashed_tokens is used to compute future hashes - # (e.g. in the hashing function, it is used to ask the sequence for - # prefix tokens) - new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) - - # If the block_hash is None, then the block is not full. - # If the block is not full, then we expect it to have a refcount of 1. - if block_hash is None: - assert new_block.ref_count == 1 - return new_block - - def append_slots( - self, - seq: Sequence, - num_lookahead_slots: int = 0, - ) -> List[Tuple[int, int]]: - """Allocate a physical slot for a new token.""" - n_blocks = seq.n_blocks - block_table = self.block_tables[seq.seq_id] - # If we need to allocate a new physical block - if len(block_table) < n_blocks: - # Currently this code only supports adding one physical block - assert len(block_table) == n_blocks - 1 - - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): - # reuse a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) - else: - # The sequence hash a new logical block. - # Allocate a new physical block. - new_block = self._allocate_last_physical_block(seq) - block_table.append(new_block) - return [] - - # We want to append the token to the last physical block. - last_block = block_table[-1] - assert last_block.device == Device.GPU - if last_block.ref_count == 1: - # Not shared with other sequences. Appendable. - if self.enable_caching: - # If the last block is now complete, we may reuse an old block - # to save memory. - maybe_new_block = self._maybe_promote_last_block( - seq, last_block) - block_table[-1] = maybe_new_block - return [] - else: - # The last block is shared with other sequences. - # Copy on Write: Allocate a new block and copy the tokens. 
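The append path above distinguishes an exclusively owned last block (appendable in place, possibly promoted into the prefix cache) from a shared one, which triggers copy-on-write. A compact sketch of just the copy-on-write branch, using a simplified block type rather than vLLM's `PhysicalTokenBlock` and a plain list as the free pool:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Block:
    block_number: int
    ref_count: int = 1


def append_to_last_block(block_table: List[Block],
                         free_blocks: List[Block]) -> List[Tuple[int, int]]:
    """Return (src, dst) copy pairs needed before writing the new token."""
    last = block_table[-1]
    if last.ref_count == 1:
        # Exclusively owned: the token can be written in place.
        return []
    # Shared with another sequence: allocate a fresh block, point this
    # sequence at it, and drop one reference on the old block.
    new_block = free_blocks.pop()
    new_block.ref_count = 1
    block_table[-1] = new_block
    last.ref_count -= 1
    return [(last.block_number, new_block.block_number)]


shared = Block(block_number=3, ref_count=2)
table = [Block(0), Block(1), shared]
pool = [Block(7, ref_count=0)]
print(append_to_last_block(table, pool))  # [(3, 7)]
```

The returned pairs correspond to the `(src, dst)` block copies the scheduler hands to the worker so the GPU cache is duplicated before the new token is written.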
- new_block = self._allocate_last_physical_block(seq) - - block_table[-1] = new_block - self.gpu_allocator.free(last_block) - return [(last_block.block_number, new_block.block_number)] - - def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: - # NOTE: fork does not allocate a new physical block. - # Thus, it is always safe from OOM. - if parent_seq.seq_id not in self.block_tables: - # Parent sequence has either been freed or never existed. - return - src_block_table = self.block_tables[parent_seq.seq_id] - self.block_tables[child_seq.seq_id] = src_block_table.copy() - - # When using a sliding window, blocks will be eventually reused. - # In this case the block tables will contain repeated blocks. - # When forking, we must make sure that each block's `ref_count` - # is only incremented by one, so we deduplicate them by wrapping - # them in a set. - for block in set(src_block_table): - block.ref_count += 1 - - def _get_physical_blocks( - self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: - - # NOTE: Here, we assume that the physical blocks are only shared by - # the sequences in the same group. - request_id = seq_group.request_id - blocks: Set[PhysicalTokenBlock] = set() - for seq in seq_group.get_seqs(): - if seq.is_finished(): - continue - blocks.update(self.block_tables[seq.seq_id]) - # Cross-attention blocks - if seq_group.is_encoder_decoder(): - blocks.update(self.cross_block_tables[request_id]) - return list(blocks) - - def can_swap_in(self, - seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> AllocStatus: - assert (num_lookahead_slots == 0 - ), "BlockSpaceManagerV1 does not support lookahead allocation" - - blocks = self._get_physical_blocks(seq_group) - num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) - if seq_group.is_encoder_decoder(): - num_swapped_seqs += 1 - num_free_blocks = self.gpu_allocator.get_num_free_blocks() - # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. - # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) + num_swapped_seqs - if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: - return AllocStatus.NEVER - elif num_free_blocks - num_required_blocks >= self.watermark_blocks: - return AllocStatus.OK - else: - return AllocStatus.LATER - - def _swap_block_table( - self, block_table: BlockTable, src_allocator: BlockAllocatorBase, - dest_allocator: BlockAllocatorBase, - mapping: Dict[PhysicalTokenBlock, - PhysicalTokenBlock]) -> BlockTable: - new_block_table: BlockTable = BlockTable() - - for from_block in block_table: - if from_block in mapping: - to_block = mapping[from_block] - to_block.ref_count += 1 - else: - to_block = dest_allocator.allocate( - from_block.block_hash, from_block.num_hashed_tokens) - mapping[from_block] = to_block - new_block_table.append(to_block) - # Free the source block swapped in to destination. - src_allocator.free(from_block) - - return new_block_table - - def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - - request_id = seq_group.request_id - - # CPU block -> GPU block. 
- # dict is efficient in lookup `if cpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.cpu_allocator, self.gpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.cpu_allocator, - self.gpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def can_swap_out(self, seq_group: SequenceGroup) -> bool: - blocks = self._get_physical_blocks(seq_group) - return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - - def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: - request_id = seq_group.request_id - - # GPU block -> CPU block. - # dict is efficient in lookup `if gpu_block in mapping` - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - self.block_tables[seq.seq_id] = \ - self._swap_block_table(self.block_tables[seq.seq_id], - self.gpu_allocator, self.cpu_allocator, - mapping) - - if seq_group.is_encoder_decoder(): - self.cross_block_tables[request_id] = \ - self._swap_block_table(self.cross_block_tables[request_id], - self.gpu_allocator, - self.cpu_allocator, - mapping) - - return [(cpu_block.block_number, gpu_block.block_number) - for cpu_block, gpu_block in mapping.items()] - - def _free_block_table(self, block_table: BlockTable) -> None: - # when using a sliding window, each seq will only use up - # to `self.block_sliding_window` blocks. When freeing - # the block table, we must make sure to not free blocks more - # than once. If no sliding window is used, there is no block - # reuse in the block table, so we must free all blocks. - blocks_to_free = (block_table[-self.block_sliding_window:] - if self.block_sliding_window is not None else - block_table) - for block in set(blocks_to_free): - if block.device == Device.GPU: - self.gpu_allocator.free(block) - else: - self.cpu_allocator.free(block) - - def free(self, seq: Sequence) -> None: - if seq.seq_id not in self.block_tables: - # Already freed or haven't been scheduled yet. - return - block_table = self.block_tables[seq.seq_id] - self._free_block_table(block_table) - del self.block_tables[seq.seq_id] - - def free_cross(self, seq_group: SequenceGroup) -> None: - if seq_group.request_id not in self.cross_block_tables: - # Already freed or hasn't ben scheduled yet. 
- return - block_table = self.cross_block_tables[seq_group.request_id] - self._free_block_table(block_table) - del self.cross_block_tables[seq_group.request_id] - - def reset(self) -> None: - # Free decoder block tables - for block_table in self.block_tables.values(): - self._free_block_table(block_table) - self.block_tables.clear() - # Free cross-attention block tables - for block_table in self.cross_block_tables.values(): - self._free_block_table(block_table) - self.cross_block_tables.clear() - - def get_block_table(self, seq: Sequence) -> List[int]: - return self.block_tables[seq.seq_id].ids() - - def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: - block_table = self.cross_block_tables[seq_group.request_id] - return [block.block_number for block in block_table] - - def get_num_free_gpu_blocks(self) -> int: - return self.gpu_allocator.get_num_free_blocks() - - def get_num_free_cpu_blocks(self) -> int: - return self.cpu_allocator.get_num_free_blocks() - - def access_all_blocks_in_seq( - self, - seq: Sequence, - access_time: float, - ) -> None: - if self.enable_caching: - # Update the last accessed time of all the blocks accessed - # in this step. - block_table = self.block_tables[seq.seq_id] - for block in block_table: - block.last_accessed = access_time - - def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int): - if seq.seq_id not in self.block_tables: - return - - # When chunked prefill is enabled, the computed full blocks - # should be calculated based on the number of computed tokens. - max_computed_tokens = (seq.data.get_num_computed_tokens() + - token_chunk_size) - computed_full_blocks = max_computed_tokens // self.block_size - - block_table = self.block_tables[seq.seq_id] - if computed_full_blocks == 0: - return - for i in reversed(range(computed_full_blocks)): - if block_table[i].computed: - break - block_table[i].computed = True - - def get_all_computed_blocks(self, seq: Sequence) -> List[int]: - if seq.seq_id not in self.block_tables: - return [] - block_table = self.block_tables[seq.seq_id] - # NOTE We exclude the last block to avoid the case where the entire - # prompt is cached. This would cause erroneous behavior in model - # runner. - return [ - b.block_number - for b in takewhile(lambda b: b.computed, block_table[:-1]) - ] - - def get_common_computed_block_ids( - self, seqs: List[Sequence]) -> GenericSequence[int]: - """Return the block ids that are common for a given sequence group. - - Used in prefill (can skip prefill of some blocks). - """ - # Can return non-empty result only with prefix caching enabled. 
- if not self.enable_caching: - return [] - - ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] - return commonprefix([ids for ids in ids_list if ids != []]) - - def mark_blocks_as_computed(self, seq_group: SequenceGroup, - token_chunk_size: int): - if self.enable_caching: - for seq in seq_group.get_seqs(): - self.compute_full_blocks_in_seq(seq, token_chunk_size) - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - if device == Device.GPU: - return self.gpu_allocator.get_prefix_cache_hit_rate() - if device == Device.CPU: - return self.cpu_allocator.get_prefix_cache_hit_rate() - raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor.py similarity index 100% rename from vllm/core/evictor_v2.py rename to vllm/core/evictor.py diff --git a/vllm/core/evictor_v1.py b/vllm/core/evictor_v1.py deleted file mode 100644 index 5db5a08a5bb6..000000000000 --- a/vllm/core/evictor_v1.py +++ /dev/null @@ -1,106 +0,0 @@ -import enum -from abc import ABC, abstractmethod -from typing import OrderedDict - -from vllm.block import PhysicalTokenBlock - - -class EvictionPolicy(enum.Enum): - """Enum for eviction policy used by make_evictor to instantiate the correct - Evictor subclass. - """ - LRU = enum.auto() - - -class Evictor(ABC): - """The Evictor subclasses should be used by the BlockAllocator class to - handle eviction of freed PhysicalTokenBlocks. - """ - - @abstractmethod - def __init__(self): - pass - - @abstractmethod - def __contains__(self, block_hash: int) -> bool: - pass - - @abstractmethod - def evict(self) -> PhysicalTokenBlock: - """Runs the eviction algorithm and returns the evicted block""" - pass - - @abstractmethod - def add(self, block: PhysicalTokenBlock): - """Adds block to the evictor, making it a candidate for eviction""" - pass - - @abstractmethod - def remove(self, block_hash: int) -> PhysicalTokenBlock: - """Simply removes the block with the hash value block_hash from the - evictor. Caller is responsible for making sure that block_hash is - contained in the evictor before calling remove. Should be used to - "bring back" blocks that have been freed but not evicted yet. - """ - pass - - @property - @abstractmethod - def num_blocks(self) -> int: - pass - - -class LRUEvictor(Evictor): - """Evicts in a least-recently-used order using the last_accessed timestamp - that's recorded in the PhysicalTokenBlock. If there are multiple blocks with - the same last_accessed time, then the one with the largest num_hashed_tokens - will be evicted. If two blocks each have the lowest last_accessed time and - highest num_hashed_tokens value, then one will be chose arbitrarily - """ - - def __init__(self): - self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict() - - def __contains__(self, block_hash: int) -> bool: - return block_hash in self.free_table - - def evict(self) -> PhysicalTokenBlock: - if len(self.free_table) == 0: - raise ValueError("No usable cache memory left") - - evicted_block = next(iter(self.free_table.values())) - # The blocks with the lowest timestamps should be placed consecutively - # at the start of OrderedDict. Loop through all these blocks to - # find the one with maximum number of hashed tokens. 
- for _, block in self.free_table.items(): - if evicted_block.last_accessed < block.last_accessed: - break - if evicted_block.num_hashed_tokens < block.num_hashed_tokens: - evicted_block = block - - self.free_table.pop(evicted_block.block_hash) - - evicted_block.computed = False - return evicted_block - - def add(self, block: PhysicalTokenBlock): - self.free_table[block.block_hash] = block - - def remove(self, block_hash: int) -> PhysicalTokenBlock: - if block_hash not in self.free_table: - raise ValueError( - "Attempting to remove block that's not in the evictor") - block: PhysicalTokenBlock = self.free_table[block_hash] - self.free_table.pop(block_hash) - return block - - @property - def num_blocks(self) -> int: - return len(self.free_table) - - -def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: - if eviction_policy == EvictionPolicy.LRU: - return LRUEvictor() - else: - raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 9e1d1b02f680..9501a516bf02 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -28,13 +28,9 @@ class BlockSpaceManager(ABC): def get_block_space_manager_class(version: str): version = version.lower() - if version == "v1": - from vllm.core.block_manager_v1 import BlockSpaceManagerV1 - return BlockSpaceManagerV1 - - if version == "v2": - from vllm.core.block_manager_v2 import BlockSpaceManagerV2 - return BlockSpaceManagerV2 + if version == "selfattn": + from vllm.core.block_manager import SelfAttnBlockSpaceManager + return SelfAttnBlockSpaceManager if version == "placeholder": from vllm.core.placeholder_block_space_manager import ( diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 1f0a121711db..88733b8f53b8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,8 +4,9 @@ import time from collections import deque from dataclasses import dataclass, field -from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import Callable, Deque, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -115,7 +116,7 @@ class ScheduledSequenceGroup: class SchedulerOutputs: """The scheduling decision made from a scheduler.""" # Scheduled sequence groups. - scheduled_seq_groups: Iterable[ScheduledSequenceGroup] + scheduled_seq_groups: GenericSequence[ScheduledSequenceGroup] # Number of prefill groups scheduled. num_prefill_groups: int # Total number of batched tokens. @@ -289,7 +290,7 @@ def scheduler_running_outputs_builder(): def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(SequenceGroup("", [], -1), + return ScheduledSequenceGroup(SequenceGroup.__new__(SequenceGroup), token_chunk_size=0) # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) @@ -311,10 +312,8 @@ def __init__( # LoRAs. This should be improved in the future. 
self.lora_config = lora_config - version = "v1" - if self.scheduler_config.use_v2_block_manager: - version = "v2" - if (self.scheduler_config.embedding_mode + version = "selfattn" + if (self.scheduler_config.task == "embedding" or self.cache_config.is_attention_free): version = "placeholder" diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6e1970bfed98..ec39856b6f67 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -7,7 +7,7 @@ The typical workflow is: - call `init_distributed_environment` to initialize the distributed environment. -- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to initialize the model parallel groups. - any code dealing with the distributed stuff @@ -20,6 +20,7 @@ steps. """ import contextlib +import gc import pickle import weakref from collections import namedtuple @@ -391,8 +392,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: # Convert negative dim to positive. dim += input_.dim() input_size = input_.size() + # NOTE: we have to use concat-style all-gather here, + # stack-style all-gather has compatibility issues with + # torch.compile . see https://github.com/pytorch/pytorch/issues/138795 + output_size = (input_size[0] * world_size, ) + input_size[1:] # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, + output_tensor = torch.empty(output_size, dtype=input_.dtype, device=input_.device) # All-gather. @@ -400,6 +405,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: input_, group=self.device_group) # Reshape + output_tensor = output_tensor.reshape((world_size, ) + input_size) output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape(input_size[:dim] + (world_size * @@ -1129,6 +1135,19 @@ def destroy_distributed_environment(): torch.distributed.destroy_process_group() +def cleanup_dist_env_and_memory(shutdown_ray: bool = False): + destroy_model_parallel() + destroy_distributed_environment() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + if shutdown_ray: + import ray # Lazy import Ray + ray.shutdown() + gc.collect() + if not current_platform.is_cpu(): + torch.cuda.empty_cache() + + def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: """ This is a collective operation that returns if each rank is in the same node diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1b132cf76a10..c268d0af9ca0 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, - Tuple, Type, Union) + Tuple, Type, Union, cast, get_args) import torch @@ -12,10 +12,12 @@ DeviceConfig, EngineConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TokenizerPoolConfig) + SpeculativeConfig, TaskOption, TokenizerPoolConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser @@ 
-84,12 +86,14 @@ class EngineArgs: model: str = 'facebook/opt-125m' served_model_name: Optional[Union[str, List[str]]] = None tokenizer: Optional[str] = None + task: TaskOption = "auto" skip_tokenizer_init: bool = False tokenizer_mode: str = 'auto' + chat_template_text_format: str = 'string' trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' - config_format: str = 'auto' + config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -133,6 +137,7 @@ class EngineArgs: tokenizer_pool_extra_config: Optional[dict] = None limit_mm_per_prompt: Optional[Mapping[str, int]] = None enable_lora: bool = False + enable_lora_bias: bool = False max_loras: int = 1 max_lora_rank: int = 16 enable_prompt_adapter: bool = False @@ -181,7 +186,7 @@ class EngineArgs: scheduling_policy: Literal["fcfs", "priority"] = "fcfs" def __post_init__(self): - if self.tokenizer is None: + if not self.tokenizer: self.tokenizer = self.model # Setup plugins @@ -198,6 +203,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default=EngineArgs.model, help='Name or path of the huggingface model to use.') + parser.add_argument( + '--task', + default=EngineArgs.task, + choices=get_args(TaskOption), + help='The task to use the model for. Each vLLM instance only ' + 'supports one task, even if the same model can be used for ' + 'multiple tasks. When the model only supports one task, "auto" ' + 'can be used to select it; otherwise, you must specify explicitly ' + 'which task to use.') parser.add_argument( '--tokenizer', type=nullable_str, @@ -238,6 +252,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'fast tokenizer if available.\n* "slow" will ' 'always use the slow tokenizer. \n* ' '"mistral" will always use the `mistral_common` tokenizer.') + parser.add_argument( + '--chat-template-text-format', + type=str, + default=EngineArgs.chat_template_text_format, + choices=['string', 'openai'], + help='The format to render text content within a chat template. ' + '"string" will keep the content field as a string whereas ' + '"openai" will parse content in the current OpenAI format.') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -373,12 +395,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Disables sliding window, ' 'capping to sliding window size') - parser.add_argument( - '--use-v2-block-manager', - default=EngineArgs.use_v2_block_manager, - action='store_true', - help='Use BlockSpaceMangerV2. By default this is set to True. ' - 'Set to False to use BlockSpaceManagerV1') + parser.add_argument('--use-v2-block-manager', + action='store_true', + help='[DEPRECATED] block manager v1 has been ' + 'removed and SelfAttnBlockSpaceManager (i.e. ' + 'block manager v2) is now the default. ' + 'Setting this flag to True or False' + ' has no effect on vLLM behavior.') parser.add_argument( '--num-lookahead-slots', type=int, @@ -417,7 +440,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='The fraction of GPU memory to be used for the model ' 'executor, which can range from 0 to 1. For example, a value of ' '0.5 would imply 50%% GPU memory utilization. If unspecified, ' - 'will use the default value of 0.9.') + 'will use the default value of 0.9. 
This is a global gpu memory ' + 'utilization limit, for example if 50%% of the gpu memory is ' + 'already used before vLLM starts and --gpu-memory-utilization is ' + 'set to 0.9, then only 40%% of the gpu memory will be allocated ' + 'to the model executor.') parser.add_argument( '--num-gpu-blocks-override', type=int, @@ -454,11 +481,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'None, we assume the model weights are not ' 'quantized and use `dtype` to determine the data ' 'type of the weights.') - parser.add_argument('--rope-scaling', - default=None, - type=json.loads, - help='RoPE scaling configuration in JSON format. ' - 'For example, {"type":"dynamic","factor":2.0}') + parser.add_argument( + '--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"rope_type":"dynamic","factor":2.0}') parser.add_argument('--rope-theta', default=None, type=float, @@ -535,6 +563,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--enable-lora', action='store_true', help='If True, enable handling of LoRA adapters.') + parser.add_argument('--enable-lora-bias', + action='store_true', + help='If True, enable bias for LoRA adapters.') parser.add_argument('--max-loras', type=int, default=EngineArgs.max_loras, @@ -836,8 +867,11 @@ def from_cli_args(cls, args: argparse.Namespace): def create_model_config(self) -> ModelConfig: return ModelConfig( model=self.model, - tokenizer=self.tokenizer, + task=self.task, + # We know this is not None because we set it in __post_init__ + tokenizer=cast(str, self.tokenizer), tokenizer_mode=self.tokenizer_mode, + chat_template_text_format=self.chat_template_text_format, trust_remote_code=self.trust_remote_code, dtype=self.dtype, seed=self.seed, @@ -906,9 +940,12 @@ def create_engine_config(self) -> EngineConfig: "supported for multimodal models and has been disabled.") self.enable_prefix_caching = False + maybe_register_config_serialize_by_value(self.trust_remote_code) + cache_config = CacheConfig( + # neuron needs block_size = max_model_len block_size=self.block_size if self.device != "neuron" else - self.max_model_len, # neuron needs block_size = max_model_len + (self.max_model_len if self.max_model_len is not None else 0), gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, @@ -966,12 +1003,6 @@ def create_engine_config(self) -> EngineConfig: "in low performance due to small KV cache space. 
Consider " "setting --max-model-len to a smaller value.", max_model_len) - if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: - self.use_v2_block_manager = True - logger.warning( - "Enabled BlockSpaceManagerV2 because it is " - "required for multi-step (--num-scheduler-steps > 1)") - speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, target_parallel_config=parallel_config, @@ -987,7 +1018,6 @@ def create_engine_config(self) -> EngineConfig: speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, enable_chunked_prefill=self.enable_chunked_prefill, - use_v2_block_manager=self.use_v2_block_manager, disable_log_stats=self.disable_log_stats, ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, @@ -1018,15 +1048,24 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is None \ else speculative_config.num_lookahead_slots + if not self.use_v2_block_manager: + logger.warning( + "[DEPRECATED] Block manager v1 has been removed, " + "and setting --use-v2-block-manager to True or False has " + "no effect on vLLM behavior. Please remove " + "--use-v2-block-manager in your engine argument. " + "If your use case is not supported by " + "SelfAttnBlockSpaceManager (i.e. block manager v2)," + " please file an issue with detailed information.") + scheduler_config = SchedulerConfig( + task=model_config.task, max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, - use_v2_block_manager=self.use_v2_block_manager, num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, - embedding_mode=model_config.embedding_mode, is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, @@ -1036,6 +1075,7 @@ def create_engine_config(self) -> EngineConfig: policy=self.scheduling_policy, ) lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, fully_sharded_loras=self.fully_sharded_loras, @@ -1078,13 +1118,6 @@ def create_engine_config(self) -> EngineConfig: or "all" in detailed_trace_modules, ) - if (model_config.get_sliding_window() is not None - and scheduler_config.chunked_prefill_enabled - and not scheduler_config.use_v2_block_manager): - raise ValueError( - "Chunked prefill is not supported with sliding window. 
" - "Set --disable-sliding-window to disable sliding window.") - return EngineConfig( model_config=model_config, cache_config=cache_config, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 30e1a09981c5..1f57aecb6481 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,6 @@ from weakref import ReferenceType import vllm.envs as envs -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs @@ -15,25 +14,24 @@ from vllm.engine.async_timeout import asyncio_timeout from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState from vllm.engine.metrics_types import StatLoggerBase +from vllm.engine.protocol import EngineClient from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid, weak_bind) +from vllm.utils import deprecate_kwargs, weak_bind logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -541,7 +539,7 @@ async def build_guided_decoding_logits_processor_async( return sampling_params -class AsyncLLMEngine: +class AsyncLLMEngine(EngineClient): """An asynchronous wrapper for :class:`LLMEngine`. 
This class is used to wrap the :class:`LLMEngine` class to make it @@ -1039,102 +1037,6 @@ async def generate( ): yield LLMEngine.validate_output(output, RequestOutput) - async def beam_search( - self, - prompt: Union[PromptType, List[int]], - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - - tokenizer = await self.get_tokenizer() - tokenizedPrompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - tokenizedLength = len(tokenizedPrompt) - - sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) - - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) - all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] - completed = [] - - for _ in range(max_tokens): - prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams - ] - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, individual_prompt in enumerate(prompts_batch): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate(individual_prompt, beam_search_params, - request_id_item))) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - logger.info(output) - - new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - completed.append(new_beam) - else: - new_beams.append(new_beam) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = sorted_completed[:beam_width] - - for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) - - beam_search_output = RequestOutput( - request_id=request_id, - prompt=prompt, - outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens, - index=i, - logprobs=beam.cum_logprob, - ) for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=tokenizedPrompt, - prompt_logprobs=None) - - yield LLMEngine.validate_output(beam_search_output, RequestOutput) - async def encode( self, prompt: PromptType, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 563e52a37d93..1dd0f097c74f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,5 @@ import time +from collections import Counter as collectionsCounter from collections import deque from contextlib import contextmanager from dataclasses import dataclass @@ -6,7 +7,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, Iterable, List, Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence -from typing import Set, Type, Union, overload +from typing import Set, Type, Union, cast, overload import torch from typing_extensions import TypeVar @@ -29,8 +30,8 
@@ from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) +from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, + EncoderDecoderInputs, InputRegistry, PromptType) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -43,7 +44,9 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - Sequence, SequenceGroup, SequenceGroupMetadata, + ParallelSampleSequenceGroup, Sequence, + SequenceGroup, SequenceGroupBase, + SequenceGroupMetadata, SequenceGroupOutput, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) @@ -188,7 +191,7 @@ def validate_output( raise TypeError(f"Expected output of type {output_type}, " f"but found type {type(output)}") - return output + return cast(_O, output) @classmethod def validate_outputs( @@ -247,11 +250,11 @@ def __init__( "enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "seed=%d, served_model_name=%s, " "num_scheduler_steps=%d, chunked_prefill_enabled=%s " "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s)", + "chat_template_text_format=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -280,13 +283,13 @@ def __init__( observability_config, model_config.seed, model_config.served_model_name, - scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, scheduler_config.chunked_prefill_enabled, scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, use_cached_outputs, + model_config.chat_template_text_format, model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. @@ -345,7 +348,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: observability_config=self.observability_config, ) - if not self.model_config.embedding_mode: + if self.model_config.task != "embedding": self._initialize_kv_caches() # If usage stat is enabled, collect relevant info. @@ -474,6 +477,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) + self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -635,14 +640,31 @@ def _verify_args(self) -> None: def _add_processed_request( self, request_id: str, - processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs], params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, priority: int = 0, - ) -> None: + ) -> Optional[SequenceGroup]: + """Add a processed request to the engine's request pool. + return the created sequence group. 
+ """ + if isinstance(params, SamplingParams) and params.n > 1: + ParallelSampleSequenceGroup.add_request( + request_id, + self, + params, + processed_inputs=processed_inputs, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + return None + self._validate_model_inputs(processed_inputs) # Create the sequences. block_size = self.cache_config.block_size @@ -696,6 +718,8 @@ def _add_processed_request( min_cost_scheduler = self.scheduler[costs.index(min(costs))] min_cost_scheduler.add_seq_group(seq_group) + return seq_group + def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() @@ -1039,6 +1063,7 @@ def _process_model_outputs(self, scheduler_outputs.scheduled_seq_groups) has_multiple_outputs: bool = len(outputs) > 1 + outputs_by_sequence_group: List[List[SequenceGroupOutput]] if has_multiple_outputs: assert self.scheduler_config.is_multi_step or \ self.speculative_config @@ -1084,6 +1109,7 @@ def _process_model_outputs(self, finished_before.append(i) continue + output: List[SequenceGroupOutput] if has_multiple_outputs: output = outputs_by_sequence_group[i] else: @@ -1096,7 +1122,7 @@ def _process_model_outputs(self, seq_group, seq_group_meta, is_first_step_output) else: seq_group.update_num_computed_tokens( - seq_group_meta.token_chunk_size) + seq_group_meta.token_chunk_size or 0) if outputs: for o in outputs: @@ -1104,18 +1130,18 @@ def _process_model_outputs(self, and seq_group.metrics is not None): if seq_group.metrics.model_forward_time is not None: seq_group.metrics.model_forward_time += ( - o.model_forward_time) + o.model_forward_time or 0) else: seq_group.metrics.model_forward_time = ( o.model_forward_time) if seq_group.metrics.model_execute_time is not None: seq_group.metrics.model_execute_time += ( - o.model_execute_time) + o.model_execute_time or 0) else: seq_group.metrics.model_execute_time = ( o.model_execute_time) - if self.model_config.embedding_mode: + if self.model_config.task == "embedding": self._process_sequence_group_outputs(seq_group, output) else: self.output_processor.process_prompt_logprob(seq_group, output) @@ -1133,7 +1159,9 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1173,7 +1201,9 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1192,7 +1222,10 @@ def _process_model_outputs(self, continue request_output = RequestOutputFactory.create( - seq_group, use_cache=self.use_cached_outputs) + seq_group, + self.seq_id_to_seq_group, + use_cache=self.use_cached_outputs, + ) if request_output: ctx.request_outputs.append(request_output) @@ -1211,7 +1244,7 @@ def _process_model_outputs(self, skip) # Tracing - self.do_tracing(scheduler_outputs) + self.do_tracing(scheduler_outputs, finished_before) return None @@ -1236,8 +1269,10 @@ def _advance_to_next_step( seq_group, seq_group_metadata, 
seq_group.state.num_steps == 1) else: - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + token_chunk_size = (seq_group_metadata.token_chunk_size + if seq_group_metadata.token_chunk_size + is not None else 0) + seq_group.update_num_computed_tokens(token_chunk_size) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( @@ -1576,7 +1611,7 @@ def _get_stats(self, # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks gpu_cache_usage_sys = 0. - if num_total_gpu is not None: + if num_total_gpu: # Guard against both None and 0 num_free_gpu = sum( scheduler.block_manager.get_num_free_gpu_blocks() for scheduler in self.scheduler) @@ -1584,7 +1619,7 @@ def _get_stats(self, num_total_cpu = self.cache_config.num_cpu_blocks cpu_cache_usage_sys = 0. - if num_total_cpu is not None and num_total_cpu > 0: + if num_total_cpu: # Guard against both None and 0 num_free_cpu = sum( scheduler.block_manager.get_num_free_cpu_blocks() for scheduler in self.scheduler) @@ -1614,6 +1649,25 @@ def _get_stats(self, n_requests: List[int] = [] finished_reason_requests: List[str] = [] + # Lora requests + running_lora_adapters = dict( + collectionsCounter([ + running_request.lora_request.lora_name + for scheduler in self.scheduler + for running_request in scheduler.running + if running_request.lora_request + ])) + waiting_lora_adapters = dict( + collectionsCounter([ + waiting_request.lora_request.lora_name + for scheduler in self.scheduler + for waiting_request in scheduler.waiting + if waiting_request.lora_request + ])) + max_lora_stat = "0" + if self.lora_config: + max_lora_stat = str(self.lora_config.max_loras) + # NOTE: This loop assumes prefill seq_groups are before # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: @@ -1663,6 +1717,15 @@ def _get_stats(self, # TPOTs. latency = seq_group.get_last_latency(now) time_per_output_tokens_iter.append(latency) + if seq_group.state.current_step == 0: + # For async_output_proc, the do_log_stats() + # is called following init_multi_step(), which + # sets the current_step to zero. + actual_num_batched_tokens +=\ + seq_group.state.num_steps - 1 + else: + actual_num_batched_tokens +=\ + seq_group.state.current_step - 1 # Because of chunked prefill, we can have a single sequence # group that does multiple prompt_runs. 
To prevent logging @@ -1735,7 +1798,9 @@ def _get_stats(self, num_generation_tokens_requests=num_generation_tokens_requests, n_requests=n_requests, finished_reason_requests=finished_reason_requests, - ) + max_lora=str(max_lora_stat), + waiting_lora_adapters=list(waiting_lora_adapters.keys()), + running_lora_adapters=list(running_lora_adapters.keys())) def add_lora(self, lora_request: LoRARequest) -> bool: return self.model_executor.add_lora(lora_request) @@ -1783,11 +1848,18 @@ def stop_profile(self) -> None: def is_tracing_enabled(self) -> bool: return self.tracer is not None - def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None: + def do_tracing(self, + scheduler_outputs: SchedulerOutputs, + finished_before: Optional[List[int]] = None) -> None: if self.tracer is None: return - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double tracing when using async output proc + if finished_before and idx in finished_before: + continue + seq_group = scheduled_seq_group.seq_group if seq_group.is_finished(): self.create_trace_span(seq_group) @@ -1852,11 +1924,8 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None: def is_encoder_decoder_model(self): return self.input_preprocessor.is_encoder_decoder_model() - def is_embedding_model(self): - return self.model_config.is_embedding_model - - def _validate_model_inputs(self, inputs: Union[LLMInputs, - EncoderDecoderLLMInputs]): + def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs, + EncoderDecoderInputs]): if self.model_config.is_multimodal_model: # For encoder-decoder multimodal models, the max_prompt_len # restricts the decoder prompt length diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 42acd3ea4c94..a46625eff1e4 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING from typing import Counter as CollectionsCounter -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Type, Union, cast import numpy as np import prometheus_client @@ -34,7 +34,11 @@ class Metrics: See https://prometheus.github.io/client_python/multiprocess/ for more details on limitations. 
""" + labelname_finish_reason = "finished_reason" + labelname_waiting_lora_adapters = "waiting_lora_adapters" + labelname_running_lora_adapters = "running_lora_adapters" + labelname_max_lora = "max_lora" _gauge_cls = prometheus_client.Gauge _counter_cls = prometheus_client.Counter _histogram_cls = prometheus_client.Histogram @@ -55,6 +59,16 @@ def __init__(self, labelnames: List[str], max_model_len: int): documentation="Number of requests waiting to be processed.", labelnames=labelnames, multiprocess_mode="sum") + self.gauge_lora_info = self._gauge_cls( + name="vllm:lora_requests_info", + documentation="Running stats on lora requests.", + labelnames=[ + self.labelname_running_lora_adapters, + self.labelname_max_lora, + self.labelname_waiting_lora_adapters, + ], + multiprocess_mode="livemostrecent", + ) self.gauge_scheduler_swapped = self._gauge_cls( name="vllm:num_requests_swapped", documentation="Number of requests swapped to CPU.", @@ -249,10 +263,11 @@ def __init__(self, labelnames: Optional[List[str]] = None, buckets: Optional[List[float]] = None): labelnames_tuple = tuple(labelnames) if labelnames else None + boundaries = buckets if buckets else [] self._histogram = ray_metrics.Histogram(name=name, description=documentation, tag_keys=labelnames_tuple, - boundaries=buckets) + boundaries=boundaries) def labels(self, **labels): self._histogram.set_default_tags(labels) @@ -267,9 +282,12 @@ class RayMetrics(Metrics): RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. Provides the same metrics as Metrics but uses Ray's util.metrics library. """ - _gauge_cls = _RayGaugeWrapper - _counter_cls = _RayCounterWrapper - _histogram_cls = _RayHistogramWrapper + _gauge_cls: Type[prometheus_client.Gauge] = cast( + Type[prometheus_client.Gauge], _RayGaugeWrapper) + _counter_cls: Type[prometheus_client.Counter] = cast( + Type[prometheus_client.Counter], _RayCounterWrapper) + _histogram_cls: Type[prometheus_client.Histogram] = cast( + Type[prometheus_client.Histogram], _RayHistogramWrapper) def __init__(self, labelnames: List[str], max_model_len: int): if ray_metrics is None: @@ -422,6 +440,9 @@ def _log_histogram(self, histogram, data: Union[List[int], for datum in data: histogram.labels(**self.labels).observe(datum) + def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None: + gauge.labels(**data).set(1) + def _log_prometheus(self, stats: Stats) -> None: # System state data self._log_gauge(self.metrics.gauge_scheduler_running, @@ -438,7 +459,17 @@ def _log_prometheus(self, stats: Stats) -> None: stats.cpu_prefix_cache_hit_rate) self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, stats.gpu_prefix_cache_hit_rate) - + # Including max-lora in metric, in future this property of lora + # config maybe extended to be dynamic. 
+ lora_info = { + self.metrics.labelname_running_lora_adapters: + ",".join(stats.running_lora_adapters), + self.metrics.labelname_waiting_lora_adapters: + ",".join(stats.waiting_lora_adapters), + self.metrics.labelname_max_lora: + stats.max_lora, + } + self._log_gauge_string(self.metrics.gauge_lora_info, lora_info) # Iteration level data self._log_counter(self.metrics.counter_num_preemption, stats.num_preemption_iter) diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index bafd5fa1a8a8..e9a5bd3b586b 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -51,6 +51,9 @@ class Stats: num_generation_tokens_requests: List[int] n_requests: List[int] finished_reason_requests: List[str] + waiting_lora_adapters: List[str] + running_lora_adapters: List[str] + max_lora: str spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 166906f24673..9e5a6b21f4c1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -3,7 +3,7 @@ import pickle from contextlib import contextmanager, suppress from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, overload) + Optional, Union, cast, overload) import cloudpickle import zmq @@ -12,8 +12,8 @@ from zmq.asyncio import Socket from vllm import PoolingParams -from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs from vllm.engine.arg_utils import AsyncEngineArgs # yapf conflicts with isort for this block # yapf: disable @@ -26,18 +26,18 @@ RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) +from vllm.engine.protocol import EngineClient # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType, TokensPrompt +from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, - RequestOutput) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid) +from vllm.utils import deprecate_kwargs logger = init_logger(__name__) @@ -53,7 +53,7 @@ class MQClientClosedError(Exception): """ -class MQLLMEngineClient: +class MQLLMEngineClient(EngineClient): """A client wrapper for MQLLMEngine that conforms to the EngineClient protocol. @@ -204,8 +204,20 @@ async def run_output_handler_loop(self): # (and record only the first one) if is_engine_errored and not self._errored_with: self._errored_with = exception + # If engine is errored, no matter the type of exception + # it will no longer be able to receive new requests, + # therefore we have to inform that the current + # processed requests failed as well. Send back a dead + # engine error give this feedback and also give a + # 'hint' to the server to shutdown next. 
+ exception = self.dead_error if request_id is None: + # If request_id is None, then the engine raised an + # exception for a batch, and we may not know the + # request that caused it, neither if it was actually + # caused by any of them (e.g. CUDA OOM). Therefore we + # broadcast the same exception for all requests. for queue_i in tuple(self.output_queues.values()): queue_i.put_nowait(exception) else: @@ -316,7 +328,7 @@ async def _check_success(error_message: str, socket: Socket): or response != VLLM_RPC_SUCCESS_STR): raise ValueError(error_message) - async def get_tokenizer(self, lora_request: LoRARequest): + async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None): return await self.tokenizer.get_lora_tokenizer_async(lora_request) async def get_decoding_config(self) -> DecodingConfig: @@ -344,8 +356,14 @@ async def abort(self, request_id: str): await self._send_one_way_rpc_request( request=RPCAbortRequest(request_id), socket=self.input_socket) - async def do_log_stats(self): - """Ignore do_log_stats (handled on MQLLMEngine polling)""" + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + ) -> None: + """ + Ignore do_log_stats (handled on MQLLMEngine polling) + """ pass async def check_health(self): @@ -444,104 +462,6 @@ def generate( lora_request, trace_headers, prompt_adapter_request, priority) - async def beam_search( - self, - prompt: Union[PromptType, List[int]], - request_id: str, - params: BeamSearchParams, - ) -> AsyncGenerator[RequestOutput, None]: - - beam_width = params.beam_width - max_tokens = params.max_tokens - ignore_eos = params.ignore_eos - temperature = params.temperature - length_penalty = params.length_penalty - - tokenizer = await self.get_tokenizer(lora_request=None) - tokenizedPrompt = prompt if isinstance( - prompt, list) else tokenizer.encode(prompt) - tokenizedLength = len(tokenizedPrompt) - - sort_beams_key = create_sort_beams_key_function( - tokenizer.eos_token_id, length_penalty) - - beam_search_params = SamplingParams(logprobs=2 * beam_width, - max_tokens=1, - temperature=temperature) - all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] - completed = [] - - for _ in range(max_tokens): - prompts_batch = [ - TokensPrompt(prompt_token_ids=beam.tokens) - for beam in all_beams - ] - - tasks = [] - - request_id = f"beam_search-{random_uuid()}" - for i, individual_prompt in enumerate(prompts_batch): - request_id_item = f"{request_id}-{i}" - task = asyncio.create_task( - collect_from_async_generator( - self.generate(individual_prompt, beam_search_params, - request_id_item))) - tasks.append(task) - - output = await asyncio.gather(*tasks) - - output = [x[0] for x in output] - - logger.info(output) - - new_beams = [] - for i, current_beam in enumerate(all_beams): - result = output[i] - - if result.outputs[0].logprobs is not None: - logprobs = result.outputs[0].logprobs[0] - for token_id, logprob_obj in logprobs.items(): - new_beam = BeamSearchSequence( - tokens=current_beam.tokens + [token_id], - cum_logprob=current_beam.cum_logprob + - logprob_obj.logprob) - - if token_id == tokenizer.eos_token_id and \ - not ignore_eos: - completed.append(new_beam) - else: - new_beams.append(new_beam) - - sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) - all_beams = sorted_beams[:beam_width] - - completed.extend(all_beams) - sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) - best_beams = 
sorted_completed[:beam_width] - - for beam in best_beams: - beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) - - beam_search_output = RequestOutput( - request_id=request_id, - prompt=prompt, - outputs=[ - CompletionOutput( - text=beam.text, - cumulative_logprob=beam.cum_logprob, - token_ids=beam.tokens, - index=i, - logprobs=beam.cum_logprob, - ) for (i, beam) in enumerate(best_beams) - ], - finished=True, - prompt_token_ids=tokenizedPrompt, - prompt_logprobs=None) - - logger.info(beam_search_output) - - yield beam_search_output - @overload # DEPRECATED def encode( self, @@ -605,9 +525,14 @@ def encode( assert (prompt is not None and pooling_params is not None and request_id is not None) - return self._process_request(prompt, pooling_params, request_id, - lora_request, trace_headers, None, - priority) + return cast( + AsyncGenerator[EmbeddingRequestOutput, None], + self._process_request(prompt, + pooling_params, + request_id, + lora_request, + trace_headers, + priority=priority)) async def _process_request( self, @@ -635,7 +560,9 @@ async def _process_request( build_guided_decoding_logits_processor_async( sampling_params=params, tokenizer=await self.get_tokenizer(lora_request), - default_guided_backend=self.decoding_config.guided_decoding_backend + default_guided_backend=(self.decoding_config.guided_decoding_backend + if self.decoding_config + else DecodingConfig.guided_decoding_backend), ) # 1) Create output queue for this requests. diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 2bf0ce83c760..f67acdf66075 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -8,7 +8,7 @@ import cloudpickle import zmq -from vllm import AsyncEngineArgs, LLMEngine, SamplingParams +from vllm import AsyncEngineArgs, SamplingParams from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) # yapf conflicts with isort for this block @@ -21,12 +21,17 @@ RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) # yapf: enable -from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.envs import VLLM_RPC_TIMEOUT, VLLM_USE_V1 from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext +if VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine +else: + from vllm.engine.llm_engine import LLMEngine + CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, SchedulerConfig, LoRAConfig] @@ -73,11 +78,9 @@ def __init__(self, # For MQLLMEngine, we can use cached outputs, since each new request # output is immediately pickled and send over the socket, which frees # the python object to be reused again. 
- use_cached_outputs = True + kwargs['use_cached_outputs'] = True - self.engine = LLMEngine(*args, - **kwargs, - use_cached_outputs=use_cached_outputs) + self.engine = LLMEngine(*args, **kwargs) self.log_requests = log_requests self.use_async_sockets = use_async_sockets @@ -138,14 +141,16 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, executor_class = LLMEngine._get_executor_cls(engine_config) - return cls( - ipc_path=ipc_path, - use_async_sockets=engine_config.model_config.use_async_output_proc, - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - usage_context=usage_context) + use_async_sockets = (engine_config.model_config.use_async_output_proc + and not VLLM_USE_V1) + + return cls(ipc_path=ipc_path, + use_async_sockets=use_async_sockets, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) def start(self): try: diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 74ddb250ccd9..3ed37a269c4b 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List +from typing import Callable, List, cast from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -9,8 +9,10 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Sequence, + SequenceGroup, SequenceGroupOutput, SequenceOutput, + SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -57,6 +59,7 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, """ for output in outputs: # Concatenate single-step prompt logprob processing results. + assert isinstance(output, CompletionSequenceGroupOutput) single_step_process_prompt_logprob(self, seq_group, output) @staticmethod @@ -100,8 +103,18 @@ def process_outputs(self, "Beam search not supported in multi-step decoding.") seq = seqs[0] seq_id = seq.seq_id - assert all( - [seq_id == output.samples[0].parent_seq_id for output in outputs]) + # This method is defined in the more generic + # SequenceGroupOutputProcessor, but here we assume that the outputs are + # of a more specific type. + assert all([ + isinstance(output, CompletionSequenceGroupOutput) + for output in outputs + ]) + compl_outputs = cast(List[CompletionSequenceGroupOutput], outputs) + assert all([ + seq_id == output.samples[0].parent_seq_id + for output in compl_outputs + ]) if is_async: # Async case: We process tokens one by one. Here, we know the token @@ -113,7 +126,7 @@ def process_outputs(self, # Since there's only one sequence per sequence group, # we can take the first sample. - samples = [output.samples[0] for output in outputs] + samples = [output.samples[0] for output in compl_outputs] # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). 
diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index cfa84077685a..da3185f33dbe 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -6,8 +6,8 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import (CompletionSequenceGroupOutput, SequenceGroup, + SequenceGroupOutput) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.utils import Counter @@ -16,7 +16,7 @@ def single_step_process_prompt_logprob( sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, - output: SequenceGroupOutput) -> None: + output: CompletionSequenceGroupOutput) -> None: """Process prompt logprobs associated with the :class:`SequenceGroupOutput` for a given step. @@ -106,110 +106,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, """ assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] + assert isinstance(output, CompletionSequenceGroupOutput) single_step_process_prompt_logprob(self, seq_group, output) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, is_async: bool) -> None: sampling_params = seq_group.sampling_params - if sampling_params.n == 1: - # only have one output sample - sample = outputs.samples[0] - # only have one sequence - seq = seq_group.seqs[0] - if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return - - # TODO: Add support for async for beam search - assert not is_async - - # Process samples - samples = outputs.samples - parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - parent_child_dict: Dict[int, List[SequenceOutput]] = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } - for sample in samples: - # Guard against a KeyError which can occur if the request was - # aborted while the output was generated - if (child_list := - parent_child_dict.get(sample.parent_seq_id)) is not None: - child_list.append(sample) - # List of (child, parent) - child_seqs: List[Tuple[Sequence, Sequence]] = [] - - # Process the child samples for each parent sequence - for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] - if len(child_samples) == 0: - # This parent sequence has no children samples. Remove - # the parent sequence from the sequence group since it will - # not be used in the future iterations. - parent.status = SequenceStatus.FINISHED_ABORTED - seq_group.remove(parent.seq_id) - for scheduler in self.scheduler: - scheduler.free_seq(parent) - continue - # Fork the parent sequence if there are multiple child samples. 
- for child_sample in child_samples[:-1]: - new_child_seq_id: int = next(self.seq_counter) - child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs) - child_seqs.append((child, parent)) - # Continue the parent sequence for the last child sample. - # We reuse the parent sequence here to reduce redundant memory - # copies, especially when using non-beam search sampling methods. - last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs) - child_seqs.append((parent, parent)) - - for seq, _ in child_seqs: - if sampling_params.detokenize and self.detokenizer: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) - else: - new_char_count = 0 - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count, - sampling_params, - lora_req=seq_group.lora_request, - ) - - # For newly created child sequences, add them to the sequence group - # and fork them in block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - for scheduler in self.scheduler: - scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return + + sample = outputs.samples[0] + seq = seq_group.first_seq + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 0c5f8fb7f5be..a71ad493d992 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -57,7 +57,7 @@ def maybe_stop_sequence( # Check if a stop token was encountered. # This assumes a single token produced per step. last_token_id = seq.get_last_token_id() - if last_token_id in sampling_params.stop_token_ids: + if last_token_id in (sampling_params.stop_token_ids or ()): if new_char_count and ( not sampling_params.include_stop_str_in_output): # Remove last token @@ -92,7 +92,7 @@ def _check_stop_strings(seq: Sequence, new_char_count: int, Returns the stop string if matched or else None. 
""" - if not new_char_count: + if not new_char_count or not sampling_params.stop: return None for stop_str in sampling_params.stop: diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 76782888031e..770982a207e6 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -1,22 +1,25 @@ from typing import List from typing import Sequence as GenericSequence -from typing import Union +from typing import cast from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import PoolerOutput, SequenceGroupOutput +from vllm.sequence import CompletionSequenceGroupOutput, SequenceGroupOutput def create_output_by_sequence_group( - outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], + outputs: GenericSequence[SamplerOutput], num_seq_groups: int) -> List[List[SequenceGroupOutput]]: """Helper method which transforms a 2d list organized by [step][sequence group] into [sequence group][step]. """ - output_by_sequence_group: List[List[SequenceGroupOutput]] = [ + output_by_sequence_group: List[List[CompletionSequenceGroupOutput]] = [ [] for _ in range(num_seq_groups) ] for step in outputs: + sequence_group_output: CompletionSequenceGroupOutput for i, sequence_group_output in enumerate(step): output_by_sequence_group[i].append(sequence_group_output) - return output_by_sequence_group + # Cast to the more generic type that CompletionSequenceGroupOutput + # inherits from. + return cast(List[List[SequenceGroupOutput]], output_by_sequence_group) diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d7ff743e0ada..5c504e0f0217 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,38 +1,49 @@ -from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, - runtime_checkable) +import asyncio +from abc import ABC, abstractmethod +from typing import AsyncGenerator, List, Mapping, Optional, Union +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptType, TokensPrompt +from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import BeamSearchParams, SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import collect_from_async_generator, random_uuid +logger = init_logger(__name__) -@runtime_checkable -class EngineClient(Protocol): + +class EngineClient(ABC): """Protocol class for Clients to Engine""" @property + @abstractmethod def is_running(self) -> bool: ... @property + @abstractmethod def is_stopped(self) -> bool: ... @property + @abstractmethod def errored(self) -> bool: ... @property + @abstractmethod def dead_error(self) -> BaseException: ... + @abstractmethod def generate( self, prompt: PromptType, @@ -46,6 +57,110 @@ def generate( """Generate outputs for a request.""" ... 
+ async def beam_search( + self, + prompt: Union[str, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + + tokenizer = await self.get_tokenizer(lora_request=None) + if isinstance(prompt, str): + tokenized_prompt = tokenizer.encode(prompt) + prompt_text = prompt + else: + tokenized_prompt = prompt + prompt_text = None + tokenized_length = len(tokenized_prompt) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [ + BeamSearchSequence(tokens=tokenized_prompt, + logprobs=[], + cum_logprob=0) + ] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenized_length:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt_text, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens[tokenized_length:], + index=i, + logprobs=beam.logprobs, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenized_prompt, + prompt_logprobs=None) + + yield beam_search_output + + @abstractmethod def encode( self, prompt: PromptType, @@ -58,6 +173,7 @@ def encode( """Generate outputs for a request from an embedding model.""" ... + @abstractmethod async def abort(self, request_id: str) -> None: """Abort a request. @@ -65,14 +181,17 @@ async def abort(self, request_id: str) -> None: request_id: The unique id of the request. """ + @abstractmethod async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" ... + @abstractmethod async def get_decoding_config(self) -> DecodingConfig: ... """Get the decoding configuration of the vLLM engine.""" + @abstractmethod async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, @@ -80,9 +199,11 @@ async def get_tokenizer( """Get the appropriate tokenizer for the request""" ... 
+ @abstractmethod async def is_tracing_enabled(self) -> bool: ... + @abstractmethod async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, @@ -90,14 +211,17 @@ async def do_log_stats( ) -> None: ... + @abstractmethod async def check_health(self) -> None: """Raise if unhealthy""" ... + @abstractmethod async def start_profile(self) -> None: """Start profiling the engine""" ... + @abstractmethod async def stop_profile(self) -> None: """Start profiling the engine""" ... diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 41354dc602c6..fef6a91414db 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -5,8 +5,8 @@ from collections import defaultdict from functools import lru_cache, partial from pathlib import Path -from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, - Mapping, Optional, Tuple, TypeVar, Union, cast) +from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List, + Literal, Mapping, Optional, Tuple, TypeVar, Union, cast) # yapf conflicts with isort for this block # yapf: disable @@ -33,6 +33,7 @@ async_get_and_parse_image, get_and_parse_audio, get_and_parse_image) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -58,10 +59,35 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): """The type of the content part.""" +class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain image_url. + This is supported by OpenAI API, although it is not documented. + + Example: + { + "image_url": "https://example.com/image.jpg" + } + """ + image_url: Required[str] + + +class CustomChatCompletionContentSimpleAudioParam(TypedDict, total=False): + """A simpler version of the param that only accepts a plain audio_url. 
+ + Example: + { + "audio_url": "https://example.com/audio.mp3" + } + """ + audio_url: Required[str] + + ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, ChatCompletionContentPartRefusalParam, - CustomChatCompletionContentPartParam] + CustomChatCompletionContentPartParam, + CustomChatCompletionContentSimpleImageParam, + CustomChatCompletionContentSimpleAudioParam, str] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -95,7 +121,7 @@ class ConversationMessage(TypedDict, total=False): role: Required[str] """The role of the message's author.""" - content: Optional[str] + content: Union[Optional[str], List[Dict[str, str]]] """The contents of the message""" tool_call_id: Optional[str] @@ -166,15 +192,18 @@ def _placeholder_str(self, modality: ModalityStr, if model_type == "molmo": return "" - raise TypeError(f"Unknown model type: {model_type}") + raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": if model_type == "ultravox": return "<|reserved_special_token_0|>" + if model_type == "qwen2_audio": + return (f"Audio {current_count}: " + f"<|audio_bos|><|AUDIO|><|audio_eos|>") raise TypeError(f"Unknown model type: {model_type}") elif modality == "video": if model_type == "qwen2_vl": return "<|vision_start|><|video_pad|><|vision_end|>" - raise TypeError(f"Unknown model type: {model_type}") + raise TypeError(f"Unknown {modality} model type: {model_type}") else: raise TypeError(f"Unknown modality: {modality}") @@ -386,60 +415,146 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} +# Define a mapping from part types to their corresponding parsing functions. +MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = { + "text": + lambda part: _TextParser(part).get("text", ""), + "image_url": + lambda part: _ImageParser(part).get("image_url", {}).get("url", ""), + "audio_url": + lambda part: _AudioParser(part).get("audio_url", {}).get("url", ""), + "refusal": + lambda part: _RefusalParser(part).get("refusal", ""), +} + + +def _parse_chat_message_content_mm_part( + part: ChatCompletionContentPartParam) -> Tuple[str, str]: + """ + Parses a given multi-modal content part based on its type. + + Args: + part: A dict containing the content part, with a potential 'type' field. + + Returns: + A tuple (part_type, content) where: + - part_type: Type of the part (e.g., 'text', 'image_url'). + - content: Parsed content (e.g., text, image URL). + + Raises: + ValueError: If the 'type' field is missing and no direct URL is found. + """ + assert isinstance( + part, dict) # This is needed to avoid mypy errors: part.get() from str + part_type = part.get("type", None) + + if isinstance(part_type, str) and part_type in MM_PARSER_MAP: + content = MM_PARSER_MAP[part_type](part) + + # Special case for 'image_url.detail' + if part_type == "image_url" and part.get("detail") != "auto": + logger.warning("'image_url.detail' is currently not supported " + "and will be ignored.") + + return part_type, content + + # Handle missing 'type' but provided direct URL fields. 
+ if part_type is None: + if part.get("image_url") is not None: + image_params = cast(CustomChatCompletionContentSimpleImageParam, + part) + return "image_url", image_params.get("image_url", "") + if part.get("audio_url") is not None: + audio_params = cast(CustomChatCompletionContentSimpleAudioParam, + part) + return "audio_url", audio_params.get("audio_url", "") + + # Raise an error if no 'type' or direct URL is found. + raise ValueError("Missing 'type' field in multimodal part.") + + if not isinstance(part_type, str): + raise ValueError("Invalid 'type' field in multimodal part.") + return part_type, "unknown part_type content" + + +VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url", + "audio_url") + def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], mm_tracker: BaseMultiModalItemTracker, + chat_template_text_format: str, ) -> List[ConversationMessage]: - texts: List[str] = [] + content: List[Union[str, Dict[str, str]]] = [] mm_parser = mm_tracker.create_parser() - keep_multimodal_content = \ + wrap_dicts = \ mm_tracker._model_config.hf_config.model_type in \ - MODEL_KEEP_MULTI_MODAL_CONTENT + MODEL_KEEP_MULTI_MODAL_CONTENT or \ + (chat_template_text_format == "openai") - has_image = False for part in parts: - part_type = part["type"] - if part_type == "text": - text = _TextParser(part)["text"] - texts.append(text) - elif part_type == "image_url": - image_url = _ImageParser(part)["image_url"] - - if image_url.get("detail", "auto") != "auto": - logger.warning( - "'image_url.detail' is currently not supported and " - "will be ignored.") - - mm_parser.parse_image(image_url["url"]) - has_image = True - elif part_type == "audio_url": - audio_url = _AudioParser(part)["audio_url"] - - mm_parser.parse_audio(audio_url["url"]) - elif part_type == "refusal": - text = _RefusalParser(part)["refusal"] - texts.append(text) - else: - raise NotImplementedError(f"Unknown part type: {part_type}") - + parse_res = _parse_chat_message_content_part( + part, + mm_parser, + wrap_dicts=wrap_dicts, + ) + if parse_res: + content.append(parse_res) + + if wrap_dicts: + # Parsing wraps images and texts as interleaved dictionaries + return [ConversationMessage(role=role, + content=content)] # type: ignore + texts = cast(List[str], content) text_prompt = "\n".join(texts) - if keep_multimodal_content: - text_prompt = "\n".join(texts) - role_content = [{'type': 'text', 'text': text_prompt}] + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +def _parse_chat_message_content_part( + part: ChatCompletionContentPartParam, + mm_parser: BaseMultiModalContentParser, + wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]: + """Parses a single part of a conversation. If wrap_dicts is True, + structured dictionary pieces for texts and images will be + wrapped in dictionaries, i.e., {"type": "text", "text", ...} and + {"type": "image"}, respectively. Otherwise multimodal data will be + handled by mm_parser, and texts will be returned as strings to be joined + with multimodal placeholders. 
+ """ + if isinstance(part, str): # Handle plain text parts + text = _TextParser(part) + return text - if has_image: - role_content = [{'type': 'image'}] + role_content - return [ConversationMessage(role=role, - content=role_content)] # type: ignore - else: - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt( - mm_placeholder_counts, text_prompt) - return [ConversationMessage(role=role, content=text_prompt)] + # Handle structured dictionary parts + part_type, content = _parse_chat_message_content_mm_part(part) + + # if part_type is text/refusal/image_url/audio_url but + # content is empty, log a warning and skip + if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and not content: + logger.warning( + "Skipping multimodal part (type: '%s')" + "with empty / unparsable content.", part_type) + return None + + if part_type in ("text", "refusal"): + return {'type': 'text', 'text': content} if wrap_dicts else content + + if part_type == "image_url": + mm_parser.parse_image(content) + return {'type': 'image'} if wrap_dicts else None + + if part_type == "audio_url": + mm_parser.parse_audio(content) + return {'type': 'audio'} if wrap_dicts else None + + raise NotImplementedError(f"Unknown part type: {part_type}") # No need to validate using Pydantic again @@ -450,6 +565,7 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content( message: ChatCompletionMessageParam, mm_tracker: BaseMultiModalItemTracker, + chat_template_text_format: str, ) -> List[ConversationMessage]: role = message["role"] content = message.get("content") @@ -465,6 +581,7 @@ def _parse_chat_message_content( role, content, # type: ignore mm_tracker, + chat_template_text_format, ) for result_msg in result: @@ -508,7 +625,11 @@ def parse_chat_messages( mm_tracker = MultiModalItemTracker(model_config, tokenizer) for msg in messages: - sub_messages = _parse_chat_message_content(msg, mm_tracker) + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + model_config.chat_template_text_format, + ) conversation.extend(sub_messages) @@ -526,7 +647,11 @@ def parse_chat_messages_futures( mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) for msg in messages: - sub_messages = _parse_chat_message_content(msg, mm_tracker) + sub_messages = _parse_chat_message_content( + msg, + mm_tracker, + model_config.chat_template_text_format, + ) conversation.extend(sub_messages) @@ -564,14 +689,14 @@ def apply_mistral_chat_template( **kwargs: Any, ) -> List[int]: if chat_template is not None: - logger.warning( + print_warning_once( "'chat_template' cannot be overridden for mistral tokenizer.") if "add_generation_prompt" in kwargs: - logger.warning( + print_warning_once( "'add_generation_prompt' is not supported for mistral tokenizer, " "so it will be ignored.") if "continue_final_message" in kwargs: - logger.warning( + print_warning_once( "'continue_final_message' is not supported for mistral tokenizer, " "so it will be ignored.") diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2010381076c7..db97fe0a0285 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,10 +6,10 @@ from tqdm import tqdm +from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) -from vllm.engine.arg_utils import EngineArgs -from vllm.engine.llm_engine import LLMEngine +from vllm.engine.arg_utils import EngineArgs, TaskOption from 
vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, apply_hf_chat_template, apply_mistral_chat_template, @@ -29,7 +29,12 @@ get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, deprecate_kwargs, is_list_of +from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of + +if envs.VLLM_USE_V1: + from vllm.v1.engine.llm_engine import LLMEngine # type: ignore +else: + from vllm.engine.llm_engine import LLMEngine # type: ignore logger = init_logger(__name__) @@ -108,6 +113,12 @@ class LLM: DEPRECATE_LEGACY: ClassVar[bool] = False """A flag to toggle whether to deprecate the legacy generate/encode API.""" + DEPRECATE_INIT_POSARGS: ClassVar[bool] = True + """ + A flag to toggle whether to deprecate positional arguments in + :meth:`LLM.__init__`. + """ + @classmethod @contextmanager def deprecate_legacy_api(cls): @@ -117,6 +128,13 @@ def deprecate_legacy_api(cls): cls.DEPRECATE_LEGACY = False + @deprecate_args( + start_index=2, # Ignore self and model + is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS, + additional_message=( + "All positional arguments other than `model` will be " + "replaced with keyword arguments in an upcoming version."), + ) def __init__( self, model: str, @@ -139,6 +157,8 @@ def __init__( disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", **kwargs, ) -> None: ''' @@ -153,6 +173,7 @@ def __init__( engine_args = EngineArgs( model=model, + task=task, tokenizer=tokenizer, tokenizer_mode=tokenizer_mode, skip_tokenizer_init=skip_tokenizer_init, @@ -316,10 +337,21 @@ def generate( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if self.llm_engine.model_config.embedding_mode: - raise ValueError( + task = self.llm_engine.model_config.task + if task != "generate": + messages = [ "LLM.generate() is only supported for (conditional) generation " - "models (XForCausalLM, XForConditionalGeneration).") + "models (XForCausalLM, XForConditionalGeneration).", + ] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "generate" in supported_tasks: + messages.append( + "Your model supports the 'generate' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task generate`.") + + raise ValueError(" ".join(messages)) if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( @@ -433,6 +465,7 @@ def sort_beams_key(x: BeamSearchSequence) -> float: for token_id, logprob_obj in logprobs.items(): new_beam = BeamSearchSequence( tokens=current_beam.tokens + [token_id], + logprobs=current_beam.logprobs + [logprobs], cum_logprob=current_beam.cum_logprob + logprob_obj.logprob) @@ -691,10 +724,18 @@ def encode( considered legacy and may be deprecated in the future. You should instead pass them via the ``inputs`` parameter. """ - if not self.llm_engine.model_config.embedding_mode: - raise ValueError( - "LLM.encode() is only supported for embedding models (XModel)." 
- ) + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.encode() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) if prompt_token_ids is not None: parsed_prompts = self._convert_v1_inputs( @@ -904,6 +945,3 @@ def _run_engine( def _is_encoder_decoder_model(self): return self.llm_engine.is_encoder_decoder_model() - - def _is_embedding_model(self): - return self.llm_engine.is_embedding_model() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index ae44b26a6c55..4a0959ea149d 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -33,6 +33,7 @@ from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import (make_arg_parser, validate_parsed_serve_args) +from vllm.entrypoints.openai.fim import get_supported_fim_encoders # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -507,6 +508,7 @@ def init_app_state( prompt_adapters=args.prompt_adapters, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, + fim_encoder=args.fim, ) state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, @@ -531,11 +533,19 @@ async def run_server(args, **uvicorn_kwargs) -> None: if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: ToolParserManager.import_tool_parser(args.tool_parser_plugin) - valide_tool_parses = ToolParserManager.tool_parsers.keys() - if args.enable_auto_tool_choice \ - and args.tool_call_parser not in valide_tool_parses: - raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " - f"(chose from {{ {','.join(valide_tool_parses)} }})") + if args.enable_auto_tool_choice: + valid_tool_parsers = ToolParserManager.tool_parsers + if args.tool_call_parser not in valid_tool_parsers: + raise KeyError( + f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valid_tool_parsers.keys())} }})") + + if args.fim is not None: + valid_fim_encoders = get_supported_fim_encoders() + if args.fim not in valid_fim_encoders: + raise KeyError( + f"invalid FIM encoder: {args.fim} " + f"(chose from {{ {','.join(valid_fim_encoders)} }})") # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. 
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index a089985ac975..9559304d6ecd 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -11,6 +11,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.openai.fim import get_supported_fim_encoders from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -213,6 +214,16 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: " into OpenAI API format, the name register in this plugin can be used " "in --tool-call-parser.") + valid_fim_encoders = get_supported_fim_encoders() + parser.add_argument( + "--fim", + type=str, + metavar="{" + ",".join(valid_fim_encoders) + "}", + default=None, + help="Select the fill-in-the-middle (FIM) encoder depending on the" + " model that you're using. Required to use the suffix parameter of the" + " OpenAI Completions API.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/fim/__init__.py b/vllm/entrypoints/openai/fim/__init__.py new file mode 100644 index 000000000000..d5f343113016 --- /dev/null +++ b/vllm/entrypoints/openai/fim/__init__.py @@ -0,0 +1,69 @@ +from functools import partial +from inspect import isclass +from typing import Callable, Dict, Iterable, Optional, Tuple, Type, Union + +from vllm.entrypoints.openai.fim.codellama_fim import CodeLlamaFIMEncoder +from vllm.entrypoints.openai.fim.fim_encoder import (FIMEncoder, + StringTemplateFIMEncoder) +from vllm.entrypoints.openai.fim.mistral_fim import MistralFIMEncoder +from vllm.transformers_utils.tokenizer import AnyTokenizer + +__all__ = [ + "FIMEncoder", "get_supported_fim_encoders", "get_fim_encoder_lookup" +] + +# Entries are either an FIMEncoder implementation class or +# tuple of (template, special_tokens_list). +_FIM_ENCODERS: Dict[str, Union[Type, Tuple[str, Iterable[str]]]] = { + "mistral": + MistralFIMEncoder, + "codellama": + CodeLlamaFIMEncoder, + "deepseek": ( + "<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>", + ("<|fim▁begin|>", "<|fim▁hole|>", "<|fim▁end|>"), + ), + "starcoder": ( + "{prefix}{suffix}", + ("", "", ""), + ) +} + + +def get_supported_fim_encoders() -> Iterable[str]: + """Return set of supported FIM encoder types.""" + return _FIM_ENCODERS.keys() + + +def get_fim_encoder_lookup( + name: Optional[str]) -> Optional[Callable[[AnyTokenizer], FIMEncoder]]: + """ + Get a function that returns a FIMEncoder instance for a given tokenizer. + Raise a KeyError exception if the name is not recognized. 
+ """ + if name is None: + return None + + if (encoder := _FIM_ENCODERS.get(name)) is None: + raise ValueError(f"fim encoder '{name}' not recognized") + + factory: Callable[[AnyTokenizer], FIMEncoder] + if isclass(encoder): + assert issubclass(encoder, FIMEncoder) + factory = encoder + else: + assert isinstance(encoder, tuple) + template, special_tokens = encoder + factory = partial(StringTemplateFIMEncoder, + name=name, + template=template, + special_tokens=special_tokens) + + def for_tokenizer(tokenizer: AnyTokenizer) -> FIMEncoder: + fim_encoder = getattr(tokenizer, "fim_encoder", None) + if fim_encoder is None: + fim_encoder = factory(tokenizer) + tokenizer.fim_encoder = fim_encoder # type: ignore[union-attr] + return fim_encoder + + return for_tokenizer diff --git a/vllm/entrypoints/openai/fim/codellama_fim.py b/vllm/entrypoints/openai/fim/codellama_fim.py new file mode 100644 index 000000000000..224d34d2288d --- /dev/null +++ b/vllm/entrypoints/openai/fim/codellama_fim.py @@ -0,0 +1,41 @@ +from typing import List + +from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class CodeLlamaFIMEncoder(FIMEncoder): + """ + FIM Encoder for Meta CodeLlama models + + Adapted from https://github.com/meta-llama/codellama/blob/e81b597e44dbecc2a0dedb9949fdf84adfc22395/llama/generation.py#L474 + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not hasattr(tokenizer, "convert_tokens_to_ids"): + raise ValueError( + "tokenizer incompatible with 'codellama' FIM encoder") + + self.bos_id = tokenizer.convert_tokens_to_ids("") + self.prefix_id = tokenizer.convert_tokens_to_ids("▁
")
+        self.suffix_id = tokenizer.convert_tokens_to_ids("▁")
+        self.middle_id = tokenizer.convert_tokens_to_ids("▁")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        if any(tid in
+               {self.bos_id, self.prefix_id, self.suffix_id, self.middle_id}
+               for tid in (None, unk_token_id)):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prefix_tokens = self.tokenizer(prefix,
+                                       add_special_tokens=False).input_ids
+        # Encode a string without an implicit leading space.
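+        # (The "☺" sentinel contributes the leading ids that [2:] slices off.)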
+        suffix_tokens = self.tokenizer("☺" + suffix,
+                                       add_special_tokens=False).input_ids[2:]
+
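+        # CodeLlama infilling layout: <PRE> {prefix} <SUF> {suffix} <MID>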
+        return ([self.bos_id, self.prefix_id] + prefix_tokens +
+                [self.suffix_id] + suffix_tokens + [self.middle_id])
diff --git a/vllm/entrypoints/openai/fim/fim_encoder.py b/vllm/entrypoints/openai/fim/fim_encoder.py
new file mode 100644
index 000000000000..9b6f27a7a1e3
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/fim_encoder.py
@@ -0,0 +1,52 @@
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Optional
+
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+
+
+class FIMEncoder(ABC):
+    """
+    An encoder of fill-in-the-middle (FIM) prompts comprising prefix
+    and suffix strings.
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        self.tokenizer = tokenizer
+
+    @abstractmethod
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        """
+        Encode the provided prompt prefix and suffix
+        to a list of token ids
+        """
+        pass
+
+
+class StringTemplateFIMEncoder(FIMEncoder):
+    """FIMEncoder implementation using a simple string template
+    with prefix and suffix variables."""
+
+    def __init__(
+        self,
+        tokenizer: AnyTokenizer,
+        name: str,
+        template: str,
+        special_tokens: Optional[Iterable[str]] = None,
+    ):
+        super().__init__(tokenizer)
+
+        if not hasattr(tokenizer, "convert_tokens_to_ids"):
+            raise ValueError(
+                "tokenizer incompatible with 'codellama' FIM encoder")
+
+        unk_token_id = getattr(tokenizer, "unk_token_id", None)
+        for special_token in special_tokens or ():
+            token_id = tokenizer.convert_tokens_to_ids(special_token)
+            if token_id is None or token_id == unk_token_id:
+                raise ValueError(
+                    f"tokenizer incompatible with '{name}' FIM encoder")
+        self.template = template
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        prompt = self.template.format(prefix=prefix, suffix=suffix)
+        return self.tokenizer(prompt, add_special_tokens=False).input_ids
diff --git a/vllm/entrypoints/openai/fim/mistral_fim.py b/vllm/entrypoints/openai/fim/mistral_fim.py
new file mode 100644
index 000000000000..21fd1cca9e21
--- /dev/null
+++ b/vllm/entrypoints/openai/fim/mistral_fim.py
@@ -0,0 +1,22 @@
+from typing import List
+
+from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerV2
+
+from vllm.entrypoints.openai.fim.fim_encoder import FIMEncoder
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
+
+
+class MistralFIMEncoder(FIMEncoder):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        # InstructTokenizerV3 is a subclass of InstructTokenizerV2
+        if not isinstance(tokenizer, MistralTokenizer) \
+            or not isinstance(tokenizer.instruct, InstructTokenizerV2):
+            raise ValueError(
+                "tokenizer incompatible with 'mistral' FIM encoder")
+
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        return self.tokenizer.encode_with_suffix(prefix=prefix, suffix=suffix)
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 6f1135f8093b..733decf80a71 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -284,6 +284,12 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "The priority of the request (lower means earlier handling; "
             "default: 0). Any priority other than 0 will raise an error "
             "if the served model does not use priority scheduling."))
+    request_id: str = Field(
+        default_factory=lambda: f"{random_uuid()}",
+        description=(
+            "The request_id related to this request. If the caller does "
+            "not set it, a random_uuid will be generated. This id is used "
+            "through out the inference process and return in response."))
 
     # doc: end-chat-completion-extra-params
 
@@ -314,9 +320,15 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
             prompt_logprobs = self.top_logprobs
 
         guided_json_object = None
-        if (self.response_format is not None
-                and self.response_format.type == "json_object"):
-            guided_json_object = True
+        if self.response_format is not None:
+            if self.response_format.type == "json_object":
+                guided_json_object = True
+            elif self.response_format.type == "json_schema":
+                json_schema = self.response_format.json_schema
+                assert json_schema is not None
+                self.guided_json = json_schema.json_schema
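+                # Default to a backend known to support JSON schema when the
+                # user has not chosen one explicitly.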
+                if self.guided_decoding_backend is None:
+                    self.guided_decoding_backend = "lm-format-enforcer"
 
         guided_decoding = GuidedDecodingParams.from_optional(
             json=self._get_guided_json_from_tool() or self.guided_json,
@@ -537,8 +549,8 @@ class CompletionRequest(OpenAIBaseModel):
         default=None,
         description=
         ("Similar to chat completion, this parameter specifies the format of "
-         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
-         "supported."),
+         "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
+         "{'type': 'text' } is supported."),
     )
     guided_json: Optional[Union[str, dict, BaseModel]] = Field(
         default=None,
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 4931195ae0e0..cd2883a3b323 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -9,8 +9,6 @@
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -40,7 +38,7 @@
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
-from vllm.utils import iterate_with_cancellation, random_uuid
+from vllm.utils import iterate_with_cancellation
 
 logger = init_logger(__name__)
 
@@ -178,7 +176,7 @@ async def create_chat_completion(
                 "\"auto\" tool choice requires "
                 "--enable-auto-tool-choice and --tool-call-parser to be set")
 
-        request_id = f"chat-{random_uuid()}"
+        request_id = f"chat-{request.request_id}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
@@ -237,11 +235,6 @@ async def create_chat_completion(
                 log_tracing_disabled_warning()
 
             if isinstance(sampling_params, BeamSearchParams):
-                assert isinstance(self.engine_client,
-                                    (AsyncLLMEngine,
-                                    MQLLMEngineClient)), \
-                    "Beam search is only supported with" \
-                    "AsyncLLMEngine and MQLLMEngineClient."
                 result_generator = self.engine_client.beam_search(
                     engine_inputs['prompt_token_ids'],
                     request_id,
@@ -331,12 +324,20 @@ async def chat_completion_stream_generator(
             else:
                 tool_parsers = [None] * num_choices
         except RuntimeError as e:
-            logger.error("Error in tool parser creation: %s", e)
+            logger.exception("Error in tool parser creation.")
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
             yield "data: [DONE]\n\n"
             return
 
+        stream_options = request.stream_options
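+        # Resolve usage-reporting flags once instead of per chunk.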
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                                       stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for res in result_generator:
                 if res.prompt_token_ids is not None:
@@ -355,7 +356,6 @@ async def chat_completion_stream_generator(
                     # NOTE num_choices defaults to 1 so this usually executes
                     # once per request
                     for i in range(num_choices):
-                        tool_parser = tool_parsers[i]
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(
@@ -371,19 +371,12 @@ async def chat_completion_stream_generator(
                             choices=[choice_data],
                             model=model_name)
 
-                        # if usage should be included
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            # if continuous usage stats are requested, add it
-                            if request.stream_options.continuous_usage_stats:
-                                usage = UsageInfo(
-                                    prompt_tokens=num_prompt_tokens,
-                                    completion_tokens=0,
-                                    total_tokens=num_prompt_tokens)
-                                chunk.usage = usage
-                            # otherwise don't
-                            else:
-                                chunk.usage = None
+                        # if continuous usage stats are requested, add it
+                        if include_continuous_usage:
+                            chunk.usage = UsageInfo(
+                                prompt_tokens=num_prompt_tokens,
+                                completion_tokens=0,
+                                total_tokens=num_prompt_tokens)
 
                         data = chunk.model_dump_json(exclude_unset=True)
                         yield f"data: {data}\n\n"
@@ -391,7 +384,7 @@ async def chat_completion_stream_generator(
                     # Send response to echo the input portion of the
                     # last message
                     if request.echo or request.continue_final_message:
-                        last_msg_content: str = ""
+                        last_msg_content: Union[str, List[Dict[str, str]]] = ""
                         if conversation and "content" in conversation[
                                 -1] and conversation[-1].get("role") == role:
                             last_msg_content = conversation[-1]["content"] or ""
@@ -411,17 +404,11 @@ async def chat_completion_stream_generator(
                                     created=created_time,
                                     choices=[choice_data],
                                     model=model_name)
-                                if (request.stream_options and
-                                        request.stream_options.include_usage):
-                                    if (request.stream_options.
-                                            continuous_usage_stats):
-                                        usage = UsageInfo(
-                                            prompt_tokens=num_prompt_tokens,
-                                            completion_tokens=0,
-                                            total_tokens=num_prompt_tokens)
-                                        chunk.usage = usage
-                                    else:
-                                        chunk.usage = None
+                                if include_continuous_usage:
+                                    chunk.usage = UsageInfo(
+                                        prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=0,
+                                        total_tokens=num_prompt_tokens)
 
                                 data = chunk.model_dump_json(
                                     exclude_unset=True)
@@ -448,6 +435,12 @@ async def chat_completion_stream_generator(
                         logprobs = None
 
                     delta_text = output.text
+
+                    if not delta_text and not output.token_ids and \
+                        not previous_num_tokens[i]:
+                        # Chunked prefill case, don't return empty chunks
+                        continue
+
                     delta_message: Optional[DeltaMessage]
 
                     # handle streaming deltas for tools with named tool_choice
@@ -501,36 +494,11 @@ async def chat_completion_stream_generator(
 
                     if output.finish_reason is None:
                         # Send token-by-token response for each request.n
-
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=delta_message,
                             logprobs=logprobs,
                             finish_reason=None)
-                        chunk = ChatCompletionStreamResponse(
-                            id=request_id,
-                            object=chunk_object_type,
-                            created=created_time,
-                            choices=[choice_data],
-                            model=model_name)
-
-                        # handle usage stats if requested & if continuous
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            if request.stream_options.continuous_usage_stats:
-                                completion_tokens = len(output.token_ids)
-                                usage = UsageInfo(
-                                    prompt_tokens=num_prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                    total_tokens=num_prompt_tokens +
-                                    completion_tokens,
-                                )
-                                chunk.usage = usage
-                            else:
-                                chunk.usage = None
-
-                        data = chunk.model_dump_json(exclude_unset=True)
-                        yield f"data: {data}\n\n"
 
                     # if the model is finished generating
                     else:
@@ -580,34 +548,32 @@ async def chat_completion_stream_generator(
                             finish_reason=output.finish_reason
                             if not auto_tools_called else "tool_calls",
                             stop_reason=output.stop_reason)
-                        chunk = ChatCompletionStreamResponse(
-                            id=request_id,
-                            object=chunk_object_type,
-                            created=created_time,
-                            choices=[choice_data],
-                            model=model_name)
-                        if (request.stream_options
-                                and request.stream_options.include_usage):
-                            if request.stream_options.continuous_usage_stats:
-                                completion_tokens = len(output.token_ids)
-                                usage = UsageInfo(
-                                    prompt_tokens=num_prompt_tokens,
-                                    completion_tokens=completion_tokens,
-                                    total_tokens=num_prompt_tokens +
-                                    completion_tokens,
-                                )
-                                chunk.usage = usage
-                            else:
-                                chunk.usage = None
-                        data = chunk.model_dump_json(exclude_unset=True)
-                        yield f"data: {data}\n\n"
+
                         finish_reason_sent[i] = True
 
+                    chunk = ChatCompletionStreamResponse(
+                        id=request_id,
+                        object=chunk_object_type,
+                        created=created_time,
+                        choices=[choice_data],
+                        model=model_name)
+
+                    # handle usage stats if requested & if continuous
+                    if include_continuous_usage:
+                        completion_tokens = previous_num_tokens[i]
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=num_prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=num_prompt_tokens + completion_tokens,
+                        )
+
+                    data = chunk.model_dump_json(exclude_unset=True)
+                    yield f"data: {data}\n\n"
+
             # once the final token is handled, if stream_options.include_usage
             # is sent, send the usage
-            if (request.stream_options
-                    and request.stream_options.include_usage):
-                completion_tokens = previous_num_tokens[i]
+            if include_usage:
+                completion_tokens = sum(previous_num_tokens)
                 final_usage = UsageInfo(
                     prompt_tokens=num_prompt_tokens,
                     completion_tokens=completion_tokens,
@@ -634,7 +600,7 @@ async def chat_completion_stream_generator(
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
-            logger.error("error in chat completion stream generator: %s", e)
+            logger.exception("Error in chat completion stream generator.")
             data = self.create_streaming_error_response(str(e))
             yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
@@ -721,7 +687,7 @@ async def chat_completion_full_generator(
                 try:
                     tool_parser = self.tool_parser(tokenizer)
                 except RuntimeError as e:
-                    logger.error("Error in tool parser creation: %s", e)
+                    logger.exception("Error in tool parser creation.")
                     return self.create_error_response(str(e))
 
                 tool_call_info = tool_parser.extract_tool_calls(
@@ -758,10 +724,13 @@ async def chat_completion_full_generator(
             choices.append(choice_data)
 
         if request.echo or request.continue_final_message:
-            last_msg_content = ""
+            last_msg_content: Union[str, List[Dict[str, str]]] = ""
             if conversation and "content" in conversation[-1] and conversation[
                     -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""
+            if isinstance(last_msg_content, list):
+                last_msg_content = "\n".join(msg['text']
+                                             for msg in last_msg_content)
 
             for choice in choices:
                 full_message = last_msg_content + (choice.message.content
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 077312dd1414..7fcf0ca3e3a6 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -8,8 +8,6 @@
 from fastapi import Request
 
 from vllm.config import ModelConfig
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -57,6 +55,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__(engine_client=engine_client,
                          model_config=model_config,
@@ -64,7 +63,8 @@ def __init__(
                          lora_modules=lora_modules,
                          prompt_adapters=prompt_adapters,
                          request_logger=request_logger,
-                         return_tokens_as_token_ids=return_tokens_as_token_ids)
+                         return_tokens_as_token_ids=return_tokens_as_token_ids,
+                         fim_encoder=fim_encoder)
 
     async def create_completion(
         self,
@@ -76,9 +76,6 @@ async def create_completion(
         See https://platform.openai.com/docs/api-reference/completions/create
         for the API specification. This API mimics the OpenAI Completion API.
 
-        NOTE: Currently we do not support the following feature:
-            - suffix (the language models we currently support do not support
-            suffix)
         """
         error_check_ret = await self._check_model(request)
         if error_check_ret is not None:
@@ -90,11 +87,6 @@ async def create_completion(
         if self.engine_client.errored:
             raise self.engine_client.dead_error
 
-        # Return error for unsupported features.
-        if request.suffix is not None:
-            return self.create_error_response(
-                "suffix is not currently supported")
-
         model_name = self.base_model_paths[0].name
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
@@ -118,6 +110,7 @@ async def create_completion(
                     request,
                     tokenizer,
                     request.prompt,
+                    suffix=request.suffix,
                     truncate_prompt_tokens=request.truncate_prompt_tokens,
                     add_special_tokens=request.add_special_tokens,
                 ))
@@ -151,11 +144,6 @@ async def create_completion(
                     log_tracing_disabled_warning()
 
                 if isinstance(sampling_params, BeamSearchParams):
-                    assert isinstance(self.engine_client,
-                                    (AsyncLLMEngine,
-                                    MQLLMEngineClient)), \
-                    "Beam search is only supported with" \
-                    "AsyncLLMEngine and MQLLMEngineClient."
                     generator = self.engine_client.beam_search(
                         prompt_inputs["prompt_token_ids"],
                         request_id_item,
@@ -265,6 +253,14 @@ async def completion_stream_generator(
         has_echoed = [False] * num_choices * num_prompts
         num_prompt_tokens = [0] * num_prompts
 
+        stream_options = request.stream_options
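+        # Determine usage-reporting behavior once for the whole stream.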
+        if stream_options:
+            include_usage = stream_options.include_usage
+            include_continuous_usage = include_usage and \
+                                       stream_options.continuous_usage_stats
+        else:
+            include_usage, include_continuous_usage = False, False
+
         try:
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
@@ -281,32 +277,27 @@ async def completion_stream_generator(
 
                 for output in res.outputs:
                     i = output.index + prompt_idx * num_choices
-                    # TODO(simon): optimize the performance by avoiding full
-                    # text O(n^2) sending.
 
                     assert request.max_tokens is not None
-                    if request.echo and request.max_tokens == 0:
-                        assert prompt_token_ids is not None
-                        assert prompt_text is not None
-                        # only return the prompt
-                        delta_text = prompt_text
-                        delta_token_ids = prompt_token_ids
-                        out_logprobs = prompt_logprobs
-                        has_echoed[i] = True
-                    elif (request.echo and request.max_tokens > 0
-                          and not has_echoed[i]):
+                    if request.echo and not has_echoed[i]:
                         assert prompt_token_ids is not None
                         assert prompt_text is not None
-                        assert prompt_logprobs is not None
-                        # echo the prompt and first token
-                        delta_text = prompt_text + output.text
-                        delta_token_ids = [
-                            *prompt_token_ids, *output.token_ids
-                        ]
-                        out_logprobs = [
-                            *prompt_logprobs,
-                            *(output.logprobs or []),
-                        ]
+                        if request.max_tokens == 0:
+                            # only return the prompt
+                            delta_text = prompt_text
+                            delta_token_ids = prompt_token_ids
+                            out_logprobs = prompt_logprobs
+                        else:
+                            assert prompt_logprobs is not None
+                            # echo the prompt and first token
+                            delta_text = prompt_text + output.text
+                            delta_token_ids = [
+                                *prompt_token_ids, *output.token_ids
+                            ]
+                            out_logprobs = [
+                                *prompt_logprobs,
+                                *(output.logprobs or []),
+                            ]
                         has_echoed[i] = True
                     else:
                         # return just the delta
@@ -314,6 +305,19 @@ async def completion_stream_generator(
                         delta_token_ids = output.token_ids
                         out_logprobs = output.logprobs
 
+                        if not delta_text and not delta_token_ids \
+                            and not previous_num_tokens[i]:
+                            # Chunked prefill case, don't return empty chunks
+                            continue
+
+                    previous_text_lens[i] += len(output.text)
+                    previous_num_tokens[i] += len(output.token_ids)
+                    finish_reason = output.finish_reason
+                    stop_reason = output.stop_reason
+
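+                    # Append the FIM suffix to the final chunk when echoing.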
+                    if finish_reason and request.echo and request.suffix:
+                        delta_text += request.suffix
+
                     if request.logprobs is not None:
                         assert out_logprobs is not None, (
                             "Did not output logprobs")
@@ -327,11 +331,6 @@ async def completion_stream_generator(
                     else:
                         logprobs = None
 
-                    previous_text_lens[i] += len(output.text)
-                    previous_num_tokens[i] += len(output.token_ids)
-                    finish_reason = output.finish_reason
-                    stop_reason = output.stop_reason
-
                     chunk = CompletionStreamResponse(
                         id=request_id,
                         created=created_time,
@@ -345,45 +344,39 @@ async def completion_stream_generator(
                                 stop_reason=stop_reason,
                             )
                         ])
-                    if (request.stream_options
-                            and request.stream_options.include_usage):
-                        if (request.stream_options.continuous_usage_stats
-                                or output.finish_reason is not None):
-                            prompt_tokens = num_prompt_tokens[prompt_idx]
-                            completion_tokens = previous_num_tokens[i]
-                            usage = UsageInfo(
-                                prompt_tokens=prompt_tokens,
-                                completion_tokens=completion_tokens,
-                                total_tokens=prompt_tokens + completion_tokens,
-                            )
-                        if request.stream_options.continuous_usage_stats:
-                            chunk.usage = usage
-                        else:
-                            chunk.usage = None
+                    if include_continuous_usage:
+                        prompt_tokens = num_prompt_tokens[prompt_idx]
+                        completion_tokens = previous_num_tokens[i]
+                        chunk.usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
 
                     response_json = chunk.model_dump_json(exclude_unset=False)
                     yield f"data: {response_json}\n\n"
 
-            if (request.stream_options
-                    and request.stream_options.include_usage):
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
+            if include_usage:
                 final_usage_chunk = CompletionStreamResponse(
                     id=request_id,
                     created=created_time,
                     model=model_name,
                     choices=[],
-                    usage=usage,
+                    usage=final_usage_info,
                 )
                 final_usage_data = (final_usage_chunk.model_dump_json(
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
             # report to FastAPI middleware aggregate usage across all choices
-            total_prompt_tokens = sum(num_prompt_tokens)
-            total_completion_tokens = sum(previous_num_tokens)
-            request_metadata.final_usage_info = UsageInfo(
-                prompt_tokens=total_prompt_tokens,
-                completion_tokens=total_completion_tokens,
-                total_tokens=total_prompt_tokens + total_completion_tokens)
+            request_metadata.final_usage_info = final_usage_info
 
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
@@ -405,6 +398,8 @@ def request_output_to_completion_response(
         num_prompt_tokens = 0
         num_generated_tokens = 0
 
+        suffix = "" if request.suffix is None else request.suffix
+
         for final_res in final_res_batch:
             prompt_token_ids = final_res.prompt_token_ids
             assert prompt_token_ids is not None
@@ -414,29 +409,28 @@ def request_output_to_completion_response(
             token_ids: GenericSequence[int]
             out_logprobs: Optional[GenericSequence[Optional[Dict[int,
                                                                  Logprob]]]]
-
             for output in final_res.outputs:
                 assert request.max_tokens is not None
-                if request.echo and request.max_tokens == 0:
+                if request.echo:
                     assert prompt_text is not None
-                    token_ids = prompt_token_ids
-                    out_logprobs = prompt_logprobs
-                    output_text = prompt_text
-                elif request.echo and request.max_tokens > 0:
-                    assert prompt_text is not None
-                    token_ids = [*prompt_token_ids, *output.token_ids]
-
-                    if request.logprobs is None:
-                        out_logprobs = None
+                    if request.max_tokens == 0:
+                        token_ids = prompt_token_ids
+                        out_logprobs = prompt_logprobs
+                        output_text = prompt_text + suffix
                     else:
-                        assert prompt_logprobs is not None
-                        assert output.logprobs is not None
-                        out_logprobs = [
-                            *prompt_logprobs,
-                            *output.logprobs,
-                        ]
-
-                    output_text = prompt_text + output.text
+                        token_ids = [*prompt_token_ids, *output.token_ids]
+
+                        if request.logprobs is None:
+                            out_logprobs = None
+                        else:
+                            assert prompt_logprobs is not None
+                            assert output.logprobs is not None
+                            out_logprobs = [
+                                *prompt_logprobs,
+                                *output.logprobs,
+                            ]
+
+                        output_text = prompt_text + output.text + suffix
                 else:
                     token_ids = output.token_ids
                     out_logprobs = output.logprobs
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index e9504cfa64b6..eb498df1e52d 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -83,7 +83,8 @@ def __init__(
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
-        self._enabled = self._check_embedding_mode(model_config.embedding_mode)
+        self._enabled = self._check_embedding_mode(
+            model_config.task == "embedding")
 
     async def create_embedding(
         self,
@@ -134,9 +135,11 @@ async def create_embedding(
             pooling_params = request.to_pooling_params()
 
             prompts = list(
-                self._tokenize_prompt_input_or_inputs(request, tokenizer,
-                                                      request.input,
-                                                      truncate_prompt_tokens))
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.input,
+                    truncate_prompt_tokens=truncate_prompt_tokens))
 
             for i, prompt_inputs in enumerate(prompts):
                 request_id_item = f"{request_id}-{i}"
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index e6d2ab93d336..1a313a5c63fd 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -2,7 +2,8 @@
 import pathlib
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
+from typing import (Callable, Iterable, Iterator, List, Optional, Tuple,
+                    TypedDict, Union)
 
 from pydantic import Field
 from typing_extensions import Annotated
@@ -10,6 +11,7 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.fim import FIMEncoder, get_fim_encoder_lookup
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -77,6 +79,7 @@ def __init__(
         prompt_adapters: Optional[List[PromptAdapterPath]],
         request_logger: Optional[RequestLogger],
         return_tokens_as_token_ids: bool = False,
+        fim_encoder: Optional[str] = None,
     ):
         super().__init__()
 
@@ -117,6 +120,9 @@ def __init__(
         self.request_logger = request_logger
         self.return_tokens_as_token_ids = return_tokens_as_token_ids
 
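+        # resolve the configured FIM encoder name (if any) into a callable
+        # that builds a FIMEncoder for a given tokenizer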
+        self.get_fim_encoder: Optional[Callable[[AnyTokenizer], FIMEncoder]] = \
+            get_fim_encoder_lookup(fim_encoder)
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -204,22 +210,32 @@ def _normalize_prompt_text_to_input(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt: str,
+        suffix: Optional[str],
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
-        if truncate_prompt_tokens is None:
-            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
+        if suffix:
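+            # fill-in-the-middle request: encode the prompt (prefix) together
+            # with the suffix using the configured FIM encoder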
+            if not (get_fim_encoder := self.get_fim_encoder):
+                raise ValueError("fim support must be enabled to use suffix")
+            if truncate_prompt_tokens is not None:
+                raise ValueError(
+                    "truncate_prompt_tokens is not supported with suffix")
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt,
+                                                       suffix=suffix)
         else:
-            encoded = tokenizer(prompt,
-                                add_special_tokens=add_special_tokens,
-                                truncation=True,
-                                max_length=truncate_prompt_tokens)
-
-        input_ids = encoded.input_ids
+            if truncate_prompt_tokens is None:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens)
+            else:
+                encoded = tokenizer(prompt,
+                                    add_special_tokens=add_special_tokens,
+                                    truncation=True,
+                                    max_length=truncate_prompt_tokens)
 
-        input_text = prompt
+            input_ids = encoded.input_ids
 
-        return self._validate_input(request, input_ids, input_text)
+        return self._validate_input(request, input_ids, input_text=prompt)
 
     def _normalize_prompt_tokens_to_input(
         self,
@@ -307,6 +323,7 @@ def _tokenize_prompt_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -320,10 +337,14 @@ def _tokenize_prompt_inputs(
                     request,
                     tokenizer,
                     prompt=text,
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
@@ -336,6 +357,7 @@ def _tokenize_prompt_input_or_inputs(
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
+        suffix: Optional[str] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = True,
     ) -> Iterator[TextTokensPrompt]:
@@ -356,10 +378,14 @@ def _tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
                     prompt=prompt_input["content"],
+                    suffix=suffix,
                     truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
+                if suffix:
+                    raise ValueError(
+                        "suffix is only supported with string prompt input")
                 yield self._normalize_prompt_tokens_to_input(
                     request,
                     tokenizer,
diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py
index 309d9bede489..a5cd176cd7b7 100644
--- a/vllm/entrypoints/openai/tool_parsers/__init__.py
+++ b/vllm/entrypoints/openai/tool_parsers/__init__.py
@@ -1,10 +1,14 @@
 from .abstract_tool_parser import ToolParser, ToolParserManager
+from .granite_20b_fc_tool_parser import Granite20bFCToolParser
+from .granite_tool_parser import GraniteToolParser
 from .hermes_tool_parser import Hermes2ProToolParser
 from .internlm2_tool_parser import Internlm2ToolParser
+from .jamba_tool_parser import JambaToolParser
 from .llama_tool_parser import Llama3JsonToolParser
 from .mistral_tool_parser import MistralToolParser
 
 __all__ = [
-    "ToolParser", "ToolParserManager", "Hermes2ProToolParser",
-    "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
+    "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
+    "GraniteToolParser", "Hermes2ProToolParser", "Internlm2ToolParser",
+    "JambaToolParser",  "Llama3JsonToolParser", "MistralToolParser"
 ]
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
new file mode 100644
index 000000000000..9ef796cd4321
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py
@@ -0,0 +1,253 @@
+import json
+import re
+from json import JSONDecoder
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
+                                                        find_common_prefix,
+                                                        is_complete_json,
+                                                        partial_json_loads)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("granite-20b-fc")
+class Granite20bFCToolParser(ToolParser):
+    """
+    Tool call parser for the granite-20b-functioncalling model intended
+    for use with the examples/tool_chat_template_granite20b_fc.jinja
+    template.
+
+    Used when --enable-auto-tool-choice --tool-call-parser granite-20b-fc
+    are both set
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
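+        # granite-20b-functioncalling prefixes each tool call with a
+        # <function_call> token followed by the call encoded as JSON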
+        self.bot_token = ""
+        self.tool_start_token = self.bot_token
+        self.tool_call_regex = re.compile(r"\s*")
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+        if self.tool_start_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        else:
+            dec = JSONDecoder()
+            try:
+                matches = list(self.tool_call_regex.finditer(model_output))
+                logger.debug("Found %d tool call matches", len(matches))
+
+                raw_function_calls = []
+
+                for i, match in enumerate(matches):
+                    # position after the <function_call> tag
+                    start_of_json = match.end()
+                    # end_index == the start of the next function call
+                    # (if exists)
+                    next_function_call_start = (matches[i + 1].start() if
+                                                i + 1 < len(matches) else None)
+
+                    raw_function_calls.append(
+                        dec.raw_decode(model_output[
+                            start_of_json:next_function_call_start])[0])
+
+                logger.debug("Extracted %d tool calls",
+                             len(raw_function_calls))
+                tool_calls = [
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=function_call["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(function_call["arguments"]),
+                        ),
+                    ) for function_call in raw_function_calls
+                ]
+
+                content = model_output[:model_output.find(self.bot_token)]
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if content else None,
+                )
+
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+
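+        # the text so far may still be a partial prefix of the bot token;
+        # wait for more output before deciding how to handle it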
+        if len(current_text) < len(
+                self.bot_token) and self.bot_token.startswith(current_text):
+            return None
+
+        if not current_text.startswith(self.bot_token):
+            return DeltaMessage(content=delta_text)
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending
+        # an incomplete string since OpenAI only ever (as far as I have
+        # seen) allows sending the entire tool/function name at once.
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+        try:
+            tool_call_arr = []
+            is_complete = []
+            try:
+                start_idx = len(self.bot_token)
+                start_idx = consume_space(start_idx, current_text)
+
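+                # the output may contain several <function_call> + JSON pairs;
+                # parse each JSON object and track whether it is complete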
+                while start_idx < len(current_text):
+                    (obj,
+                     end_idx) = partial_json_loads(current_text[start_idx:],
+                                                   flags)
+                    is_complete.append(
+                        is_complete_json(current_text[start_idx:start_idx +
+                                                      end_idx]))
+                    start_idx += end_idx
+                    start_idx = consume_space(start_idx, current_text)
+                    start_idx += len(self.bot_token)
+                    start_idx = consume_space(start_idx, current_text)
+                    tool_call_arr.append(obj)
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # select the tool call we are currently working on
+            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+                if len(tool_call_arr) > 0 else {}
+
+            # case -- if no tokens have been streamed for the tool, e.g.
+            #   only the array brackets, stream nothing
+            if len(tool_call_arr) == 0:
+                return None
+
+            # case: we are starting a new tool in the array
+            #   -> array has > 0 length AND length has moved past cursor
+            elif (len(tool_call_arr) > 0
+                  and len(tool_call_arr) > self.current_tool_id + 1):
+
+                # if we're moving on to a new call, first make sure we
+                # haven't missed anything in the previous one that was
+                # auto-generated due to JSON completions, but wasn't
+                # streamed to the client yet.
+                if self.current_tool_id >= 0:
+                    cur_arguments = current_tool_call.get("arguments")
+                    if cur_arguments:
+                        cur_args_json = json.dumps(cur_arguments)
+                        sent = len(
+                            self.streamed_args_for_tool[self.current_tool_id])
+                        argument_diff = cur_args_json[sent:]
+
+                        logger.debug("got arguments diff: %s", argument_diff)
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=argument_diff).
+                                          model_dump(exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += argument_diff
+                    else:
+                        delta = None
+                else:
+                    delta = None
+                # re-set stuff pertaining to progress in the current tool
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug("starting on new tool %d", self.current_tool_id)
+                return delta
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            elif not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      type="function",
+                                      id=f"chatcmpl-tool-{random_uuid()}",
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                else:
+                    delta = None
+
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+                cur_arguments = current_tool_call.get("arguments")
+                delta = None
+
+                if cur_arguments:
+                    sent = len(
+                        self.streamed_args_for_tool[self.current_tool_id])
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_arguments = self.prev_tool_call_arr[
+                        self.current_tool_id].get("arguments")
+
+                    argument_diff = None
+                    if is_complete[self.current_tool_id]:
+                        argument_diff = cur_args_json[sent:]
+                    elif prev_arguments:
+                        prev_args_json = json.dumps(prev_arguments)
+                        if cur_args_json != prev_args_json:
+
+                            prefix = find_common_prefix(
+                                prev_args_json, cur_args_json)
+                            argument_diff = prefix[sent:]
+
+                    if argument_diff is not None:
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=argument_diff).
+                                          model_dump(exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += argument_diff
+
+            self.prev_tool_call_arr = tool_call_arr
+            return delta
+
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
+            logger.debug(
+                "Skipping chunk as a result of tool streaming extraction "
+                "error")
+            return None
diff --git a/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
new file mode 100644
index 000000000000..763cb5645541
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py
@@ -0,0 +1,227 @@
+import json
+from typing import Dict, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
+    ToolParser, ToolParserManager)
+from vllm.entrypoints.openai.tool_parsers.utils import (consume_space,
+                                                        find_common_prefix,
+                                                        is_complete_json,
+                                                        partial_json_loads)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("granite")
+class GraniteToolParser(ToolParser):
+    """
+    Tool call parser for the granite models. Intended
+    for use with the examples/tool_chat_template_granite.jinja
+    template.
+
+    Used when --enable-auto-tool-choice --tool-call-parser granite
+    are both set
+    """
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
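+        # Granite emits tool calls as a bare JSON array of objects with
+        # "name" and "arguments" fields; anything else is plain content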
+        stripped = model_output.strip()
+        if not stripped or stripped[0] != '[':
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+        else:
+            try:
+                raw_function_calls = json.loads(stripped)
+                if not isinstance(raw_function_calls, list):
+                    raise Exception(
+                        "Expected a list of tool calls, got "
+                        f"{type(raw_function_calls)}")
+
+                logger.debug("Extracted %d tool calls",
+                             len(raw_function_calls))
+                tool_calls = [
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=function_call["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(function_call["arguments"]),
+                        ),
+                    ) for function_call in raw_function_calls
+                ]
+
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=None,
+                )
+
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+
+        start_idx = consume_space(0, current_text)
+        if (not current_text or start_idx >= len(current_text)
+                or current_text[start_idx] != '['):
+            return DeltaMessage(content=delta_text)
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending
+        # an incomplete string since OpenAI only ever (as far as I have
+        # seen) allows sending the entire tool/function name at once.
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+        try:
+            tool_call_arr = None
+            is_complete = None
+            try:
+                tool_calls, end_idx = partial_json_loads(
+                    current_text[start_idx:], flags)
+                if type(tool_calls) is list:
+                    tool_call_arr = tool_calls
+                else:
+                    return DeltaMessage(content=delta_text)
+
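+                # only the last element of a partially parsed array can be
+                # incomplete; flag it if the JSON has not finished streaming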
+                is_complete = [True] * len(tool_calls)
+                if not is_complete_json(
+                        current_text[start_idx:start_idx + end_idx]):
+                    is_complete[-1] = False
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # select the tool call we are currently working on
+            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+                if len(tool_call_arr) > 0 else {}
+
+            # case -- if no tokens have been streamed for the tool, e.g.
+            #   only the array brackets, stream nothing
+            if len(tool_call_arr) == 0:
+                return None
+
+            # case: we are starting a new tool in the array
+            #   -> array has > 0 length AND length has moved past cursor
+            elif (len(tool_call_arr) > 0
+                  and len(tool_call_arr) > self.current_tool_id + 1):
+
+                # if we're moving on to a new call, first make sure we
+                # haven't missed anything in the previous one that was
+                # auto-generated due to JSON completions, but wasn't
+                # streamed to the client yet.
+                if self.current_tool_id >= 0:
+                    cur_arguments = current_tool_call.get("arguments")
+                    if cur_arguments:
+                        cur_args_json = json.dumps(cur_arguments)
+                        sent = len(
+                            self.streamed_args_for_tool[self.current_tool_id])
+                        argument_diff = cur_args_json[sent:]
+
+                        logger.debug("got arguments diff: %s", argument_diff)
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=argument_diff).
+                                          model_dump(exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += argument_diff
+                    else:
+                        delta = None
+                else:
+                    delta = None
+                # re-set stuff pertaining to progress in the current tool
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug("starting on new tool %d", self.current_tool_id)
+                return delta
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            elif not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      type="function",
+                                      id=f"chatcmpl-tool-{random_uuid()}",
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                else:
+                    delta = None
+
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+                cur_arguments = current_tool_call.get("arguments")
+                delta = None
+
+                if cur_arguments:
+                    sent = len(
+                        self.streamed_args_for_tool[self.current_tool_id])
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_arguments = self.prev_tool_call_arr[
+                        self.current_tool_id].get("arguments")
+
+                    argument_diff = None
+                    if is_complete[self.current_tool_id]:
+                        argument_diff = cur_args_json[sent:]
+                    elif prev_arguments:
+                        prev_args_json = json.dumps(prev_arguments)
+                        if cur_args_json != prev_args_json:
+
+                            prefix = find_common_prefix(
+                                prev_args_json, cur_args_json)
+                            argument_diff = prefix[sent:]
+
+                    if argument_diff is not None:
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=argument_diff).
+                                          model_dump(exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += argument_diff
+
+            self.prev_tool_call_arr = tool_call_arr
+            return delta
+
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
+            logger.debug(
+                "Skipping chunk as a result of tool streaming extraction "
+                "error")
+            return None
diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
index bcbcda3fa528..faa6f653b835 100644
--- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py
@@ -53,7 +53,8 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.tool_call_start_token_id = self.vocab.get(
             self.tool_call_start_token)
         self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
-        if not self.tool_call_start_token_id or not self.tool_call_end_token_id:
+        if (self.tool_call_start_token_id is None
+                or self.tool_call_end_token_id is None):
             raise RuntimeError(
                 "Hermes 2 Pro Tool parser could not locate tool call start/end "
                 "tokens in the tokenizer!")
@@ -103,9 +104,9 @@ def extract_tool_calls(
                     tool_calls=tool_calls,
                     content=content if content else None)
 
-            except Exception as e:
-                logger.error("Error in extracting tool call from response %s",
-                             e)
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
                 return ExtractedToolCallInformation(tools_called=False,
                                                     tool_calls=[],
                                                     content=model_output)
@@ -333,6 +334,6 @@ def extract_tool_calls_streaming(
 
             return delta
 
-        except Exception as e:
-            logger.error("Error trying to handle streaming tool call: %s", e)
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
             return None  # do not stream a delta. skip this token ID.
diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
index 905ab7db3d04..cb391e11bbde 100644
--- a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
@@ -166,8 +166,8 @@ def extract_tool_calls_streaming(
             tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
             self.prev_tool_call_arr = [tool_call_arr]
             return delta
-        except Exception as e:
-            logger.error("Error trying to handle streaming tool call: %s", e)
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
             logger.debug(
                 "Skipping chunk as a result of tool streaming extraction "
                 "error")
diff --git a/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
new file mode 100644
index 000000000000..cfd024853f88
--- /dev/null
+++ b/vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py
@@ -0,0 +1,300 @@
+import json
+import re
+from typing import Dict, List, Sequence, Union
+
+import partial_json_parser
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaFunctionCall, DeltaMessage,
+                                              DeltaToolCall,
+                                              ExtractedToolCallInformation,
+                                              FunctionCall, ToolCall)
+from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
+from vllm.entrypoints.openai.tool_parsers.utils import (
+    extract_intermediate_diff)
+from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers import MistralTokenizer
+from vllm.utils import random_uuid
+
+logger = init_logger(__name__)
+
+
+@ToolParserManager.register_module("jamba")
+class JambaToolParser(ToolParser):
+
+    def __init__(self, tokenizer: AnyTokenizer):
+        super().__init__(tokenizer)
+
+        if isinstance(self.model_tokenizer, MistralTokenizer):
+            raise ValueError(
+                "Detected a MistralTokenizer tokenizer when using a Jamba model"
+            )
+
+        self.current_tool_name_sent: bool = False
+        self.prev_tool_call_arr: List[Dict] = []
+        self.current_tool_id: int = -1
+        self.streamed_args_for_tool: List[str] = [
+        ]  # map what has been streamed for each tool so far to a list
+
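+        # Jamba wraps the generated tool call array in special start/end
+        # tokens; adjust_request below keeps them in the decoded output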
+        self.tool_calls_start_token: str = "<tool_calls>"
+        self.tool_calls_end_token: str = "</tool_calls>"
+
+        self.tool_calls_regex = re.compile(
+            rf"{self.tool_calls_start_token}(.*?){self.tool_calls_end_token}",
+            re.DOTALL)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ToolParser "
+                "constructor during construction.")
+        self.tool_calls_start_token_id = self.vocab.get(
+            self.tool_calls_start_token)
+        self.tool_calls_end_token_id = self.vocab.get(
+            self.tool_calls_end_token)
+        if (self.tool_calls_start_token_id is None
+                or self.tool_calls_end_token_id is None):
+            raise RuntimeError(
+                "Jamba Tool parser could not locate tool calls start/end "
+                "tokens in the tokenizer!")
+
+    def adjust_request(
+            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
+        if request.tools and request.tool_choice != 'none':
+            # do not skip special tokens because Jamba uses them to mark
+            # the start and end of the tool call information
+            request.skip_special_tokens = False
+        return request
+
+    def extract_tool_calls(
+            self, model_output: str,
+            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
+
+        # sanity check; avoid unnecessary processing
+        if self.tool_calls_start_token not in model_output:
+            return ExtractedToolCallInformation(tools_called=False,
+                                                tool_calls=[],
+                                                content=model_output)
+
+        else:
+
+            try:
+                # use a regex to find the tool call between the tags
+                function_calls = self.tool_calls_regex.findall(model_output)[0]
+
+                # load the JSON, and then use it to build the Function and
+                # Tool Call
+                raw_function_calls = json.loads(function_calls)
+                tool_calls = [
+                    ToolCall(
+                        type="function",
+                        function=FunctionCall(
+                            name=function_call["name"],
+                            # function call args are JSON but as a string
+                            arguments=json.dumps(function_call["arguments"])))
+                    for function_call in raw_function_calls
+                ]
+
+                content = model_output[:model_output.
+                                       find(self.tool_calls_start_token)]
+                return ExtractedToolCallInformation(
+                    tools_called=True,
+                    tool_calls=tool_calls,
+                    content=content if
+                    (len(content) > 0 and content != " ") else None)
+
+            except Exception:
+                logger.exception(
+                    "Error in extracting tool call from response.")
+                return ExtractedToolCallInformation(tools_called=False,
+                                                    tool_calls=[],
+                                                    content=model_output)
+
+    def extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+
+        # if the tool call token is not in the tokens generated so far, append
+        # output to contents since it's not a tool
+        if self.tool_calls_start_token not in current_text:
+            return DeltaMessage(content=delta_text)
+
+        # if the tool call token ID IS in the tokens generated so far, that
+        # means we're parsing as tool calls now
+
+        # handle if we detected the start of tool calls token which means
+        # the start of tool calling
+        if (self.tool_calls_start_token_id in delta_token_ids
+                and len(delta_token_ids) == 1):
+            # if it's the only token, return None, so we don't send a chat
+            # completion and don't send a control token
+            return None
+
+        # bit mask flags for partial JSON parsing. If the name hasn't been
+        # sent yet, don't allow sending
+        # an incomplete string since OpenAI only ever (as far as I have
+        # seen) allows sending the entire tool/function name at once.
+        flags = Allow.ALL if self.current_tool_name_sent \
+            else Allow.ALL & ~Allow.STR
+        try:
+
+            # Extract the tool calls between the special tool call tokens
+            parsable_arr = current_text.split(
+                self.tool_calls_start_token)[-1].split(
+                    self.tool_calls_end_token)[0]
+
+            # tool calls are generated in an array, so do partial JSON
+            # parsing on the entire array
+            try:
+                tool_call_arr: List[Dict] = partial_json_parser.loads(
+                    parsable_arr, flags)
+            except partial_json_parser.core.exceptions.MalformedJSON:
+                logger.debug('not enough tokens to parse into JSON yet')
+                return None
+
+            # select the tool call we are currently working on
+
+            current_tool_call: Dict = tool_call_arr[self.current_tool_id] \
+                if len(tool_call_arr) > 0 else {}
+
+            # case -- if no tokens have been streamed for the tool, e.g.
+            #   only the array brackets, stream nothing
+            if len(tool_call_arr) == 0:
+                return None
+
+            # case: we are starting a new tool in the array
+            #   -> array has > 0 length AND length has moved past cursor
+            elif (len(tool_call_arr) > 0
+                  and len(tool_call_arr) > self.current_tool_id + 1):
+
+                # if we're moving on to a new call, first make sure we
+                # haven't missed anything in the previous one that was
+                # auto-generated due to JSON completions, but wasn't
+                # streamed to the client yet.
+                if self.current_tool_id >= 0:
+                    diff: Union[str, None] = current_tool_call.get("arguments")
+
+                    if diff:
+                        diff = json.dumps(diff).replace(
+                            self.streamed_args_for_tool[self.current_tool_id],
+                            "")
+                        delta = DeltaMessage(tool_calls=[
+                            DeltaToolCall(index=self.current_tool_id,
+                                          function=DeltaFunctionCall(
+                                              arguments=diff).model_dump(
+                                                  exclude_none=True))
+                        ])
+                        self.streamed_args_for_tool[
+                            self.current_tool_id] += diff
+                    else:
+                        delta = None
+                else:
+                    delta = None
+                # re-set stuff pertaining to progress in the current tool
+                self.current_tool_id = len(tool_call_arr) - 1
+                self.current_tool_name_sent = False
+                self.streamed_args_for_tool.append("")
+                logger.debug("starting on new tool %d", self.current_tool_id)
+                return delta
+
+            # case: update an existing tool - this is handled below
+
+            # if the current tool name hasn't been sent, send if available
+            # - otherwise send nothing
+            if not self.current_tool_name_sent:
+                function_name = current_tool_call.get("name")
+                if function_name:
+
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      type="function",
+                                      id=f"chatcmpl-tool-{random_uuid()}",
+                                      function=DeltaFunctionCall(
+                                          name=function_name).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.current_tool_name_sent = True
+                else:
+                    delta = None
+
+            # now we know we're on the same tool call and we're streaming
+            # arguments
+            else:
+
+                prev_arguments = self.prev_tool_call_arr[
+                    self.current_tool_id].get("arguments")
+                cur_arguments = current_tool_call.get("arguments")
+
+                new_text = delta_text.replace("\'", "\"")
+
+                if not cur_arguments and not prev_arguments:
+                    delta = None
+                elif not cur_arguments and prev_arguments:
+                    logger.error(
+                        "INVARIANT - impossible to have arguments reset "
+                        "mid-arguments")
+                    delta = None
+                elif cur_arguments and not prev_arguments:
+                    cur_arguments_json = json.dumps(cur_arguments)
+                    logger.debug("finding %s in %s", new_text,
+                                 cur_arguments_json)
+
+                    arguments_delta = cur_arguments_json[:cur_arguments_json.
+                                                         index(new_text) +
+                                                         len(new_text)]
+                    logger.debug("First tokens in arguments received: %s",
+                                 arguments_delta)
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=arguments_delta).
+                                      model_dump(exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += arguments_delta
+
+                elif cur_arguments and prev_arguments:
+                    cur_args_json = json.dumps(cur_arguments)
+                    prev_args_json = json.dumps(prev_arguments)
+                    logger.debug("Searching for diff between \n%s\n%s",
+                                 cur_args_json, prev_args_json)
+
+                    argument_diff = extract_intermediate_diff(
+                        cur_args_json, prev_args_json)
+                    logger.debug("got arguments diff: %s", argument_diff)
+                    delta = DeltaMessage(tool_calls=[
+                        DeltaToolCall(index=self.current_tool_id,
+                                      function=DeltaFunctionCall(
+                                          arguments=argument_diff).model_dump(
+                                              exclude_none=True))
+                    ])
+                    self.streamed_args_for_tool[
+                        self.current_tool_id] += argument_diff
+                else:
+                    # try parsing it with regular JSON - if it works we're
+                    # at the end, and we need to send the difference between
+                    # tokens streamed so far and the valid JSON
+                    delta = None
+
+            # check to see if the name is defined and has been sent. if so,
+            # stream the name - otherwise keep waiting
+            # finish by setting old and returning None as base case
+            self.prev_tool_call_arr = tool_call_arr
+            return delta
+
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
+            logger.debug(
+                "Skipping chunk as a result of tool streaming extraction "
+                "error")
+            return None
diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
index 3cf34bc4928a..a5f44d69e5fd 100644
--- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@@ -1,6 +1,6 @@
 import json
 import re
-from json import JSONDecodeError, JSONDecoder
+from json import JSONDecoder
 from typing import Dict, List, Sequence, Union
 
 import partial_json_parser
@@ -14,34 +14,15 @@
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
-from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
+from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
+                                                        is_complete_json,
+                                                        partial_json_loads)
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
 
 
-# partial_json_parser doesn't support extra data and
-# JSONDecorder.raw_decode doesn't support partial JSON
-def partial_json_loads(input_str, flags):
-    try:
-        return (partial_json_parser.loads(input_str, flags), len(input_str))
-    except JSONDecodeError as e:
-        if "Extra data" in e.msg:
-            dec = JSONDecoder()
-            return dec.raw_decode(input_str)
-        else:
-            raise
-
-
-def is_complete_json(input_str):
-    try:
-        json.loads(input_str)
-        return True
-    except JSONDecodeError:
-        return False
-
-
 @ToolParserManager.register_module("llama3_json")
 class Llama3JsonToolParser(ToolParser):
     """
@@ -112,9 +93,8 @@ def extract_tool_calls(
                                                content=None)
             return ret
 
-        except Exception as e:
-            logger.error("Error in extracting tool call from response: %s", e)
-            print("ERROR", e)
+        except Exception:
+            logger.exception("Error in extracting tool call from response.")
             # return information to just treat the tool call as regular JSON
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
@@ -269,8 +249,8 @@ def extract_tool_calls_streaming(
             self.prev_tool_call_arr = tool_call_arr
             return delta
 
-        except Exception as e:
-            logger.error("Error trying to handle streaming tool call: %s", e)
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
             logger.debug(
                 "Skipping chunk as a result of tool streaming extraction "
                 "error")
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index c6dc0688e38f..f5c0d92f3f9b 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -63,7 +63,7 @@ def __init__(self, tokenizer: AnyTokenizer):
         self.bot_token = "[TOOL_CALLS]"
         self.bot_token_id = self.vocab.get(self.bot_token)
         self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
-        if not self.bot_token_id:
+        if self.bot_token_id is None:
             raise RuntimeError(
                 "Mistral Tool Parser could not locate the tool call token in "
                 "the tokenizer!")
@@ -111,8 +111,8 @@ def extract_tool_calls(
                 tool_calls=tool_calls,
                 content=content if len(content) > 0 else None)
 
-        except Exception as e:
-            logger.error("Error in extracting tool call from response: %s", e)
+        except Exception:
+            logger.exception("Error in extracting tool call from response.")
             # return information to just treat the tool call as regular JSON
             return ExtractedToolCallInformation(tools_called=False,
                                                 tool_calls=[],
@@ -298,8 +298,8 @@ def extract_tool_calls_streaming(
             self.prev_tool_call_arr = tool_call_arr
             return delta
 
-        except Exception as e:
-            logger.error("Error trying to handle streaming tool call: %s", e)
+        except Exception:
+            logger.exception("Error trying to handle streaming tool call.")
             logger.debug(
                 "Skipping chunk as a result of tool streaming extraction "
                 "error")
diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py
index db7fc5259fc4..4c79c8b35d32 100644
--- a/vllm/entrypoints/openai/tool_parsers/utils.py
+++ b/vllm/entrypoints/openai/tool_parsers/utils.py
@@ -1,3 +1,9 @@
+import json
+from json import JSONDecodeError, JSONDecoder
+
+import partial_json_parser
+
+
 def find_common_prefix(s1: str, s2: str) -> str:
     """
     Finds a common prefix that is shared between two strings, if there is one.
@@ -85,3 +91,30 @@ def find_all_indices(string, substring):
             break
         indices.append(index)
     return indices
+
+
+# partial_json_parser doesn't support extra data and
+# JSONDecoder.raw_decode doesn't support partial JSON
+def partial_json_loads(input_str, flags):
+    try:
+        return (partial_json_parser.loads(input_str, flags), len(input_str))
+    except JSONDecodeError as e:
+        if "Extra data" in e.msg:
+            dec = JSONDecoder()
+            return dec.raw_decode(input_str)
+        else:
+            raise
+
+
+def is_complete_json(input_str):
+    try:
+        json.loads(input_str)
+        return True
+    except JSONDecodeError:
+        return False
+
+
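+# return the index of the first non-whitespace character in s at or after i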
+def consume_space(i, s):
+    while i < len(s) and s[i].isspace():
+        i += 1
+    return i
diff --git a/vllm/envs.py b/vllm/envs.py
index 8b541e5b78c0..ae6825f28007 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -27,11 +27,13 @@
     VLLM_USAGE_SOURCE: str = ""
     VLLM_CONFIGURE_LOGGING: int = 1
     VLLM_LOGGING_LEVEL: str = "INFO"
+    VLLM_LOGGING_PREFIX: str = ""
     VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
     VLLM_TRACE_FUNCTION: int = 0
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
+    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -64,8 +66,10 @@
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
-    VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_CUSTOM_OPS: List[str] = []
+    VLLM_DISABLED_KERNELS: List[str] = []
+    VLLM_USE_V1: bool = False
 
 
 def get_default_cache_root():
@@ -205,7 +209,17 @@ def get_default_config_root():
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
     "VLLM_TORCH_COMPILE_LEVEL":
     lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
-
+    # Fine-grained control over which custom ops to enable/disable.
+    # Use 'all' to enable all, 'none' to disable all.
+    # Also specify a list of custom op names to enable (prefixed with a '+'),
+    # or disable (prefixed with a '-').
+    # Examples:
+    # - 'all,-op1' to enable all except op1
+    # - 'none,+op1,+op2' to enable only op1 and op2
+    # By default, all custom ops are enabled when running without Inductor
+    # and disabled when running with Inductor (compile_level >= Inductor).
+    "VLLM_CUSTOM_OPS":
+    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -255,6 +269,10 @@ def get_default_config_root():
     "VLLM_LOGGING_LEVEL":
     lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"),
 
+    # if set, VLLM_LOGGING_PREFIX will be prepended to all log messages
+    "VLLM_LOGGING_PREFIX":
+    lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
+
     # Trace function calls
     # If set to 1, vllm will trace function calls
     # Useful for debugging
@@ -275,6 +293,11 @@ def get_default_config_root():
     "VLLM_USE_FLASHINFER_SAMPLER":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),
 
+    # If set, vllm will force flashinfer to use tensor cores;
+    # otherwise will use heuristic based on model architecture.
+    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
@@ -426,10 +449,17 @@ def get_default_config_root():
     "VLLM_SKIP_P2P_CHECK":
     lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1",
 
-    # If set, allowing the use of deprecated block manager V1
-    "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
-    lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
-                           ) == "1",
+    # List of quantization kernels that should be disabled, used for testing
+    # and performance comparisons. Currently only affects MPLinearKernel
+    # selection
+    # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
+    "VLLM_DISABLED_KERNELS":
+    lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
+        "VLLM_DISABLED_KERNELS"].split(","),
+
+    # If set, use the V1 code path.
+    "VLLM_USE_V1":
+    lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
 }
 
 # end-env-vars-definition
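
The two list-valued variables above are parsed by the lambdas shown in this hunk. A minimal standalone sketch of that parsing (not vLLM code; the dict argument stands in for os.environ):

def parse_custom_ops(env: dict) -> list:
    # Mirrors the VLLM_CUSTOM_OPS lambda: strip spaces, then split on commas.
    return env.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(",")

def parse_disabled_kernels(env: dict) -> list:
    # Mirrors the VLLM_DISABLED_KERNELS lambda: empty list when unset.
    return ([] if "VLLM_DISABLED_KERNELS" not in env else
            env["VLLM_DISABLED_KERNELS"].split(","))

print(parse_custom_ops({"VLLM_CUSTOM_OPS": "all, -silu_and_mul"}))
# ['all', '-silu_and_mul']
print(parse_custom_ops({}))
# [''] -- an unset VLLM_CUSTOM_OPS parses to a single empty string
print(parse_disabled_kernels({"VLLM_DISABLED_KERNELS": "MacheteLinearKernel"}))
# ['MacheteLinearKernel']
print(parse_disabled_kernels({}))
# []
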
diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py
index e14ecc13a9dc..884267d23dfc 100644
--- a/vllm/executor/multiproc_worker_utils.py
+++ b/vllm/executor/multiproc_worker_utils.py
@@ -3,7 +3,6 @@
 import os
 import sys
 import threading
-import traceback
 import uuid
 from dataclasses import dataclass
 from multiprocessing import Queue
@@ -227,10 +226,9 @@ def _run_worker_process(
             except KeyboardInterrupt:
                 break
             except BaseException as e:
-                tb = traceback.format_exc()
-                logger.error(
-                    "Exception in worker %s while processing method %s: %s, %s",
-                    process_name, method, e, tb)
+                logger.exception(
+                    "Exception in worker %s while processing method %s.",
+                    process_name, method)
                 exception = e
             result_queue.put(
                 Result(task_id=task_id, value=output, exception=exception))
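
The hunk above replaces a manual traceback.format_exc() string with logger.exception, which appends the active exception's traceback automatically when called inside an except block. A small standalone sketch of the same pattern (names are illustrative):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("worker-demo")

def run_method():
    raise ValueError("boom")

try:
    run_method()
except BaseException:
    # Logs at ERROR level and appends the current traceback; no explicit
    # traceback.format_exc() argument is needed.
    logger.exception("Exception in worker %s while processing method %s.",
                     "Worker_0", "run_method")
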
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 7e46acefc5b0..0af7b3386d89 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -10,7 +10,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
-from vllm.utils import get_ip, is_hip, is_xpu
+from vllm.utils import get_ip, is_hip
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -231,7 +231,7 @@ def initialize_ray_cluster(
     assert_ray_available()
 
     # Connect to a ray cluster.
-    if is_hip() or is_xpu():
+    if is_hip() or current_platform.is_xpu():
         ray.init(address=ray_address,
                  ignore_reinit_error=True,
                  num_gpus=parallel_config.world_size)
diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py
index a8c8672cb5fe..7b73922ddd2c 100644
--- a/vllm/inputs/__init__.py
+++ b/vllm/inputs/__init__.py
@@ -1,7 +1,8 @@
-from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt, build_explicit_enc_dec_prompt,
-                   to_enc_dec_tuple_list, zip_enc_dec_prompts)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+                   ExplicitEncoderDecoderPrompt, PromptType, SingletonInputs,
+                   SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
+                   build_explicit_enc_dec_prompt, to_enc_dec_tuple_list,
+                   token_inputs, zip_enc_dec_prompts)
 from .registry import InputContext, InputRegistry
 
 INPUT_REGISTRY = InputRegistry()
@@ -19,8 +20,11 @@
     "PromptType",
     "SingletonPrompt",
     "ExplicitEncoderDecoderPrompt",
-    "LLMInputs",
-    "EncoderDecoderLLMInputs",
+    "TokenInputs",
+    "token_inputs",
+    "SingletonInputs",
+    "DecoderOnlyInputs",
+    "EncoderDecoderInputs",
     "build_explicit_enc_dec_prompt",
     "to_enc_dec_tuple_list",
     "zip_enc_dec_prompts",
@@ -31,9 +35,9 @@
 
 
 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings
 
+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")
 
@@ -41,4 +45,21 @@ def __getattr__(name: str):
 
         return PromptType
 
+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
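
The __getattr__ hook added above is the PEP 562 module-level pattern: the old names keep importing, but touching them emits a DeprecationWarning. A sketch of the same shim in isolation (module and class names here are hypothetical):

# deprecated_names.py -- hypothetical module demonstrating the pattern
import warnings

class NewName:  # stands in for DecoderOnlyInputs / EncoderDecoderInputs
    pass

def __getattr__(name: str):
    if name == "OldName":
        warnings.warn(
            DeprecationWarning("OldName has been renamed to NewName."),
            stacklevel=2)
        return NewName
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

A caller doing `from deprecated_names import OldName` still receives NewName, just with a warning attached.
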
diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py
index 724cdd2e6e80..9a094191eda3 100644
--- a/vllm/inputs/data.py
+++ b/vllm/inputs/data.py
@@ -1,5 +1,5 @@
 from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
-                    Optional, Tuple, Union)
+                    Optional, Tuple, Union, cast)
 
 from typing_extensions import NotRequired, TypedDict, TypeVar
 
@@ -51,7 +51,7 @@ class TokensPrompt(TypedDict):
 
 SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
 """
-Set of possible schemas for a single LLM input:
+Set of possible schemas for a single prompt:
 
 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
@@ -120,13 +120,8 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
 """
 
 
-class LLMInputs(TypedDict):
-    """
-    The inputs in :class:`~vllm.LLMEngine` before they are
-    passed to the model executor.
-
-    This specifies the data required for decoder-only models.
-    """
+class TokenInputs(TypedDict):
+    """Represents token-based inputs."""
     prompt_token_ids: List[int]
     """The token IDs of the prompt."""
 
@@ -150,7 +145,40 @@ class LLMInputs(TypedDict):
     """
 
 
-class EncoderDecoderLLMInputs(LLMInputs):
+def token_inputs(
+    prompt_token_ids: List[int],
+    prompt: Optional[str] = None,
+    multi_modal_data: Optional["MultiModalDataDict"] = None,
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+) -> TokenInputs:
+    """Construct :class:`TokenInputs` from optional values."""
+    inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
+
+    if prompt is not None:
+        inputs["prompt"] = prompt
+    if multi_modal_data is not None:
+        inputs["multi_modal_data"] = multi_modal_data
+    if mm_processor_kwargs is not None:
+        inputs["mm_processor_kwargs"] = mm_processor_kwargs
+
+    return inputs
+
+
+SingletonInputs = TokenInputs
+"""
+A processed :class:`SingletonPrompt` which can be passed to
+:class:`vllm.sequence.Sequence`.
+"""
+
+DecoderOnlyInputs = TokenInputs
+"""
+The inputs in :class:`~vllm.LLMEngine` before they are
+passed to the model executor.
+This specifies the data required for decoder-only models.
+"""
+
+
+class EncoderDecoderInputs(TokenInputs):
     """
     The inputs in :class:`~vllm.LLMEngine` before they are
     passed to the model executor.
@@ -204,11 +232,12 @@ def zip_enc_dec_prompts(
     be zipped with the encoder/decoder prompts.
     """
     if mm_processor_kwargs is None:
-        mm_processor_kwargs = {}
-    if isinstance(mm_processor_kwargs, Dict):
+        mm_processor_kwargs = cast(Dict[str, Any], {})
+    if isinstance(mm_processor_kwargs, dict):
         return [
-            build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt,
-                                          mm_processor_kwargs)
+            build_explicit_enc_dec_prompt(
+                encoder_prompt, decoder_prompt,
+                cast(Dict[str, Any], mm_processor_kwargs))
             for (encoder_prompt,
                  decoder_prompt) in zip(enc_prompts, dec_prompts)
         ]
@@ -229,9 +258,9 @@ def to_enc_dec_tuple_list(
 
 
 def __getattr__(name: str):
-    if name == "PromptInput":
-        import warnings
+    import warnings
 
+    if name == "PromptInput":
         msg = ("PromptInput has been renamed to PromptType. "
                "The original name will be removed in an upcoming version.")
 
@@ -239,4 +268,21 @@ def __getattr__(name: str):
 
         return PromptType
 
+    if name == "LLMInputs":
+        msg = ("LLMInputs has been renamed to DecoderOnlyInputs. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return DecoderOnlyInputs
+
+    if name == "EncoderDecoderLLMInputs":
+        msg = (
+            "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. "
+            "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return EncoderDecoderInputs
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
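
Because the optional TokenInputs fields are NotRequired, the token_inputs() helper above only inserts the keys whose values were actually supplied. A quick check, assuming a vLLM build that includes this change:

from vllm.inputs import token_inputs

inputs = token_inputs(prompt_token_ids=[1, 2, 3], prompt="hello")
print(inputs)
# {'prompt_token_ids': [1, 2, 3], 'prompt': 'hello'}
# multi_modal_data / mm_processor_kwargs are omitted entirely rather than
# stored as None, because they are NotRequired TypedDict keys.
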
diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py
index e5fa1e418427..e79d2c813bb4 100644
--- a/vllm/inputs/parse.py
+++ b/vllm/inputs/parse.py
@@ -1,12 +1,12 @@
-from typing import List, Literal, Sequence, TypedDict, Union, overload
+from typing import List, Literal, Sequence, TypedDict, Union, cast, overload
 
 from typing_extensions import TypeIs
 
 from vllm.utils import is_list_of
 
-from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
-                   TokensPrompt)
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+                   ExplicitEncoderDecoderPrompt, PromptType, SingletonPrompt,
+                   TextPrompt, TokensPrompt)
 
 
 class ParsedText(TypedDict):
@@ -44,13 +44,16 @@ def parse_and_batch_prompt(
 
         if is_list_of(prompt, str):
             # case 2: array of strings
+            prompt = cast(List[str], prompt)
             return [
                 ParsedText(content=elem, is_tokens=False) for elem in prompt
             ]
         if is_list_of(prompt, int):
             # case 3: array of tokens
+            prompt = cast(List[int], prompt)
             return [ParsedTokens(content=prompt, is_tokens=True)]
         if is_list_of(prompt, list):
+            prompt = cast(List[List[int]], prompt)
             if len(prompt[0]) == 0:
                 raise ValueError("please provide at least one prompt")
 
@@ -100,7 +103,7 @@ def is_explicit_encoder_decoder_prompt(
     return isinstance(prompt, dict) and "encoder_prompt" in prompt
 
 
-def is_valid_encoder_decoder_llm_inputs(
-    inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
-) -> TypeIs[EncoderDecoderLLMInputs]:
+def is_encoder_decoder_inputs(
+    inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
+) -> TypeIs[EncoderDecoderInputs]:
     return "encoder_prompt_token_ids" in inputs
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 64387fd2fa47..82ce7d392b71 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -10,7 +10,7 @@
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.utils import print_warning_once
 
-from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
+from .data import (DecoderOnlyInputs, EncoderDecoderInputs, PromptType,
                    SingletonPrompt)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
 
@@ -306,7 +306,7 @@ def _build_enc_dec_llm_inputs(
         encoder_comps: PromptComponents,
         decoder_comps: DecoderPromptComponents,
         mm_processor_kwargs: Dict[str, Any],
-    ) -> EncoderDecoderLLMInputs:
+    ) -> EncoderDecoderInputs:
         encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps
         decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps
 
@@ -324,7 +324,7 @@ def _build_enc_dec_llm_inputs(
             decoder_prompt_ids,
             force_bos=(encoder_mm_data is None and decoder_mm_data is None)))
 
-        return EncoderDecoderLLMInputs(
+        return EncoderDecoderInputs(
             prompt_token_ids=decoder_prompt_ids,
             prompt=decoder_prompt,
             multi_modal_data=decoder_mm_data,
@@ -338,11 +338,11 @@ def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
         request_id: str,
-    ) -> EncoderDecoderLLMInputs:
+    ) -> EncoderDecoderInputs:
         '''
         For encoder/decoder models only:
         Process an input prompt into an
-        :class:`EncoderDecoderLLMInputs` instance.
+        :class:`EncoderDecoderInputs` instance.
 
         There are two types of input prompts:
         singleton prompts which carry only the
@@ -369,7 +369,7 @@ def _process_encoder_decoder_prompt(
 
         Returns:
 
-        * :class:`EncoderDecoderLLMInputs` instance
+        * :class:`EncoderDecoderInputs` instance
         '''
 
         encoder_comps: PromptComponents
@@ -411,7 +411,7 @@ async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
         request_id: str,
-    ) -> EncoderDecoderLLMInputs:
+    ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents
@@ -455,17 +455,17 @@ def _build_decoder_only_llm_inputs(
         self,
         prompt_comps: PromptComponents,
         prompt_adapter_request: Optional[PromptAdapterRequest],
-    ) -> LLMInputs:
+    ) -> DecoderOnlyInputs:
         (prompt, prompt_token_ids, multi_modal_data,
          mm_processor_kwargs) = prompt_comps
 
         prompt_token_ids = self._apply_prompt_adapter(
             prompt_token_ids, prompt_adapter_request=prompt_adapter_request)
 
-        return LLMInputs(prompt_token_ids=prompt_token_ids,
-                         prompt=prompt,
-                         multi_modal_data=multi_modal_data,
-                         mm_processor_kwargs=mm_processor_kwargs)
+        return DecoderOnlyInputs(prompt_token_ids=prompt_token_ids,
+                                 prompt=prompt,
+                                 multi_modal_data=multi_modal_data,
+                                 mm_processor_kwargs=mm_processor_kwargs)
 
     def _process_decoder_only_prompt(
         self,
@@ -473,10 +473,10 @@ def _process_decoder_only_prompt(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
+    ) -> DecoderOnlyInputs:
         '''
         For decoder-only models:
-        Process an input prompt into an :class:`LLMInputs` instance.
+        Process an input prompt into a :class:`DecoderOnlyInputs` instance.
 
         Arguments:
 
@@ -487,7 +487,7 @@ def _process_decoder_only_prompt(
 
         Returns:
 
-        * :class:`LLMInputs` instance
+        * :class:`DecoderOnlyInputs` instance
         '''
 
         prompt_comps = self._extract_prompt_components(
@@ -507,7 +507,7 @@ async def _process_decoder_only_prompt_async(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> LLMInputs:
+    ) -> DecoderOnlyInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._extract_prompt_components_async(
             prompt,
@@ -526,7 +526,7 @@ def preprocess(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
+    ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
         """Preprocess the input prompt."""
         if self.is_encoder_decoder_model():
             # Encoder-decoder model requires special mapping of
@@ -554,7 +554,7 @@ async def preprocess_async(
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
-    ) -> Union[LLMInputs, EncoderDecoderLLMInputs]:
+    ) -> Union[DecoderOnlyInputs, EncoderDecoderInputs]:
         """Async version of :meth:`preprocess`."""
         if self.is_encoder_decoder_model():
             # Encoder-decoder model requires special mapping of
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index 5bd3e1c86f66..4cebc91ce715 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -12,7 +12,7 @@
 from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once,
                         resolve_mm_processor_kwargs)
 
-from .data import LLMInputs
+from .data import DecoderOnlyInputs
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -100,7 +100,7 @@ def __getitem__(self, key: str) -> int:
             raise KeyError(msg) from exc
 
 
-InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
+InputProcessor = Callable[[InputContext, DecoderOnlyInputs], DecoderOnlyInputs]
 """Preprocess the inputs to the model."""
 
 
@@ -134,7 +134,7 @@ def _default_dummy_data_factory(
         # Avoid circular import
         from vllm.sequence import SequenceData
 
-        dummy_seq_data = SequenceData.from_token_counts((0, seq_len))
+        dummy_seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
         dummy_multi_modal_data = None
 
         return dummy_seq_data, dummy_multi_modal_data
@@ -245,8 +245,11 @@ def dummy_data_for_profiling(
 
         return seq_data, mm_data
 
-    def _default_input_processor(self, ctx: InputContext,
-                                 inputs: LLMInputs) -> LLMInputs:
+    def _default_input_processor(
+        self,
+        ctx: InputContext,
+        inputs: DecoderOnlyInputs,
+    ) -> DecoderOnlyInputs:
         """The default input processor is a no-op."""
         return inputs
 
@@ -279,7 +282,7 @@ def _get_model_input_processor(self, model_cls: Type[nn.Module]):
             .get(model_cls, self._default_input_processor)
 
     def process_input(self, model_config: "ModelConfig",
-                      inputs: LLMInputs) -> LLMInputs:
+                      inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
         """
         Apply an input processor to an instance of model inputs.
 
diff --git a/vllm/logger.py b/vllm/logger.py
index 77dddbfb6096..ccf09691a052 100644
--- a/vllm/logger.py
+++ b/vllm/logger.py
@@ -15,8 +15,10 @@
 VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING
 VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH
 VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL
+VLLM_LOGGING_PREFIX = envs.VLLM_LOGGING_PREFIX
 
-_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
+_FORMAT = (f"{VLLM_LOGGING_PREFIX}%(levelname)s %(asctime)s "
+           "%(filename)s:%(lineno)d] %(message)s")
 _DATE_FORMAT = "%m-%d %H:%M:%S"
 
 DEFAULT_LOGGING_CONFIG = {
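
Because _FORMAT is built at import time, VLLM_LOGGING_PREFIX has to be set before vllm is imported; every record produced by the default config then starts with the prefix. A rough standalone approximation of the resulting format (timestamp and file name will differ):

import logging

prefix = "[rank0] "  # what VLLM_LOGGING_PREFIX would contribute
fmt = (f"{prefix}%(levelname)s %(asctime)s "
       "%(filename)s:%(lineno)d] %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(fmt, datefmt="%m-%d %H:%M:%S"))
log = logging.getLogger("prefix-demo")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("engine started")
# [rank0] INFO 10-17 12:00:00 demo.py:11] engine started
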
diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py
index a7887a048746..04fc635828d4 100644
--- a/vllm/lora/fully_sharded_layers.py
+++ b/vllm/lora/fully_sharded_layers.py
@@ -70,6 +70,14 @@ def apply(self, x: torch.Tensor,
                                        self.lora_b_stacked,
                                        add_input=True)
         # now have column partitioned output
+
+        if self.bias_stacked is not None:
+            self.bias_stacked = self.bias_stacked.view(
+                -1, self.bias_stacked.shape[-1])
+            self.bias_stacked = self.bias_stacked[
+                self.punica_wrapper.token_lora_indices]
+            output += self.bias_stacked
+
         output = output.view(*out_orig_shape)
         return output
 
@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
     left_offset = 0
     for idx in range(n):
         shard_size = layer.lora_b_stacked[idx].shape[2]
+
+        if layer.bias_stacked is not None:
+            bias = layer.bias_stacked[idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[layer.punica_wrapper.token_lora_indices]
+                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
+                output[:, left_offset:left_offset + shard_size] += bias
+
         layer.punica_wrapper.add_expand_slice(
             output,
             buffers[idx],
@@ -295,6 +312,15 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        if bias is None:
+            return bias
+        shard_size = self.bias_stacked.shape[2]
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
 
@@ -318,6 +344,13 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
+
+        if self.bias_stacked is not None:
+            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
+            bias = bias[self.punica_wrapper.token_lora_indices]
+            bias[self.punica_wrapper.token_lora_indices == -1] = 0
+            output += bias
+
         self.punica_wrapper.add_expand_slice(output, buffer,
                                              self.lora_b_stacked, start_idx,
                                              shard_size)
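
The slice_bias methods above shard a full bias vector column-wise so each tensor-parallel rank only adds the rows it owns. A toy sketch of that start/end arithmetic (sizes are illustrative, not from a real model):

import torch

full_bias = torch.arange(8, dtype=torch.float32)  # full output_size = 8
tp_size = 2
shard_size = full_bias.numel() // tp_size

def slice_bias_for_rank(bias: torch.Tensor, tp_rank: int) -> torch.Tensor:
    # Same indexing as slice_bias in the hunk above.
    start_idx = tp_rank * shard_size
    end_idx = (tp_rank + 1) * shard_size
    return bias[start_idx:end_idx]

print(slice_bias_for_rank(full_bias, 0))  # tensor([0., 1., 2., 3.])
print(slice_bias_for_rank(full_bias, 1))  # tensor([4., 5., 6., 7.])
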
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 6254c67596e6..abc8dde24112 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -67,6 +67,64 @@ def dec(*args, **kwargs):
     return dec
 
 
+def apply_bias(
+    indices: torch.Tensor,
+    output: torch.Tensor,
+    bias_stacked: torch.Tensor,
+):
+    """Applies bias to output
+
+    Input shapes:
+        bias_stacked:    (num_loras, output_dim)
+        indices:         (batch_size)
+        output:          (batch_size, output_dim)
+    """
+    org_output = output
+    output = output.view(-1, output.shape[-1])
+    indices = indices.view(-1)
+
+    bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
+    bias_stacked = bias_stacked[indices]
+    bias_stacked[indices == -1] = 0
+    output += bias_stacked
+
+    return output.view_as(org_output)
+
+
+def apply_bias_packed_nslice(
+    indices: torch.Tensor,
+    output: torch.Tensor,
+    output_slices: Tuple[int, ...],
+    bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+):
+    """Applies bias to output
+
+    Input shapes:
+        bias_stacked:      3 element tuple of (num_loras, output_dim)
+        indices:           (batch_size)
+        output:            (batch_size, q_slice_size + 2*kv_slice_size)
+        output_slices:     n-1 element tuple of (slice_size...),
+                           where n is the number of slices
+    """
+    org_output = output
+    output = output.view(-1, output.shape[-1])
+    indices = indices.view(-1)
+
+    offset_left = 0
+    for slice_idx in range(len(output_slices)):
+        bias = bias_stacked[slice_idx]
+        if bias is not None:
+            bias = bias.view(-1, bias.shape[-1])
+            bias = bias[indices]
+            bias[indices == -1] = 0
+            output[:,
+                   offset_left:offset_left + output_slices[slice_idx]] += bias
+
+        offset_left += output_slices[slice_idx]
+
+    return output.view_as(org_output)
+
+
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
@@ -105,6 +163,7 @@ def set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
         ...
@@ -203,6 +262,7 @@ def set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
@@ -299,10 +359,22 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                max_loras,
+                1,
+                self.output_size,
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def set_lora(
         self,
@@ -310,6 +382,7 @@ def set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
@@ -319,10 +392,21 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -401,11 +485,25 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                max_loras,
+                1,
+                self.output_size,
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
+
         self.output_dim = self.lora_b_stacked.shape[2]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         return lora_a
@@ -418,18 +516,30 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        if bias is None:
+            return bias
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        shard_size = self.output_dim
+        start_idx = tensor_model_parallel_rank * shard_size
+        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -437,10 +547,21 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -534,6 +655,17 @@ def create_lora_weights(
                 dtype=lora_config.lora_dtype,
                 device=self.device,
             ) for _ in range(n_slices))
+        if lora_config.bias_enabled:
+            self.bias_stacked = tuple(
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.output_size // 2,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ) for _ in range(n_slices))
+        else:
+            self.bias_stacked = None
 
         self.output_dim = self.lora_b_stacked[0].shape[2]
 
@@ -542,6 +674,9 @@ def reset_lora(self, index: int):
         self.lora_a_stacked[1][index] = 0
         self.lora_b_stacked[0][index] = 0
         self.lora_b_stacked[1][index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[0][index] = 0
+            self.bias_stacked[1][index] = 0
 
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
@@ -562,18 +697,32 @@ def slice_lora_b(
         ]
         return lora_b
 
+    def slice_bias(
+        self, bias: List[Union[torch.Tensor,
+                               None]]) -> List[Union[torch.Tensor, None]]:
+        if bias[0] is None or bias[1] is None:
+            return bias
+        shard_size = self.output_dim
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
+        bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         if lora_a[0] is not None:
             self.lora_a_stacked[0][
@@ -582,6 +731,10 @@ def set_lora(
             self.lora_b_stacked[0][
                 index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                     lora_b[0].T, non_blocking=True)
+        if bias is not None and bias[0] is not None:
+            self.bias_stacked[0][index,
+                                 0, :bias[0].shape[0]].copy_(bias[0].T,
+                                                             non_blocking=True)
         if lora_a[1] is not None:
             self.lora_a_stacked[1][
                 index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
@@ -589,10 +742,22 @@ def set_lora(
             self.lora_b_stacked[1][
                 index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                     lora_b[1].T, non_blocking=True)
+        if bias is not None and bias[1] is not None:
+            self.bias_stacked[1][index,
+                                 0, :bias[1].shape[0]].copy_(bias[1].T,
+                                                             non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias_packed_nslice(
+                self.indices,
+                output,
+                (self.output_dim, self.output_dim),
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora_packed_nslice(
             output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
             (self.output_dim, self.output_dim))
@@ -654,17 +819,36 @@ def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        bias_q, bias_k, bias_v = None, None, None
+        bias_q = bias[self.q_proj_shard_size *
+                      self.q_shard_id:self.q_proj_shard_size *
+                      (self.q_shard_id + 1)]
+        k_offset = self.q_proj_total_size
+        bias_k = bias[k_offset +
+                      self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        v_offset = k_offset + self.kv_proj_total_size
+        bias_v = bias[v_offset +
+                      self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        bias = torch.cat([bias_q, bias_k, bias_v], dim=0)
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -672,6 +856,10 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     @classmethod
     @_not_fully_sharded_can_replace
@@ -768,6 +956,32 @@ def create_lora_weights(
                 device=self.device,
             ),
         )
+        if lora_config.bias_enabled:
+            self.bias_stacked = (
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.q_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.kv_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.kv_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+            )
+        else:
+            self.bias_stacked = None
 
         self.output_slices = (
             self.q_proj_shard_size,
@@ -787,6 +1001,10 @@ def reset_lora(self, index: int):
         self.lora_b_stacked[1][index] = 0
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[0][index] = 0
+            self.bias_stacked[1][index] = 0
+            self.bias_stacked[2][index] = 0
 
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
@@ -812,18 +1030,40 @@ def slice_lora_b(
         lora_b = [lora_b_q, lora_b_k, lora_b_v]
         return lora_b
 
+    def slice_bias(
+        self, bias: List[Union[torch.Tensor,
+                               None]]) -> List[Union[torch.Tensor, None]]:
+        bias_q, bias_k, bias_v = None, None, None
+        if bias[0] is not None:
+            bias_q = bias[0][self.q_proj_shard_size *
+                             self.q_shard_id:self.q_proj_shard_size *
+                             (self.q_shard_id + 1)]
+        if bias[1] is not None:
+            bias_k = bias[1][self.kv_proj_shard_size *
+                             self.kv_shard_id:self.kv_proj_shard_size *
+                             (self.kv_shard_id + 1)]
+        if bias[2] is not None:
+            bias_v = bias[2][self.kv_proj_shard_size *
+                             self.kv_shard_id:self.kv_proj_shard_size *
+                             (self.kv_shard_id + 1)]
+        bias = [bias_q, bias_k, bias_v]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         if lora_b[0] is not None:
             lora_b_q = lora_b[0]
@@ -854,9 +1094,28 @@ def set_lora(
                 index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                     lora_a[2].T, non_blocking=True)
 
+        if bias is not None:
+            if bias[0] is not None:
+                self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
+                    bias[0].T, non_blocking=True)
+            if bias[1] is not None:
+                self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
+                    bias[1].T, non_blocking=True)
+            if bias[2] is not None:
+                self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
+                    bias[2].T, non_blocking=True)
+
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias_packed_nslice(
+                self.indices,
+                output,
+                self.output_slices,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                    self.lora_a_stacked,
                                                    self.lora_b_stacked, 1.0,
@@ -919,9 +1178,27 @@ def create_lora_weights(
             device=self.device,
         )
 
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                (
+                    max_loras,
+                    1,
+                    self.output_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
+        # Lazily initialized
+        self.indices: torch.Tensor
+        self.indices_len: List[int]
+
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
@@ -934,18 +1211,24 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.base_layer.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -953,9 +1236,20 @@ def set_lora(
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -1132,6 +1426,7 @@ def set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index,
@@ -1199,7 +1494,7 @@ def _get_logits(
                                                       neginf=float("-inf")))
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
-               lora_logits.shape[1], ] = lora_logits
+               lora_logits.shape[1]] = lora_logits
 
         # LogitsProcessorWithLoRA always using bgmv
         self.punica_wrapper.add_lora_logits(logits, hidden_states,
@@ -1276,6 +1571,7 @@ def set_lora(
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         ...
 
diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py
index 14081b5ba441..b648312ba76e 100644
--- a/vllm/lora/lora.py
+++ b/vllm/lora/lora.py
@@ -17,6 +17,7 @@ def __init__(
         lora_alpha: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
     ) -> None:
@@ -25,6 +26,7 @@ def __init__(
         self.lora_alpha = lora_alpha
         self.lora_a = lora_a
         self.lora_b = lora_b
+        self.bias = bias
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
@@ -66,7 +68,8 @@ def create_dummy_lora_weights(
             rank: int,
             dtype: torch.dtype,
             device: torch.types.Device,
-            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
+            embeddings_tensor_dim: Optional[int] = None,
+            bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],
                              dtype=dtype,
@@ -76,6 +79,14 @@ def create_dummy_lora_weights(
                              dtype=dtype,
                              device=device,
                              pin_memory=pin_memory)
+        if bias_enabled:
+            bias = torch.zeros([output_dim],
+                               dtype=dtype,
+                               device=device,
+                               pin_memory=pin_memory)
+        else:
+            bias = None
+
         embeddings_tensor = torch.rand(
             10,
             embeddings_tensor_dim,
@@ -88,6 +99,7 @@ def create_dummy_lora_weights(
             lora_alpha=1,
             lora_a=lora_a,
             lora_b=lora_b,
+            bias=bias,
             embeddings_tensor=embeddings_tensor,
         )
 
@@ -102,6 +114,7 @@ def __init__(
         lora_alphas: List[Optional[int]],
         lora_a: List[Optional[torch.Tensor]],
         lora_b: List[Optional[torch.Tensor]],
+        bias: Optional[List[Optional[torch.Tensor]]] = None,
         scaling: Optional[List[float]] = None,
     ) -> None:
         super().__init__(
@@ -110,6 +123,7 @@ def __init__(
             lora_alpha=0,
             lora_a=lora_a,
             lora_b=lora_b,
+            bias=bias,
             scaling=scaling,  # type: ignore
             embeddings_tensor=None,
         )
@@ -141,6 +155,7 @@ def pack(
             [lora.lora_alpha if lora is not None else None for lora in loras],
             [lora.lora_a if lora is not None else None for lora in loras],
             [lora.lora_b if lora is not None else None for lora in loras],
+            [lora.bias if lora is not None else None for lora in loras],
             scaling=[
                 1 if lora is not None else None  # type: ignore
                 for lora in loras
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index aaadca9a4d16..1f1f55b986c3 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -119,7 +119,8 @@ def from_lora_tensors(
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
-            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
+            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
+                tensor_name)
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
@@ -137,7 +138,14 @@ def from_lora_tensors(
                 loras[module_name] = LoRALayerWeights(module_name, rank,
                                                       lora_alpha, None, None,
                                                       lora_embeddings_tensor)
-            if is_lora_a:
+            if is_bias:
+                loras[module_name].bias = tensor.to(device=device,
+                                                    dtype=dtype).t()
+                if pin_memory:
+                    bias = loras[module_name].bias
+                    if bias is not None:
+                        loras[module_name].bias = bias.pin_memory()
+            elif is_lora_a:
                 loras[module_name].lora_a = tensor.to(device=device,
                                                       dtype=dtype).t()
                 if pin_memory:
@@ -215,7 +223,7 @@ def from_local_checkpoint(
             with safetensors.safe_open(lora_tensor_path,
                                        framework="pt") as f:  # type: ignore
                 for lora_module in f.keys():  # noqa
-                    module_name, _ = parse_fine_tuned_lora_name(lora_module)
+                    module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
                     part_name = module_name.split(".")[-1]
                     if part_name not in expected_lora_modules:
                         unexpected_modules.append(module_name)
@@ -384,8 +392,12 @@ def activate_adapter(
             module_lora = lora_model.get_lora(module_name)
             if module_lora:
                 module_lora.optimize()
+                # Drop the bias unless it was enabled via enable_lora_bias.
+                if not self.lora_config.bias_enabled:
+                    module_lora.bias = None
                 module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
-                                module_lora.embeddings_tensor)
+                                module_lora.embeddings_tensor,
+                                module_lora.bias)
             else:
                 module.reset_lora(index)
         return True
@@ -507,6 +519,7 @@ def create_dummy_lora(
         """Create zero-initialized LoRAModel for warmup."""
         model = LoRAModel(lora_id, rank, {}, scaling_factor)
         for module_name, module in self.model.named_modules():
+            bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
                     or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
@@ -534,7 +547,8 @@ def create_dummy_lora(
                         rank,
                         module.lora_a_stacked.dtype,
                         "cpu",
-                        embeddings_tensor_dim=embeddings_tensor_dim)
+                        embeddings_tensor_dim=embeddings_tensor_dim,
+                        bias_enabled=bias_enabled)
                 else:
                     lora = LoRALayerWeights.create_dummy_lora_weights(
                         module_name,
@@ -543,6 +557,7 @@ def create_dummy_lora(
                         rank,
                         module.lora_a_stacked.dtype,
                         "cpu",
+                        bias_enabled=bias_enabled,
                     )
                 lora.optimize()
             else:
@@ -557,6 +572,7 @@ def create_dummy_lora(
                         rank,
                         module.lora_a_stacked[i].dtype,
                         "cpu",
+                        bias_enabled=bias_enabled,
                     )
                     lora.optimize()
                     subloras.append(lora)
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index a780429f413d..707ad2006e85 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
     return new_module
 
 
-def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
+def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
     """Parse the name of lora weights.
 
     args:
@@ -101,15 +101,20 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         Tuple(module_name, is_lora_a):
             module_name: the name of the module, e.g. model.dense1,
             is_lora_a whether the tensor is lora_a or lora_b.
+            is_bias: whether the tensor is a LoRA bias.
     """
     parts = name.split(".")
+    assert parts[0] == "base_model"
+    assert parts[1] == "model"
+    if parts[-1] == "weight" and (parts[-2] == "lora_A"
+                                  or parts[-2] == "lora_B"):
+        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
 
-    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
-        if parts[-1] == "weight":
-            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
-                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
-        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
+    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
+        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
+
+    if parts[-1] == "bias":
+        return ".".join(parts[2:-2]), False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
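
With the change above, parse_fine_tuned_lora_name returns a third flag marking bias tensors. A few illustrative calls, assuming a vLLM build containing this change (the key names follow the base_model.model.<module>.<lora part> layout the parser asserts on):

from vllm.lora.utils import parse_fine_tuned_lora_name

print(parse_fine_tuned_lora_name(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"))
# ('model.layers.0.self_attn.q_proj', True, False)   lora_A weight

print(parse_fine_tuned_lora_name(
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.bias"))
# ('model.layers.0.self_attn.q_proj', False, True)   LoRA bias

print(parse_fine_tuned_lora_name(
    "base_model.model.model.embed_tokens.lora_embedding_A"))
# ('model.embed_tokens', True, False)                embedding lora_A
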
 
diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index d0e90245ad01..71eed6eb68d7 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -1,14 +1,24 @@
+from functools import lru_cache
+from typing import Dict, Type
+
 import torch.nn as nn
 
 import vllm.envs as envs
 from vllm.compilation.levels import CompilationLevel
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils import is_cpu, is_hip, is_xpu
+from vllm.utils import is_hip, print_warning_once
+
+logger = init_logger(__name__)
 
 
 class CustomOp(nn.Module):
+    """
+    Base class for custom ops.
+    Dispatches the forward method to the appropriate backend.
+    """
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
@@ -17,7 +27,6 @@ def forward(self, *args, **kwargs):
 
     def forward_native(self, *args, **kwargs):
         """PyTorch-native implementation of the forward method.
-
         This method is optional. If implemented, it can be used with compilers
         such as torch.compile or PyTorch XLA. Also, it can be used for testing
         purposes.
@@ -56,16 +65,67 @@ def dispatch_forward(self):
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
 
-        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
+        enabled = self.enabled()
+        logger.debug("custom op %s %s", self.__class__.name,
+                     "enabled" if enabled else "disabled")
+
+        if not enabled:
             return self.forward_native
 
         if is_hip():
             return self.forward_hip
-        elif is_cpu():
+        elif current_platform.is_cpu():
             return self.forward_cpu
         elif current_platform.is_tpu():
             return self.forward_tpu
-        elif is_xpu():
+        elif current_platform.is_xpu():
             return self.forward_xpu
         else:
             return self.forward_cuda
+
+    @classmethod
+    def enabled(cls) -> bool:
+        # if no name, then it was not registered
+        if not hasattr(cls, "name"):
+            print_warning_once(
+                f"Custom op {cls.__name__} was not registered, "
+                f"which means it won't appear in the op registry. "
+                f"It will be enabled/disabled based on the global settings.")
+            return CustomOp.default_on()
+
+        enabled = f"+{cls.name}" in envs.VLLM_CUSTOM_OPS
+        disabled = f"-{cls.name}" in envs.VLLM_CUSTOM_OPS
+        assert not (enabled
+                    and disabled), f"Cannot enable and disable {cls.name}"
+
+        return (CustomOp.default_on() or enabled) and not disabled
+
+    # On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
+    # Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
+    @staticmethod
+    @lru_cache()
+    def default_on() -> bool:
+        count_none = envs.VLLM_CUSTOM_OPS.count("none")
+        count_all = envs.VLLM_CUSTOM_OPS.count("all")
+        assert count_none + count_all <= 1, "Can only specify 'none' or 'all'"
+        return envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR and \
+            not count_none > 0 or count_all > 0
+
+    # Dictionary of all custom ops (classes, indexed by registered name).
+    # To check if an op with a name is enabled, call .enabled() on the class.
+    # Examples:
+    # - MyOp.enabled()
+    # - op_registry["my_op"].enabled()
+    op_registry: Dict[str, Type['CustomOp']] = {}
+
+    # Decorator to register custom ops.
+    @classmethod
+    def register(cls, name: str):
+
+        def decorator(op_cls):
+            assert name not in cls.op_registry, f"Duplicate op name: {name}"
+            op_cls.name = name
+            cls.op_registry[name] = op_cls
+            return op_cls
+
+        return decorator
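
The enabled() classmethod above resolves a per-op decision from VLLM_CUSTOM_OPS: an explicit '+name' wins over the default, '-name' always disables, and 'all'/'none' set the default together with the compilation level. A standalone sketch of that resolution rule (op names here are illustrative):

from typing import List

def op_enabled(name: str, custom_ops: List[str], default_on: bool) -> bool:
    # Sketch of CustomOp.enabled(): '+name' forces the op on, '-name' forces
    # it off, otherwise the global default applies.
    enabled = f"+{name}" in custom_ops
    disabled = f"-{name}" in custom_ops
    assert not (enabled and disabled)
    return (default_on or enabled) and not disabled

# VLLM_CUSTOM_OPS="none,+silu_and_mul" => default off, one op re-enabled:
print(op_enabled("silu_and_mul", ["none", "+silu_and_mul"], False))  # True
print(op_enabled("gelu_fast", ["none", "+silu_and_mul"], False))     # False
# VLLM_CUSTOM_OPS="all,-silu_and_mul" => default on, one op disabled:
print(op_enabled("silu_and_mul", ["all", "-silu_and_mul"], True))    # False
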
diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index c28bd71c9f68..cb8ae37ba48f 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 import copy
 import json
-import math
 from collections import defaultdict
 from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Union
 
+import numpy as np
 import torch
 from lark import Lark
 from outlines import grammars
@@ -77,8 +77,14 @@ def __call__(self, input_ids: List[int],
                 f"Unsupported instruction type {type(instruction)}")
 
         mask = torch.full((scores.shape[-1], ),
-                          -math.inf,
+                          -torch.inf,
                           device=scores.device)
+        # The tokenizer may support more token ids than the model can generate,
+        # e.g. Llama 3.2 Vision models have an `<|image|>` token with
+        # id 128256, but scores.shape == torch.Size([128256]), so that id is
+        # out of range. Using NumPy is faster for filtering the token ids.
+        allowed_tokens = np.array(allowed_tokens)
+        allowed_tokens = allowed_tokens[allowed_tokens < scores.shape[-1]]
         mask[allowed_tokens] = 0
         scores.add_(mask)
         return scores
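
The new masking path drops allowed token ids that fall outside the logits vector before building the additive mask. A standalone sketch of that filtering step (the vocabulary size and token ids below are illustrative):

```python
import numpy as np
import torch

vocab_size = 8                        # stands in for scores.shape[-1]
scores = torch.zeros(vocab_size)
allowed_tokens = [1, 3, 8]            # 8 is out of range, like <|image|> above

mask = torch.full((vocab_size, ), -torch.inf)
allowed = np.array(allowed_tokens)
allowed = allowed[allowed < vocab_size]   # drop ids the model cannot emit
mask[allowed] = 0
scores.add_(mask)
print(scores)   # positions 1 and 3 stay at 0.0, everything else is -inf
```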
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 43056786d35c..658a3700f33d 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -11,8 +11,44 @@
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import LazyDict
 
 
+@CustomOp.register("fatrelu_and_mul")
+class FatreluAndMul(CustomOp):
+    """An activation function for FATReLU.
+
+    The function computes x -> FATReLU(x[:d]) * x[d:] where
+    d = x.shape[-1] // 2.
+    This is used in openbmb/MiniCPM-S-1B-sft.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self, threshold: float = 0.):
+        super().__init__()
+        self.threshold = threshold
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        x1 = x[..., :d]
+        x2 = x[..., d:]
+        x1 = F.threshold(x1, self.threshold, 0.0)
+        return x1 * x2
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        from vllm import _custom_ops as ops
+
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.fatrelu_and_mul(out, x, self.threshold)
+        return out
+
+
+@CustomOp.register("silu_and_mul")
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
@@ -47,6 +83,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+@CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
 
@@ -96,6 +133,7 @@ def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
 
+@CustomOp.register("gelu_new")
 class NewGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -117,6 +155,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return ops.gelu_new(x)
 
 
+@CustomOp.register("gelu_fast")
 class FastGELU(CustomOp):
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -137,8 +176,8 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return ops.gelu_fast(x)
 
 
+@CustomOp.register("quick_gelu")
 class QuickGELU(CustomOp):
-
     # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
@@ -162,6 +201,7 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
     # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
 
 
+@CustomOp.register("relu2")
 class ReLUSquaredActivation(CustomOp):
     """
     Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -217,15 +257,24 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data.copy_(loaded_weight)
 
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    "gelu_fast": FastGELU(),
-    "gelu_new": NewGELU(),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-    "relu2": ReLUSquaredActivation(),
-    "quick_gelu": QuickGELU(),
-}
+_ACTIVATION_REGISTRY = LazyDict({
+    "gelu":
+    lambda: nn.GELU(),
+    "gelu_fast":
+    lambda: FastGELU(),
+    "gelu_new":
+    lambda: NewGELU(),
+    "gelu_pytorch_tanh":
+    lambda: nn.GELU(approximate="tanh"),
+    "relu":
+    lambda: nn.ReLU(),
+    "relu2":
+    lambda: ReLUSquaredActivation(),
+    "silu":
+    lambda: nn.SiLU(),
+    "quick_gelu":
+    lambda: QuickGELU(),
+})
 
 
 def get_act_fn(
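
For reference, the FATReLU gating computed by `FatreluAndMul.forward_native` can be reproduced with plain PyTorch; a minimal sketch of the native path only (the fused `ops.fatrelu_and_mul` CUDA kernel is not exercised here):

```python
import torch
import torch.nn.functional as F


def fatrelu_and_mul_ref(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Split the last dimension in half: gate with FATReLU, multiply by the rest.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return F.threshold(x1, threshold, 0.0) * x2


x = torch.randn(4, 16)                 # (num_tokens, 2 * d)
out = fatrelu_and_mul_ref(x, threshold=0.1)
print(out.shape)                       # torch.Size([4, 8])
```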
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 5964d5a5465f..5ae40a2af5a2 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -116,7 +116,7 @@ def single_marlin_moe(
 
     intermediate_cache = torch.ops._moe_C.marlin_gemm_moe(
         hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales,
-        w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K,
+        w_zeros, g_idx, sort_indices, workspace, scalar_type.id, M, N, K,
         is_k_full, E, topk, block_size_m, True, False)
 
     return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)
@@ -272,7 +272,7 @@ def fused_marlin_moe(
         g_idx1,
         sort_indices1,
         workspace,
-        scalar_type1,
+        scalar_type1.id,
         M,
         2 * N,
         K,
@@ -297,7 +297,7 @@ def fused_marlin_moe(
         g_idx2,
         sort_indices2,
         workspace,
-        scalar_type2,
+        scalar_type2.id,
         M,
         K,
         N,
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index bce740d0db75..8dd36620e3fa 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -37,13 +37,13 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor,
         raise NotImplementedError
 
 
+@CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
-
         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                     2 * intermediate_size,
@@ -74,7 +74,6 @@ def apply(
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         return self.forward(x=x,
                             layer=layer,
                             router_logits=router_logits,
@@ -97,7 +96,6 @@ def forward_cuda(
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
 
@@ -134,7 +132,6 @@ def forward_tpu(
             num_expert_group: Optional[int] = None,
             custom_routing_function: Optional[Callable] = None
     ) -> torch.Tensor:
-
         from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
         assert not use_grouped_topk
         assert num_expert_group is None
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index d55f86056d17..30b43f375dd5 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -7,6 +7,7 @@
 from vllm.model_executor.custom_op import CustomOp
 
 
+@CustomOp.register("rms_norm")
 class RMSNorm(CustomOp):
     """Root mean square normalization.
 
@@ -26,7 +27,6 @@ def __init__(
         self.variance_epsilon = eps
         self.variance_size_override = (None if var_hidden_size == hidden_size
                                        else var_hidden_size)
-
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward_native(
@@ -122,6 +122,7 @@ def extra_repr(self) -> str:
         return s
 
 
+@CustomOp.register("gemma_rms_norm")
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
 
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 94f30412e43b..8d380cbedf54 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
 import torch.nn.functional as F
@@ -135,6 +135,36 @@ def apply(self,
         return F.linear(x, layer.weight, bias)
 
 
+class TiedWeightLinearMethod(UnquantizedLinearMethod):
+    """Linear method base with noop create_weights
+
+    Can be used to prevent the initialization of weights
+    during the initialization of modules with weight tying.
+    """
+
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        ...
+
+
+class QuantizationConfigOverride(QuantizationConfig):
+    """Config class to inject a specific LinearMethod.
+    """
+
+    def __init__(self, cls: Type[LinearMethodBase]):
+        self.cls = cls
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional[LinearMethodBase]:
+        return self.cls()
+
+
+QuantizationConfigOverride.__abstractmethods__ = frozenset()
+
+
 class LinearBase(torch.nn.Module):
     """Base linear layer.
 
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 1d5b6fad2e16..288f5a1134b6 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -48,14 +48,15 @@ def forward(
         self,
         lm_head: VocabParallelEmbedding,
         hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
+        sampling_metadata: Optional[SamplingMetadata] = None,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[torch.Tensor]:
         if self.logits_as_input:
             logits = hidden_states
         else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
+            if sampling_metadata is not None:
+                hidden_states = _prune_hidden_states(hidden_states,
+                                                     sampling_metadata)
 
             # Get the logits for the next tokens.
             logits = self._get_logits(hidden_states, lm_head, embedding_bias)
@@ -69,7 +70,8 @@ def forward(
                 logits *= self.scale
 
             # Apply logits processors (if any).
-            logits = _apply_logits_processors(logits, sampling_metadata)
+            if sampling_metadata is not None:
+                logits = _apply_logits_processors(logits, sampling_metadata)
 
         return logits
 
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index ed7241af6cd1..be5639df985f 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -6,18 +6,18 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.attention.backends.utils import PAD_SLOT_ID
 
 
-def causal_conv1d_fn(
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor] = None,
-    query_start_loc: Optional[torch.Tensor] = None,
-    cache_indices: Optional[torch.Tensor] = None,
-    has_initial_state: Optional[torch.Tensor] = None,
-    conv_states: Optional[torch.Tensor] = None,
-    activation: Optional[str] = "silu",
-):
+def causal_conv1d_fn(x: torch.Tensor,
+                     weight: torch.Tensor,
+                     bias: Optional[torch.Tensor] = None,
+                     query_start_loc: Optional[torch.Tensor] = None,
+                     cache_indices: Optional[torch.Tensor] = None,
+                     has_initial_state: Optional[torch.Tensor] = None,
+                     conv_states: Optional[torch.Tensor] = None,
+                     activation: Optional[str] = "silu",
+                     pad_slot_id: int = PAD_SLOT_ID):
     """
     x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
         sequences are concatenated from left to right for varlen
@@ -37,6 +37,13 @@ def causal_conv1d_fn(
     conv_states: (...,dim,width - 1) itype
         updated inplace if provided
     activation: either None or "silu" or "swish"
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padded
+        entries that will not be processed; for example, with
+        cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        the kernel will not process entries at
+        indices 0 and 3
+
 
     out: (batch, dim, seqlen)
     """
@@ -46,10 +53,10 @@ def causal_conv1d_fn(
         x = x.contiguous()
     bias = bias.contiguous() if bias is not None else None
 
-    out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
-                                cache_indices, has_initial_state, activation
-                                in ["silu", "swish"])
-    return out
+    ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc,
+                          cache_indices, has_initial_state, activation
+                          in ["silu", "swish"], pad_slot_id)
+    return x
 
 
 def causal_conv1d_update(x: torch.Tensor,
@@ -58,7 +65,8 @@ def causal_conv1d_update(x: torch.Tensor,
                          bias: Optional[torch.Tensor] = None,
                          activation: Optional[str] = None,
                          cache_seqlens: Optional[torch.Tensor] = None,
-                         conv_state_indices: Optional[torch.Tensor] = None):
+                         conv_state_indices: Optional[torch.Tensor] = None,
+                         pad_slot_id: int = PAD_SLOT_ID):
     """
     x: (batch, dim) or (batch, dim, seqlen)
     conv_state: (batch, dim, state_len), where state_len >= width - 1
@@ -73,7 +81,12 @@ def causal_conv1d_update(x: torch.Tensor,
         If not None, the conv_state is a larger tensor along the batch dim, 
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
-
+    pad_slot_id: int
+        if conv_state_indices is passed, lets the kernel identify padded
+        entries that will not be processed; for example, with
+        conv_state_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        the kernel will not process entries at
+        indices 0 and 3
     out: (batch, dim) or (batch, dim, seqlen)
     """
     if activation not in [None, "silu", "swish"]:
@@ -82,8 +95,8 @@ def causal_conv1d_update(x: torch.Tensor,
     unsqueeze = x.dim() == 2
     if unsqueeze:
         x = x.unsqueeze(-1)
-    out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
-                                   cache_seqlens, conv_state_indices)
+    ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val,
+                             cache_seqlens, conv_state_indices, pad_slot_id)
     if unsqueeze:
-        out = out.squeeze(-1)
-    return out
+        x = x.squeeze(-1)
+    return x
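
The new `pad_slot_id` argument marks batch entries whose cache slot should be neither read nor written. A small sketch of how padded entries are identified (the sentinel value below is illustrative; the real one comes from `vllm.attention.backends.utils.PAD_SLOT_ID`):

```python
import torch

PAD_SLOT_ID = -1    # illustrative sentinel; vLLM imports the real value from
                    # vllm.attention.backends.utils

# One cache slot per batch entry; padded entries point at the sentinel and are
# skipped by the kernel instead of reading/writing conv_states.
cache_indices = torch.tensor([PAD_SLOT_ID, 1, 20, PAD_SLOT_ID])
active = cache_indices != PAD_SLOT_ID
print(active)                            # tensor([False,  True,  True, False])
print(torch.nonzero(active).flatten())   # entries 1 and 2 are processed
```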
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 08b016c20c42..1484b79815ab 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -1,14 +1,13 @@
 # Copyright (c) 2024, Tri Dao, Albert Gu.
 # Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
 
-from typing import Tuple
-
 import torch
 import triton
 import triton.language as tl
 from packaging import version
 
 from vllm import _custom_ops as ops
+from vllm.attention.backends.utils import PAD_SLOT_ID
 
 TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
 
@@ -50,6 +49,7 @@ def _selective_scan_update_kernel(
     z_ptr,
     out_ptr,
     state_batch_indices_ptr,
+    pad_slot_id,
     # Matrix dimensions
     batch,
     nheads,
@@ -143,10 +143,11 @@ def _selective_scan_update_kernel(
     if HAS_Z:
         z_ptrs = z_ptr + offs_m * stride_z_dim
     out_ptrs = out_ptr + offs_m * stride_out_dim
+    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+    if HAS_STATE_BATCH_INDICES:
+        mask &= (state_batch_idx != pad_slot_id)
+    state = tl.load(state_ptrs, mask=mask, other=0.0)
 
-    state = tl.load(state_ptrs,
-                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
-                    other=0.0)
     x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
     if not TIE_HDIM:
         dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
@@ -177,9 +178,11 @@ def _selective_scan_update_kernel(
 
     dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
     state = state * dA + dB * x[:, None]
-    tl.store(state_ptrs,
-             state,
-             mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
+
+    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
+    if HAS_STATE_BATCH_INDICES:
+        mask &= (state_batch_idx != pad_slot_id)
+    tl.store(state_ptrs, state, mask=mask)
     out = tl.sum(state * C[None, :], axis=1)
     if HAS_D:
         out += x * D
@@ -198,7 +201,8 @@ def selective_state_update(state,
                            z=None,
                            dt_bias=None,
                            dt_softplus=False,
-                           state_batch_indices=None):
+                           state_batch_indices=None,
+                           pad_slot_id=PAD_SLOT_ID):
     """
     Argument:
         state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -210,6 +214,12 @@ def selective_state_update(state,
         D: (dim,) or (nheads, dim)
         z: (batch, dim) or (batch, nheads, dim)
         dt_bias: (dim,) or (nheads, dim)
+        pad_slot_id: int
+            if state_batch_indices is passed, lets the kernel identify
+            padded entries that will not be processed; for example, with
+            state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
+            the kernel will not process entries at
+            indices 0 and 3
     Return:
         out: (batch, dim) or (batch, nheads, dim)
     """
@@ -276,6 +286,7 @@ def selective_state_update(state,
             z,
             out,
             state_batch_indices,
+            pad_slot_id,
             batch,
             nheads,
             dim,
@@ -319,22 +330,25 @@ def selective_state_update(state,
     return out
 
 
-def selective_scan_fn(
-        u,
-        ssm_states,
-        delta,
-        A,
-        B,
-        C,
-        D=None,
-        z=None,
-        delta_bias=None,
-        delta_softplus=False,
-        query_start_loc=None,
-        cache_indices=None,
-        has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]:
+def selective_scan_fn(u,
+                      ssm_states,
+                      delta,
+                      A,
+                      B,
+                      C,
+                      D=None,
+                      z=None,
+                      delta_bias=None,
+                      delta_softplus=False,
+                      query_start_loc=None,
+                      cache_indices=None,
+                      has_initial_state=None,
+                      pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
     """
     u: (dim, total_length) for varlen or (batch, dim, seqlen) 
+        applies changes in place.
+    ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
+        applies changes in place.
     delta: (dim, total_length) for varlen or (batch, dim, seqlen)
     A: (dim, dstate) 
     B: (ngroups, dstate, total_length) for varlen or 
@@ -357,12 +371,14 @@ def selective_scan_fn(
         indicate if the ssm_state at the corresponding index should be 
         used as initial state. Not providing argument assumes 
         there's no initial state
-
+    pad_slot_id: int
+        if cache_indices is passed, lets the kernel identify padding entries
+        that will not be processed; for example, with
+        cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
+        the kernel will not process entries at indices 0 and 3
     returns
         output: (dim, total_length) for varlen or (batch, dim, seqlen) 
                 supports inplace replacement
-        last_state has shape (batch, dim, dstate). 
-                supports inplace replacement if ssm_state was provided
     """
     if u.stride(-1) != 1:
         u = u.contiguous()
@@ -387,7 +403,7 @@ def selective_scan_fn(
 
     ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
                            query_start_loc, cache_indices, has_initial_state,
-                           ssm_states)
+                           ssm_states, pad_slot_id)
 
     if z is None:
         return delta  # output written inplace to delta
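
The Triton kernel change applies the same idea to `selective_state_update`: loads and stores of the SSM state are masked out when `state_batch_idx == pad_slot_id`. A plain-PyTorch sketch of that masking (shapes and the sentinel value are illustrative):

```python
import torch

PAD_SLOT_ID = -1                        # illustrative sentinel, as above

dim, dstate = 4, 3
state = torch.zeros(2, dim, dstate)     # two real cache slots
new_state = torch.ones(3, dim, dstate)  # one update per batch entry

# Batch entry -> cache slot; padded entries must neither read nor write state.
state_batch_indices = torch.tensor([0, PAD_SLOT_ID, 1])
valid = state_batch_indices != PAD_SLOT_ID
state[state_batch_indices[valid]] = new_state[valid]
print(state.sum(dim=(1, 2)))   # both real slots updated, padding ignored
```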
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 76ccb3dfe0a6..3455a4ccf282 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
     """Enumeration for different types of pooling methods."""
     LAST = 0
     ALL = 1
+    CLS = 2
 
 
 class Pooler(nn.Module):
@@ -23,12 +24,13 @@ class Pooler(nn.Module):
     3. Returns structured results as `PoolerOutput`.
 
     Attributes:
-        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+        pooling_type: The type of pooling to use (LAST, ALL, CLS).
         normalize: Whether to normalize the pooled data.
     """
 
     def __init__(self, pooling_type: PoolingType, normalize: bool):
         super().__init__()
+
         self.pooling_type = pooling_type
         self.normalize = normalize
 
@@ -38,10 +40,16 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         """Pools specific information from hidden states based on metadata."""
+
         prompt_lens = PoolingTensors.from_pooling_metadata(
             pooling_metadata, hidden_states.device).prompt_lens
 
-        if self.pooling_type == PoolingType.LAST:
+        if self.pooling_type is PoolingType.CLS:
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
+                                                         dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
+        elif self.pooling_type == PoolingType.LAST:
             last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_flat_indices]
         elif self.pooling_type == PoolingType.ALL:
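
The new CLS branch selects the first token of each prompt from the flattened hidden states by offsetting with a cumulative sum of prompt lengths. A worked example of the index math, alongside the existing LAST pooling for comparison:

```python
import torch

prompt_lens = torch.tensor([3, 5, 2])           # three prompts, 10 tokens total
hidden_states = torch.arange(10).unsqueeze(-1)  # flattened (num_tokens, hidden)

# CLS pooling: index of the first token of each prompt in the flat batch.
first_token_flat_indices = torch.zeros_like(prompt_lens)
first_token_flat_indices[1:] += torch.cumsum(prompt_lens, dim=0)[:-1]
print(first_token_flat_indices)                 # tensor([0, 3, 8])

# LAST pooling, for comparison: index of each prompt's final token.
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
print(last_token_flat_indices)                  # tensor([2, 7, 9])

print(hidden_states[first_token_flat_indices].flatten())  # tensor([0, 3, 8])
```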
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 410b3cb5321c..38dd1f2e10fc 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -3,7 +3,8 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -21,10 +22,12 @@ def __init__(
         weight_bits: int,
         group_size: int,
         zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits != 4:
             raise ValueError(
@@ -35,7 +38,8 @@ def __init__(
     def __repr__(self) -> str:
         return (f"AWQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"zero_point={self.zero_point})")
+                f"zero_point={self.zero_point}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     def get_name(self) -> str:
         return "awq"
@@ -61,11 +65,15 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])
-        return cls(weight_bits, group_size, zero_point)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["AWQLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         return None
 
@@ -73,6 +81,10 @@ def get_scaled_act_names(self) -> List[str]:
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
 
 
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
 
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index abb18d31b5a8..ecc345f116c3 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -1,6 +1,10 @@
 from typing import Any, Dict, List, Optional, cast
 
 import torch
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import (QuantizationArgs,
+                                             QuantizationStrategy,
+                                             QuantizationType)
 from pydantic import BaseModel
 
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -16,8 +20,7 @@
     CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_matched_target, is_activation_quantization_format,
+    find_matched_target, is_activation_quantization_format,
     should_ignore_layer)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.platforms import current_platform
@@ -97,12 +100,21 @@ def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
                 target_scheme_map[target][
                     "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
-                try:
-                    target_scheme_map[target][
-                        "input_activations"] = QuantizationArgs.parse_obj(
-                            quant_config.get("input_activations"))
-                except Exception:
-                    target_scheme_map[target]["input_activations"] = None
+
+                target_scheme_map[target]["input_activations"] = None
+                if is_activation_quantization_format(quant_format):
+                    input_activations = quant_config.get("input_activations")
+                    # The only case where we have activation quant supported
+                    # but no input_activations provided in the config
+                    # should be w8a16fp8. w8a16fp8 can also run for cases
+                    # where there is an input_quant but it is ignored.
+                    if not input_activations:
+                        assert target_scheme_map[target][
+                            "weights"].type == QuantizationType.FLOAT
+                    else:
+                        target_scheme_map[target][
+                            "input_activations"] = QuantizationArgs.parse_obj(
+                                quant_config.get("input_activations"))
 
         return cls(target_scheme_map=target_scheme_map,
                    ignore=ignore,
@@ -241,8 +253,6 @@ def _get_scheme_from_parts(
                     group_size=weight_quant.group_size,
                     actorder=weight_quant.actorder)
 
-        # Detect If Activation Quantization.
-        # TODO @dsikka: clean-up conditions
         if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 is_fp8_w8a8_supported = self._check_scheme_supported(
@@ -253,16 +263,19 @@ def _get_scheme_from_parts(
                         is_static_input_scheme=(input_quant
                                                 and not input_quant.dynamic))
                 else:
+                    # note: input_quant will be present for converted models;
+                    # will be ignored during inference post loading
                     return CompressedTensorsW8A16Fp8(
                         strategy=weight_quant.strategy,
-                        is_static_input_scheme=(input_quant
-                                                and not input_quant.dynamic))
+                        is_static_input_scheme=not input_quant.dynamic)
 
+            # note: input_quant can be None
             if self._is_fp8_w8a16(weight_quant, input_quant):
+                is_static_input_scheme = (input_quant
+                                          and not input_quant.dynamic)
                 return CompressedTensorsW8A16Fp8(
                     strategy=weight_quant.strategy,
-                    is_static_input_scheme=(input_quant
-                                            and not input_quant.dynamic))
+                    is_static_input_scheme=is_static_input_scheme)
 
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Int8(
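
The rewritten `from_config` branch only parses `input_activations` for activation-quantized formats, and tolerates its absence solely for float (w8a16fp8) weights. A standalone sketch of that decision, with illustrative config dicts and format names:

```python
def parse_input_activations(quant_config: dict, quant_format: str):
    # Illustrative activation-quantized formats; the real list lives in
    # is_activation_quantization_format().
    activation_formats = ("naive-quantized", "int-quantized", "float-quantized")
    if quant_format not in activation_formats:
        return None                    # weight-only formats carry no act quant
    input_activations = quant_config.get("input_activations")
    if not input_activations:
        # Only legal for float weights (the w8a16fp8 case described above).
        assert quant_config["weights"]["type"] == "float"
        return None
    return input_activations


w8a16fp8 = {"weights": {"type": "float", "num_bits": 8}}
print(parse_input_activations(w8a16fp8, "float-quantized"))        # None

w8a8_int8 = {"weights": {"type": "int", "num_bits": 8},
             "input_activations": {"type": "int", "num_bits": 8}}
print(parse_input_activations(w8a8_int8, "int-quantized"))         # the dict
```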
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index af04d725159f..733eece4b5fa 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -3,14 +3,14 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization import QuantizationStrategy
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     WNA16_SUPPORTED_BITS)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    CompressionFormat, QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
index 3d55d55cc390..1671a23d77c6 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
@@ -1,11 +1,10 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index 5931ec36c97d..7270b302ef96 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
     requantize_with_max_scale)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index 245a35c8783a..15d9cdbcbb86 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -1,13 +1,12 @@
 from typing import Callable, List, Optional
 
 import torch
+from compressed_tensors.quantization import QuantizationStrategy
 from torch.nn import Parameter
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_int8_linear, convert_to_channelwise)
 from vllm.model_executor.parameter import (BasevLLMParameter,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index cb65557be8f9..a51573801778 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -1,12 +1,11 @@
 from typing import Callable, List, Optional, Set
 
 import torch
+from compressed_tensors.quantization import ActivationOrdering
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    ActivationOrdering)
 from vllm.model_executor.layers.quantization.kernels import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
index fc531b9d666e..a74eaef5efde 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -1,111 +1,13 @@
 import re
-from enum import Enum
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Iterable, Optional
 
-from pydantic import BaseModel, Field, field_validator
+from compressed_tensors import CompressionFormat
 from torch.nn import Module
 
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     FUSED_LAYER_NAME_MAPPING)
 
 
-class CompressionFormat(Enum):
-    dense = "dense"
-    sparse_bitmask = "sparse-bitmask"
-    naive_quantized = "naive-quantized"
-    float_quantized = "float-quantized"
-    int_quantized = "int-quantized"
-    pack_quantized = "pack-quantized"
-    marlin_24 = "marlin-24"
-
-
-class QuantizationType(str, Enum):
-    """
-    Enum storing quantization type options
-    """
-
-    INT = "int"
-    FLOAT = "float"
-
-
-class QuantizationStrategy(str, Enum):
-    """
-    Enum storing quantization strategy options
-    """
-
-    TENSOR = "tensor"
-    CHANNEL = "channel"
-    GROUP = "group"
-    BLOCK = "block"
-    TOKEN = "token"
-
-
-class ActivationOrdering(str, Enum):
-    """
-    Enum storing strategies for activation ordering
-
-    Group: reorder groups and weight\n
-    Weight: only reorder weight, not groups. Slightly lower latency and
-    accuracy compared to group actorder\n
-    """
-
-    GROUP = "group"
-    WEIGHT = "weight"
-
-
-class QuantizationArgs(BaseModel):
-    """
-    User facing arguments used to define a quantization config 
-    for weights or activations
-
-    :param num_bits: quantization bit depth
-    :param type: dtype to quantized to, either int or float
-    :param symmetric: whether or not quantization scale is symmetric
-    :param strategy: string determining the scope of scale/zero-point to apply
-    :param group_size: group length to use for the group strategy
-    :param block_structure: 2d block structure to use for the block 
-    strategy, must be of the format "2x4", "8x16", etc.
-    :param dynamic: set True to perform dynamic quantization -
-        values will not be calibrated during calibration phase, 
-        instead during inference new quantization ranges will be 
-        observed with every sample. Defaults to False for static
-        quantization. Note that enabling dynamic quantization 
-        will change the default observer to a memoryless one
-    :param actorder: whether to apply group quantization in decreasing order of
-        activation. Defaults to None for arbitrary ordering
-    """
-
-    num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
-    symmetric: bool = True
-    group_size: Optional[int] = None
-    strategy: Optional[QuantizationStrategy] = None
-    block_structure: Optional[str] = None
-    dynamic: bool = False
-    actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
-        default="minmax",
-        description=("The class to use to compute the quantization param - "
-                     "scale and zero-point'"),
-    )
-    observer_kwargs: Dict[str, Any] = Field(
-        default_factory=dict,
-        description=
-        ("optional dict of kwargs to be passed directly to torch quantization "
-         "Observers constructor excluding quantization range or symmetry"),
-    )
-
-    @field_validator("actorder", mode="before")
-    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
-        if isinstance(value, bool):
-            return ActivationOrdering.GROUP if value else None
-
-        if isinstance(value, str):
-            return ActivationOrdering(value.lower())
-
-        return value
-
-
 def is_activation_quantization_format(format: str) -> bool:
     _ACTIVATION_QUANTIZATION_FORMATS = [
         CompressionFormat.naive_quantized.value,
diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
index fe50c4930d04..b04612a9b00d 100644
--- a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
+++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -42,6 +42,10 @@ def __init__(self,
         self.config = c
         self.w_q_name = w_q_param_name
         self.w_s_name = w_s_param_name
+        if c.zero_points:
+            assert w_zp_param_name is not None
+        if c.has_g_idx:
+            assert w_gidx_param_name is not None
         self.w_zp_name = w_zp_param_name
         self.w_gidx_name = w_gidx_param_name
 
diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py
index 47591c2aa644..94a3dc2584d6 100644
--- a/vllm/model_executor/layers/quantization/kernels/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/__init__.py
@@ -1,6 +1,8 @@
-import os
 from typing import List, Optional, Type
 
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.kernels.exllama import (
+    ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.machete import (
     MacheteLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.marlin import (
@@ -13,6 +15,7 @@
 _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
     MacheteLinearKernel,
     MarlinLinearKernel,
+    ExllamaLinearKernel,
 ]
 
 
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
-        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
-            .split(","):
+        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
diff --git a/vllm/model_executor/layers/quantization/kernels/exllama.py b/vllm/model_executor/layers/quantization/kernels/exllama.py
new file mode 100644
index 000000000000..1d85d62ec83e
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/exllama.py
@@ -0,0 +1,140 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    pack_quantized_values_into_int32)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class ExllamaLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
+    # currently untested so not added to the list
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if c.has_g_idx and\
+            c.partition_weight_shape[0] != c.full_weight_shape[0]:
+            return False, "Act reordering currently not supported by Exllama, "\
+                          "when the input features are partitioned across "\
+                          "devices"
+
+        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
+            return False, "Output features must be a multiple of the pack " \
+                            "factor (32 / num_bits) so that we can correctly " \
+                            "pack the zero points"
+
+        if c.act_type != torch.float16:
+            return False, "Exllama only supports float16 activations"
+
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Quant type ({c.weight_type}) not supported by "\
+                           "Exllama, supported types are: "\
+                           f"{cls.SUPPORTED_QUANT_TYPES}"
+
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                           " the number of input features "\
+                           f"({c.full_weight_shape[0]})"
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+
+        # For Exllama, we need to set a zero-point tensor if there is not one
+        if not c.zero_points:
+            self.w_zp_name = "qzeros"
+            device = getattr(layer, self.w_q_name).device
+            groups = c.partition_weight_shape[0] // c.group_size
+            out_features = c.partition_weight_shape[1]
+
+            if c.weight_type.has_bias():
+                # if the type has a bias we have to create a zeros tensor that
+                # contains the bias values repeated for each group (-1 due to
+                # a bug in the original GPTQ checkpoint format leading to
+                # exllama kernel adding 1 to the zero points during inference)
+                # Documentation of the bug can be found here:
+                #  https://garden.danieldk.eu/GPTQ-Checkpoint-Format
+                zeros = torch.full((groups, out_features),
+                                   c.weight_type.bias - 1,
+                                   dtype=torch.int32,
+                                   device=device)
+            else:
+                raise NotImplementedError(
+                    "A 0 zero-point is not supported by Exllama due to "
+                    "a bug in the original GPTQ checkpoint format leading to "
+                    "exllama kernel adding 1 to the zero points during "
+                    "inference")
+            zeros = pack_quantized_values_into_int32(zeros,
+                                                     c.weight_type,
+                                                     packed_dim=1)
+            setattr(layer, self.w_zp_name,
+                    torch.nn.Parameter(zeros, requires_grad=False))
+
+        if c.has_g_idx:
+
+            def transform_w_g_idx(x):
+                # Exllama wants the permutation array instead of the group
+                # indices
+                return torch.argsort(x).to(torch.int)
+
+            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
+        else:
+            self.w_gidx_name = "g_idx"
+            device = getattr(layer, self.w_q_name).device
+            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
+                                                         dtype=torch.int,
+                                                         device=device),
+                                             requires_grad=False)
+            setattr(layer, self.w_gidx_name, empty_g_idx)
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            assert self.w_gidx_name is not None
+            g_idx = getattr(layer, self.w_gidx_name)
+
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x_cont = x.data.contiguous()
+            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
+            return x_cont
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x.to(dtype=c.act_type)
+
+        # Repack weights and scales for Exllama
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
+
+        assert w_zp is not None, "Zero points are required by Exllama"
+        assert w_g_idx is not None, "Group index is required by Exllama"
+        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
+                               c.weight_type.size_bits)
+
+        if bias is not None:
+            output.add_(bias)
+        return output.reshape(out_shape)
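
The zero-point synthesis above fills a `(groups, out_features)` tensor with `weight_type.bias - 1` and packs it, because the exllama kernel re-adds the 1 at inference time. A sketch of just that step, assuming a build that includes the renamed packing helper:

```python
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32)
from vllm.scalar_type import scalar_types

# Symmetric GPTQ-style checkpoints carry no zero points, so a constant tensor
# of (bias - 1) is synthesized; the exllama kernel adds the 1 back at runtime.
wtype = scalar_types.uint4b8                  # 4-bit unsigned, bias 8
groups, out_features = 2, 16
zeros = torch.full((groups, out_features), wtype.bias - 1, dtype=torch.int32)
packed = pack_quantized_values_into_int32(zeros, wtype, packed_dim=1)
print(zeros[0, 0].item())                     # 7
print(packed.shape)                           # torch.Size([2, 2]): 16 * 4 // 32
```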
diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py
index fa39cb511528..e5696d08f30f 100644
--- a/vllm/model_executor/layers/quantization/kernels/machete.py
+++ b/vllm/model_executor/layers/quantization/kernels/machete.py
@@ -8,7 +8,7 @@
     MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    pack_weights_into_int32, unpack_weights_into_int32)
+    pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -71,13 +71,13 @@ def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
             if c.has_g_idx:
-                x_unpacked = unpack_weights_into_int32(x.data,
-                                                       c.weight_type,
-                                                       packed_dim=0)
+                x_unpacked = unpack_quantized_values_into_int32(x.data,
+                                                                c.weight_type,
+                                                                packed_dim=0)
                 x_perm = x_unpacked[perm, :]
-                x.data = pack_weights_into_int32(x_perm,
-                                                 c.weight_type,
-                                                 packed_dim=0)
+                x.data = pack_quantized_values_into_int32(x_perm,
+                                                          c.weight_type,
+                                                          packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                            self.config.weight_type)
             return x
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 833d00073564..c217f5ca620a 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -20,9 +20,9 @@
 }
 
 
-def pack_weights_into_int32(w_q: torch.Tensor,
-                            wtype: ScalarType,
-                            packed_dim: int = 0):
+def pack_quantized_values_into_int32(w_q: torch.Tensor,
+                                     wtype: ScalarType,
+                                     packed_dim: int = 0):
     # move dim to pack to the end
     perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
     inv_perm = tuple(perm.index(i) for i in range(len(perm)))
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
     return res.permute(inv_perm)
 
 
-def unpack_weights_into_int32(w_q: torch.Tensor,
-                              wtype: ScalarType,
-                              packed_dim: int = 0):
+def unpack_quantized_values_into_int32(w_q: torch.Tensor,
+                                       wtype: ScalarType,
+                                       packed_dim: int = 0):
     # move dim to pack to the end
     perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
     inv_perm = tuple(perm.index(i) for i in range(len(perm)))
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index d4e9ed87ed54..2158ad333967 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
         return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
+@CustomOp.register("rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 
@@ -468,7 +469,7 @@ def __init__(
         self.long_factor = long_factor
 
         scale = self.max_position_embeddings / \
-            self.original_max_position_embeddings
+                self.original_max_position_embeddings
         if scale <= 1.0:
             scaling_factor = 1.0
         else:
@@ -920,13 +921,10 @@ def get_rope(
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style, dtype)
     else:
-        scaling_type = rope_scaling[
-            "type"] if "type" in rope_scaling else rope_scaling["rope_type"]
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        if scaling_type not in {"su", "longrope"}:
-            scaling_factor = rope_scaling.get("factor", 1.0)
+        scaling_type = rope_scaling["rope_type"]
+
         if scaling_type == "llama3":
+            scaling_factor = rope_scaling["factor"]
             low_freq_factor = rope_scaling["low_freq_factor"]
             high_freq_factor = rope_scaling["high_freq_factor"]
             original_max_position = rope_scaling[
@@ -937,16 +935,39 @@ def get_rope(
                                                scaling_factor, low_freq_factor,
                                                high_freq_factor,
                                                original_max_position)
+        elif scaling_type == "default":
+            if "mrope_section" in rope_scaling:
+                rotary_emb = MRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                    mrope_section=rope_scaling["mrope_section"],
+                )
+            else:
+                rotary_emb = RotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    dtype,
+                )
         elif scaling_type == "linear":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
                                                       is_neox_style,
                                                       scaling_factor, dtype)
         elif scaling_type == "dynamic":
+            scaling_factor = rope_scaling["factor"]
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
                 head_size, rotary_dim, max_position, base, is_neox_style,
                 scaling_factor, dtype)
         elif scaling_type == "yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             extra_kwargs = {
@@ -961,6 +982,7 @@ def get_rope(
                                                     scaling_factor, dtype,
                                                     **extra_kwargs)
         elif scaling_type == "deepseek_yarn":
+            scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling[
                 "original_max_position_embeddings"]
             # assert max_position == original_max_position * scaling_factor
@@ -973,9 +995,7 @@ def get_rope(
             rotary_emb = DeepseekScalingRotaryEmbedding(
                 head_size, rotary_dim, original_max_position, base,
                 is_neox_style, scaling_factor, dtype, **extra_kwargs)
-        # The correct one should be "longrope" but keep "su" here
-        # for backward compatible
-        elif scaling_type == "su" or scaling_type == "longrope":
+        elif scaling_type == "longrope":
             short_factor = rope_scaling["short_factor"]
             long_factor = rope_scaling["long_factor"]
             original_max_position = rope_scaling[
@@ -989,16 +1009,6 @@ def get_rope(
                 head_size, rotary_dim, max_position, original_max_position,
                 base, is_neox_style, dtype, short_factor, long_factor,
                 **extra_kwargs)
-        elif scaling_type == "mrope":
-            rotary_emb = MRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                dtype,
-                mrope_section=rope_scaling["mrope_section"],
-            )
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
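
After this hunk, get_rope no longer recognizes the legacy "type", "su", or "mrope" spellings: the rope_scaling dict must carry a "rope_type" key, and mrope is selected via "default" plus an "mrope_section" entry. A minimal sketch of the expected dict shapes, with illustrative (assumed) numeric values:

llama3_scaling = {
    "rope_type": "llama3",                 # legacy "type" key is no longer read
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}

mrope_scaling = {
    "rope_type": "default",                # "mrope" as a scaling type is gone
    "mrope_section": [16, 24, 24],         # mrope is now chosen via this key
}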
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 42a6a0e6b322..f86c6ec362eb 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from importlib.util import find_spec
 from math import inf
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Tuple, Union
 
 import msgspec
 import torch
@@ -117,12 +117,15 @@ class SamplerOutput(
     # block/sync across workers, cpu-gpu sync time and sampling time.
     model_execute_time: Optional[float] = None
 
-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> CompletionSequenceGroupOutput:
         return self.outputs[idx]
 
     def __setitem__(self, idx: int, value):
         self.outputs[idx] = value
 
+    def __iter__(self) -> Iterator[CompletionSequenceGroupOutput]:
+        return iter(self.outputs)
+
     def __len__(self):
         return len(self.outputs)
 
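
With __iter__ in place, a SamplerOutput can be consumed directly in a for loop instead of indexing .outputs by hand. A small hypothetical helper showing the pattern (attribute names on the per-group outputs follow vllm.sequence):

from typing import List

def collect_sampled_tokens(sampler_output) -> List[int]:
    # Each iteration yields a CompletionSequenceGroupOutput via the new __iter__.
    return [
        sample.output_token
        for seq_group_output in sampler_output
        for sample in seq_group_output.samples
    ]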
diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py
index 00c82fb77186..a9f1e6e88d79 100644
--- a/vllm/model_executor/model_loader/neuron.py
+++ b/vllm/model_executor/model_loader/neuron.py
@@ -6,7 +6,6 @@
 
 import torch
 import torch.nn as nn
-import transformers
 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
@@ -108,39 +107,11 @@ def load_weights(self, model_name_or_path: str, **kwargs):
         neuronx_module = importlib.import_module(neuronx_module_path)
         neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
 
-        split_model_dir = f"{model_name_or_path}-split"
-        if _is_pretrained_neuron_checkpoint(model_name_or_path):
-            split_model_dir = model_name_or_path
-        elif not os.path.exists(f"{model_name_or_path}-split"):
-            hf_model_cls = getattr(transformers, hf_model_cls_name)
-            from transformers_neuronx.module import save_pretrained_split
-
-            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
-                                                    low_cpu_mem_usage=True)
-            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
-
-        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
+        self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
                                                        **kwargs)
         self.model.to_neuron()
 
 
-def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
-    # Checking if the neuron checkpoint is saved in the old format.
-    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
-        return True
-    # Checking if the neuron checkpoint is saved in the new format.
-    pretrained_split_files = ["config.json", "generation_config.json"]
-    pretrained_split_format = ".safetensors"
-    for file in pretrained_split_files:
-        file_path = os.path.join(model_name_or_path, file)
-        if not os.path.isfile(file_path):
-            return False
-    for file in os.listdir(model_name_or_path):
-        if file.endswith(pretrained_split_format):
-            return True
-    return False
-
-
 def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index 1e2857ee28cb..0c51314bc90d 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -499,8 +499,8 @@ def kv_cache_scales_loader(
         logger.error("File or directory '%s' not found.", filename)
     except json.JSONDecodeError:
         logger.error("Error decoding JSON in file '%s'.", filename)
-    except Exception as e:
-        logger.error("An error occurred while reading '%s': %s", filename, e)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
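
The switch to logger.exception keeps the message but also records the active traceback, so the exception object no longer needs to be interpolated by hand. A generic sketch with the standard logging module (the function itself is hypothetical):

import logging

logger = logging.getLogger(__name__)

def read_text_file(filename: str):
    try:
        with open(filename) as f:
            return f.read()
    except Exception:
        # Logs at ERROR level and appends the full traceback automatically.
        logger.exception("An error occurred while reading '%s'.", filename)
        return None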
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 54ed548ba8bc..f2cfdf8ffd30 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -26,6 +26,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class BaiChuanModel(nn.Module):
 
     def __init__(self,
@@ -432,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
-    """Baichuan 13B and Baichuan2 7B/13B."""
+    """Baichuan 13B and Baichuan2 7B/13B.
+    NOTE: the class name has a lower case 'c'.
+    """
 
     def __init__(
         self,
@@ -450,7 +454,9 @@ def __init__(
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
-    """Baichuan 7B."""
+    """Baichuan 7B.
+    NOTE: the class name has an upper case 'C'.
+    """
 
     def __init__(
         self,
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
new file mode 100644
index 000000000000..4c0a0e303e65
--- /dev/null
+++ b/vllm/model_executor/models/bert.py
@@ -0,0 +1,419 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BertConfig
+
+from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.backends.xformers import XFormersImpl
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class BertEmbedding(nn.Module):
+
+    def __init__(self, config: BertConfig):
+
+        super().__init__()
+        self.size = config.hidden_size
+        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
+                                                      config.hidden_size)
+        self.position_embeddings = VocabParallelEmbedding(
+            config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = VocabParallelEmbedding(
+            config.type_vocab_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                      eps=config.layer_norm_eps)
+        self.position_ids = nn.Parameter(
+            torch.empty((1, config.max_position_embeddings)), )
+
+        self.position_embedding_type = config.position_embedding_type
+        if self.position_embedding_type != "absolute":
+            raise ValueError("Only 'absolute' position_embedding_type "
+                             "is supported")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+
+        # Input embeddings.
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+
+        # Token type embeddings. (TODO: move off hotpath?)
+        token_type_embeddings = self.token_type_embeddings(
+            torch.zeros(input_shape,
+                        dtype=torch.long,
+                        device=inputs_embeds.device))
+
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class BertEncoder(nn.Module):
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.layer = nn.ModuleList([
+            BertLayer(config=config,
+                      cache_config=cache_config,
+                      quant_config=quant_config,
+                      prefix=f"{prefix}.layer.{layer_idx}")
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        for i in range(len(self.layer)):
+            layer = self.layer[i]
+            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.attention = BertAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            layer_norm_eps=config.layer_norm_eps,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention")
+
+        self.intermediate = BertIntermediate(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.intermediate")
+
+        self.output = BertOutput(hidden_size=config.hidden_size,
+                                 intermediate_size=config.intermediate_size,
+                                 layer_norm_eps=config.layer_norm_eps,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.output")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ):
+        attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
+        intermediate_output = self.intermediate(attn_output)
+        output = self.output(intermediate_output, attn_output)
+        return output
+
+
+class BertAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_norm_eps: float,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.self = BertSelfAttention(hidden_size=hidden_size,
+                                      num_attention_heads=num_attention_heads,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.self")
+
+        self.output = BertSelfOutput(hidden_size=hidden_size,
+                                     layer_norm_eps=layer_norm_eps,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.output")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        self_output = self.self(hidden_states, kv_cache, attn_metadata)
+        return self.output(self_output, hidden_states)
+
+
+class BertSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj")
+
+        self.attn = Attention(num_heads=self.num_heads,
+                              head_size=self.head_dim,
+                              scale=self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.attn")
+
+        if not isinstance(self.attn.impl, XFormersImpl):
+            raise ValueError(
+                "Encoder-only models currently require XFORMERS attention "
+                "backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        output = self.attn(q,
+                           k,
+                           v,
+                           kv_cache,
+                           attn_metadata,
+                           attn_type=AttentionType.ENCODER_ONLY)
+        return output
+
+
+class BertSelfOutput(nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 layer_norm_eps: float,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.dense = RowParallelLinear(input_size=hidden_size,
+                                       output_size=hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.dense")
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor,
+                input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertIntermediate(nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.dense = ColumnParallelLinear(input_size=hidden_size,
+                                          output_size=intermediate_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.dense")
+        self.intermediate_act_fn = get_act_fn(hidden_act)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 layer_norm_eps: float,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.dense = RowParallelLinear(input_size=intermediate_size,
+                                       output_size=hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.dense")
+
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor,
+                input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertModel(nn.Module):
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.embeddings = BertEmbedding(config)
+        self.encoder = BertEncoder(config,
+                                   cache_config,
+                                   quant_config,
+                                   prefix=f"{prefix}.encoder")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids)
+
+        return self.encoder(hidden_states, kv_caches, attn_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "query", "q"),
+            ("qkv_proj", "key", "k"),
+            ("qkv_proj", "value", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "pooler" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class BertEmbeddingModel(nn.Module):
+    """A model that uses Bert to provide embedding functionalities.
+    """A model that uses BERT to provide embedding functionality.
+
+    This class encapsulates the BertModel and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of BertModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+    def __init__(
+        self,
+        config: BertConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.model = BertModel(config, cache_config, quant_config)
+        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          kv_caches=kv_caches,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors,
+                          attn_metadata=attn_metadata)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
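
As the check in BertSelfAttention states, the encoder-only path currently requires the XFormers backend. A hypothetical end-to-end sketch of exercising the new BertEmbeddingModel; the checkpoint name and its mapping onto this class are assumptions, not part of the diff:

import os

os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"  # required by BertSelfAttention

from vllm import LLM

# Assumed: a BERT-based embedding checkpoint resolved to BertEmbeddingModel.
llm = LLM(model="BAAI/bge-base-en-v1.5")
outputs = llm.encode(["vLLM now supports encoder-only BERT models."])
print(len(outputs[0].outputs.embedding))  # embedding dimension, e.g. 768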
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index 7c8e76461dd6..1f2d7384076e 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -10,7 +10,7 @@
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import LLMInputs
+from vllm.inputs import DecoderOnlyInputs, token_inputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -63,7 +63,7 @@ def dummy_seq_data_for_blip(
     else:
         image_feature_size = image_feature_size_override
 
-    return SequenceData.from_token_counts(
+    return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
     )
@@ -89,14 +89,14 @@ def dummy_image_for_blip(
 def input_processor_for_blip(
     model_config: ModelConfig,
     hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
-    llm_inputs: LLMInputs,
+    inputs: DecoderOnlyInputs,
     *,
     image_token_id: int,
     image_feature_size_override: Optional[int] = None,
 ):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
@@ -107,22 +107,22 @@ def input_processor_for_blip(
 
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
-        llm_inputs.get("prompt"),
-        llm_inputs["prompt_token_ids"],
+        inputs.get("prompt"),
+        inputs["prompt_token_ids"],
         placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
 class BlipVisionEmbeddings(nn.Module):
 
-    def __init__(self, config: BlipVisionConfig):
+    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
         super().__init__()
 
         self.config = config
@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
 
     def __init__(
         self,
-        config: BlipVisionConfig,
+        config: Union[BlipVisionConfig, Blip2VisionConfig],
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -189,11 +190,13 @@ def __init__(
             self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
         self.projection = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.projection",
         )
 
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,9 +238,12 @@ def forward(
 
 class BlipMLP(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -246,11 +252,13 @@ def __init__(self,
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -262,24 +270,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class BlipEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         # fallback to sdpa attention if tp unavailable
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = BlipParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             # Blip doesn't have SDPA attention implemented in transformers
             # use eager attention instead for cpu backend
             self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = BlipMLP(config, quant_config=quant_config)
+        self.mlp = BlipMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
 
@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
         config: BlipConfig
     """
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -321,8 +340,10 @@ def __init__(self,
             num_hidden_layers = num_hidden_layers_override
 
         self.layers = nn.ModuleList([
-            BlipEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            BlipEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(self, inputs_embeds: torch.Tensor):
@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -354,19 +380,24 @@ def __init__(self,
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
         )
 
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
             self.post_layernorm = None
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
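
The LLMInputs to DecoderOnlyInputs/token_inputs migration above repeats across the multimodal input processors in this diff. A condensed sketch of the new contract; the processor body is illustrative, and only the imports and the token_inputs call mirror the code above:

from vllm.inputs import DecoderOnlyInputs, token_inputs

def input_processor_example(inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs  # nothing multimodal to expand; pass through unchanged

    # Illustrative only: prepend four placeholder tokens, then rebuild the
    # inputs as a fresh dict instead of mutating the original.
    new_token_ids = [0] * 4 + inputs["prompt_token_ids"]
    return token_inputs(prompt_token_ids=new_token_ids,
                        prompt=inputs.get("prompt"),
                        multi_modal_data=multi_modal_data)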
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 3ab235754a40..cd2013e91514 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -9,7 +9,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -421,7 +422,7 @@ def dummy_seq_data_for_blip2(
     else:
         image_feature_size = image_feature_size_override
 
-    return SequenceData.from_token_counts(
+    return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
     )
@@ -449,10 +450,10 @@ def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
     raise NotImplementedError(msg)
 
 
-def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     hf_config = ctx.get_hf_config(Blip2Config)
     image_feature_size = get_blip2_image_feature_size(hf_config)
@@ -460,15 +461,15 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
     # The original model places image tokens at the front
     # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
     new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
-    new_token_ids += llm_inputs["prompt_token_ids"]
+    new_token_ids += inputs["prompt_token_ids"]
 
-    new_prompt = llm_inputs.get("prompt")
+    new_prompt = inputs.get("prompt")
     if new_prompt is not None:
         new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt
 
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -489,7 +490,7 @@ def __init__(self,
         self.multimodal_config = multimodal_config
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_model = BlipVisionModel(config.vision_config)
+        self.vision_model = BlipVisionModel(config.vision_config, quant_config)
 
         self.query_tokens = nn.Parameter(
             torch.zeros(1, config.num_query_tokens,
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index b2c9e221690b..77ab7de6165f 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -24,6 +24,7 @@
 from transformers import BloomConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -218,6 +219,7 @@ def forward(
         return output
 
 
+@support_torch_compile
 class BloomModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 03c7419f6f6a..aaf559ca386c 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -11,7 +11,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -69,7 +70,7 @@ def dummy_seq_data_for_chameleon(
     else:
         image_feature_size = image_feature_size_override
 
-    return SequenceData.from_token_counts(
+    return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
     )
@@ -106,7 +107,8 @@ def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
     return seq_data, mm_data
 
 
-def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_chameleon(ctx: InputContext,
+                                  inputs: DecoderOnlyInputs):
 
     """
     Processing input prompt to insert required tokens for image placeholder.
@@ -114,16 +116,16 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
     See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58
     """ # noqa
 
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
-        llm_inputs.get("prompt"),
-        llm_inputs["prompt_token_ids"],
+        inputs.get("prompt"),
+        inputs["prompt_token_ids"],
         placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
         repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
         pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
@@ -137,9 +139,9 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
     new_token_ids += [CHAMELEON_SEP_TOKEN_ID]
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 class ChameleonLayerNorm(nn.LayerNorm):
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index f26c9f950dd3..ca90d10e9f9f 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -13,8 +13,9 @@
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -22,8 +23,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -39,7 +39,9 @@
                            SequenceData)
 from vllm.transformers_utils.configs import ChatGLMConfig
 
-from .interfaces import SupportsLoRA, SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .utils import (is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers)
 
 logger = init_logger(__name__)
 
@@ -149,20 +151,24 @@ def find_all_positions(input_ids: List[int], target: int) -> List[int]:
     return [index for index, value in enumerate(input_ids) if value == target]
 
 
-def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_glmv(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return inputs
+
     hf_config = ctx.get_hf_config(ChatGLMConfig)
     vision_config = getattr(hf_config, 'vision_config', None)
 
     if vision_config is None:
-        return llm_inputs
+        return inputs
     elif isinstance(vision_config, dict):
         image_placeholder_length = calculate_image_placeholder(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
 
-    input_ids = llm_inputs.get("prompt_token_ids")
-    position_ids = llm_inputs.get("position_ids")
+    input_ids = inputs["prompt_token_ids"]
+
     tokenizer = cached_get_tokenizer(
         ctx.model_config.model,
         trust_remote_code=ctx.model_config.trust_remote_code)
@@ -171,20 +177,19 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
         raw_batch_data = tokenizer.apply_chat_template(
             conversation=[{
                 "role": "user",
-                "image": llm_inputs['multi_modal_data']["image"],
-                "content": llm_inputs['prompt']
+                "image": multi_modal_data["image"],
+                "content": inputs['prompt'],
             }],
             add_generation_prompt=True,
             tokenize=True,
             return_tensors="pt",
-            return_dict=True).data
+            return_dict=True,
+        ).data
     except Exception:
-        logger.error("Failed to process content (%s)", llm_inputs['prompt'])
+        logger.error("Failed to process content (%s)", inputs['prompt'])
         raise
     input_ids = raw_batch_data['input_ids'][0].tolist()
 
-    if position_ids is None:
-        position_ids = list(range(len(input_ids)))
     boi_token_id = hf_config.boi_token_id
     eoi_token_id = hf_config.eoi_token_id
     boi_positions = find_all_positions(input_ids, boi_token_id)
@@ -193,7 +198,6 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
     assert len(boi_positions) == len(eoi_positions)
 
     new_input_ids = []
-    new_position_ids = []
     final_processed_position = 0
     final_processed_position = 0
 
@@ -201,29 +205,28 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs):
         assert boi_position < eoi_position
         new_input_ids.extend(input_ids[final_processed_position:boi_position +
                                        1])
-        new_position_ids.extend(
-            list(range(final_processed_position, boi_position + 1)))
         new_input_ids.extend([input_ids[boi_position + 1]] *
                              image_placeholder_length)
-        new_position_ids.extend([boi_position + 1] * image_placeholder_length)
         final_processed_position = eoi_position
 
     new_input_ids.extend(input_ids[final_processed_position:])
-    new_position_ids.extend(
-        list(range(final_processed_position, len(input_ids))))
 
-    assert len(new_input_ids) == len(new_position_ids)
+    prompt = inputs.get("prompt")
+    if prompt is None:
+        prompt = tokenizer.decode(new_input_ids)
 
-    llm_inputs["prompt_token_ids"] = new_input_ids
-    llm_inputs["position_ids"] = new_position_ids
-    return llm_inputs
+    return token_inputs(
+        prompt_token_ids=new_input_ids,
+        prompt=prompt,
+        multi_modal_data=multi_modal_data,
+    )
 
 
 class GLMAttention(nn.Module):
 
     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -314,7 +317,7 @@ class GLMMLP(nn.Module):
 
     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -357,7 +360,7 @@ class GLMBlock(nn.Module):
 
     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -428,9 +431,10 @@ class GLMTransformer(nn.Module):
 
     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.post_layer_norm = config.post_layer_norm
@@ -439,10 +443,11 @@ def __init__(
         self.num_layers = config.num_layers
 
         # Transformer layers.
-        self.layers = nn.ModuleList([
-            GLMBlock(config, cache_config, quant_config)
-            for i in range(self.num_layers)
-        ])
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            self.num_layers,
+            lambda prefix: GLMBlock(config, cache_config, quant_config),
+            prefix=f"{prefix}.layers",
+        )
 
         if self.post_layer_norm:
             layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@@ -450,6 +455,10 @@ def __init__(
             self.final_layernorm = layer_norm_func(
                 config.hidden_size, eps=config.layernorm_epsilon)
 
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(["hidden_states"],
+                                                    config.hidden_size))
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -457,16 +466,16 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        for i in range(self.num_layers):
+        for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
-                kv_cache=kv_caches[i],
+                kv_cache=kv_caches[i - self.start_layer],
                 attn_metadata=attn_metadata,
             )
         # Final layer norm.
-        if self.post_layer_norm:
+        if get_pp_group().is_last_rank and self.post_layer_norm:
             hidden_states = self.final_layernorm(hidden_states)
 
         return hidden_states
@@ -476,7 +485,7 @@ class ChatGLMModel(nn.Module):
 
     def __init__(
         self,
-        config,
+        config: ChatGLMConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
@@ -504,6 +513,9 @@ def __init__(
         else:
             self.vision = None
 
+        self.make_empty_intermediate_tensors = (
+            self.encoder.make_empty_intermediate_tensors)
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> GLMImagePixelInputs:
 
@@ -529,24 +541,26 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ) -> torch.Tensor:
-
-        inputs_embeds = self.embedding(input_ids)
-        image_input = self._parse_and_validate_image_input(**kwargs)
-
-        if image_input["pixel_values"] is not None:
-            pixel_values = image_input["pixel_values"].to(
-                dtype=inputs_embeds.dtype)
-            image_embeds = self.vision(pixel_values)
-
-            boi_token_id = self.config.boi_token_id
-            eoi_token_id = self.config.eoi_token_id
-
-            inputs_embeds = merge_glm_vision_embeddings(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                vision_embeddings=image_embeds,
-                boi_token_id=boi_token_id,
-                eoi_token_id=eoi_token_id)
+        if intermediate_tensors is None:
+            inputs_embeds = self.embedding(input_ids)
+            image_input = self._parse_and_validate_image_input(**kwargs)
+
+            if image_input["pixel_values"] is not None:
+                pixel_values = image_input["pixel_values"].to(
+                    dtype=inputs_embeds.dtype)
+                image_embeds = self.vision(pixel_values)
+
+                boi_token_id = self.config.boi_token_id
+                eoi_token_id = self.config.eoi_token_id
+
+                inputs_embeds = merge_glm_vision_embeddings(
+                    input_ids=input_ids,
+                    inputs_embeds=inputs_embeds,
+                    vision_embeddings=image_embeds,
+                    boi_token_id=boi_token_id,
+                    eoi_token_id=eoi_token_id)
+        else:
+            inputs_embeds = intermediate_tensors["hidden_states"]
 
         # Run encoder.
         hidden_states = self.encoder(
@@ -555,6 +569,9 @@ def forward(
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
         )
+
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({"hidden_states": hidden_states})
         return hidden_states
 
 
@@ -562,7 +579,8 @@ def forward(
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
-class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
+class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
+                         SupportsMultiModal):
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -610,7 +628,8 @@ def forward(self,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, **kwargs)
+                                         attn_metadata, intermediate_tensors,
+                                         **kwargs)
         return hidden_states
 
     def compute_logits(
@@ -656,6 +675,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue
+            if is_pp_missing_parameter(name, self):
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
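
The ChatGLM changes follow vLLM's usual pipeline-parallel recipe: layers sharded with make_layers, hidden states handed off through IntermediateTensors on non-last ranks, and is_pp_missing_parameter checks in the weight loader. A condensed, model-agnostic sketch of the forward-path half of that recipe (the enclosing class is assumed; names mirror the diff):

from vllm.distributed import get_pp_group
from vllm.sequence import IntermediateTensors

def forward(self, input_ids, positions, kv_caches, attn_metadata,
            intermediate_tensors=None):
    if intermediate_tensors is None:
        # First pipeline rank: compute input embeddings locally.
        hidden_states = self.embedding(input_ids)
    else:
        # Later ranks: resume from the tensors sent by the previous stage.
        hidden_states = intermediate_tensors["hidden_states"]

    hidden_states = self.encoder(hidden_states, positions, kv_caches,
                                 attn_metadata)

    if not get_pp_group().is_last_rank:
        # Hand off to the next stage instead of producing final hidden states.
        return IntermediateTensors({"hidden_states": hidden_states})
    return hidden_states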
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index edfb0c2b5e19..6b45cb384d4a 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -11,7 +11,7 @@
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import LLMInputs
+from vllm.inputs import DecoderOnlyInputs, token_inputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -62,7 +62,7 @@ def dummy_seq_data_for_clip(
     else:
         image_feature_size = image_feature_size_override
 
-    return SequenceData.from_token_counts(
+    return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
     )
@@ -106,14 +106,14 @@ def dummy_video_for_clip(
 def input_processor_for_clip(
     model_config: ModelConfig,
     hf_config: CLIPVisionConfig,
-    llm_inputs: LLMInputs,
+    inputs: DecoderOnlyInputs,
     *,
     image_token_id: int,
     image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
@@ -130,16 +130,16 @@ def input_processor_for_clip(
 
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
-        llm_inputs.get("prompt"),
-        llm_inputs["prompt_token_ids"],
+        inputs.get("prompt"),
+        inputs["prompt_token_ids"],
         placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
@@ -192,6 +192,7 @@ def __init__(
         self,
         config: CLIPVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -211,12 +212,14 @@ def __init__(
             head_size=self.head_dim,
             total_num_heads=self.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -259,20 +262,25 @@ def forward(
 
 class CLIPMLP(nn.Module):
 
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -284,21 +292,29 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class CLIPEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = CLIPParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = CLIPParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             self.self_attn = CLIPSdpaAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = CLIPMLP(config, quant_config=quant_config)
+        self.mlp = CLIPMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
 
@@ -327,11 +343,15 @@ class CLIPEncoder(nn.Module):
         config: CLIPConfig
     """
 
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
 
         if num_hidden_layers_override is None:
@@ -339,8 +359,10 @@ def __init__(self,
         else:
             num_hidden_layers = num_hidden_layers_override
         self.layers = nn.ModuleList([
-            CLIPEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            CLIPEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(self, inputs_embeds: torch.Tensor):
@@ -354,11 +376,17 @@ def forward(self, inputs_embeds: torch.Tensor):
 
 class CLIPVisionTransformer(nn.Module):
 
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         embed_dim = config.hidden_size
 
@@ -370,19 +398,25 @@ def __init__(self,
         self.encoder = CLIPEncoder(
             config=config,
             quant_config=quant_config,
-            num_hidden_layers_override=num_hidden_layers_override)
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
+        )
 
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
             self.post_layernorm = None
 
     def forward(
@@ -405,10 +439,15 @@ class CLIPVisionModel(nn.Module):
     config_class = CLIPVisionConfig
     main_input_name = "pixel_values"
 
-    def __init__(self,
-                 config: CLIPVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -418,7 +457,10 @@ def __init__(self,
         self.vision_model = CLIPVisionTransformer(
             config=config,
             quant_config=quant_config,
-            num_hidden_layers_override=num_hidden_layers_override)
+            num_hidden_layers_override=num_hidden_layers_override,
+            require_post_norm=require_post_norm,
+            prefix=f"{prefix}.vision_model",
+        )
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
         return self.vision_model(pixel_values)
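Note: the clip.py changes above thread a dotted `prefix` string through every nested submodule so that per-layer quantization settings and loaded weights can be resolved by fully qualified name (e.g. `vision_model.encoder.layers.0.mlp.fc1`). The following is a minimal, self-contained sketch of that pattern using hypothetical classes, not vLLM's own:

```python
# Minimal sketch (hypothetical classes): threading a dotted prefix through
# nested modules so each leaf knows its fully qualified parameter name.
import torch.nn as nn


class Leaf(nn.Module):

    def __init__(self, prefix: str = "") -> None:
        super().__init__()
        self.qualified_name = prefix  # e.g. used to look up per-layer quant settings
        self.proj = nn.Linear(8, 8)


class Block(nn.Module):

    def __init__(self, prefix: str = "") -> None:
        super().__init__()
        self.mlp = Leaf(prefix=f"{prefix}.mlp")


class Encoder(nn.Module):

    def __init__(self, num_layers: int, prefix: str = "") -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            Block(prefix=f"{prefix}.layers.{i}") for i in range(num_layers)
        ])


encoder = Encoder(num_layers=2, prefix="vision_model.encoder")
print(encoder.layers[1].mlp.qualified_name)  # vision_model.encoder.layers.1.mlp
```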
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index 578cd2f04861..348e6d20f329 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -28,6 +28,7 @@
 from transformers import CohereConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -250,6 +251,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class CohereModel(nn.Module):
 
     def __init__(
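Note: `@support_torch_compile` is applied here in its bare form, while the gemma2.py hunk further below drops the explicit `dynamic_arg_dims` argument in favour of the same bare form. The sketch below only mirrors the two calling spellings seen in this diff; it is not vLLM's implementation of the decorator.

```python
# Illustration only: a decorator usable both bare and with keyword arguments,
# mirroring the two @support_torch_compile spellings in this diff.
from typing import Optional


def support_torch_compile(cls=None, *, dynamic_arg_dims: Optional[dict] = None):

    def wrap(target_cls):
        # Record compile metadata on the class; the real decorator instead
        # installs a torch.compile-aware forward path.
        target_cls._compile_dynamic_arg_dims = dynamic_arg_dims
        return target_cls

    if cls is None:        # used as @support_torch_compile(dynamic_arg_dims=...)
        return wrap
    return wrap(cls)       # used as bare @support_torch_compile


@support_torch_compile
class MyModel:
    pass


@support_torch_compile(dynamic_arg_dims={"input_ids": 0})
class MyOtherModel:
    pass
```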
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 702be7b7f5ed..38114836bfdb 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -242,7 +242,7 @@ def __init__(
                                         bias=False,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
-        rope_scaling['type'] = 'deepseek_yarn'
+        rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,
diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py
index 13811d33768a..a87e1c022862 100644
--- a/vllm/model_executor/models/eagle.py
+++ b/vllm/model_executor/models/eagle.py
@@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
         self.model = model_cls(self.config.model, *args, **kwargs)
         self.fc = nn.Linear(config.model.hidden_size * 2,
                             config.model.hidden_size,
-                            bias=getattr(self.config, "bias", False))
+                            bias=getattr(self.config, "eagle_fc_bias", False))
 
         self.orig_vocab_size = config.vocab_size
         self.truncated_vocab_size = config.truncated_vocab_size
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index dfb8fe55d2fb..4126ceb7117d 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -29,6 +29,7 @@
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -311,6 +312,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class ExaoneModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
new file mode 100644
index 000000000000..6840ac8b9e30
--- /dev/null
+++ b/vllm/model_executor/models/florence2.py
@@ -0,0 +1,261 @@
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.bart import (BartDecoder, BartEncoder,
+                                             BartParallelLMHead,
+                                             BartScaledWordEmbedding)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .utils import AutoWeightsLoader
+
+
+class Florence2LanguageModel(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model)
+        self.encoder = BartEncoder(config,
+                                   cache_config=cache_config,
+                                   quant_config=quant_config)
+        self.decoder = BartDecoder(config,
+                                   cache_config=cache_config,
+                                   quant_config=quant_config)
+
+        if self.config.tie_word_embeddings:
+            self.encoder.embed_tokens.weight = self.shared.weight
+            self.decoder.embed_tokens.weight = self.shared.weight
+
+    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
+                encoder_input_ids: torch.Tensor,
+                encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                Indices of *decoder* input sequence tokens in the vocabulary.
+                Padding will be ignored by default if you provide it.
+            positions
+                Positions of *decoder* input sequence tokens.
+            encoder_input_ids
+                Indices of *encoder* input sequence tokens in the vocabulary.
+            encoder_positions:
+                Positions of *encoder* input sequence tokens.
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                vLLM Attention metadata structure
+        Returns:
+            Model output torch.Tensor
+        """
+
+        encoder_hidden_states = None
+
+        if encoder_input_ids.numel() > 0:
+            # Run encoder attention if a non-zero number of encoder tokens
+            # are provided as input
+            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
+                                                 positions=encoder_positions,
+                                                 kv_caches=kv_caches,
+                                                 attn_metadata=attn_metadata)
+
+        # decoder outputs consist of
+        # (dec_features, past_key_value, dec_hidden, dec_attn)
+        decoder_outputs = self.decoder(
+            decoder_input_ids=input_ids,
+            decoder_positions=positions,
+            encoder_hidden_states=encoder_hidden_states,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata)
+
+        return decoder_outputs
+
+
+class Florence2LanguageForConditionalGeneration(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.model = Florence2LanguageModel(config,
+                                            cache_config=cache_config,
+                                            quant_config=quant_config)
+        embed_scale = math.sqrt(
+            config.d_model) if config.scale_embedding else 1.0
+
+        self.vocab_size = config.vocab_size
+        self.lm_head = BartParallelLMHead(self.vocab_size,
+                                          config.d_model,
+                                          embed_scale=embed_scale)
+
+        self.logits_processor = LogitsProcessor(self.vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        encoder_input_ids: torch.Tensor,
+        encoder_positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                torch.Tensor of *decoder* input token ids.
+            positions
+                torch.Tensor of *decoder* position indices.
+            encoder_input_ids
+                torch.Tensor of *encoder* input token ids.
+            encoder_positions
+                torch.Tensor of *encoder* position indices
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                vLLM Attention metadata structure
+        Returns:
+            Output torch.Tensor
+        """
+        return self.model(input_ids, positions, encoder_input_ids,
+                          encoder_positions, kv_caches, attn_metadata)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(self, logits: torch.Tensor,
+               sampling_metadata: SamplingMetadata) -> SamplerOutput:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if "final_logits_bias" in name:
+                    continue
+                if self.config.tie_word_embeddings and "embed_tokens" in name:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class Florence2ForConditionalGeneration(nn.Module):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+
+        # TODO(Isotr0py): Add vision backbone
+        self.language_model = Florence2LanguageForConditionalGeneration(
+            config=config.text_config,
+            cache_config=cache_config,
+            quant_config=quant_config)
+
+    @property
+    def sampler(self):
+        return self.language_model.sampler
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        *,
+        encoder_input_ids: torch.Tensor,
+        encoder_positions: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            input_ids
+                torch.Tensor of *decoder* input token ids.
+            positions
+                torch.Tensor of *decoder* position indices.
+            encoder_input_ids
+                torch.Tensor of *encoder* input token ids.
+            encoder_positions
+                torch.Tensor of *encoder* position indices
+            kv_caches:
+                Layer-wise list of KV cache tensors
+            attn_metadata:
+                vLLM Attention metadata structure
+        Returns:
+            Output torch.Tensor
+        """
+        return self.language_model(input_ids, positions, encoder_input_ids,
+                                   encoder_positions, kv_caches, attn_metadata)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        return self.language_model.sample(logits, sampling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        skip_prefixes = [
+            "image_projection", "vision_tower", "image_proj_norm",
+            "image_pos_embed", "visual_temporal_embed"
+        ]
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
+        loader.load_weights(weights)
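Note: `Florence2LanguageForConditionalGeneration.load_weights` above uses the stacked-parameter pattern: separate q/k/v checkpoint tensors are renamed onto the fused `qkv_proj` parameter and loaded shard by shard, with a `for ... else` fallback for everything else. A standalone sketch of just the renaming step, with made-up checkpoint names:

```python
# Sketch of the stacked-parameter renaming used above: checkpoint names such as
# "...self_attn.q_proj.weight" map onto the fused "...qkv_proj.weight" parameter
# together with a shard id ("q"/"k"/"v"). Names here are made up for illustration.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_names = [
    "model.encoder.layers.0.self_attn.q_proj.weight",
    "model.encoder.layers.0.self_attn.k_proj.weight",
    "model.encoder.layers.0.self_attn.out_proj.weight",
]

for name in checkpoint_names:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            fused = name.replace(weight_name, param_name)
            print(f"{name} -> {fused} (shard {shard_id})")
            break
    else:
        print(f"{name} -> loaded as-is")
```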
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 62a1b1f8cd4c..358d1dd288c4 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -27,7 +27,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.linear import ColumnParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -149,10 +150,10 @@ def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
     return model_image_input
 
 
-def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     image_data = multi_modal_data["image"]
@@ -176,8 +177,8 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
     # process prompts
-    prompt = llm_inputs.get("prompt")
-    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    prompt = inputs.get("prompt")
+    prompt_token_ids = inputs["prompt_token_ids"]
     tokenizer = cached_get_tokenizer(model_config.model)
     # dim0 is batch_size, dim1 is subseq_size which will always be 1
     image_input_ids: List[List[
@@ -190,9 +191,9 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
     new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
         1:] + boa_token
 
-    return LLMInputs(prompt=new_prompt,
-                     prompt_token_ids=new_prompt_token_ids,
-                     multi_modal_data=new_multi_modal_data)
+    return token_inputs(prompt=new_prompt,
+                        prompt_token_ids=new_prompt_token_ids,
+                        multi_modal_data=new_multi_modal_data)
 
 
 def input_mapper_for_fuyu(ctx: InputContext, data: object):
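Note: the Fuyu input processor above rebuilds the prompt as image placeholder ids, then BOS, then the original prompt without its leading BOS, then a BOA token, and wraps the result in `token_inputs`. A toy illustration of that token-list assembly, with made-up ids:

```python
# Toy illustration of the prompt rewriting above, with made-up token ids.
# The real processor derives image_input_ids from the image patches and the
# bos/boa ids from the tokenizer.
image_input_ids = [9001, 9001, 9002]       # flattened image placeholder ids
bos_token = [1]
boa_token = [7]
prompt_token_ids = [1, 15, 27, 42]         # original prompt (starts with BOS)

new_prompt_token_ids = (image_input_ids + bos_token + prompt_token_ids[1:] +
                        boa_token)
print(new_prompt_token_ids)                # [9001, 9001, 9002, 1, 15, 27, 42, 7]
```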
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index 91e556db70a0..436bd45d53f3 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -22,6 +22,7 @@
 from transformers import GemmaConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -239,6 +240,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class GemmaModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index bcb03ef55ef9..d79248f93f5a 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -31,14 +31,16 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
@@ -239,13 +241,7 @@ def forward(
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class Gemma2Model(nn.Module):
 
     def __init__(
@@ -461,3 +457,50 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                            if self.config.tie_word_embeddings else None),
         )
         loader.load_weights(weights)
+
+
+class Gemma2EmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses Gemma2 with additional embedding functionalities.
+
+    This class encapsulates the Gemma2Model and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of Gemma2Model used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self.model = Gemma2Model(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors, inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
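Note: `Gemma2EmbeddingModel` selects `Pooler(pooling_type=PoolingType.LAST, normalize=True)`, i.e. each sequence is represented by the hidden state of its final token, L2-normalized. A plain-PyTorch sketch of that pooling step (shapes are illustrative, not vLLM's pooler internals):

```python
# Plain-PyTorch sketch of last-token pooling with L2 normalization, the
# behaviour selected above via Pooler(pooling_type=PoolingType.LAST,
# normalize=True). Shapes are illustrative.
import torch
import torch.nn.functional as F

hidden_states = torch.randn(2, 7, 16)        # (batch, seq_len, hidden)
seq_lens = torch.tensor([5, 7])              # true lengths per sequence

last_idx = seq_lens - 1                      # index of each sequence's last token
pooled = hidden_states[torch.arange(hidden_states.size(0)), last_idx]
embeddings = F.normalize(pooled, p=2, dim=-1)
print(embeddings.shape)                      # torch.Size([2, 16])
```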
diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py
deleted file mode 100644
index e8e10598c164..000000000000
--- a/vllm/model_executor/models/gemma2_embedding.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Iterable, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
-
-from .gemma2 import Gemma2Model
-from .interfaces import SupportsPP
-
-
-class Gemma2EmbeddingModel(nn.Module, SupportsPP):
-    """A model that uses Gemma2 with additional embedding functionalities.
-
-   This class encapsulates the Gemma2Model and provides an interface for
-   embedding operations and customized pooling functions.
-
-   Attributes:
-       model: An instance of Gemma2Model used for forward operations.
-       _pooler: An instance of Pooler used for pooling operations.
-   """
-
-    def __init__(
-        self,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-        self.model = Gemma2Model(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model(input_ids, positions, kv_caches, attn_metadata,
-                          intermediate_tensors, inputs_embeds)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 975502340e5f..3330d8402136 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -24,6 +24,7 @@
 from transformers import GPT2Config
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
@@ -182,6 +183,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPT2Model(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index 6c4a04667c5d..fb151c2bf4bf 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -25,17 +25,23 @@
 from transformers import GPTBigCodeConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               RowParallelLinear)
+                                               QuantizationConfigOverride,
+                                               RowParallelLinear,
+                                               TiedWeightLinearMethod)
+# yapf: enable
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
@@ -187,6 +193,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTBigCodeModel(nn.Module):
 
     def __init__(
@@ -205,9 +212,15 @@ def __init__(
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
         self.vocab_size = config.vocab_size + lora_vocab
-        self.wte = VocabParallelEmbedding(self.vocab_size,
-                                          self.embed_dim,
-                                          org_num_embeddings=config.vocab_size)
+        self.wte = VocabParallelEmbedding(
+            self.vocab_size,
+            self.embed_dim,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE
+            # We need bigger padding if using lora for kernel
+            # compatibility
+            if not lora_config else lora_config.lora_vocab_padding_size,
+        )
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -249,7 +262,7 @@ def forward(
 class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {"c_attn": ["c_attn"]}
 
-    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
 
     embedding_modules = {
         "wte": "input_embeddings",
@@ -273,16 +286,38 @@ def __init__(
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head = self.transformer.wte
-        else:
-            self.lm_head = ParallelLMHead(
-                self.transformer.vocab_size,
-                self.transformer.embed_dim,
-                org_num_embeddings=self.config.vocab_size)
+
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        if self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=QuantizationConfigOverride(
+                    TiedWeightLinearMethod),
+                params_dtype=self.transformer.wte.weight.dtype,
+            )
+            self.lm_head.register_parameter("weight",
+                                            self.transformer.wte.weight)
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+                quant_config=quant_config,
+            )
+
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -321,7 +356,7 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if "lm_head.weight" in name and self.config.tie_word_embeddings:
                 continue
             if ".attn.bias" in name:
                 # Skip attention mask.
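Note: with `tie_word_embeddings`, the GPTBigCode change above creates a padded `ParallelLMHead` and then rebinds its weight to the token embedding via `register_parameter`, so both modules share one tensor and checkpoint `lm_head.weight` entries can be skipped. A minimal sketch of that tying mechanism with plain `torch.nn` modules:

```python
# Sketch of weight tying via register_parameter: after rebinding, lm_head and
# the embedding share the same storage, so updating one updates the other.
import torch.nn as nn

vocab_size, hidden_size = 100, 16
wte = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

lm_head.register_parameter("weight", wte.weight)
assert lm_head.weight is wte.weight          # single shared parameter
```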
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index d40bf8c88ee1..0451d16b6c73 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -23,6 +23,7 @@
 from transformers import GPTJConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -174,6 +175,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTJModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 23a1ca06cc69..1bccef7a5f17 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -23,6 +23,7 @@
 from transformers import GPTNeoXConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -187,6 +188,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GPTNeoXModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index dcf4f5b27704..5a397ed8ff6a 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -28,6 +28,7 @@
 from transformers import GraniteConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -254,6 +255,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class GraniteModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index 3b0b6febaa48..43f4f29814e6 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -113,7 +113,8 @@ def __init__(
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -130,12 +131,14 @@ def __init__(
             self.head_dim,
             self.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.out_proj = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
@@ -178,7 +181,8 @@ def __init__(
         self,
         config: Idefics2Config,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
@@ -187,12 +191,14 @@ def __init__(
             config.intermediate_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -204,13 +210,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class Idefics2EncoderLayer(nn.Module):
 
-    def __init__(self, config: Idefics2Config):
+    def __init__(
+        self,
+        config: Idefics2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.embed_dim = config.hidden_size
-        self.self_attn = Idefics2VisionAttention(config)
+        self.self_attn = Idefics2VisionAttention(config,
+                                                 quant_config=quant_config,
+                                                 prefix=f"{prefix}.self_attn")
         self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
-        self.mlp = Idefics2VisionMLP(config)
+        self.mlp = Idefics2VisionMLP(config,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
 
@@ -245,12 +260,20 @@ class Idefics2Encoder(nn.Module):
         config: Idefics2Config
     """
 
-    def __init__(self, config: Idefics2Config):
+    def __init__(
+        self,
+        config: Idefics2Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.layers = nn.ModuleList([
-            Idefics2EncoderLayer(config)
-            for _ in range(config.num_hidden_layers)
+            Idefics2EncoderLayer(config,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(config.num_hidden_layers)
         ])
 
     def forward(
@@ -275,12 +298,20 @@ def forward(
 
 class Idefics2VisionTransformer(nn.Module):
 
-    def __init__(self, config: Idefics2VisionConfig):
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         embed_dim = config.hidden_size
         self.config = config
         self.embeddings = Idefics2VisionEmbeddings(config)
-        self.encoder = Idefics2Encoder(config)
+        self.encoder = Idefics2Encoder(config,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.encoder")
         self.post_layernorm = nn.LayerNorm(embed_dim,
                                            eps=config.layer_norm_eps)
 
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 35be1cec3d43..9761635d2a6c 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -97,6 +97,37 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         return embeddings
 
 
+class InternVisionPatchModel(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+        self.embeddings = InternVisionEmbeddings(config)
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if pixel_values is None and pixel_embeds is None:
+            raise ValueError(
+                'You have to specify pixel_values or pixel_embeds')
+
+        if pixel_embeds is not None:
+            hidden_states = pixel_embeds
+        elif pixel_values is not None:
+            if pixel_values.ndim == 4:
+                hidden_states = self.embeddings(pixel_values)
+            else:
+                raise ValueError(
+                    f'wrong pixel_values size: {pixel_values.shape}')
+
+        return hidden_states
+
+
 class InternParallelAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -106,6 +137,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         *,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -134,6 +166,7 @@ def __init__(
             num_dummy_heads + self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
 
         self.qk_normalization = config.qk_normalization
@@ -150,6 +183,7 @@ def __init__(
             self.dummy_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.proj",
         )
 
     def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
@@ -253,20 +287,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 class InternMLP(nn.Module):
 
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -284,6 +324,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         *,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ) -> None:
         super().__init__()
 
@@ -293,9 +334,12 @@ def __init__(
 
         self.attn = self._init_attn(config,
                                     quant_config,
-                                    num_dummy_heads=num_dummy_heads)
+                                    num_dummy_heads=num_dummy_heads,
+                                    prefix=f"{prefix}.attn")
 
-        self.mlp = InternMLP(config, quant_config=quant_config)
+        self.mlp = InternMLP(config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
                                              eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
@@ -312,6 +356,7 @@ def _init_attn(
         quant_config: Optional[QuantizationConfig],
         *,
         num_dummy_heads: int,
+        prefix: str = "",
     ):
         # fallback to sdpa attention if tp unavailable
         tp_size = get_tensor_model_parallel_world_size()
@@ -320,7 +365,8 @@ def _init_attn(
         if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0:
             return InternParallelAttention(config,
                                            quant_config=quant_config,
-                                           num_dummy_heads=num_dummy_heads)
+                                           num_dummy_heads=num_dummy_heads,
+                                           prefix=prefix)
 
         return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
 
@@ -346,6 +392,7 @@ def __init__(
         *,
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -359,8 +406,9 @@ def __init__(
         self.layers = nn.ModuleList([
             InternVisionEncoderLayer(config,
                                      quant_config,
-                                     num_dummy_heads=num_dummy_heads)
-            for _ in range(num_hidden_layers)
+                                     num_dummy_heads=num_dummy_heads,
+                                     prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(self, inputs_embeds: torch.Tensor):
@@ -381,7 +429,8 @@ def __init__(
         *,
         num_hidden_layers_override: Optional[int] = None,
         num_dummy_heads: int = 0,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -392,6 +441,7 @@ def __init__(
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             num_dummy_heads=num_dummy_heads,
+            prefix=f"{prefix}.encoder",
         )
 
     def get_input_embeddings(self):
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index f6cde44e9d83..9a77e48626ca 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -7,6 +7,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -230,6 +231,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class InternLM2Model(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py
new file mode 100644
index 000000000000..6effd70b75da
--- /dev/null
+++ b/vllm/model_executor/models/internlm2_ve.py
@@ -0,0 +1,166 @@
+# -*- coding: utf-8 -*-
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig
+from vllm.distributed import get_pp_group
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.models.internlm2 import (InternLM2Attention,
+                                                  InternLM2ForCausalLM,
+                                                  InternLM2MLP, InternLM2Model)
+from vllm.sequence import IntermediateTensors
+
+from .utils import make_layers
+
+
+class InternLM2VEDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attention = InternLM2Attention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        self.feed_forward = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.feed_forward_ve = InternLM2MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.attention_norm = RMSNorm(config.hidden_size,
+                                      eps=config.rms_norm_eps)
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_norm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_norm(
+                hidden_states, residual)
+        hidden_states = self.attention(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.ffn_norm(hidden_states, residual)
+        if visual_token_mask is not None and visual_token_mask.any():
+            visual_token_mask = visual_token_mask.repeat(
+                1, self.hidden_size).bool()
+            text_token_mask = ~visual_token_mask
+            hidden_states[visual_token_mask] = self.feed_forward_ve(
+                hidden_states[visual_token_mask].reshape(
+                    -1, self.hidden_size)).flatten()
+            if text_token_mask.any():
+                hidden_states[text_token_mask] = self.feed_forward(
+                    hidden_states[text_token_mask].reshape(
+                        -1, self.hidden_size)).flatten()
+        else:
+            hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class InternLM2VEModel(InternLM2Model):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: InternLM2VEDecoderLayer(config, cache_config,
+                                                   quant_config),
+            prefix=f"{prefix}.layers")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        visual_token_mask: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.tok_embeddings(input_ids)
+            residual = None
+        else:
+            assert intermediate_tensors is not None
+            hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                kv_caches[i - self.start_layer],
+                attn_metadata,
+                residual,
+                visual_token_mask=visual_token_mask,
+            )
+        if not get_pp_group().is_last_rank:
+            return IntermediateTensors({
+                "hidden_states": hidden_states,
+                "residual": residual
+            })
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class InternLM2VEForCausalLM(InternLM2ForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, cache_config, quant_config)
+        self.model = InternLM2VEModel(config, cache_config, quant_config)
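Note: `InternLM2VEDecoderLayer` above routes visual tokens through `feed_forward_ve` and text tokens through `feed_forward` using a boolean mask. A plain-PyTorch sketch of that masked routing (sizes are illustrative, and plain `nn.Linear` layers stand in for the real MLPs):

```python
# Plain-PyTorch sketch of the visual-expert routing above: tokens flagged by a
# boolean mask go through one expert, the rest through another.
import torch
import torch.nn as nn

hidden_size = 16
feed_forward = nn.Linear(hidden_size, hidden_size)      # text expert
feed_forward_ve = nn.Linear(hidden_size, hidden_size)   # visual expert

hidden_states = torch.randn(6, hidden_size)             # flattened tokens
visual_token_mask = torch.tensor([1, 1, 0, 0, 1, 0], dtype=torch.bool)

out = torch.empty_like(hidden_states)
out[visual_token_mask] = feed_forward_ve(hidden_states[visual_token_mask])
out[~visual_token_mask] = feed_forward(hidden_states[~visual_token_mask])
```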
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 9024831df543..3ae37d9fe5d8 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -17,10 +17,13 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
+from vllm.model_executor.layers.quantization import (AWQConfig,
+                                                     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
-from vllm.model_executor.models.intern_vit import InternVisionModel
+from vllm.model_executor.models.intern_vit import (InternVisionModel,
+                                                   InternVisionPatchModel)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
@@ -276,13 +279,13 @@ def _expand_image_prompt(
     def input_processor(
         self,
         ctx: InputContext,
-        llm_inputs: LLMInputs,
+        inputs: DecoderOnlyInputs,
         *,
         max_dynamic_patch: Optional[int] = None,
-    ) -> LLMInputs:
-        multi_modal_data = llm_inputs.get("multi_modal_data")
+    ) -> DecoderOnlyInputs:
+        multi_modal_data = inputs.get("multi_modal_data")
         if multi_modal_data is None or "image" not in multi_modal_data:
-            return llm_inputs
+            return inputs
 
         model_config = ctx.model_config
         hf_config = ctx.get_hf_config()
@@ -311,8 +314,8 @@ def input_processor(
             model_config.tokenizer,
             trust_remote_code=model_config.trust_remote_code)
 
-        prompt = llm_inputs.get("prompt")
-        prompt_token_ids = llm_inputs["prompt_token_ids"]
+        prompt = inputs.get("prompt")
+        prompt_token_ids = inputs["prompt_token_ids"]
         if prompt is None:
             prompt = tokenizer.decode(prompt_token_ids)
 
@@ -320,9 +323,9 @@ def input_processor(
                                                num_patches)
         new_prompt_token_ids = tokenizer.encode(new_prompt)
 
-        return LLMInputs(prompt=prompt,
-                         prompt_token_ids=new_prompt_token_ids,
-                         multi_modal_data=multi_modal_data)
+        return token_inputs(prompt=prompt,
+                            prompt_token_ids=new_prompt_token_ids,
+                            multi_modal_data=multi_modal_data)
 
     def input_mapper(
         self,
@@ -342,6 +345,8 @@ def input_mapper(
         elif is_list_of(data, Image.Image):
             # we can't stack here because images may have different num_patches
             data = [image_pixel_values_mapper(img) for img in data]
+        else:
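+            # Anything that is not PIL image data is treated as precomputed
+            # image embeddings and passed through unchanged.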
+            return MultiModalInputs({"image_embeds": data})
         model_config = ctx.model_config
         tokenizer = cached_get_tokenizer(
             model_config.tokenizer,
@@ -414,23 +419,24 @@ def __init__(self,
 
         self.config = config
         self.multimodal_config = multimodal_config
+        self._patch_quant_config(config, quant_config)
 
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
-        self.select_layer = config.select_layer
         self.num_image_token = int(
             (image_size // patch_size)**2 * (config.downsample_ratio**2))
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
 
-        vision_feature_layer = self.select_layer
-        if vision_feature_layer < 0:
-            num_hidden_layers = config.vision_config.num_hidden_layers \
-                + vision_feature_layer + 1
-        else:
-            num_hidden_layers = vision_feature_layer + 1
-        self.vision_model = self._init_vision_model(config, num_hidden_layers)
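+        # "Mono" InternVL checkpoints use an InternLM2VE language model that
+        # routes image tokens through a dedicated FFN branch (selected via
+        # `visual_token_mask`) and the lighter InternVisionPatchModel.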
+        self.llm_arch_name = config.text_config.architectures[0]
+        self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
+        self.vision_model = self._init_vision_model(
+            config,
+            quant_config=quant_config,
+            is_mono=self.is_mono,
+            prefix="vision_model",
+        )
 
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
@@ -441,6 +447,18 @@ def __init__(self,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    def _patch_quant_config(self, config: PretrainedConfig,
+                            quant_config: QuantizationConfig):
+        # The AWQ models from OpenGVLab are missing `modules_to_not_convert`;
+        # patch the quant_config to add it back.
+        if isinstance(quant_config, AWQConfig):
+            text_config = config.text_config
+            llm_quant_config = getattr(text_config, "quantization_config",
+                                       None)
+            if (not quant_config.modules_to_not_convert) and \
+                (llm_quant_config is not None):
+                quant_config.modules_to_not_convert.append("vision_model")
+
     @cached_property
     def sampler(self):
         if hasattr(self.language_model, "sampler"):
@@ -448,10 +466,30 @@ def sampler(self):
 
         return Sampler()
 
-    def _init_vision_model(self, config: PretrainedConfig,
-                           num_hidden_layers: int):
-        return InternVisionModel(config.vision_config,
-                                 num_hidden_layers_override=num_hidden_layers)
+    def _init_vision_model(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        is_mono: bool,
+        prefix: str,
+    ):
+        if not is_mono:
+            vision_feature_layer = config.select_layer
+            if vision_feature_layer < 0:
+                num_hidden_layers = config.vision_config.num_hidden_layers \
+                    + vision_feature_layer + 1
+            else:
+                num_hidden_layers = vision_feature_layer + 1
+
+            return InternVisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                num_hidden_layers_override=num_hidden_layers,
+                prefix=prefix,
+            )
+        else:
+            return InternVisionPatchModel(config.vision_config)
 
     def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
         vit_hidden_size = config.vision_config.hidden_size
@@ -559,6 +597,14 @@ def _process_image_input(
 
         return image_embeds
 
+    def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
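+        # Only the mono (InternLM2VE) language model consumes this mask; it
+        # marks the positions that hold image-context tokens.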
+        if self.is_mono:
+            visual_token_mask = (
+                input_ids == self.img_context_token_id).reshape(-1, 1)
+        else:
+            visual_token_mask = None
+        return visual_token_mask
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -571,6 +617,7 @@ def forward(
         if intermediate_tensors is not None:
             input_ids = None
             inputs_embeds = None
+            visual_token_mask = None
         else:
             image_input = self._parse_and_validate_image_input(**kwargs)
             if image_input is not None:
@@ -580,16 +627,24 @@ def forward(
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.img_context_token_id)
+                visual_token_mask = self._get_visual_token_mask(input_ids)
                 input_ids = None
             else:
                 inputs_embeds = None
-
-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  kv_caches,
-                                                  attn_metadata,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+                visual_token_mask = None
+
+        forward_kwargs = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            "attn_metadata": attn_metadata,
+            "intermediate_tensors": intermediate_tensors,
+            "inputs_embeds": inputs_embeds,
+        }
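+        # Only InternLM2VEModel accepts `visual_token_mask`; other language
+        # models do not take the extra keyword argument.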
+        if self.is_mono:
+            forward_kwargs.update({"visual_token_mask": visual_token_mask})
+
+        hidden_states = self.language_model.model(**forward_kwargs)
         return hidden_states
 
     def compute_logits(
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index c5e5393442e3..b947f24a693b 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from
-# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
+# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
 # reserved.
@@ -26,6 +26,7 @@
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -212,6 +213,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class JAISModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index ac251b88e872..fddd39fb8c85 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -1,6 +1,5 @@
 # coding=utf-8
 """Inference-only Jamba model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -29,7 +28,8 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
-from vllm.model_executor.models.mamba_cache import MambaCacheManager
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
@@ -41,13 +41,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class JambaMambaMixer(nn.Module):
     """
@@ -60,10 +53,9 @@ class JambaMambaMixer(nn.Module):
     **selective** state spaces)
     """
 
-    def __init__(self, config: JambaConfig, layer_idx):
+    def __init__(self, config: JambaConfig):
         super().__init__()
         self.config = config
-        self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.ssm_state_size = config.mamba_d_state
         self.conv_kernel_size = config.mamba_d_conv
@@ -129,8 +121,8 @@ def __init__(self, config: JambaConfig, layer_idx):
                                    eps=config.rms_norm_eps)
 
     def forward(self, hidden_states: torch.Tensor,
-                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
-                ssm_state: torch.Tensor):
+                attn_metadata: AttentionMetadata,
+                mamba_cache_params: MambaCacheParams):
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -153,17 +145,18 @@ def forward(self, hidden_states: torch.Tensor,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=conv_state,
+                conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                conv_state,
+                mamba_cache_params.conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)
 
         # 3. State Space Model sequence transformation
@@ -188,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor,
             and attn_metadata.context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -197,11 +190,12 @@ def forward(self, hidden_states: torch.Tensor,
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             scan_outputs = selective_state_update(
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
                 self.A,
@@ -211,7 +205,7 @@ def forward(self, hidden_states: torch.Tensor,
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-            )
+                state_batch_indices=mamba_cache_params.state_indices_tensor)
             scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
@@ -292,7 +286,7 @@ def __init__(self,
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
-        self.mamba = JambaMambaMixer(config, layer_idx)
+        self.mamba = JambaMambaMixer(config)
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
@@ -307,8 +301,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -318,8 +311,8 @@ def forward(
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
+        hidden_states = self.mamba(hidden_states, attn_metadata,
+                                   mamba_cache_params)
         # Fully Connected
         hidden_states, residual = self.pre_ff_layernorm(
             hidden_states, residual)
@@ -476,17 +469,14 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
-
         for i in range(len(self.layers)):
             layer = self.layers[i]
             kv_cache = None
-            current_ssm_state = None
-            current_conv_state = None
+            layer_mamba_cache_params = None
             if isinstance(layer, JambaAttentionDecoderLayer):
                 kv_cache = kv_caches[(i - self.config.attn_layer_offset) //
                                      self.config.attn_layer_period]
@@ -494,8 +484,8 @@ def forward(
                 current_state_layer = i - (1 +
                                            (i - self.config.attn_layer_offset)
                                            // self.config.attn_layer_period)
-                current_ssm_state = ssm_state[current_state_layer]
-                current_conv_state = conv_state[current_state_layer]
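+                # at_layer_idx() slices out this layer's conv/ssm state view
+                # along with the shared state index tensor.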
+                layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
+                    current_state_layer)
 
             hidden_states, residual = layer(
                 positions=positions,
@@ -503,9 +493,7 @@ def forward(
                 kv_cache=kv_cache,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
-            )
+                mamba_cache_params=layer_mamba_cache_params)
         hidden_states, _ = self.final_layernorm(hidden_states, residual)
         return hidden_states
 
@@ -588,13 +576,16 @@ def forward(self,
             self.mamba_cache = MambaCacheManager(
                 self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
                 *self._get_mamba_cache_shape())
-
-        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
-            input_ids, attn_metadata, **kwargs)
-
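+        # current_run_tensors() now also returns the per-sequence state
+        # index tensor; bundle everything into MambaCacheParams.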
+        (
+            mamba_cache_tensors,
+            state_indices_tensor,
+        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
+                                                 **kwargs)
+        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
+                                              mamba_cache_tensors[1],
+                                              state_indices_tensor)
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache_tensors[0],
-                                   mamba_cache_tensors[1])
+                                   attn_metadata, mamba_cache_params)
         return hidden_states
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index ad5cfcc44022..c346e3e808e3 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -38,6 +38,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     get_compressed_tensors_cache_scale)
@@ -47,8 +48,9 @@
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_hip
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -266,13 +268,7 @@ def forward(
         return hidden_states, residual
 
 
-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        "positions": 0,
-        "inputs_embeds": 0,
-        "intermediate_tensors": 0,
-    })
+@support_torch_compile
 class LlamaModel(nn.Module):
 
     def __init__(
@@ -615,3 +611,52 @@ def permute(w: torch.Tensor, n_heads: int):
                 name = name.replace(item, mapping[item])
 
         return name, loaded_weight
+
+
+class LlamaEmbeddingModel(nn.Module, SupportsPP):
+    """
+    A model that uses Llama with additional embedding functionalities.
+
+    This class encapsulates the LlamaModel and provides an interface for
+    embedding operations and customized pooling functions.
+
+    Attributes:
+        model: An instance of LlamaModel used for forward operations.
+        _pooler: An instance of Pooler used for pooling operations.
+    """
+
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+
+        self.model = LlamaModel(**kwargs)
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        return self.model(input_ids, positions, kv_caches, attn_metadata,
+                          intermediate_tensors, inputs_embeds)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py
deleted file mode 100644
index 13574e84d7aa..000000000000
--- a/vllm/model_executor/models/llama_embedding.py
+++ /dev/null
@@ -1,59 +0,0 @@
-from typing import Iterable, List, Optional, Tuple, Union
-
-import torch
-from torch import nn
-
-from vllm.attention import AttentionMetadata
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
-from vllm.model_executor.pooling_metadata import PoolingMetadata
-from vllm.sequence import IntermediateTensors, PoolerOutput
-
-from .interfaces import SupportsPP
-from .llama import LlamaModel
-
-
-class LlamaEmbeddingModel(nn.Module, SupportsPP):
-    """A model that uses Llama with additional embedding functionalities.
-
-   This class encapsulates the LlamaModel and provides an interface for
-   embedding operations and customized pooling functions.
-
-   Attributes:
-       model: An instance of LlamaModel used for forward operations.
-       _pooler: An instance of Pooler used for pooling operations.
-   """
-
-    def __init__(
-        self,
-        **kwargs,
-    ) -> None:
-        super().__init__()
-        self.model = LlamaModel(**kwargs)
-        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def forward(
-        self,
-        input_ids: Optional[torch.Tensor],
-        positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        return self.model(input_ids, positions, kv_caches, attn_metadata,
-                          intermediate_tensors, inputs_embeds)
-
-    def pooler(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> Optional[PoolerOutput]:
-        return self._pooler(hidden_states, pooling_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        self.model.load_weights(weights)
-
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        self.model.load_kv_cache_scales(quantization_param_path)
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index 864b9ff66a84..b005d83c17f9 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,15 +1,16 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
+                    Tuple, TypedDict, Union)
 
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig
+from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
+                          PretrainedConfig, SiglipVisionConfig)
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -22,6 +23,10 @@
                    dummy_seq_data_for_clip, get_max_clip_image_tokens,
                    input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
+                      dummy_seq_data_for_pixtral_hf,
+                      get_max_pixtral_hf_image_tokens,
+                      input_processor_for_pixtral_hf)
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
@@ -31,8 +36,13 @@
 
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    data: Union[torch.Tensor, List[torch.Tensor]]
+    """
+    Shape: `(batch_size * num_images, num_channels, height, width)`
+
+    Note that `height` or `width` may be different per batch and image,
+    in which case the data is passed as a list instead of a batched tensor.
+    """
 
 
 class LlavaImageEmbeddingInputs(TypedDict):
@@ -77,6 +87,8 @@ def get_max_llava_image_tokens(ctx: InputContext):
         num_image_tokens = get_max_clip_image_tokens(vision_config)
     elif isinstance(vision_config, SiglipVisionConfig):
         num_image_tokens = get_max_siglip_image_tokens(vision_config)
+    elif isinstance(vision_config, PixtralVisionConfig):
+        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
     else:
         msg = f"Unsupported vision config: {type(vision_config)}"
         raise NotImplementedError(msg)
@@ -120,15 +132,26 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int,
 
         mm_data = dummy_image_for_siglip(vision_config, num_images)
         return seq_data, mm_data
+    elif isinstance(vision_config, PixtralVisionConfig):
+        seq_data = dummy_seq_data_for_pixtral_hf(
+            vision_config,
+            seq_len,
+            num_images,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        mm_data = dummy_image_for_pixtral_hf(vision_config, num_images)
+        return seq_data, mm_data
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
 
 
-def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config(LlavaConfig)
@@ -151,7 +174,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
         return input_processor_for_clip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
@@ -159,16 +182,35 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
         return input_processor_for_siglip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        # We ignore image_feature_size_override since we have non-uniform
+        # image sizes for Pixtral
+        return input_processor_for_pixtral_hf(
+            model_config,
+            vision_config,
+            inputs,
+            image_token_id=hf_config.image_token_index,
+        )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
 
 
-def _init_vision_tower(hf_config: LlavaConfig):
+class LlavaLikeConfig(Protocol):
+    vision_config: PretrainedConfig
+    vision_feature_layer: int
+
+
+def init_vision_tower_for_llava(
+    hf_config: LlavaLikeConfig,
+    quant_config: Optional[QuantizationConfig],
+    *,
+    require_post_norm: Optional[bool] = None,
+):
     vision_config = hf_config.vision_config
 
     # Initialize the vision tower only up to the required feature layer
@@ -182,12 +224,23 @@ def _init_vision_tower(hf_config: LlavaConfig):
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
     elif isinstance(vision_config, SiglipVisionConfig):
         return SiglipVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
+        )
+    elif isinstance(vision_config, PixtralVisionConfig):
+        return PixtralHFVisionModel(
+            vision_config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -210,8 +263,18 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config
 
+        # NOTE: These are special cases for Pixtral-12B in the HF-format
+        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
+        if (config.text_config.architectures is None
+                and config.text_config.model_type == "mistral"):
+            config.text_config.architectures = ["MistralForCausalLM"]
+        if (config.projector_hidden_act is None
+                and config.vision_config.hidden_act == "gelu"):
+            config.projector_hidden_act = "gelu"
+
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(
+            config, quant_config, require_post_norm=False)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -243,9 +306,38 @@ def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
 
         return data
 
+    def _validate_image_sizes(self, images: List[torch.Tensor],
+                              sizes: List[torch.Tensor]) -> List[torch.Tensor]:
+        if not isinstance(sizes, list):
+            sizes = [sizes]
+
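+        # Pixtral-HF images are variable-sized, so cross-check each image
+        # tensor against the (height, width) pairs in `image_sizes`.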
+        total_images = sum(size.numel() // 2 for size in sizes)
+        if total_images != len(images):
+            raise ValueError("Mismatch in number of images. "
+                             f"Expected {total_images}, got {len(images)}")
+        img_idx = 0
+        for size in sizes:
+            # Flatten the size tensor to a list of (height, width) pairs
+            size = size.view(-1, 2).tolist()
+            for expected_h, expected_w in size:
+                if img_idx >= len(images):
+                    raise ValueError("Ran out of images before sizes. "
+                                     f"{img_idx} >= {len(images)}")
+                img = images[img_idx]
+                if img.shape[-2:] != (expected_h, expected_w):
+                    raise ValueError(
+                        "Image size mismatch. Expected "
+                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
+                if img.shape[-3] != 3:
+                    raise ValueError("Image channel mismatch. Expected 3, "
+                                     f"got {img.shape[-3]}")
+                img_idx += 1
+        return images
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[LlavaImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
         if pixel_values is None and image_embeds is None:
@@ -256,6 +348,34 @@ def _parse_and_validate_image_input(
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
+            # Case for models like PixtralHF that have dynamic image sizes
+            # so we need to produce a list of tensors
+            if image_sizes is not None:
+                images = pixel_values
+
+                def flatten_to_3d_tensors(item):
+                    if isinstance(item, torch.Tensor):
+                        if item.dim() >= 3:
+                            return [t for t in item.view(-1, *item.shape[-3:])]
+                        else:
+                            raise ValueError(
+                                f"Unexpected tensor dimension: {item.dim()}")
+                    elif isinstance(item, list):
+                        return [
+                            t for subitem in item
+                            for t in flatten_to_3d_tensors(subitem)
+                        ]
+                    else:
+                        raise ValueError(f"Unexpected type: {type(item)}")
+
+                # Flatten the (possibly nested) batched input into a flat
+                # list of 3D image tensors
+                images = flatten_to_3d_tensors(pixel_values)
+
+                return LlavaImagePixelInputs(
+                    type="pixel_values",
+                    data=self._validate_image_sizes(images, image_sizes),
+                )
+
             return LlavaImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(
@@ -286,7 +406,8 @@ def _select_image_features(self, image_features: torch.Tensor, *,
 
     def _image_pixels_to_features(
         self,
-        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
+        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
+                            PixtralHFVisionModel],
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
 
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index 766f6a4cc83f..9466e72ecc63 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -12,24 +12,26 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
 
 from .clip import (CLIPVisionModel, dummy_image_for_clip,
                    dummy_seq_data_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
-from .llava import LlavaMultiModalProjector
+from .llava import LlavaMultiModalProjector, init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
+                    init_vllm_registered_model)
 
 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
@@ -201,10 +203,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int,
     raise NotImplementedError(msg)
 
 
-def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_llava_next(ctx: InputContext,
+                                   inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config(LlavaNextConfig)
@@ -239,7 +242,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
         return input_processor_for_clip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
@@ -247,7 +250,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
         return input_processor_for_siglip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
@@ -256,32 +259,6 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
     raise NotImplementedError(msg)
 
 
-def _init_vision_tower(hf_config: LlavaNextConfig):
-    vision_config = hf_config.vision_config
-
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@@ -300,7 +277,8 @@ def __init__(self,
         self.multimodal_config = multimodal_config
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(
+            config, quant_config, require_post_norm=False)
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
         self.multi_modal_projector = LlavaMultiModalProjector(
@@ -311,6 +289,10 @@ def __init__(self,
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
 
+        # The same model class supports both language generation and
+        # embedding because both kinds of checkpoints use the same
+        # architecture name, so a pooler is always attached.
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
@@ -604,14 +586,12 @@ def forward(
             image_input = self._parse_and_validate_image_input(**kwargs)
 
             if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.config.image_token_index)
-
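+                # embed_multimodal() embeds text tokens and drops the vision
+                # embeddings into the image-token positions in one pass.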
+                inputs_embeds = embed_multimodal(
+                    input_ids,
+                    self.config.image_token_index,
+                    self.language_model.model.get_input_embeddings,
+                    lambda _: self._process_image_input(image_input),
+                )
                 input_ids = None
             else:
                 inputs_embeds = None
@@ -640,6 +620,13 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index e10c1f9e6e04..43eec43d5664 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -11,7 +11,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -25,6 +26,7 @@
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
+from .llava import init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
 from .utils import (AutoWeightsLoader, init_vllm_registered_model,
@@ -139,10 +141,10 @@ def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
 
 
 def input_processor_for_llava_next_video(ctx: InputContext,
-                                         llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+                                         inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "video" not in multi_modal_data:
-        return llm_inputs
+        return inputs
     video_data = multi_modal_data["video"]
 
     model_config = ctx.model_config
@@ -160,15 +162,15 @@ def input_processor_for_llava_next_video(ctx: InputContext,
 
         new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
             tokenizer,
-            llm_inputs.get("prompt"),
-            llm_inputs["prompt_token_ids"],
+            inputs.get("prompt"),
+            inputs["prompt_token_ids"],
             placeholder_token_id=hf_config.video_token_index,
             repeat_count=video_feature_size,
         )
 
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
+        return token_inputs(prompt_token_ids=new_token_ids,
+                            prompt=new_prompt,
+                            multi_modal_data=multi_modal_data)
 
     elif is_list_of(video_data, np.ndarray):
         raise NotImplementedError(
@@ -178,32 +180,6 @@ def input_processor_for_llava_next_video(ctx: InputContext,
     raise NotImplementedError(msg)
 
 
-def _init_vision_tower(hf_config: LlavaNextVideoConfig):
-    vision_config = hf_config.vision_config
-
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
 # adopted from transformers modeling_llava_next_video.py
 class LlavaNextVideoPooler(nn.Module):
 
@@ -280,7 +256,8 @@ def __init__(self,
         self.multimodal_config = multimodal_config
 
         # Initialize the vision tower only up to the required feature layer
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(
+            config, quant_config, require_post_norm=False)
         self.vision_resampler = LlavaNextVideoPooler(config)
         self.multi_modal_projector = LlavaNextMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index 46e97e78d482..47e62409072e 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -15,8 +15,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
-from vllm.logger import init_logger
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -31,14 +31,13 @@
                    dummy_video_for_clip, get_clip_image_feature_size,
                    get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsMultiModal, SupportsPP
+from .llava import init_vision_tower_for_llava
 from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
                      dummy_video_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
-logger = init_logger(__name__)
-
 # Result in the max possible feature size (2x2 grid of 336x336px tiles)
 MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448
 
@@ -252,10 +251,10 @@ def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int,
 
 
 def input_processor_when_multimodal_input_image(ctx: InputContext,
-                                                llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+                                                inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config(LlavaOnevisionConfig)
@@ -290,7 +289,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
         return input_processor_for_clip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
@@ -298,7 +297,7 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
         return input_processor_for_siglip(
             model_config,
             vision_config,
-            llm_inputs,
+            inputs,
             image_token_id=hf_config.image_token_index,
             image_feature_size_override=image_feature_size,
         )
@@ -308,10 +307,10 @@ def input_processor_when_multimodal_input_image(ctx: InputContext,
 
 
 def input_processor_when_multimodal_input_video(ctx: InputContext,
-                                                llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+                                                inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "video" not in multi_modal_data:
-        return llm_inputs
+        return inputs
     video_data = multi_modal_data["video"]
 
     model_config = ctx.model_config
@@ -326,15 +325,15 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
 
         new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
             tokenizer,
-            llm_inputs.get("prompt"),
-            llm_inputs["prompt_token_ids"],
+            inputs.get("prompt"),
+            inputs["prompt_token_ids"],
             placeholder_token_id=hf_config.video_token_index,
             repeat_count=video_feature_size,
         )
 
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
+        return token_inputs(prompt_token_ids=new_token_ids,
+                            prompt=new_prompt,
+                            multi_modal_data=multi_modal_data)
 
     elif is_list_of(video_data, np.ndarray):
         raise NotImplementedError(
@@ -345,46 +344,20 @@ def input_processor_when_multimodal_input_video(ctx: InputContext,
 
 
 def input_processor_for_llava_onevision(ctx: InputContext,
-                                        llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+                                        inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or ("video" not in multi_modal_data
                                     and "image" not in multi_modal_data):
-        return llm_inputs
+        return inputs
     if "image" in multi_modal_data:
-        return input_processor_when_multimodal_input_image(ctx, llm_inputs)
+        return input_processor_when_multimodal_input_image(ctx, inputs)
     if "video" in multi_modal_data:
-        return input_processor_when_multimodal_input_video(ctx, llm_inputs)
+        return input_processor_when_multimodal_input_video(ctx, inputs)
 
     msg = "Unsupported multi data type"
     raise NotImplementedError(msg)
 
 
-def _init_vision_tower(hf_config: LlavaOnevisionConfig):
-    vision_config = hf_config.vision_config
-
-    # Initialize the vision tower only up to the required feature layer
-    vision_feature_layer = hf_config.vision_feature_layer
-    if vision_feature_layer < 0:
-        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
-            + vision_feature_layer + 1
-    else:
-        num_hidden_layers = vision_feature_layer + 1
-
-    if isinstance(vision_config, CLIPVisionConfig):
-        return CLIPVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-    elif isinstance(vision_config, SiglipVisionConfig):
-        return SiglipVisionModel(
-            vision_config,
-            num_hidden_layers_override=num_hidden_layers,
-        )
-
-    msg = f"Unsupported vision config: {type(vision_config)}"
-    raise NotImplementedError(msg)
-
-
 class LlavaOnevisionMultiModalProjector(nn.Module):
 
     def __init__(self, config: LlavaOnevisionConfig):
@@ -427,7 +400,8 @@ def __init__(self,
         self.multimodal_config = multimodal_config
 
         # Initialize the vision tower only up to the required feature layer
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(
+            config, quant_config, require_post_norm=False)
         self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
         self.language_model = init_vllm_registered_model(
             config.text_config, cache_config, quant_config)
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index 1112a2181135..9f4f391a6682 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,6 +1,5 @@
 # coding=utf-8
 """PyTorch MAMBA model."""
-from dataclasses import dataclass
 from typing import Iterable, List, Optional, Tuple
 
 import torch
@@ -10,7 +9,6 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -24,12 +22,13 @@
     QuantizationConfig)
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     composed_weight_loader, default_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.interfaces import (HasInnerState,
                                                    IsAttentionFree)
-from vllm.model_executor.models.mamba_cache import MambaCacheManager
+from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
+                                                    MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors
@@ -39,13 +38,6 @@
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
-@dataclass
-class MambaCacheParams:
-    is_prompt: bool = False
-    conv_state: torch.Tensor = torch.Tensor()
-    ssm_state: torch.Tensor = torch.Tensor()
-
-
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 class MambaMixer(nn.Module):
     """
@@ -67,7 +59,7 @@ def __init__(self, config: MambaConfig, layer_idx):
         self.conv_kernel_size = config.conv_kernel
         self.intermediate_size = config.intermediate_size
         self.time_step_rank = int(config.time_step_rank)
-
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.conv1d = ColumnParallelLinear(
             input_size=self.conv_kernel_size,
             output_size=self.intermediate_size,
@@ -117,10 +109,17 @@ def __init__(self, config: MambaConfig, layer_idx):
             input_is_parallel=True,
         )
         self.activation = config.hidden_act
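+        # FalconMamba additionally applies RMSNorm to dt, B and C before the
+        # SSM step; vanilla Mamba does not (see the note in forward()).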
+        if self.is_falcon_mamba:
+            self.dt_layernorm = RMSNorm(self.time_step_rank,
+                                        eps=config.mixer_rms_eps)
+            self.b_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
+            self.c_layernorm = RMSNorm(self.ssm_state_size,
+                                       eps=config.mixer_rms_eps)
 
     def forward(self, hidden_states: torch.Tensor,
-                attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
-                ssm_state: torch.Tensor):
+                attn_metadata: AttentionMetadata,
+                mamba_cache_params: MambaCacheParams):
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -143,17 +142,18 @@ def forward(self, hidden_states: torch.Tensor,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
-                conv_states=conv_state,
+                conv_states=mamba_cache_params.conv_state,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                conv_state,
+                mamba_cache_params.conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-            )
+                conv_state_indices=mamba_cache_params.state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)
 
         # 3. State Space Model sequence transformation
@@ -165,8 +165,12 @@ def forward(self, hidden_states: torch.Tensor,
             [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
             dim=-1,
         )
-
-        # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
+        # Note that Jamba and FalconMamba normalize B, C, and time_step here,
+        # but Mamba doesn't.
+        if self.is_falcon_mamba:
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
 
         discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
@@ -177,7 +181,7 @@ def forward(self, hidden_states: torch.Tensor,
             and attn_metadata.context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -186,11 +190,12 @@ def forward(self, hidden_states: torch.Tensor,
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
+                cache_indices=mamba_cache_params.state_indices_tensor,
                 has_initial_state=attn_metadata.context_lens_tensor > 0,
                 query_start_loc=attn_metadata.query_start_loc)
         else:
             scan_outputs = selective_state_update(
-                ssm_state,
+                mamba_cache_params.ssm_state,
                 hidden_states.transpose(0, 1),
                 discrete_time_step.transpose(0, 1),
                 self.A,
@@ -200,7 +205,7 @@ def forward(self, hidden_states: torch.Tensor,
                 gate.transpose(0, 1),
                 time_proj_bias,
                 dt_softplus=True,
-            )
+                state_batch_indices=mamba_cache_params.state_indices_tensor)
             scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
@@ -209,37 +214,6 @@ def forward(self, hidden_states: torch.Tensor,
         return contextualized_states
 
 
-class MambaMLP(nn.Module):
-
-    def __init__(
-        self,
-        config: MambaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        hidden_size = config.hidden_size
-        intermediate_size = config.intermediate_size
-        hidden_act = config.hidden_act
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-
-    def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
-
-
 class MambaDecoderLayer(nn.Module):
 
     def __init__(self,
@@ -250,20 +224,16 @@ def __init__(self,
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
+        self.is_falcon_mamba = config.model_type == "falcon_mamba"
         self.mixer = MambaMixer(config, layer_idx)
-
-        self.feed_forward = MambaMLP(config, quant_config=quant_config)
         self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
-        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
-                                        eps=config.layer_norm_epsilon)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
         **kwargs,
     ):
         if residual is None:
@@ -272,12 +242,8 @@ def forward(
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
-                                   ssm_state)
-        # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.mixer(hidden_states, attn_metadata,
+                                   mamba_cache_params)
         return hidden_states, residual
 
 
@@ -319,53 +285,27 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
     ) -> torch.Tensor:
+
         hidden_states = self.embeddings(input_ids)
         residual = None
 
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            current_ssm_state = ssm_state[i]
-            current_conv_state = conv_state[i]
-
             hidden_states, residual = layer(
                 positions=positions,
                 hidden_states=hidden_states,
                 attn_metadata=attn_metadata,
                 residual=residual,
-                conv_state=current_conv_state,
-                ssm_state=current_ssm_state,
-            )
+                mamba_cache_params=mamba_cache_params.at_layer_idx(i))
         hidden_states, _ = self.norm_f(hidden_states, residual)
 
         return hidden_states
 
 
 class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-    }
-
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "qkv_proj",
-        "o_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    embedding_modules = {
-        "embeddings": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
@@ -388,8 +328,18 @@ def __init__(
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
-
-        self.lm_head = self.backbone.embeddings
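+        # Reuse the input embeddings as the LM head when word embeddings are
+        # tied; otherwise create a separate, padded ParallelLMHead.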
+        if config.tie_word_embeddings:
+            self.lm_head = self.backbone.embeddings
+        else:
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE
+                # We need bigger padding if using lora for kernel
+                # compatibility
+                if not lora_config else lora_config.lora_vocab_padding_size,
+            )
 
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Optional[MambaCacheManager] = None
@@ -413,12 +363,18 @@ def forward(self,
                 self.lm_head.weight.dtype, self.config.num_hidden_layers,
                 max_batch_size, *self._get_mamba_cache_shape())
 
-        mamba_cache_tensors = self.mamba_cache.current_run_tensors(
-            input_ids, attn_metadata, **kwargs)
+        (
+            mamba_cache_tensors,
+            state_indices_tensor,
+        ) = self.mamba_cache.current_run_tensors(input_ids, attn_metadata,
+                                                 **kwargs)
 
-        hidden_states = self.backbone(input_ids, positions, kv_caches,
-                                      attn_metadata, mamba_cache_tensors[0],
-                                      mamba_cache_tensors[1])
+        mamba_cache_params = MambaCacheParams(mamba_cache_tensors[0],
+                                              mamba_cache_tensors[1],
+                                              state_indices_tensor)
+
+        hidden_states = self.backbone(input_ids, positions, attn_metadata,
+                                      mamba_cache_params)
 
         return hidden_states
 
@@ -457,43 +413,15 @@ def sample(
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-
             if "A_log" in name:
                 name = name.replace("A_log", "A")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
 
-            if ".self_attn." in name:
-                name = name.replace(".self_attn", "")
-
-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py
index 8d1ba3737d4a..79393421f3ae 100644
--- a/vllm/model_executor/models/mamba_cache.py
+++ b/vllm/model_executor/models/mamba_cache.py
@@ -1,8 +1,22 @@
-from typing import Dict, List, Optional
+from dataclasses import dataclass
+from typing import Dict, List
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.utils import PAD_SLOT_ID
+
+
+@dataclass
+class MambaCacheParams:
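+    # conv_state/ssm_state are stacked across layers; state_indices_tensor
+    # maps each sequence in the batch to its slot inside those caches.
+    # at_layer_idx() returns a single layer's view with the same slots.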
+    conv_state: torch.Tensor = torch.Tensor()
+    ssm_state: torch.Tensor = torch.Tensor()
+    state_indices_tensor: torch.Tensor = torch.Tensor()
+
+    def at_layer_idx(self, layer_idx):
+        return MambaCacheParams(self.conv_state[layer_idx],
+                                self.ssm_state[layer_idx],
+                                self.state_indices_tensor)
 
 
 class MambaCacheManager:
@@ -24,6 +38,7 @@ def __init__(self, dtype, num_mamba_layers, max_batch_size,
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
         self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
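+        # Free list of cache slots: an index is popped when a new
+        # (request, seq) pair is assigned and pushed back once it finishes.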
+        self.free_cache_indices = list(range(max_batch_size))
 
     def current_run_tensors(self, input_ids: torch.Tensor,
                             attn_metadata: AttentionMetadata, **kwargs):
@@ -36,30 +51,43 @@ def current_run_tensors(self, input_ids: torch.Tensor,
             finished_requests_ids = kwargs["finished_requests_ids"]
 
             self._release_finished_requests(finished_requests_ids)
-            mamba_cache_tensors = self._prepare_current_run_mamba_cache(
+            state_indices = self._prepare_current_run_mamba_cache(
                 request_ids_to_seq_ids, finished_requests_ids)
 
+            state_indices_tensor = torch.as_tensor(state_indices,
+                                                   dtype=torch.int32,
+                                                   device="cuda")
+            mamba_cache_tensors = self.mamba_cache
+
         else:
             # CUDA graph capturing runs
-            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
+            (mamba_cache_tensors,
+             state_indices_tensor) = kwargs["seqlen_agnostic_capture_inputs"]
 
-        return mamba_cache_tensors
+        return (mamba_cache_tensors, state_indices_tensor)
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
-        Copy the relevant Mamba cache into the CUDA graph input buffer
-        that was provided during the capture runs
-        (JambaForCausalLM.mamba_gc_cache_buffer).
+        Copy the relevant state_indices into the CUDA graph input buffer.
         """
         assert all(
             key in kwargs
             for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
         finished_requests_ids = kwargs["finished_requests_ids"]
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+        assert "seqlen_agnostic_capture_inputs" in input_buffers
+        _, input_state_indices_buffer = input_buffers[
+            "seqlen_agnostic_capture_inputs"]
 
         self._release_finished_requests(finished_requests_ids)
-        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                              finished_requests_ids)
+        state_indices = self._prepare_current_run_mamba_cache(
+            request_ids_to_seq_ids, finished_requests_ids)
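+        # Pad the index list to the captured batch size with PAD_SLOT_ID so
+        # the unused graph slots do not touch any live cache entry.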
+        cuda_graph_pad_len = input_state_indices_buffer.shape[0] - len(
+            state_indices)
+        state_indices.extend([PAD_SLOT_ID] * cuda_graph_pad_len)
+
+        input_state_indices_buffer.copy_(
+            torch.as_tensor(state_indices, dtype=torch.int32, device="cuda"))
 
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
@@ -67,13 +95,10 @@ def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         The buffer is used to maintain the Mamba Cache during the CUDA graph
         replay runs.
         """
-        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
-
-    def _swap_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, [to_index,from_index]] = \
-             cache_t[:, [from_index,to_index]]
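+        # Capture with a PAD_SLOT_ID-filled indices buffer; it is
+        # overwritten by copy_inputs_before_cuda_graphs() at replay time.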
+        state_indices_tensor = torch.as_tensor([PAD_SLOT_ID] * batch_size,
+                                               dtype=torch.int32,
+                                               device="cuda")
+        return (self.mamba_cache, state_indices_tensor)
 
     def _copy_mamba_cache(self, from_index: int, to_index: int):
         assert len(self.mamba_cache) > 0
@@ -81,142 +106,53 @@ def _copy_mamba_cache(self, from_index: int, to_index: int):
             cache_t[:, to_index].copy_(cache_t[:, from_index],
                                        non_blocking=True)
 
-    def _move_out_if_already_occupied(self, index: int,
-                                      all_occupied_indices: List[int]):
-        if index in all_occupied_indices:
-            first_free_index = self._first_free_index_in_mamba_cache()
-            # In case occupied, move the occupied to a new empty block
-            self._move_cache_index_and_mappings(from_index=index,
-                                                to_index=first_free_index)
-
-    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
-                                                       seq_id: int,
-                                                       destination_index: int):
+    def _assign_seq_id_to_cache_index(self, cur_rid: str, seq_id: int,
+                                      finished_requests_ids) -> int:
         """
-        Assign (req_id,seq_id) pair to a `destination_index` index, if
-        already occupied, move the occupying index to a free index.
+        Assign the (req_id, seq_id) pair to a free cache index (reusing or
+        copying an existing one for parallel sampling); finished requests
+        are mapped to PAD_SLOT_ID.
         """
-        all_occupied_indices = self._get_all_occupied_indices()
-        if cur_rid not in self.mamba_cache_indices_mapping:
-            self._move_out_if_already_occupied(
-                index=destination_index,
-                all_occupied_indices=all_occupied_indices)
+        if cur_rid in finished_requests_ids:
+            # Set as pad; do not allocate a destination index.
+            return PAD_SLOT_ID
+        elif cur_rid not in self.mamba_cache_indices_mapping:
+            destination_index = self.free_cache_indices.pop()
             self.mamba_cache_indices_mapping[cur_rid] = {
                 seq_id: destination_index
             }
+            return destination_index
         elif seq_id not in (seq_ids2indices :=
                             self.mamba_cache_indices_mapping[cur_rid]):
-            # parallel sampling , where n > 1, assume prefill have
-            # already happened now we only need to copy the already
-            # existing cache into the siblings seq_ids caches
+            # Parallel sampling (n > 1): prefill is assumed to have already
+            # happened, so copy the existing cache into the sibling
+            # seq_ids' caches.
-            self._move_out_if_already_occupied(
-                index=destination_index,
-                all_occupied_indices=all_occupied_indices)
-            index_exists = list(seq_ids2indices.values())[0]
+            index_exists = next(iter(seq_ids2indices.values()))
             # case of decoding n>1, copy prefill cache to decoding indices
+            destination_index = self.free_cache_indices.pop()
             self._copy_mamba_cache(from_index=index_exists,
                                    to_index=destination_index)
             self.mamba_cache_indices_mapping[cur_rid][
                 seq_id] = destination_index
+            return destination_index
         else:
             # already exists
-            cache_index_already_exists = self.mamba_cache_indices_mapping[
-                cur_rid][seq_id]
-            if cache_index_already_exists != destination_index:
-                # In case the seq id already exists but not in
-                # the right destination, swap it with what's occupying it
-                self._swap_pair_indices_and_mappings(
-                    from_index=cache_index_already_exists,
-                    to_index=destination_index)
+            return self.mamba_cache_indices_mapping[cur_rid][seq_id]
 
     def _prepare_current_run_mamba_cache(
             self, request_ids_to_seq_ids: Dict[str, list[int]],
-            finished_requests_ids: List[str]):
-        running_indices = []
-        request_ids_to_seq_ids_flatten = [
-            (req_id, seq_id)
+            finished_requests_ids: List[str]) -> List[int]:
+        return [
+            self._assign_seq_id_to_cache_index(req_id, seq_id,
+                                               finished_requests_ids)
             for req_id, seq_ids in request_ids_to_seq_ids.items()
             for seq_id in seq_ids
         ]
-        batch_size = len(request_ids_to_seq_ids_flatten)
-        for dest_index, (request_id,
-                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
-            if request_id in finished_requests_ids:
-                # Do not allocate cache index for requests that run
-                # and finish right after
-                continue
-            self._assign_seq_id_to_mamba_cache_in_specific_dest(
-                request_id, seq_id, dest_index)
-            running_indices.append(dest_index)
-
-        self._clean_up_first_bs_blocks(batch_size, running_indices)
-        conv_state = self.mamba_cache[0][:, :batch_size]
-        temporal_state = self.mamba_cache[1][:, :batch_size]
-
-        return (conv_state, temporal_state)
-
-    def _get_all_occupied_indices(self):
-        return [
-            cache_idx
-            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
-            for cache_idx in seq_ids2indices.values()
-        ]
-
-    def _clean_up_first_bs_blocks(self, batch_size: int,
-                                  indices_for_current_run: List[int]):
-        # move out all of the occupied but currently not running blocks
-        # outside of the first n blocks
-        destination_indices = range(batch_size)
-        max_possible_batch_size = self.mamba_cache[0].shape[1]
-        for destination_index in destination_indices:
-            if destination_index in self._get_all_occupied_indices() and  \
-               destination_index not in indices_for_current_run:
-                # move not running indices outside of the batch
-                all_other_indices = list(
-                    range(batch_size, max_possible_batch_size))
-                first_avail_index = self._first_free_index_in_mamba_cache(
-                    all_other_indices)
-                self._swap_indices(from_index=destination_index,
-                                   to_index=first_avail_index)
-
-    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
-        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
-        self._update_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
-        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
-        self._swap_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-                elif to_index == index:
-                    seq_ids2index.update({seq_id: from_index})
-
-    def _update_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-                    return
 
     def _release_finished_requests(self,
                                    finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.mamba_cache_indices_mapping:
+                for seq_id in self.mamba_cache_indices_mapping[req_id]:
+                    self.free_cache_indices.append(
+                        self.mamba_cache_indices_mapping[req_id][seq_id])
                 self.mamba_cache_indices_mapping.pop(req_id)
-
-    def _first_free_index_in_mamba_cache(
-            self, indices_range: Optional[List[int]] = None) -> int:
-        assert self.mamba_cache is not None
-        if indices_range is None:
-            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            indices_range = list(range(max_possible_batch_size))
-        all_occupied_indices = self._get_all_occupied_indices()
-        for i in indices_range:
-            if i not in all_occupied_indices:
-                return i
-        raise Exception("Couldn't find a free spot in the mamba cache! This"
-                        "should never happen")
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 41c2877194bb..03fb036020f2 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -29,11 +29,12 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -152,6 +153,7 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        hidden_act_param: float,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -163,10 +165,13 @@ def __init__(
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config)
-        if hidden_act != "silu":
+        if hidden_act == "silu":
+            self.act_fn = SiluAndMul()
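+        # "fatrelu" applies a thresholded-ReLU gate (FatreluAndMul); the
+        # threshold comes from the config's hidden_act_param.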
+        elif hidden_act == "fatrelu":
+            self.act_fn = FatreluAndMul(threshold=hidden_act_param)
+        else:
             raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+                             "Only silu and fatrelu are supported for now.")
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -304,6 +309,7 @@ def _init_ffn_block(self):
                 hidden_size=self.hidden_size,
                 intermediate_size=self.config.intermediate_size,
                 hidden_act=self.config.hidden_act,
+                hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
                 quant_config=self.quant_config,
             )
         else:
@@ -343,6 +349,7 @@ def forward(
         return hidden_states, None
 
 
+@support_torch_compile
 class MiniCPMModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 9ee4dd0f0623..2ec51dc4647f 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -36,7 +36,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2,
@@ -256,7 +257,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
 
 
 def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int):
-    return SequenceData.from_token_counts((0, seq_len))
+    return SequenceData.from_prompt_token_counts((0, seq_len))
 
 
 def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig,
@@ -279,10 +280,10 @@ def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int,
     return seq_data, mm_data
 
 
-def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_minicpmv(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
     model_config = ctx.model_config
     version = get_version_by_config(model_config.hf_config)
     tokenizer = cached_get_tokenizer(
@@ -297,8 +298,8 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
         return image_processor. \
             get_slice_image_placeholder(image_size, num_image)
 
-    prompt = llm_inputs.get("prompt")
-    token_ids = llm_inputs.get("prompt_token_ids")
+    prompt = inputs.get("prompt")
+    token_ids = inputs.get("prompt_token_ids")
     if prompt is None:
         prompt = tokenizer.decode(token_ids)
 
@@ -332,12 +333,11 @@ def get_placeholder(image_size: Tuple[int, int], num_image: int):
         _build_image_input(ctx, image) for image in images
     ]
 
-    llm_inputs = LLMInputs(
+    return token_inputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
         multi_modal_data=multi_modal_data,
     )
-    return llm_inputs
 
 
 def input_mapper_for_minicpmv(ctx: InputContext, data: object):
@@ -395,7 +395,7 @@ def __init__(
 
         self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
-        self.vpm = self.init_vision_module()
+        self.vpm = self.init_vision_module(config, quant_config)
         param_dtype = torch.get_default_dtype()
         self.vpm.to(dtype=param_dtype)
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
@@ -647,7 +647,11 @@ def init_llm(
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_vision_module(self) -> nn.Module:
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
         raise NotImplementedError
 
     def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
@@ -693,7 +697,11 @@ def init_llm(
                                        quant_config=quant_config),
                           name="model")
 
-    def init_vision_module(self) -> nn.Module:
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
         # TODO :refactor this vision model
         try:
             import timm
@@ -817,8 +825,13 @@ def init_llm(
                                      quant_config=quant_config),
                           name="model")
 
-    def init_vision_module(self) -> nn.Module:
-        model = Idefics2VisionTransformer(self.config.vision_config)
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model
@@ -929,9 +942,13 @@ def init_llm(
                                      quant_config=quant_config),
                           name="model")
 
-    def init_vision_module(self) -> nn.Module:
-
-        model = Idefics2VisionTransformer(self.config.vision_config)
+    def init_vision_module(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+    ) -> nn.Module:
+        model = Idefics2VisionTransformer(config.vision_config,
+                                          quant_config=quant_config)
         if self.config.drop_vision_last_layer:
             model.encoder.layers = model.encoder.layers[:-1]
         return model
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 45d6ad3c0efa..44ef49729c96 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -14,10 +14,10 @@
 # limitations under the License.
 """PyTorch Mllama model."""
 import math
-from array import array
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -28,12 +28,16 @@
                                            CausalLMOutputWithPast)
 from transformers.models.mllama.image_processing_mllama import (
     get_optimal_tiled_canvas)
+from transformers.models.mllama.processing_mllama import (
+    get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
+                         EncoderDecoderInputs, InputContext)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -47,7 +51,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
+from vllm.sequence import SequenceData
 
 from .clip import CLIPMLP
 from .interfaces import SupportsMultiModal
@@ -72,31 +76,45 @@ class MllamaImagePixelInputs(TypedDict):
 # TODO: support LlamaImageEmbeddingInputs
 
 
-def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
+def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
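+    """Return the number of image tokens in the last contiguous group of
+    <|image|> tokens in the prompt."""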
+    num_images = 0
+    for token_id in prompt_token_ids[::-1]:
+        if token_id == MLLAMA_IMAGE_TOKEN_ID:
+            num_images += 1
+        elif num_images > 0:
+            break
+    return num_images
+
+
+def input_processor_for_mllama(ctx: InputContext,
+                               inputs: Union[DecoderOnlyInputs,
+                                             EncoderDecoderInputs]):
     # move encoder_prompt to prompt
-    if llm_inputs.get("prompt") is None:
-        llm_inputs["prompt"] = llm_inputs["encoder_prompt"]
-        llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"]
+    if inputs.get("prompt") is None:
+        inputs["prompt"] = inputs["encoder_prompt"]
+        inputs["prompt_token_ids"] = inputs["encoder_prompt_token_ids"]
 
     # process multi-modal data
-    assert "decoder_multi_modal_data" not in llm_inputs, \
-        "multi-modal data should be put in encoder message of mllama"
-    multi_modal_data = llm_inputs.get("encoder_multi_modal_data")
+    multi_modal_data = inputs.get("encoder_multi_modal_data")
 
     if multi_modal_data is None or "image" not in multi_modal_data \
         or multi_modal_data["image"] is None:
         # text-only
-        llm_inputs["encoder_prompt"] = ""
-        llm_inputs["encoder_prompt_token_ids"] = []
-        llm_inputs["encoder_multi_modal_data"] = {}
-        return llm_inputs
+        inputs["encoder_prompt"] = ""
+        inputs["encoder_prompt_token_ids"] = []
+        inputs["encoder_multi_modal_data"] = {}
+        return inputs
 
-    # get num_tiles
     if isinstance(multi_modal_data['image'], Image.Image):
         multi_modal_data['image'] = [multi_modal_data['image']]
+    # Since only the last group of consecutive images is attended to by
+    # the decoded tokens, we only need the number of tiles for those images.
+    num_decode_images = _get_num_image_in_last_group(
+        inputs["prompt_token_ids"])
     hf_config = ctx.model_config.hf_config
     num_tiles = 0
-    for image in multi_modal_data["image"]:
+    for image in multi_modal_data["image"][::-1]:
         width, height = image.size
         tile_size = hf_config.vision_config.image_size
         canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,17 +126,21 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         num_tiles_height = canvas_height // tile_size
         num_tiles_width = canvas_width // tile_size
         num_tiles += num_tiles_height * num_tiles_width
+        num_decode_images -= 1
+        if num_decode_images == 0:
+            break
 
-    # set encoder prompt based on num_tiles
+    # Set the encoder prompt length based on the number of tiles.
+    # This tells the block manager to allocate the correct number
+    # of slots for the encoder tokens.
     assert hf_config.vision_config.image_size % 14 == 0, \
         "chunk size should be multiple of 14"
     token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
     num_tokens = num_tiles * token_per_chunk
-    llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
-    llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID
-                                              ] * num_tokens
+    inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens
+    inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID] * num_tokens
 
-    return llm_inputs
+    return inputs
 
 
 def get_max_mllama_image_tokens(ctx: InputContext) -> int:
@@ -131,17 +153,18 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int):
     # <|image|> * num_images + 0 * (seq_len - num_images)
     assert seq_len >= num_images, \
         "seq_len should be greater than or equal to num_images"
-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [MLLAMA_IMAGE_TOKEN_ID]) * num_images
-    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images)
-    return SequenceData(token_ids)
+
+    return SequenceData.from_prompt_token_counts(
+        (MLLAMA_IMAGE_TOKEN_ID, num_images),
+        (0, seq_len - num_images),
+    )
 
 
 def dummy_encoder_seq_data(ctx: InputContext, num_images: int):
     num_tokens = get_max_mllama_image_tokens(ctx) * num_images
-    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
-                      [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens
-    return SequenceData(token_ids)
+
+    return SequenceData.from_prompt_token_counts(
+        (MLLAMA_IMAGE_TOKEN_ID, num_tokens))
 
 
 def dummy_image(num_images: int, ):
@@ -356,9 +379,13 @@ def forward(
 
 class MllamaVisionEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: config_mllama.MllamaVisionConfig,
-                 is_gated: bool = False):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+        is_gated: bool = False,
+    ) -> None:
         super().__init__()
 
         self.hidden_size = config.hidden_size
@@ -367,7 +394,9 @@ def __init__(self,
         self.intermediate_size = config.intermediate_size
 
         self.self_attn = MllamaVisionSdpaAttention(config)
-        self.mlp = CLIPMLP(config)
+        self.mlp = CLIPMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
 
         self.input_layernorm = nn.LayerNorm(self.hidden_size,
                                             eps=config.norm_eps)
@@ -404,16 +433,23 @@ def forward(
 
 class MllamaVisionEncoder(nn.Module):
 
-    def __init__(self,
-                 config: config_mllama.MllamaVisionConfig,
-                 num_layers=32,
-                 is_gated=False,
-                 output_hidden_states=None):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig],
+        num_layers: int = 32,
+        is_gated: bool = False,
+        output_hidden_states=None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList([
-            MllamaVisionEncoderLayer(config, is_gated)
-            for _ in range(num_layers)
+            MllamaVisionEncoderLayer(config,
+                                     quant_config=quant_config,
+                                     is_gated=is_gated,
+                                     prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_layers)
         ])
         self.output_hidden_states = output_hidden_states or []
 
@@ -440,8 +476,14 @@ def forward(
 
 class MllamaVisionModel(nn.Module):
 
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.image_size = config.image_size
         self.patch_size = config.patch_size
         self.max_num_tiles = config.max_num_tiles
@@ -477,12 +519,19 @@ def __init__(self, config: config_mllama.MllamaVisionConfig):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
-            output_hidden_states=config.intermediate_layers_indices)
-        self.global_transformer = MllamaVisionEncoder(config,
-                                                      config.num_global_layers,
-                                                      is_gated=True)
+            output_hidden_states=config.intermediate_layers_indices,
+            prefix=f"{prefix}.transformer",
+        )
+        self.global_transformer = MllamaVisionEncoder(
+            config,
+            quant_config,
+            config.num_global_layers,
+            is_gated=True,
+            prefix=f"{prefix}.global_transformer",
+        )
 
     def apply_class_embedding(self,
                               hidden_state: torch.Tensor) -> torch.Tensor:
@@ -625,6 +674,7 @@ def __init__(
         config: Optional[config_mllama.MllamaTextConfig] = None,
         layer_idx: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -650,6 +700,7 @@ def __init__(
             self.num_key_value_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
@@ -657,6 +708,7 @@ def __init__(
             bias=False,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
         # use huggingface's instead
@@ -669,12 +721,14 @@ def __init__(
             self.head_dim,
             self.scaling,
             self.num_local_key_value_heads,
+            prefix=f"{prefix}.attn",
         )
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -697,29 +751,93 @@ def forward(
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
 
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_DECODER)
+        if attention_mask is not None:
+            output = self.attention_with_mask(q, k, v, kv_cache,
+                                              attention_mask,
+                                              kv_range_for_decode,
+                                              attn_metadata)
+        else:
+            output = self.attn(q,
+                               k,
+                               v,
+                               kv_cache,
+                               attn_metadata,
+                               attn_type=AttentionType.ENCODER_DECODER)
         out, _ = self.o_proj(output)
         return out
 
+    def attention_with_mask(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attention_mask: torch.Tensor,
+        kv_range_for_decode: List[Tuple[int, int]],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        # Skip writing kv-cache for the initial profiling run.
+        if len(kv_cache.shape) == 3:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_local_key_value_heads, self.head_dim)
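+            # Only the kv ranges that decode tokens will attend to are
+            # written to the paged cache; prefill uses the full k/v below.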
+            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+            PagedAttention.write_to_paged_cache(
+                cached_k, cached_v, key_cache, value_cache,
+                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+        # We have to call F.scaled_dot_product_attention for prefill when
+        # using a custom cross-attention mask, because the mask is neither
+        # a standard causal mask nor a block-diagonal mask that
+        # xformers.BlockDiagonalMask can optimize.
+        # The mask is computed specially to support multiple and
+        # interleaved images.
+        q_len = q.shape[0]
+        kv_len = k.shape[0]
+        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
+                                   self.num_key_value_groups, q_len,
+                                   self.head_dim).contiguous()
+        k = k.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len,
+                                              self.head_dim).contiguous()
+        v = v.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len,
+                                              self.head_dim).contiguous()
+        attention_mask = attention_mask.view(1, 1, q_len, kv_len)
+        output = F.scaled_dot_product_attention(q,
+                                                k,
+                                                v,
+                                                attn_mask=attention_mask,
+                                                is_causal=False)
+        output = output.permute(2, 0, 1, 3).reshape(
+            q_len, self.num_local_heads * self.head_dim)
+        return output
+
 
 class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
     """Cross-attention transformer block with tanh-gated attention
     and feedforward."""
 
-    def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
-                 quant_config: Optional[QuantizationConfig]) \
-        -> None:
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.layer_idx = layer_idx
         self.cross_attn = MllamaTextCrossAttention(
             config=config,
             layer_idx=layer_idx,
             quant_config=quant_config,
+            prefix=f"{prefix}.cross_attn",
         )
 
         self.input_layernorm = RMSNorm(config.hidden_size,
@@ -731,6 +849,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                 eps=config.rms_norm_eps)
@@ -741,6 +860,7 @@ def forward(
         hidden_states: torch.Tensor,
         cross_attention_states: torch.Tensor,
         cross_attention_mask: torch.Tensor,
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: torch.Tensor,
         kv_cache: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
@@ -751,6 +871,7 @@ def forward(
         hidden_states = self.cross_attn(
             hidden_states=hidden_states,
             attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             cross_attention_states=cross_attention_states,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
@@ -772,10 +893,15 @@ class MllamaTextModel(nn.Module):
     config_class = config_mllama.MllamaTextConfig
     base_model_prefix = "model"
 
-    def __init__(self, config: config_mllama.MllamaTextConfig,
-                 cache_config: Optional[CacheConfig],
-                 quant_config: Optional[QuantizationConfig]):
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        cache_config: Optional[CacheConfig],
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8,
@@ -787,13 +913,20 @@ def __init__(self, config: config_mllama.MllamaTextConfig,
             if layer_idx in self.cross_attention_layers:
                 layers.append(
                     MllamaCrossAttentionDecoderLayer(
-                        config, layer_idx, quant_config=quant_config))
+                        config,
+                        layer_idx,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.layers.{layer_idx}",
+                    ))
             else:
                 # TODO: force LlamaDecoderLayer to config.attention_bias=False
                 layers.append(
-                    LlamaDecoderLayer(config,
-                                      cache_config=cache_config,
-                                      quant_config=quant_config))
+                    LlamaDecoderLayer(
+                        config,
+                        cache_config=cache_config,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.layers.{layer_idx}",
+                    ))
 
         self.layers = nn.ModuleList(layers)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -804,6 +937,7 @@ def forward(
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -820,6 +954,7 @@ def forward(
                         hidden_states=hidden_states,
                         cross_attention_states=cross_attention_states,
                         cross_attention_mask=cross_attention_mask,
+                        kv_range_for_decode=kv_range_for_decode,
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                         kv_cache=kv_caches[idx],
@@ -848,12 +983,19 @@ class MllamaForCausalLM(nn.Module):
         "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"
     ]
 
-    def __init__(self, config: config_mllama.MllamaTextConfig,
-                 cache_config: Optional[CacheConfig],
-                 quant_config: Optional[QuantizationConfig]):
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        cache_config: Optional[CacheConfig],
+        quant_config: Optional[QuantizationConfig],
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(config, cache_config, quant_config)
+        self.model = MllamaTextModel(config,
+                                     cache_config,
+                                     quant_config,
+                                     prefix=f"{prefix}.model")
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
@@ -868,6 +1010,7 @@ def forward(
         positions: Optional[torch.LongTensor],
         cross_attention_states: Optional[torch.LongTensor],
         cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                       torch.Tensor]],
         kv_caches: List[torch.Tensor],
@@ -879,6 +1022,7 @@ def forward(
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -908,11 +1052,14 @@ def __init__(self,
             config.pad_token_id if config.pad_token_id is not None else -1
         self.image_size = config.vision_config.image_size
 
-        self.vision_model = MllamaVisionModel(config.vision_config)
+        self.vision_model = MllamaVisionModel(config.vision_config,
+                                              quant_config,
+                                              prefix="vision_model")
         self.language_model = MllamaForCausalLM(
             config.text_config,
             cache_config=cache_config,
             quant_config=quant_config,
+            prefix="language_model",
         )
         self.multi_modal_projector = nn.Linear(
             config.vision_config.vision_output_dim,
@@ -1026,36 +1173,102 @@ def _parse_and_validate_image_input(self, **kwargs: object):
         raise AssertionError("This line should be unreachable.")
 
     def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: AttentionMetadata):
+                            attn_metadata: AttentionMetadata,
+                            actual_encoder_seq_lens: List[int]):
 
         cross_attention_states_flat = torch.zeros(
-            sum(attn_metadata.encoder_seq_lens),
+            sum(actual_encoder_seq_lens),
             cross_attention_states.shape[-1],
             device=cross_attention_states.device,
             dtype=cross_attention_states.dtype)
         start_pos = 0
-        for seq_len, vision_token_in_batch in zip(
-                attn_metadata.encoder_seq_lens, cross_attention_states):
+        for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens,
+                                                  cross_attention_states):
             end_pos = start_pos + seq_len
             cross_attention_states_flat[
                 start_pos:end_pos] = vision_token_in_batch[:seq_len]
             start_pos = end_pos
         cross_attention_states = cross_attention_states_flat
+        return cross_attention_states
+
+    def get_cross_attention_states(
+        self,
+        image_inputs: MllamaImagePixelInputs,
+        attn_metadata: AttentionMetadata,
+        actual_encoder_seq_lens: List[int],
+    ) -> Tuple[torch.Tensor]:
+        # NOTE: llama's reference implementation runs vision model on CPU
+        pixel_values = image_inputs['data']
+        aspect_ratio_ids = image_inputs['aspect_ratio_ids']
+        aspect_ratio_mask = image_inputs['aspect_ratio_mask']
+        cross_attention_states = self.vision_model(pixel_values,
+                                                   aspect_ratio_ids,
+                                                   aspect_ratio_mask)
+        cross_attention_states = self.multi_modal_projector(
+            cross_attention_states)
+
+        bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
+        cross_attention_states = cross_attention_states.view(
+            bsz, -1, image_token_dim)
+
+        cross_attention_states = self.flat_encoder_result(
+            cross_attention_states, attn_metadata, actual_encoder_seq_lens)
+
+        return cross_attention_states
+
+    def get_cross_attention_mask(
+        self,
+        input_ids: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        num_tiles: List[List[int]],
+        num_tokens_per_tile: int,
+        dtype: torch.dtype,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        token_ids = input_ids.tolist()
+        start = 0
+        batch_token_ids = []
+        for seq_len in attn_metadata.seq_lens:
+            batch_token_ids.append(token_ids[start:start + seq_len])
+            start += seq_len
+        sparse_mask = [
+            get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+            for t in batch_token_ids
+        ]
+
+        # Skip generating the cross-attention mask if all samples
+        # are text-only or have only a single leading image.
+        if skip_attention_mask(sparse_mask):
+            return None, None
 
+        dense_mask, tile_range_for_decode = \
+            convert_sparse_cross_attention_mask_to_dense(
+                sparse_mask, num_tiles, attn_metadata.seq_lens)
+        cross_attention_mask = \
+            convert_dense_cross_attention_mask_to_tensor(
+                dense_mask, num_tokens_per_tile, input_ids.device, dtype)
+        kv_range_for_decode = [[
+            t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
+        ] for t in tile_range_for_decode]
+
+        return cross_attention_mask, kv_range_for_decode
+
+    def get_full_text_row_masked_out_mask(
+        self,
+        attn_metadata: AttentionMetadata,
+        device: torch.device,
+    ) -> torch.Tensor:
         full_text_row_masked_out_mask = torch.ones(
             (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
         start_pos = 0
-        for seq_len, encoder_seq_len in zip(
-                attn_metadata.seq_lens_tensor.cpu(),
-                attn_metadata.encoder_seq_lens):
+        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
+                                            attn_metadata.encoder_seq_lens):
             if encoder_seq_len == 0:
                 full_text_row_masked_out_mask[start_pos:start_pos +
                                               seq_len] = False
             start_pos += seq_len
         full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
-            cross_attention_states.device)
-
-        return cross_attention_states, full_text_row_masked_out_mask
+            device)
+        return full_text_row_masked_out_mask
 
     def forward(
         self,
@@ -1069,39 +1282,54 @@ def forward(
             attn_metadata.num_decode_tokens > 0:
             raise ValueError("Chunk prefill not supported")
         image_inputs = self._parse_and_validate_image_input(**kwargs)
+        cross_attention_states = None
+        cross_attention_mask = None
+        kv_range_for_decode = None
+
+        # For 1) text-only prefill and decode, 2) image-present decode.
         if image_inputs is None:
-            cross_attention_mask = None
             full_text_row_masked_out_mask = (
                 attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                     input_ids.device)
-            cross_attention_states = None
             skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+
+        # For image-present prefill.
         else:
-            # NOTE: llama's reference implementation runs vision model on CPU
-            pixel_values = image_inputs['data']
-            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
-            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
-            cross_attention_states = self.vision_model(pixel_values,
-                                                       aspect_ratio_ids,
-                                                       aspect_ratio_mask)
-            cross_attention_states = self.multi_modal_projector(
-                cross_attention_states)
-
-            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
-            cross_attention_states = cross_attention_states.view(
-                bsz, -1, image_token_dim)
-
-            cross_attention_states, full_text_row_masked_out_mask = \
-                self.flat_encoder_result(cross_attention_states, attn_metadata)
             skip_cross_attention = False
-            # TODO: support multi-image by this mask
-            cross_attention_mask = None
+
+            # Get the actual number of encoder tokens for each sample.
+            # attn_metadata.encoder_seq_lens only counts the last group of
+            # images for each sample; that value is used to trick the block
+            # manager into allocating blocks for those images only.
+            # See input_processor_for_mllama() for more details.
+            num_tiles_tensor = kwargs.pop("num_tiles")
+            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
+            num_tokens_per_tile = (self.image_size // 14)**2 + 1
+            actual_encoder_seq_lens = [
+                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+            ]
+            for actual_len, last_group_len in zip(
+                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
+                assert actual_len >= last_group_len
+
+            cross_attention_states = self.get_cross_attention_states(
+                image_inputs, attn_metadata, actual_encoder_seq_lens)
+
+            full_text_row_masked_out_mask = \
+                self.get_full_text_row_masked_out_mask(
+                    attn_metadata, input_ids.device)
+
+            cross_attention_mask, kv_range_for_decode = \
+                self.get_cross_attention_mask(
+                    input_ids, attn_metadata, num_tiles,
+                    num_tokens_per_tile, cross_attention_states.dtype)
 
         outputs = self.language_model(
             input_ids=input_ids,
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -1140,3 +1368,76 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
+    for mask in sparse_mask:
+        # Skip text-only samples.
+        if len(mask) == 0:
+            continue
+        # If the sample contains more than one image,
+        # we can't skip the mask.
+        if len(mask) != 1:
+            return False
+        # If the sample contains only one image
+        # but it is not the leading one,
+        # we can't skip the mask.
+        if mask[0][0] != 0 or mask[0][1] != -1:
+            return False
+    return True
+
+
+def convert_sparse_cross_attention_mask_to_dense(
+    sparse_mask: List[List[List[int]]],
+    num_tiles: List[List[int]],
+    lengths: List[int],
+) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
+    total_length = sum(lengths)
+    total_tiles = sum([sum(tiles) for tiles in num_tiles])
+    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
+    # A list of ranges, one per sample: range[i] = [start, end] means that
+    # tiles[start:end] of the i-th sample will be used for
+    # cross-attention decoding.
+    tile_range_for_decode = []
+
+    seq_start = 0
+    tile_start = 0
+    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
+        ts, td = -1, 0
+        for mask, tile in zip(masks, tiles):
+            if len(mask) != 2:
+                continue
+            start, end = mask
+            end = min(end, length)
+            if end == -1:
+                end = length
+            if end == length:
+                if ts == -1:
+                    ts = tile_start
+                td += tile
+            dense_mask[seq_start + start:seq_start + end,
+                       tile_start:tile_start + tile] = 1
+            tile_start += tile
+        tile_range_for_decode.append((ts, ts + td))
+        seq_start += length
+
+    return dense_mask, tile_range_for_decode
+
+
+def convert_dense_cross_attention_mask_to_tensor(
+    cross_attention_token_mask: np.ndarray,
+    num_tokens_per_tile: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
+    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
+
+    mask = 1.0 - mask
+    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
+
+    ninf = torch.finfo(dtype).min
+    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
+    mask *= full_text_mask
+    # (num_prompt_tokens, num_encoder_tokens)
+    return mask
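
The mask helpers added at the end of mllama.py are easy to exercise in isolation. Below is a minimal, standalone sketch (assuming a vLLM build that includes this patch; all sequence lengths, image positions, and tile counts are made up) of how the sparse per-sample mask is densified and how `kv_range_for_decode` is derived from `tile_range_for_decode`.

```python
# Standalone sketch: sparse per-sample masks -> dense tile mask -> token-level
# cross-attention mask and kv_range_for_decode. Assumes vLLM with this patch
# is importable; every number below is illustrative.
import torch

from vllm.model_executor.models.mllama import (
    convert_dense_cross_attention_mask_to_tensor,
    convert_sparse_cross_attention_mask_to_dense, skip_attention_mask)

# Sample 1: 10 prompt tokens, one leading image ([0, -1]) split into 4 tiles.
sparse_mask = [[[0, -1]]]
num_tiles = [[4]]
seq_lens = [10]

# A single leading image per sample means the mask can be skipped entirely.
assert skip_attention_mask(sparse_mask)

# Sample 2: 8 prompt tokens with two images (4 tiles each), so the mask is
# required.
sparse_mask.append([[2, 6], [6, -1]])
num_tiles.append([4, 4])
seq_lens.append(8)
assert not skip_attention_mask(sparse_mask)

dense_mask, tile_range_for_decode = \
    convert_sparse_cross_attention_mask_to_dense(
        sparse_mask, num_tiles, seq_lens)
print(dense_mask.shape)        # (18, 12): 10 + 8 tokens x 4 + 4 + 4 tiles
print(tile_range_for_decode)   # [(0, 4), (8, 12)]: tiles used at decode time

# With an image size of 560 (the Llama 3.2 Vision default) and the 14-pixel
# patches hard-coded in forward(), each tile contributes
# (560 // 14) ** 2 + 1 = 1601 encoder tokens (the +1 is the class token).
num_tokens_per_tile = (560 // 14) ** 2 + 1
cross_attention_mask = convert_dense_cross_attention_mask_to_tensor(
    dense_mask, num_tokens_per_tile, torch.device("cpu"), torch.float32)
print(cross_attention_mask.shape)  # (18, 12 * 1601)

kv_range_for_decode = [[t[0] * num_tokens_per_tile,
                        t[1] * num_tokens_per_tile]
                       for t in tile_range_for_decode]
```
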
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index ccfee165368e..3c34227767e0 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -1,4 +1,3 @@
-import logging
 import math
 import re
 from array import array
@@ -14,16 +13,15 @@
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
-import vllm.envs as envs
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
                               tensor_model_parallel_all_gather)
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -32,22 +30,21 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import make_layers
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
-from vllm.platforms import current_platform
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.processor import get_processor
 
-log = logging.getLogger(__name__)
+from .interfaces import SupportsMultiModal, SupportsPP
+from .utils import (get_vit_attn_backend,
+                    make_empty_intermediate_tensors_factory, make_layers)
 
 # TODO: hard-coded for now. Consider making it configurable.
 VIT_LAYERS = [-2, -9]
@@ -189,35 +186,12 @@ def __init__(
         )
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.get_device_capability()[0] >= 8
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    log.warning(
-                        "Current Molmo implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Molmo does not support {selected_backend} backend now.")
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Molmo does not support {self.attn_backend} backend now.")
 
     def forward(self,
                 inputs_q: torch.Tensor,
@@ -239,10 +213,15 @@ def forward(self,
         xk = xk.view(*kv_shape)
         xv = xv.view(*kv_shape)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             from flash_attn import flash_attn_func
             output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False)
-        else:
+        elif self.attn_backend == _Backend.TORCH_SDPA:
+            xq, xk, xv = (rearrange(x, "b s h d -> b h s d")
+                          for x in (xq, xk, xv))
+            output = F.scaled_dot_product_attention(xq, xk, xv)
+            output = rearrange(output, "b h s d -> b s h d")
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0)
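
The TORCH_SDPA branch above differs from the flash-attn and xformers branches only in tensor layout: `F.scaled_dot_product_attention` expects (batch, heads, seq, dim), while the other two take (batch, seq, heads, dim), hence the two `rearrange` calls around it. A quick standalone check of that equivalence, with random tensors and illustrative shapes:

```python
# Verifies that non-causal SDPA in (b, h, s, d) layout matches plain softmax
# attention computed directly in the (b, s, h, d) layout used by the
# flash-attn / xformers branches. Shapes are illustrative only.
import torch
import torch.nn.functional as F
from einops import rearrange

b, s, h, d = 2, 16, 4, 32
xq, xk, xv = (torch.randn(b, s, h, d) for _ in range(3))

# Reference: explicit non-causal attention in (b, s, h, d) layout.
scores = torch.einsum("bqhd,bkhd->bhqk", xq, xk) / d**0.5
ref = torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), xv)

# Mirror of the TORCH_SDPA branch added above.
q, k, v = (rearrange(x, "b s h d -> b h s d") for x in (xq, xk, xv))
out = F.scaled_dot_product_attention(q, k, v)
out = rearrange(out, "b h s d -> b s h d")

assert torch.allclose(out, ref, atol=1e-5)
```
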
 
@@ -765,6 +744,10 @@ def __init__(
         assert config.layer_norm_type == "rms"
         self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps)
 
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -945,14 +928,20 @@ def pad_images(
     return images, image_input_idx, image_masks
 
 
-def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
-    prompt = llm_inputs["prompt"]
-    multi_modal_data = llm_inputs.get("multi_modal_data")
-    image = multi_modal_data.get("image")
+def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs):
+    prompt = inputs.get("prompt")
+    multi_modal_data = inputs.get("multi_modal_data")
+    image = None if multi_modal_data is None else multi_modal_data.get("image")
+
     processor = cached_get_processor(ctx.model_config.model,
                                      trust_remote_code=True,
                                      revision=ctx.model_config.code_revision)
 
+    model_config = ctx.model_config
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
+
     # NOTE: message formatting for raw text prompt is only applied for
     # offline inference; for online inference, the prompt is always in
     # instruction format and tokenized.
@@ -962,9 +951,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
     elif prompt is not None:
         out = processor.process(prompt, image)
     else:
-        out = processor.process(None,
-                                image,
-                                tokens=llm_inputs["prompt_token_ids"])
+        out = processor.process(None, image, tokens=inputs["prompt_token_ids"])
 
     image_processor = processor.image_processor
     max_total_crops = 1 + image_processor.max_crops
@@ -1017,9 +1004,13 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
 
     multi_modal_data = dict(image=image_data)
 
-    return LLMInputs(
+    prompt = inputs.get("prompt")
+    if prompt is None:
+        prompt = tokenizer.decode(out["input_ids"])
+
+    return token_inputs(
         prompt_token_ids=out["input_ids"],
-        prompt=llm_inputs["prompt"],
+        prompt=prompt,
         multi_modal_data=multi_modal_data,
     )
 
@@ -1028,7 +1019,7 @@ def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs):
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_molmo)
-class MolmoForCausalLM(nn.Module, SupportsMultiModal):
+class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(
         self,
@@ -1060,6 +1051,9 @@ def __init__(
                                                 or config.vocab_size)
         self.sampler = Sampler()
 
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
     def _parse_and_validate_image_input(
         self,
         **kwargs: object,
@@ -1143,31 +1137,36 @@ def forward(
         positions: torch.LongTensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         **kwargs: object,
     ) -> SamplerOutput:
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            image_input = self._parse_and_validate_image_input(**kwargs)
 
-        image_input = self._parse_and_validate_image_input(**kwargs)
-
-        if image_input is not None:
-            inputs_embeds = self.model.embed_tokens(input_ids)
-            image_features = self._process_image_input(image_input)
+            if image_input is not None:
+                inputs_embeds = self.model.embed_tokens(input_ids)
+                image_features = self._process_image_input(image_input)
 
-            inputs_embeds = self._merge_multimodal_embeddings(
-                inputs_embeds,
-                image_features,
-                image_input["image_input_idx"],
-                image_input["seq_len"],
-            )
+                inputs_embeds = self._merge_multimodal_embeddings(
+                    inputs_embeds,
+                    image_features,
+                    image_input["image_input_idx"],
+                    image_input["seq_len"],
+                )
 
-            input_ids = None
-        else:
-            inputs_embeds = None
+                input_ids = None
+            else:
+                inputs_embeds = None
 
         hidden_states = self.model(
             input_ids=input_ids,
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
 
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index e3d3937b13fa..ee802030a5ef 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -204,6 +205,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class MPTModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index 14515e16e34a..72a09129fed6 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -27,6 +27,7 @@
 from torch import nn
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -290,6 +291,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class NemotronModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index a52e3cb6039b..df4fd0a3256e 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -4,10 +4,13 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
+from typing import Optional
+
 import torch.nn as nn
 from transformers import PretrainedConfig
 
 from vllm.inputs import INPUT_REGISTRY
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 from .intern_vit import InternVisionModel
@@ -55,10 +58,31 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
             nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False),
         )
 
-    def _init_vision_model(self, config: PretrainedConfig,
-                           num_hidden_layers: int):
-        # We added additional dummy heads to the original num of heads to make
-        # the number of heads divisible by 8.
-        return InternVisionModel(config.vision_config,
-                                 num_hidden_layers_override=num_hidden_layers,
-                                 num_dummy_heads=7)
+    def _init_vision_model(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig],
+        *,
+        is_mono: bool,
+        prefix: str,
+    ):
+        if not is_mono:
+            vision_feature_layer = config.select_layer
+            if vision_feature_layer < 0:
+                num_hidden_layers = config.vision_config.num_hidden_layers \
+                    + vision_feature_layer + 1
+            else:
+                num_hidden_layers = vision_feature_layer + 1
+
+            # We added additional dummy heads to the original num of heads to
+            # make the number of heads divisible by 8.
+            return InternVisionModel(
+                config.vision_config,
+                quant_config=quant_config,
+                num_hidden_layers_override=num_hidden_layers,
+                num_dummy_heads=7,
+                prefix=prefix,
+            )
+        else:
+            msg = "Monolith mode is not applicable to NVLM_D"
+            raise NotImplementedError(msg)
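
The `num_hidden_layers` arithmetic follows the usual InternVL-style convention: a negative `select_layer` counts from the end of the vision tower. A tiny restatement with made-up layer counts:

```python
# Mirrors the logic in _init_vision_model above; layer counts are hypothetical.
def _num_hidden_layers(total_layers: int, select_layer: int) -> int:
    if select_layer < 0:
        return total_layers + select_layer + 1
    return select_layer + 1


assert _num_hidden_layers(24, -1) == 24   # last layer -> keep all 24 layers
assert _num_hidden_layers(24, -2) == 23   # second-to-last -> drop one layer
assert _num_hidden_layers(24, 5) == 6     # layer index 5 -> keep layers 0..5
```
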
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 5ca7c66f5407..90ab8abcb84b 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -28,6 +28,7 @@
 from transformers import OlmoConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -221,6 +222,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class OlmoModel(nn.Module):
 
     def __init__(self,
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 3bcdb0d87fd5..37c3fa919124 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -24,6 +24,7 @@
 from transformers import OPTConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -279,6 +280,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class OPTModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 0913193f73a4..055407587c59 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -11,6 +11,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -184,7 +185,6 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -203,9 +203,10 @@ def forward(
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states, None
+        return hidden_states
 
 
+@support_torch_compile
 class OrionModel(nn.Module):
 
     def __init__(
@@ -233,8 +234,9 @@ def __init__(
             prefix=f"{prefix}.layers")
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+            make_empty_intermediate_tensors_factory([
+                "hidden_states",
+            ], config.hidden_size))
 
     def forward(
         self,
@@ -246,24 +248,20 @@ def forward(
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
             hidden_states = self.embed_tokens(input_ids)
-            residual = None
         else:
-            assert intermediate_tensors
+            assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i - self.start_layer],
                 attn_metadata,
-                residual,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
-                "residual": residual
             })
         hidden_states = self.norm(hidden_states)
         return hidden_states
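
The Orion refactor folds the residual addition into the decoder layer itself, so the model no longer threads a separate `residual` tensor across layers or pipeline ranks; only `hidden_states` has to appear in the intermediate tensors. A toy sketch of the new calling convention (not the real OrionDecoderLayer):

```python
# Toy sketch: the residual is consumed inside the layer, so a single tensor is
# returned and only "hidden_states" crosses pipeline-parallel ranks.
import torch
import torch.nn as nn


class ToyPostAddLayer(nn.Module):

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.mlp(self.norm(hidden_states))
        # The residual is added here, inside the layer.
        return residual + hidden_states


layer = ToyPostAddLayer(8)
hidden_states = layer(torch.randn(2, 4, 8))      # no (hidden, residual) tuple
intermediate_tensors = {"hidden_states": hidden_states}   # "residual" is gone
```
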
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index 99d000ea13a2..7a62a098a452 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -7,7 +7,8 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -68,7 +69,8 @@ def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
     return seq_data, mm_data
 
 
-def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
+def input_processor_for_paligemma(ctx: InputContext,
+                                  inputs: DecoderOnlyInputs):
 
     """
     The correct prompt format needs to be:
@@ -77,9 +79,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
     """ # noqa
 
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_config(PaliGemmaConfig)
@@ -91,8 +93,8 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     image_token_str_pad = image_token_str * image_feature_size
     image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
 
-    orig_prompt = llm_inputs.get("prompt")
-    orig_prompt_ids = llm_inputs.get("prompt_token_ids")
+    orig_prompt = inputs.get("prompt")
+    orig_prompt_ids = inputs.get("prompt_token_ids")
 
     if orig_prompt is not None and image_token_str in orig_prompt:
         logger.warning(
@@ -106,9 +108,9 @@ def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
     new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  #newline
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 class PaliGemmaMultiModalProjector(nn.Module):
@@ -140,7 +142,8 @@ def __init__(self,
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.vision_tower = SiglipVisionModel(config.vision_config,
+                                              quant_config)
         self.multi_modal_projector = PaliGemmaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             projection_dim=config.vision_config.projection_dim)
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index b625d19f6447..fc9ef15db26c 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -27,6 +27,7 @@
 from transformers import PersimmonConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -209,6 +210,7 @@ def forward(
         return outputs
 
 
+@support_torch_compile
 class PersimmonModel(nn.Module):
 
     def __init__(self,
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index 0918f21a40e2..ec20cb249ba9 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -102,8 +102,9 @@ def __init__(self,
         # pylint: disable=C0301
         # Refer to:
         # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
-        rope_theta = 10000
-        max_position_embeddings = getattr(config, "n_positions", 2048)
+        rope_theta = getattr(config, "rope_theta", 10000.0)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          2048)
         self.rotary_emb = get_rope(
             self.head_size,
             rotary_dim=rotary_dim,
diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index 4cfeb3bb3496..3a7afc606bb9 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -179,7 +179,7 @@ def __init__(
                 rope_scaling["factor"] = self.rope_position_scale
         else:
             rope_scaling = {
-                "type": "linear",
+                "rope_type": "linear",
                 "factor": self.rope_position_scale,
             }
 
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 00a04dac8878..855a9b17585a 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -27,16 +27,21 @@
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
 
 from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
@@ -65,7 +70,8 @@
                                                      projection_dim=768)
 
 
-def _init_img_processor(hf_config: PretrainedConfig):
+def _init_img_processor(hf_config: PretrainedConfig,
+                        quant_config: Optional[QuantizationConfig]):
     clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
     layer_idx = hf_config.img_processor.get('layer_idx', -2)
 
@@ -77,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
         num_hidden_layers = layer_idx + 1
 
     img_processor = CLIPVisionModel(
-        clip_config, num_hidden_layers_override=num_hidden_layers)
+        clip_config,
+        quant_config,
+        num_hidden_layers_override=num_hidden_layers,
+    )
 
     return img_processor
 
@@ -143,14 +152,15 @@ def get_img_features(self,
 class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
     """Phi3 Image embedding with HD transform."""
 
-    def __init__(self, config: PretrainedConfig) -> None:
+    def __init__(self, config: PretrainedConfig,
+                 quant_config: Optional[QuantizationConfig]) -> None:
         super().__init__()
 
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(
             config, 'n_embd') else config.hidden_size
 
-        self.img_processor = _init_img_processor(config)
+        self.img_processor = _init_img_processor(config, quant_config)
 
         image_dim_out = config.img_processor['image_dim_out']
         self.num_img_tokens = config.img_processor['num_img_tokens']
@@ -289,10 +299,6 @@ def add_image_newline(self, image_features_hd):
             dim=2).reshape(num_images, -1, hid_dim)
         return image_features_hd_newline
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
-
 
 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
@@ -385,32 +391,37 @@ def dummy_data_for_phi3v(ctx: InputContext,
     return seq_data, mm_data
 
 
-# Reserve this function to also handle placeholders for additional images
-# [ref: PR #5820]
 @lru_cache
-def _get_image_placeholder_token_ids(model_config: ModelConfig,
-                                     idx: int) -> List[int]:
+def _get_image_placeholder_token_id_candidates(
+    model_config: ModelConfig,
+    idx: int,
+) -> List[List[int]]:
     assert idx > 0
 
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
+    # This is used when the image token is at the start of the string
+    start_candidate = tokenizer.encode(f"<|image_{idx}|>",
+                                       add_special_tokens=False)
+
+    # This is used when the image token is in the middle of the string
     # We need to get the token for "<", not "▁<"
     # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
     a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
-        f"a<|image_{idx}|>", add_special_tokens=False)
+    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
+                                                      add_special_tokens=False)
     assert a_token_id == a_token_id_
 
-    return image_placeholder_token_ids
+    return [start_candidate, middle_candidate]
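
Two candidates are needed because SentencePiece-style tokenizers can encode the leading `<` of `<|image_N|>` differently at the start of a string than after other text. An illustrative check (it assumes the Phi-3-vision checkpoint's tokenizer can be downloaded; the resulting ids are model-specific):

```python
# Illustrative only: why both a start-of-string and a mid-string encoding are
# collected for the image placeholder.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/Phi-3-vision-128k-instruct")

start_candidate = tokenizer.encode("<|image_1|>", add_special_tokens=False)
_, *middle_candidate = tokenizer.encode("a<|image_1|>",
                                        add_special_tokens=False)

# With a SentencePiece-style tokenizer the two lists may differ only in their
# first token ("▁<" vs "<"); input_processor_for_phi3v tries both candidates
# when matching placeholder spans in prompt_token_ids.
print(start_candidate)
print(middle_candidate)
```
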
 
 
 def input_processor_for_phi3v(ctx: InputContext,
-                              llm_inputs: LLMInputs,
+                              inputs: DecoderOnlyInputs,
                               *,
                               num_crops: Optional[int] = None):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     model_config = ctx.model_config
     hf_config = ctx.get_hf_image_processor_config()
@@ -442,7 +453,7 @@ def input_processor_for_phi3v(ctx: InputContext,
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")
 
-    prompt = llm_inputs.get("prompt")
+    prompt = inputs.get("prompt")
     if prompt is None:
         # for async server request, we assume prompt and its token_ids is always
         # in correct format. And num_image_tags == len(image_data) always True.
@@ -459,18 +470,20 @@ def input_processor_for_phi3v(ctx: InputContext,
                 image_data), "The count of image_placeholder not match image's"
         new_prompt = prompt
 
-    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()
+    prompt_token_ids = inputs["prompt_token_ids"].copy()
 
-    # masked place_holder with image token id
+    # Replace the image placeholder tokens with the image token id
     for idx in image_idx:
-        image_token_ids = _get_image_placeholder_token_ids(model_config,
-                                                           idx=idx)
-        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
-            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
-                prompt_token_ids[i:i + len(image_token_ids)] = [
-                    _IMAGE_TOKEN_ID
-                ] * len(image_token_ids)
-                break
+        candidates = _get_image_placeholder_token_id_candidates(model_config,
+                                                                idx=idx)
+
+        for candidate in candidates:
+            for i in range(len(prompt_token_ids) - len(candidate) + 1):
+                if prompt_token_ids[i:i + len(candidate)] == candidate:
+                    prompt_token_ids[i:i +
+                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
+                                                        len(candidate))
+                    break
 
     # merge consecutive tag ids
     merged_token_ids: List[int] = []
@@ -497,10 +510,9 @@ def input_processor_for_phi3v(ctx: InputContext,
             new_token_ids.append(token_id)
 
     # NOTE: Create a defensive copy of the original inputs
-    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
-                           prompt=new_prompt,
-                           multi_modal_data=multi_modal_data)
-    return llm_inputs
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -520,12 +532,23 @@ def __init__(self,
         self.multimodal_config = multimodal_config
         self.image_token_id = _IMAGE_TOKEN_ID
 
-        # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            quant_config=quant_config,
+        )
+
+        # TODO: Optionally initializes this for supporting input embeddings.
+        self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
 
         self.language_model = LlamaForCausalLM(config, cache_config,
                                                quant_config)
 
+        # The same model class supports both language generation and embedding
+        # because the architecture name is the same
+        self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
@@ -649,8 +672,7 @@ def forward(self,
 
             if image_input is not None:
                 vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
+                inputs_embeds = self.embed_tokens(input_ids)
                 inputs_embeds = merge_multimodal_embeddings(
                     input_ids, inputs_embeds, vision_embeddings,
                     self.image_token_id)
@@ -682,13 +704,27 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={
+                "model.vision_embed_tokens.wte": "embed_tokens",
                 "model.vision_embed_tokens.": "vision_embed_tokens.",
                 "lm_head.": "language_model.lm_head.",
                 "model.": "language_model.model.",
             })
 
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        autoloaded_weights = loader.load_weights(weights,
+                                                 mapper=hf_to_vllm_mapper)
+
+        # The HF config doesn't specify whether these are tied,
+        # so we detect it this way
+        if "embed_tokens" not in autoloaded_weights:
+            self.embed_tokens = self.language_model.model.embed_tokens
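
The newly added `Pooler(pooling_type=PoolingType.LAST, normalize=True)` is what lets the same class serve embedding requests: it returns the hidden state of each sequence's final token, L2-normalized. A conceptual restatement of that pooling step (a toy sketch, not vLLM's `Pooler` implementation, which operates on `PoolingMetadata` rather than a plain list of sequence lengths):

```python
# Toy restatement of last-token pooling with L2 normalization.
from typing import List

import torch
import torch.nn.functional as F


def last_token_pool(hidden_states: torch.Tensor,
                    seq_lens: List[int]) -> torch.Tensor:
    # hidden_states: flattened (total_prompt_tokens, hidden_size).
    last_indices = torch.cumsum(torch.tensor(seq_lens), dim=0) - 1
    pooled = hidden_states[last_indices]
    return F.normalize(pooled, p=2, dim=-1)


hidden = torch.randn(7, 16)              # two prompts, of length 3 and 4
embeddings = last_token_pool(hidden, [3, 4])
print(embeddings.shape)                  # torch.Size([2, 16]), unit-norm rows
```
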
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index c8957dcae6b1..18dbee94e10b 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -3,18 +3,25 @@
 from itertools import tee
 from typing import Iterable, List, Mapping, Optional, Tuple, Union
 
+import numpy
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from mistral_common.protocol.instruct.messages import ImageChunk
 from PIL import Image
-from transformers import PretrainedConfig
+from transformers import PixtralVisionConfig, PretrainedConfig
+from transformers.models.pixtral.image_processing_pixtral import (
+    _num_image_tokens)
+from transformers.models.pixtral.modeling_pixtral import (
+    PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
 from xformers.ops.fmha import memory_efficient_attention
 from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
+from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
@@ -25,6 +32,8 @@
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
+from vllm.transformers_utils.processor import cached_get_processor
+from vllm.utils import is_list_of
 
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import init_vllm_registered_model
@@ -62,7 +71,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
     image_feature_size = (size**2) // (patch_size**2)
 
     num_image_tokens = image_feature_size * num_images
-    seq_data = SequenceData.from_token_counts(
+    seq_data = SequenceData.from_prompt_token_counts(
         (image_token_id, num_image_tokens),
         (0, seq_len - num_image_tokens),
     )
@@ -102,8 +111,8 @@ def input_mapper_for_pixtral(ctx: InputContext,
     return MultiModalInputs({"images": images})
 
 
-def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_pixtral(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is not None and "image" in multi_modal_data:
         tokenizer = cached_get_tokenizer(
             ctx.model_config.tokenizer,
@@ -112,15 +121,15 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
         mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
         image_token_id = mm_encoder.special_ids.img
 
-        if image_token_id not in llm_inputs['prompt_token_ids']:
+        if image_token_id not in inputs['prompt_token_ids']:
             raise ValueError(
-                (f"You've passed {llm_inputs=} without {image_token_id=}"
+                (f"You've passed {inputs=} without {image_token_id=}"
                  " Make sure to process your input via mistral_common's"
                  " tokenizer or pass a chat completion request. For more"
                  " For more info, see: "
                  "https://github.com/vllm-project/vllm/issues/8411."))
 
-    return llm_inputs
+    return inputs
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral)
@@ -576,3 +585,448 @@ def __init__(self, args: VisionEncoderArgs, dim: int):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w_out(self.gelu(self.w_in(x)))
+
+
+#### HF Transformers version of Pixtral ####
+# Based off https://github.com/huggingface/transformers/blob/d7950bff82b18c823193d17d72188c5e46d06c83/src/transformers/models/pixtral/modeling_pixtral.py
+# This model follows the Llava family, meaning image embeddings are
+# substituted for the `[IMG]` token placeholders.
+# The model uses [`PixtralVisionModel`] for its vision encoder,
+# and [`MistralForCausalLM`] for its language decoder.
+
+
+def get_pixtral_hf_patch_grid_length(*, image_size: int,
+                                     patch_size: int) -> int:
+    # Since interpolation is applied, the image size need not be divisible
+    # by the patch size:
+    # assert image_size % patch_size == 0
+    return image_size // patch_size
+
+
+def get_pixtral_hf_num_patches(*, image_size: int, patch_size: int) -> int:
+    grid_length = get_pixtral_hf_patch_grid_length(image_size=image_size,
+                                                   patch_size=patch_size)
+    return grid_length * grid_length
+
+
+def get_max_pixtral_hf_image_feature_size(
+        hf_config: PixtralVisionConfig) -> int:
+    return get_pixtral_hf_num_patches(image_size=hf_config.image_size,
+                                      patch_size=hf_config.patch_size)
+
+
+def get_max_pixtral_hf_image_tokens(hf_config: PixtralVisionConfig) -> int:
+    return get_max_pixtral_hf_image_feature_size(hf_config)
+
+
+def dummy_seq_data_for_pixtral_hf(
+    hf_config: PixtralVisionConfig,
+    seq_len: int,
+    num_images: int,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    if image_feature_size_override is None:
+        image_feature_size = get_max_pixtral_hf_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    return SequenceData.from_prompt_token_counts(
+        (image_token_id, image_feature_size * num_images),
+        (0, seq_len - image_feature_size * num_images),
+    )
+
+
+def dummy_image_for_pixtral_hf(
+    hf_config: PixtralVisionConfig,
+    num_images: int,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    width = height = hf_config.image_size
+    if image_width_override is not None:
+        width = image_width_override
+    if image_height_override is not None:
+        height = image_height_override
+
+    image = Image.new("RGB", (width, height), color=0)
+    return {"image": image if num_images == 1 else [image] * num_images}
+
+
+def get_pixtral_hf_image_feature_size(hf_config: PixtralVisionConfig,
+                                      image_width: int,
+                                      image_height: int) -> Tuple[int, int]:
+    # Adapted from transformers.models.pixtral.image_processing_pixtral.get_resize_output_image_size # noqa: E501
+    # https://github.com/huggingface/transformers/blob/2bd4d5897dc73e8b172832070a6f9e567a0df017/src/transformers/models/pixtral/image_processing_pixtral.py#L180 # noqa: E501
+    max_width, max_height = hf_config.image_size, hf_config.image_size
+    patch_width, patch_height = hf_config.patch_size, hf_config.patch_size
+
+    ratio = max(image_width / max_width, image_height / max_height)
+
+    if ratio > 1:
+        image_width = int(numpy.ceil(image_width / ratio))
+        image_height = int(numpy.ceil(image_height / ratio))
+
+    num_height_tokens, num_width_tokens = _num_image_tokens(
+        (image_height, image_width), (patch_height, patch_width))
+
+    return num_width_tokens, num_height_tokens
+
+
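
A worked example of `get_pixtral_hf_image_feature_size` with hypothetical sizes (it assumes `_num_image_tokens` performs ceiling division over the patch grid, as in transformers' Pixtral image processor); `input_processor_for_pixtral_hf` below then expands this grid with `[IMG_BREAK]`/`[IMG_END]` markers:

```python
# Hypothetical sizes; mirrors the resize-then-count logic above.
import math

image_size, patch_size = 1024, 16      # hf_config.image_size / patch_size
w, h = 1300, 800                       # input image, wider than allowed

ratio = max(w / image_size, h / image_size)   # ~1.27 -> downscale
if ratio > 1:
    w = math.ceil(w / ratio)           # 1024
    h = math.ceil(h / ratio)           # 631

num_width_tokens = math.ceil(w / patch_size)    # 64
num_height_tokens = math.ceil(h / patch_size)   # 40
print(num_width_tokens, num_height_tokens)
```
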
+def input_processor_for_pixtral_hf(
+    model_config: ModelConfig,
+    hf_config: PixtralVisionConfig,
+    inputs: DecoderOnlyInputs,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[Union[int, List[int]]] = None,
+) -> DecoderOnlyInputs:
+    assert image_feature_size_override is None, (
+        "image_feature_size_override is not supported for Pixtral")
+
+    multi_modal_data = inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return inputs
+
+    processor = cached_get_processor(model_config.model)
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        image_data = [image_data]
+    elif not is_list_of(image_data, Image.Image):
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    new_prompt = inputs.get("prompt")
+    new_token_ids = inputs["prompt_token_ids"]
+
+    image_token = processor.image_token
+    image_break_token = processor.image_break_token
+    image_end_token = processor.image_end_token
+
+    # Update new_prompt if present
+    if new_prompt:
+        parts = new_prompt.split(image_token)
+        assert len(parts) - 1 == len(image_data)
+        new_parts = [parts[0]]  # Start with the part before any image tokens
+
+        for image, next_part in zip(image_data, parts[1:]):
+            w, h = image.size
+            (num_width_tokens,
+             num_height_tokens) = get_pixtral_hf_image_feature_size(
+                 hf_config, image_width=w, image_height=h)
+
+            replace_tokens = [image_token] * num_width_tokens + [
+                image_break_token
+            ]
+            replace_tokens = replace_tokens * num_height_tokens
+            replace_tokens[-1] = image_end_token
+
+            new_parts.append("".join(replace_tokens))
+            new_parts.append(next_part)
+
+        new_prompt = "".join(new_parts)
+
+    # Update new_token_ids
+    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
+    image_token_id = convert_tokens_to_ids(image_token)
+    image_break_id = convert_tokens_to_ids(image_break_token)
+    image_end_id = convert_tokens_to_ids(image_end_token)
+    placeholder_token_id = -999
+    # Find all image token indices at once
+    placeholder_indices = [
+        idx for idx, token_id in enumerate(new_token_ids)
+        if token_id == image_token_id
+    ]
+    assert len(placeholder_indices) == len(image_data)
+    replace_tokens_list = []
+    for placeholder_idx, image in zip(placeholder_indices, image_data):
+        new_token_ids[placeholder_idx] = placeholder_token_id
+
+        w, h = image.size
+        (num_width_tokens,
+         num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
+                                                                image_width=w,
+                                                                image_height=h)
+
+        replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
+        replace_tokens = replace_tokens * num_height_tokens
+        replace_tokens[-1] = image_end_id
+        replace_tokens_list.append(replace_tokens)
+
+    # Backward iteration for replacement without affecting known indices
+    for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
+                                               reversed(replace_tokens_list)):
+        new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
+
+    # NOTE: Create a defensive copy of the original inputs
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
+
+
+class PixtralHFMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        assert config.intermediate_size is not None
+        # TODO: Use quant_config and prefix after optimizing this
+        self.gate_proj = nn.Linear(config.hidden_size,
+                                   config.intermediate_size,
+                                   bias=False)
+        self.up_proj = nn.Linear(config.hidden_size,
+                                 config.intermediate_size,
+                                 bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size,
+                                   config.hidden_size,
+                                   bias=False)
+        self.act = get_act_fn(config.hidden_act)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
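+        # Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), i.e. the
+        # activation gates the up projection before the down projection.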
+        return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
+
+
+class PixtralHFAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        assert not config.hidden_size % config.num_attention_heads
+        self.n_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+        self.scale = self.head_dim**-0.5
+
+        # TODO: Use quant_config and prefix after optimizing this
+        self.q_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.k_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.v_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+        self.o_proj = nn.Linear(config.hidden_size,
+                                config.hidden_size,
+                                bias=False)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        batch, patches, _ = hidden_states.size()
+
+        q = self.q_proj(hidden_states)
+        k = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+
+        # Transpose q and k to apply HF's Rotary Position Embedding
+        q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
+        k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
+        cos, sin = position_embeddings
+        q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
+
+        # Transpose q and k back for attention
+        q = q.transpose(1, 2).contiguous()
+        k = k.transpose(1, 2).contiguous()
+        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
+
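+        # xFormers memory_efficient_attention expects (batch, seq, heads,
+        # head_dim); the attn_bias restricts attention to each image's
+        # own patch sequence.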
+        out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
+        out = out.reshape(batch, patches, self.n_heads * self.head_dim)
+
+        return self.o_proj(out)
+
+
+class PixtralHFTransformerBlock(nn.Module):
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.attention_norm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.attention = PixtralHFAttention(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.attention")
+        self.feed_forward = PixtralHFMLP(config,
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.feed_forward")
+        self.ffn_norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        r = self.attention.forward(self.attention_norm(hidden_states),
+                                   attention_mask=attention_mask,
+                                   position_embeddings=position_embeddings)
+        h = hidden_states + r
+        r = self.feed_forward.forward(self.ffn_norm(h))
+        out = h + r
+        return out
+
+
+class PixtralHFTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
+
+        self.layers = nn.ModuleList([
+            PixtralHFTransformerBlock(config=config,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
+        ])
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: BlockDiagonalMask,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        for layer in self.layers:
+            x = layer(x, attention_mask, position_embeddings)
+        return x
+
+
+class PixtralHFVisionModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PixtralVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.patch_conv = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=config.hidden_size,
+            kernel_size=config.patch_size,
+            stride=config.patch_size,
+            bias=False,
+        )
+        self.ln_pre = RMSNorm(config.hidden_size, eps=1e-5)
+        self.transformer = PixtralHFTransformer(
+            config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.transformer",
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.transformer.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.transformer.layers)} "
+                "layers.")
+
+        if require_post_norm is True:
+            msg = "PixtralHFVisionModel does not have post-layernorm"
+            raise ValueError(msg)
+
+        self.dtype = next(self.parameters()).dtype
+        self.device = next(self.parameters()).device
+        self.patch_positional_embedding = PixtralRotaryEmbedding(
+            config, self.device)
+
+    def forward(
+        self,
+        pixel_values: List[torch.Tensor],
+    ) -> torch.Tensor:
+        """
+        Args:
+            pixel_values: Each image to be processed is a separate tensor in
+                pixel_values; the input is a list of tensors because a batch
+                of requests can contain multiple images, each potentially
+                with its own shape.
+
+        Returns:
+            image_features: tensor of token features for
+                all tokens of all images of shape (N_toks, D)
+        """
+        # pass images through initial convolution independently
+        patch_embeds_list = [
+            self.patch_conv(img.unsqueeze(0).to(self.dtype))
+            for img in pixel_values
+        ]
+
+        # flatten to a single sequence
+        patch_embeds = torch.cat(
+            [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1)
+        patch_embeds = self.ln_pre(patch_embeds)
+
+        # positional embeddings
+        position_ids = position_ids_in_meshgrid(
+            patch_embeds_list,
+            max_width=self.config.image_size // self.config.patch_size).to(
+                self.device)
+
+        position_embedding = self.patch_positional_embedding(
+            patch_embeds, position_ids)
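+        # All images are flattened into one sequence, so a block-diagonal
+        # mask keeps each image attending only to its own patches.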
+        attention_mask = BlockDiagonalMask.from_seqlens(
+            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        out = self.transformer(patch_embeds, attention_mask,
+                               position_embedding)
+
+        return out
+
+    # (TODO) Add prefix argument for filtering out weights to be loaded
+    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = []
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                param = params_dict[name.replace(weight_name, param_name)]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index fd8a27eec3b9..cd3f7c1b6c4d 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -22,7 +22,8 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -652,30 +653,30 @@ def get_image_text(image_num: int, padding: bool) -> str:
 
 
 def input_processor_for_qwen(ctx: InputContext,
-                             llm_inputs: LLMInputs) -> LLMInputs:
+                             inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
     """Processes the inputs, which may or may not be multimodal.
     Multimodal inputs will only be processed if the model has a "visual"
     component in its model config, otherwise they'll be ignored.
 
     Args:
         ctx: Context of the loaded model.
-        llm_inputs: LLM inputs which may have a multi_modal_data attribute.
+        inputs: LLM inputs which may have a multi_modal_data attribute.
 
     Returns:
         If the model is language only or not multimodal inputs were provided,
-        returns llm_inputs unmodified. Otherwise, processes the multimodal
+        returns inputs unmodified. Otherwise, processes the multimodal
         images / image embeddings and adds the fixed-length image placeholders.
     """
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
 
     # Only process images if we have multimodal data and a visual config
     hf_config = ctx.get_hf_config()
     if (multi_modal_data is None or "image" not in multi_modal_data
             or not hasattr(hf_config, "visual")):
-        return llm_inputs
+        return inputs
 
-    prompt = llm_inputs.get("prompt")
-    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    prompt = inputs.get("prompt")
+    prompt_token_ids = inputs["prompt_token_ids"]
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(
         model_config.tokenizer,
@@ -713,9 +714,9 @@ def input_processor_for_qwen(ctx: InputContext,
 
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
-    return LLMInputs(prompt=new_prompt,
-                     prompt_token_ids=new_prompt_token_ids,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt=new_prompt,
+                        prompt_token_ids=new_prompt_token_ids,
+                        multi_modal_data=multi_modal_data)
 
 
 def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
@@ -822,7 +823,7 @@ def dummy_data_for_qwen(
     # The presence of a visual config indicates this is a multimodal model.
     # If we don't have it, the model is considered an LLM for warmup purposes.
     if not hasattr(hf_config, "visual"):
-        seq_data = SequenceData.from_token_counts((0, seq_len))
+        seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
         mm_data = None
         return seq_data, mm_data
 
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index eb9a9aa9364c..23eb1482ffef 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -365,6 +365,28 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     embedding_modules = {}
     embedding_padding_modules = []
 
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: Qwen2Config,
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
new file mode 100644
index 000000000000..3d049eeb920b
--- /dev/null
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -0,0 +1,462 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team.
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
+from functools import lru_cache
+from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
+
+import librosa
+import numpy as np
+import torch
+import torch.nn as nn
+from transformers import Qwen2AudioConfig, Qwen2AudioEncoder
+
+from vllm.attention import AttentionMetadata
+from vllm.config import CacheConfig, MultiModalConfig
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.models.qwen2 import Qwen2Model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
+from vllm.sequence import IntermediateTensors, SequenceData
+
+from .interfaces import SupportsMultiModal, SupportsPP
+
+logger = init_logger(__name__)
+
+_KEYS_TO_MODIFY_MAPPING = {
+    "language_model.lm_head": "lm_head",
+    "language_model.model": "language_model",
+}
+
+
+# === Audio Inputs === #
+class Qwen2AudioInputs(TypedDict):
+    input_features: torch.Tensor
+    """Shape: 
+    `(num_audios, num_mel_bins, 3000)`
+    """
+
+    feature_attention_mask: torch.Tensor
+    """Shape: `(num_audios, 3000)`
+    """
+
+
+# === Audio Encoder === #
+
+
+class Qwen2AudioMultiModalProjector(nn.Module):
+
+    def __init__(self, audio_hidden_size: int, text_hidden_size: int):
+        super().__init__()
+        self.linear = nn.Linear(audio_hidden_size, text_hidden_size, bias=True)
+
+    def forward(self, audio_features):
+        hidden_states = self.linear(audio_features)
+        return hidden_states
+
+
+def dummy_data_for_qwen2_audio(ctx: InputContext, seq_len: int,
+                               mm_counts: Mapping[str, int]):
+    num_audios = mm_counts["audio"]
+    max_llm_audio_tokens = get_max_qwen2_audio_audio_tokens(ctx) * num_audios
+    if seq_len - max_llm_audio_tokens - 2 < 0:
+        raise RuntimeError(
+            f"Qwen2-Audio cannot process {num_audios} audios in a prompt, "
+            "please increase max_model_len or reduce audio limit by "
+            "--limit-mm-per-prompt.")
+
+    audio_token_index = ctx.model_config.hf_config.audio_token_index
+
+    dummy_seqdata = SequenceData.from_prompt_token_counts(
+        (audio_token_index, max_llm_audio_tokens),
+        (0, seq_len - max_llm_audio_tokens),
+    )
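+    # Each audio output token covers roughly 2 conv frames x 2 mel frames x
+    # 160 samples (the 16 kHz hop length), which sizes the dummy waveform.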
+    dummy_audio = np.full((max_llm_audio_tokens * 2 * 2 * 160, ), 0.)
+    return dummy_seqdata, {"audio": [(dummy_audio, 16000)] * num_audios}
+
+
+def get_processor(
+    processor_name: str,
+    *args,
+    trust_remote_code: bool = False,
+    **kwargs,
+):
+    """Gets a processor for the given model name via HuggingFace.
+
+    Derived from `vllm.transformers_utils.image_processor.get_image_processor`.
+    """
+    # don't put this import at the top level
+    # it will call torch.cuda.device_count()
+    from transformers import AutoProcessor
+
+    try:
+        processor = AutoProcessor.from_pretrained(
+            processor_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            **kwargs)
+    except ValueError as e:
+        # If the error pertains to the processor class not existing or not
+        # currently being imported, suggest using the --trust-remote-code flag.
+        # Unlike AutoTokenizer, AutoProcessor does not separate such errors
+        if not trust_remote_code:
+            err_msg = (
+                "Failed to load the processor. If the processor is "
+                "a custom processor not yet available in the HuggingFace "
+                "transformers library, consider setting "
+                "`trust_remote_code=True` in LLM or using the "
+                "`--trust-remote-code` flag in the CLI.")
+            raise RuntimeError(err_msg) from e
+        else:
+            raise e
+
+    return processor
+
+
+cached_get_processor = lru_cache(get_processor)
+
+
+def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
+    """
+    Computes the output length of the convolutional layers
+    and the output length of the audio encoder
+    """
+    input_lengths = (input_lengths - 1) // 2 + 1
+    output_lengths = (input_lengths - 2) // 2 + 1
+    return input_lengths, output_lengths
+
+
+def get_max_qwen2_audio_audio_tokens(ctx: InputContext) -> int:
+    max_source_position = (
+        ctx.model_config.hf_config.audio_config.max_source_positions)
+    output_lengths = (max_source_position - 2) // 2 + 1
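+    # e.g. max_source_positions=1500 gives (1500 - 2) // 2 + 1 = 750 tokens.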
+    return output_lengths
+
+
+def input_processor_for_qwen2_audio(
+        ctx: InputContext, inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
+    multi_modal_data = inputs.get("multi_modal_data")
+    if multi_modal_data is None or "audio" not in multi_modal_data:
+        return inputs
+
+    audios = multi_modal_data["audio"]
+    if not isinstance(audios, list):
+        audios = [audios]
+
+    if len(audios) == 0:
+        return inputs
+
+    processor = cached_get_processor(ctx.model_config.model)
+    resampled_audios = [
+        librosa.resample(audio,
+                         orig_sr=sampling_rate,
+                         target_sr=processor.feature_extractor.sampling_rate)
+        for audio, sampling_rate in audios
+    ]
+    audio_input_lengths = np.array(
+        [min(3000, audio.shape[0] // 160 + 1) for audio in resampled_audios])
+
+    audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
+        audio_input_lengths)
+
+    audio_token_index = ctx.model_config.hf_config.audio_token_index
+
+    input_ids = inputs['prompt_token_ids']
+
+    new_input_ids = []
+    audio_num = input_ids.count(audio_token_index)
+    assert len(audio_input_lengths) == audio_num, \
+        (f'The text input contains {audio_num} audio tokens, '
+         f'but {len(audio_input_lengths)} audios were provided')
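+    # Expand each audio placeholder to its encoder output length, e.g. a
+    # placeholder with an output length of 750 becomes 750 audio token ids.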
+    start = 0
+    for audio_idx in range(audio_num):
+        end = input_ids.index(audio_token_index, start)
+        new_input_ids.extend(input_ids[start:end])  # text part
+
+        new_input_ids.extend([audio_token_index] *
+                             audio_output_lengths[audio_idx])
+        start = end + 1
+    new_input_ids.extend(input_ids[start:])
+
+    return token_inputs(
+        prompt_token_ids=new_input_ids,
+        prompt=inputs['prompt'],
+        multi_modal_data=multi_modal_data,
+    )
+
+
+def input_mapper_for_qwen2_audio(
+    ctx: InputContext,
+    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
+) -> MultiModalInputs:
+    """Input mapper for Qwen2-Audio."""
+    if not isinstance(multi_modal_data, list):
+        multi_modal_data = [multi_modal_data]
+
+    if len(multi_modal_data) == 0:
+        return MultiModalInputs()
+
+    processor = cached_get_processor(ctx.model_config.model)
+    audio_feature_extractor = processor.feature_extractor
+    if audio_feature_extractor is None:
+        raise RuntimeError(
+            "No HuggingFace audio_feature_extractor is available "
+            "to process the audio object")
+
+    try:
+        resampled_audios = [
+            librosa.resample(
+                audio,
+                orig_sr=sampling_rate,
+                target_sr=processor.feature_extractor.sampling_rate)
+            for audio, sampling_rate in multi_modal_data
+        ]
+        batch_data = audio_feature_extractor(resampled_audios,
+                                             sampling_rate=16000,
+                                             return_attention_mask=True,
+                                             padding="max_length",
+                                             return_tensors="pt").data
+        batch_data["feature_attention_mask"] = batch_data.pop("attention_mask")
+    except Exception:
+        logger.error("Failed to process audio (%s)", multi_modal_data)
+        raise
+
+    return MultiModalInputs(batch_data)
+
+
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_audio)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_audio)
+@MULTIMODAL_REGISTRY.register_input_mapper("audio",
+                                           input_mapper_for_qwen2_audio)
+@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
+    "audio", get_max_qwen2_audio_audio_tokens)
+class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                         SupportsPP):
+
+    def __init__(self,
+                 config: Qwen2AudioConfig,
+                 multimodal_config: MultiModalConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+
+        self.config = config
+        self.multimodal_config = multimodal_config
+
+        self.audio_tower = Qwen2AudioEncoder(config.audio_config)
+        self.multi_modal_projector = Qwen2AudioMultiModalProjector(
+            config.audio_config.d_model, config.text_config.hidden_size)
+
+        self.quant_config = quant_config
+
+        self.language_model = Qwen2Model(config.text_config, cache_config,
+                                         quant_config)
+        self.unpadded_vocab_size = config.text_config.vocab_size
+        if config.text_config.tie_word_embeddings:
+            self.lm_head = self.language_model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
+                                          config.text_config.hidden_size,
+                                          quant_config=quant_config)
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.text_config.vocab_size,
+                                                logit_scale)
+        self.sampler = Sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.language_model.make_empty_intermediate_tensors)
+
+    def _validate_and_reshape_mm_tensor(self,
+                                        mm_input: Union[torch.Tensor,
+                                                        List[torch.Tensor]],
+                                        name: str) -> torch.Tensor:
+        if not isinstance(mm_input, (torch.Tensor, list)):
+            raise ValueError(f"Incorrect type of {name}. "
+                             f"Got type: {type(mm_input)}")
+        if isinstance(mm_input, torch.Tensor):
+            return torch.concat(list(mm_input))
+        else:
+            return torch.concat(mm_input)
+
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[Qwen2AudioInputs]:
+        input_features = kwargs.pop('input_features', None)
+        feature_attention_mask = kwargs.pop('feature_attention_mask', None)
+        if input_features is None:
+            return None
+        input_features = self._validate_and_reshape_mm_tensor(
+            input_features, 'input_features')
+        feature_attention_mask = self._validate_and_reshape_mm_tensor(
+            feature_attention_mask, 'feature_attention_mask')
+        if not isinstance(input_features, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio input features. "
+                             f"Got type: {type(input_features)}")
+        return Qwen2AudioInputs(input_features=input_features,
+                                feature_attention_mask=feature_attention_mask)
+
+    def _process_audio_input(self,
+                             audio_input: Qwen2AudioInputs) -> torch.Tensor:
+
+        input_features = audio_input["input_features"]
+        feature_attention_mask = audio_input["feature_attention_mask"]
+
+        audio_feat_lengths, audio_output_lengths = (
+            self.audio_tower._get_feat_extract_output_lengths(
+                feature_attention_mask.sum(-1)))
+
+        batch_size, _, max_mel_seq_len = input_features.shape
+        max_seq_len = (max_mel_seq_len - 2) // 2 + 1
+        # Create a sequence tensor of shape (batch_size, max_seq_len)
+        seq_range = (torch.arange(
+            0,
+            max_seq_len,
+            dtype=audio_feat_lengths.dtype,
+            device=audio_feat_lengths.device).unsqueeze(0).expand(
+                batch_size, max_seq_len))
+        lengths_expand = audio_feat_lengths.unsqueeze(-1).expand(
+            batch_size, max_seq_len)
+        # Create mask
+        padding_mask = seq_range >= lengths_expand
+
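+        # Expand the boolean padding mask to (batch, 1, seq, seq) and convert
+        # it to an additive mask where padded positions are set to -inf.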
+        audio_attention_mask_ = padding_mask.view(
+            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
+                                                  max_seq_len)
+        audio_attention_mask = audio_attention_mask_.to(
+            dtype=self.audio_tower.conv1.weight.dtype,
+            device=self.audio_tower.conv1.weight.device)
+        audio_attention_mask[audio_attention_mask_] = float("-inf")
+
+        audio_outputs = self.audio_tower(input_features,
+                                         attention_mask=audio_attention_mask)
+        selected_audio_feature = audio_outputs.last_hidden_state
+        audio_features = self.multi_modal_projector(selected_audio_feature)
+        num_audios, max_audio_tokens, embed_dim = audio_features.shape
+        audio_features_mask = torch.arange(max_audio_tokens).expand(
+            num_audios, max_audio_tokens
+        ).to(audio_output_lengths.device) < audio_output_lengths.unsqueeze(1)
+        masked_audio_features = audio_features[audio_features_mask].view(
+            -1, embed_dim)
+
+        return masked_audio_features
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        **kwargs: object,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        if intermediate_tensors is not None:
+            input_ids = None
+            inputs_embeds = None
+        else:
+            audio_input = self._parse_and_validate_audio_input(**kwargs)
+
+            if audio_input is None:
+                inputs_embeds = None
+            else:
+                inputs_embeds = self.language_model.embed_tokens(input_ids)
+                masked_audio_features = self._process_audio_input(audio_input)
+                # merge llm embeddings and audio features
+                mask = (input_ids == self.config.audio_token_index)
+                inputs_embeds[mask, :] = masked_audio_features
+
+                input_ids = None
+
+        hidden_states = self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if (self.config.text_config.tie_word_embeddings
+                    and "lm_head.weight" in name):
+                continue
+            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
+                if key_to_modify in name:
+                    name = name.replace(key_to_modify, new_key)
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name or 'audio' in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 7dcf52a56e98..ee0eeb9db380 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -119,5 +119,6 @@ def pooler(
         return self._pooler(hidden_states, pooling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        loader = AutoWeightsLoader(self)
+        loader = AutoWeightsLoader(self,
+                                   ignore_unexpected_prefixes=["lm_head."])
         loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 4a39b3fbe5a4..4e60fe70b25f 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -22,7 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
-from functools import lru_cache, partial
+from functools import partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
                     Tuple, Type, TypedDict, Union)
 
@@ -34,17 +34,18 @@
 from transformers.image_utils import (get_image_size,
                                       infer_channel_dimension_format,
                                       to_numpy_array)
+from transformers.models.qwen2_vl.configuration_qwen2_vl import (
+    Qwen2VLConfig, Qwen2VLVisionConfig)
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import (
     make_batched_images, make_batched_videos, smart_resize)
 
-import vllm.envs as envs
 from vllm.attention import AttentionMetadata
-from vllm.attention.selector import (_Backend, backend_name_to_enum,
-                                     get_global_forced_attn_backend)
+from vllm.attention.selector import _Backend
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
+                         token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.activation import QuickGELU
@@ -60,15 +61,14 @@
                              MultiModalInputs)
 from vllm.multimodal.base import MultiModalData
 from vllm.multimodal.image import cached_get_image_processor
-from vllm.platforms import current_platform
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
-from vllm.transformers_utils.processor import get_processor
-from vllm.utils import is_cpu
+from vllm.transformers_utils.config import uses_mrope
+from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (PPMissingLayer, get_vit_attn_backend,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory)
 
 logger = init_logger(__name__)
@@ -79,7 +79,7 @@
 class Qwen2VLImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: 
+    """Shape:
     `(num_patches, num_channels * patch_size * patch_size)`
     """
 
@@ -103,14 +103,14 @@ class Qwen2VLImageEmbeddingInputs(TypedDict):
 
 class Qwen2VLVideoInputs(TypedDict):
     pixel_values_videos: torch.Tensor
-    """Shape: 
-    `(num_patches, 
+    """Shape:
+    `(num_patches,
       num_channels * temporal_patch_size * patch_size * patch_size)`
     """
 
     video_grid_thw: torch.Tensor
     """Shape: `(num_videos, 3)`
-    
+
     This should be in `(grid_t, grid_h, grid_w)` format.
     """
 
@@ -213,37 +213,12 @@ def __init__(
                                       quant_config=quant_config)
 
         # Detect attention implementation.
-        selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-        if selected_backend is None:
-            backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-            if backend_by_env_var is not None:
-                selected_backend = backend_name_to_enum(backend_by_env_var)
-        if selected_backend is None:
-            # For Volta and Turing GPUs, use xformers instead.
-            device_available = current_platform.has_device_capability(80)
-            if device_available:
-                from transformers.utils import is_flash_attn_2_available
-
-                if is_flash_attn_2_available():
-                    self._use_flash_attn = True
-                else:
-                    logger.warning(
-                        "Current Qwen2-VL implementation has a bug with "
-                        "`vllm-flash-attn` inside vision module, so we use "
-                        "xformers backend instead. You can run `pip install "
-                        "flash-attn to use flash-attention backend.")
-                    self._use_flash_attn = False
-            else:
-                self._use_flash_attn = False
-        else:
-            if selected_backend == _Backend.FLASH_ATTN:
-                self._use_flash_attn = True
-            elif selected_backend == _Backend.XFORMERS:
-                self._use_flash_attn = False
-            else:
-                raise RuntimeError(
-                    f"Qwen2-VL does not support {selected_backend} backend now."
-                )
+        self.attn_backend: _Backend = get_vit_attn_backend()
+        if self.attn_backend not in {
+                _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+        }:
+            raise RuntimeError(
+                f"Qwen2-VL does not support {self.attn_backend} backend now.")
 
     def forward(
         self,
@@ -272,7 +247,7 @@ def forward(
             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
 
-        if self._use_flash_attn:
+        if self.attn_backend == _Backend.FLASH_ATTN:
             # from vllm_flash_attn.flash_attn_interface import (
             #   flash_attn_varlen_func)
             from flash_attn import flash_attn_varlen_func
@@ -293,7 +268,7 @@ def forward(
             context_layer = rearrange(output,
                                       "(b s) ... -> b s ...",
                                       b=batch_size)
-        elif is_cpu():
+        elif self.attn_backend == _Backend.TORCH_SDPA:
             seq_length = q.size(1)
             q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]]
             attention_mask = torch.zeros([1, seq_length, seq_length],
@@ -308,7 +283,7 @@ def forward(
                                                     attention_mask,
                                                     dropout_p=0.0)
             context_layer = rearrange(output, "b h s d -> b s h d ")
-        else:
+        elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
@@ -570,13 +545,14 @@ def forward(
 
 # === Vision input helpers === #
 
-cached_get_processor = lru_cache(get_processor)
-
 
 def mm_input_mapper_for_qwen2_vl(
     ctx: InputContext,
     data: MultiModalData[object],
     data_type_key: str,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ) -> MultiModalInputs:
     """Input mapper for Qwen2-VL."""
     if data_type_key == "image" and isinstance(data, dict):
@@ -585,8 +561,19 @@ def mm_input_mapper_for_qwen2_vl(
             "image_grid_thw": data.get("image_grid_thw"),
         })
     model_config = ctx.model_config
+    # Handle mm processor kwargs; we pass these at creation time
+    # because preprocess() in transformers doesn't expose them
+    mm_processor_kwargs = {}
+    if min_pixels:
+        mm_processor_kwargs["min_pixels"] = min_pixels
+    if max_pixels:
+        mm_processor_kwargs["max_pixels"] = max_pixels
+
     image_processor = cached_get_image_processor(
-        model_config.model, trust_remote_code=model_config.trust_remote_code)
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
+    )
     if image_processor is None:
         raise RuntimeError("No HuggingFace processor is available "
                            "to process the image object")
@@ -659,25 +646,36 @@ def _get_max_image_info(
     image_processor,
     data_type_key: str = "image",
     mm_count: int = 1,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
 ):
+    # Limit min / max pixels unless they're explicitly provided
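+    # 28 * 28 pixels corresponds to one merged patch (patch size 14 with a
+    # 2x2 merge), so these defaults roughly bound each item to 1..1280
+    # visual tokens.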
+    if min_pixels is None:
+        min_pixels = max(image_processor.min_pixels, 28 * 28)
+    if max_pixels is None:
+        max_pixels = min(image_processor.max_pixels, 1280 * 28 * 28)
+
     return _get_vision_info(
         image_processor,
         height=9999999,
         width=9999999,
-
-        # Limit min / max pixels.
-        min_pixels=max(image_processor.min_pixels, 28 * 28),
-        max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28),
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
         data_type_key=data_type_key,
         mm_count=mm_count,
     )
 
 
-def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
+def get_max_qwen2_vl_mm_tokens(ctx: InputContext,
+                               data_type_key: str,
+                               *,
+                               min_pixels: Optional[int] = None,
+                               max_pixels: Optional[int] = None) -> int:
     image_processor = cached_get_image_processor(ctx.model_config.model)
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key=data_type_key,
-                            mm_count=1)
+                            mm_count=1, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     return max_llm_image_tokens
 
 
@@ -688,14 +686,20 @@ def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int:
 
 
 def dummy_data_for_qwen2_vl(
-    ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
+    ctx: InputContext,
+    seq_len: int,
+    mm_counts: Mapping[str, int],
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None
 ) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
     image_processor = cached_get_image_processor(ctx.model_config.model)
 
     num_images = mm_counts["image"]
     max_resized_height, max_resized_width, max_llm_image_tokens = \
         _get_max_image_info(image_processor, data_type_key="image",
-                            mm_count=num_images)
+                            mm_count=num_images, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_image_tokens - 2 < 0:
         raise RuntimeError(
             f"Qwen2-VL cannot process {num_images} images in a prompt, "
@@ -706,16 +710,17 @@ def dummy_data_for_qwen2_vl(
     num_videos = mm_counts["video"]
     max_resized_height, max_resized_width, max_llm_video_tokens = \
         _get_max_image_info(image_processor, data_type_key="video",
-                            mm_count=num_videos)
+                            mm_count=num_videos, min_pixels=min_pixels,
+                            max_pixels=max_pixels)
     if seq_len - max_llm_video_tokens - 2 < 0:
         raise RuntimeError(
-            f"Qwen2-VL cannot process {num_images} videos in a prompt, "
+            f"Qwen2-VL cannot process {num_videos} videos in a prompt, "
             "please increase max_model_len or reduce video limit by "
             "--limit-mm-per-prompt.")
 
     hf_config = ctx.get_hf_config(Qwen2VLConfig)
 
-    dummy_seqdata = SequenceData.from_token_counts(
+    dummy_seqdata = SequenceData.from_prompt_token_counts(
         (hf_config.vision_start_token_id, 1),
         (hf_config.image_token_id, max_llm_image_tokens),
         (hf_config.vision_end_token_id, 1),
@@ -734,6 +739,8 @@ def _get_llm_num_vision_tokens(
     mm_inputs: list,
     data_type_key: str,
     image_processor,
+    min_pixels: int,
+    max_pixels: int,
 ):
     """Get number of vision tokens of multimodal inputs.
 
@@ -743,12 +750,13 @@ def _get_llm_num_vision_tokens(
     image = to_numpy_array(mm_inputs[0])
     input_data_format = infer_channel_dimension_format(image)
     height, width = get_image_size(image, channel_dim=input_data_format)
+
     _, _, llm_num_vision_tokens = _get_vision_info(
         image_processor,
         height=height,
         width=width,
-        min_pixels=image_processor.min_pixels,
-        max_pixels=image_processor.max_pixels,
+        min_pixels=min_pixels,
+        max_pixels=max_pixels,
         do_resize=image_processor.do_resize,
         data_type_key=data_type_key,
         mm_count=len(mm_inputs),
@@ -758,7 +766,8 @@ def _get_llm_num_vision_tokens(
 
 def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
                        data_type_key: str, image_processor: Any,
-                       prompt_token_ids: List[int]) -> List[int]:
+                       prompt_token_ids: List[int], min_pixels: Optional[int],
+                       max_pixels: Optional[int]) -> List[int]:
     """
     Expand pad tokens for multi-modal inputs (e.g., images or videos).
 
@@ -769,6 +778,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
         data_type_key (str): The type of the multi-modal input.
         image_processor (Any): The image processor used to process the inputs.
         prompt_token_ids (List[int]): The list of token IDs in the prompt.
+        min_pixels (Optional[int]): Min pixels used for image processing.
+        max_pixels (Optional[int]): Max pixels used for image processing.
 
     Returns:
         List[int]: The list of token IDs for the multi-modal inputs.
@@ -785,6 +796,8 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
             [data] if data_type_key == "image" else data,
             data_type_key=data_type_key,
             image_processor=image_processor,
+            min_pixels=min_pixels,
+            max_pixels=max_pixels,
         )
         if cnt == 0:
             end_idx = indices[cnt]
@@ -798,17 +811,27 @@ def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable,
     return prompt_token_ids_with_data
 
 
-def input_processor_for_qwen2_vl(ctx: InputContext,
-                                 llm_inputs: LLMInputs) -> LLMInputs:
-    multi_modal_data = llm_inputs.get("multi_modal_data", None)
+def input_processor_for_qwen2_vl(
+    ctx: InputContext,
+    inputs: DecoderOnlyInputs,
+    *,
+    min_pixels: Optional[int] = None,
+    max_pixels: Optional[int] = None,
+) -> DecoderOnlyInputs:
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None:
-        return llm_inputs
+        return inputs
 
     image_inputs = multi_modal_data.get("image", None)
     video_inputs = multi_modal_data.get("video", None)
 
     processor = cached_get_processor(ctx.model_config.model)
     image_processor = processor.image_processor
+    # Apply processor kwarg overrides for image processor options
+    min_pixels = min_pixels if min_pixels else image_processor.min_pixels
+    max_pixels = max_pixels if max_pixels else image_processor.max_pixels
+
+    model_config = ctx.model_config
     hf_config = ctx.get_hf_config(Qwen2VLConfig)
 
     # To avoid redundant processing of vision objects (resize, rescale, etc.),
@@ -816,7 +839,7 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`.
     #
     # The following code is equivalent to:
-    #    prompt = llm_inputs["prompt"]
+    #    prompt = inputs["prompt"]
     #    inputs = processor(text=[prompt],
     #                       images=image_inputs,
     #                       videos=video_inputs,
@@ -824,14 +847,11 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
     #                       return_tensors="pt")
     #    prompt_token_ids = inputs["input_ids"][0].tolist()
 
-    prompt_token_ids = llm_inputs.get("prompt_token_ids", None)
-    if prompt_token_ids is None:
-        prompt = llm_inputs["prompt"]
-        prompt_token_ids = processor.tokenizer(
-            prompt,
-            padding=True,
-            return_tensors=None,
-        )["input_ids"]
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
+
+    prompt_token_ids = inputs["prompt_token_ids"]
 
     # Expand image pad tokens.
 
@@ -856,20 +876,30 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
         else:
             prompt_token_ids = _expand_pad_tokens(image_inputs,
                                                   hf_config.image_token_id,
-                                                  make_batched_images, "image",
+                                                  make_batched_images,
+                                                  "image",
                                                   image_processor,
-                                                  prompt_token_ids)
+                                                  prompt_token_ids,
+                                                  min_pixels=min_pixels,
+                                                  max_pixels=max_pixels)
 
     if video_inputs is not None:
         prompt_token_ids = _expand_pad_tokens(video_inputs,
                                               hf_config.video_token_id,
-                                              make_batched_videos, "video",
+                                              make_batched_videos,
+                                              "video",
                                               image_processor,
-                                              prompt_token_ids)
+                                              prompt_token_ids,
+                                              min_pixels=min_pixels,
+                                              max_pixels=max_pixels)
+
+    prompt = inputs.get("prompt")
+    if prompt is None:
+        prompt = tokenizer.decode(prompt_token_ids)
 
-    return LLMInputs(
+    return token_inputs(
         prompt_token_ids=prompt_token_ids,
-        prompt=llm_inputs["prompt"],
+        prompt=prompt,
         multi_modal_data=multi_modal_data,
     )
 
@@ -1061,8 +1091,7 @@ def forward(
             if image_input is None and video_input is None:
                 inputs_embeds = None
             else:
-                rope_scaling = getattr(self.config, "rope_scaling", {})
-                if rope_scaling.get("type", None) == "mrope":
+                if uses_mrope(self.config):
                     assert positions.ndim == 2 and positions.size(0) == 3, (
                         "multimodal section rotary embedding requires "
                         f"(3, seq_len) positions, but got {positions.size()}")
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index b06d3d612dbc..717615988a90 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -26,8 +26,10 @@
     "AquilaModel": ("llama", "LlamaForCausalLM"),
     "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
-    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
-    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
+    # baichuan-7b, upper case 'C' in the class name
+    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),
+    # baichuan-13b, lower case 'c' in the class name
+    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
@@ -47,12 +49,14 @@
     "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
     "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
+    "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
+    "FalconMambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
     "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
@@ -83,12 +87,18 @@
     # [Encoder-decoder]
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
+    "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
 }
 
 _EMBEDDING_MODELS = {
-    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
+    # [Text-only]
+    "BertModel": ("bert", "BertEmbeddingModel"),
+    "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
+    "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
-    "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"),
+    # [Multimodal]
+    "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501
+    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
 }
 
 _MULTIMODAL_MODELS = {
@@ -111,6 +121,7 @@
     "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"),  # noqa: E501
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"),  # noqa: E501
+    "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"),  # noqa: E501
     "UltravoxModel": ("ultravox", "UltravoxModel"),
     # [Encoder-decoder]
     "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"),  # noqa: E501
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index 743a81f8f9e9..91277b0ccd14 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -13,7 +13,7 @@
 
 from vllm.config import ModelConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
-from vllm.inputs import LLMInputs
+from vllm.inputs import DecoderOnlyInputs, token_inputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -67,7 +67,7 @@ def dummy_seq_data_for_siglip(
     else:
         image_feature_size = image_feature_size_override
 
-    return SequenceData.from_token_counts(
+    return SequenceData.from_prompt_token_counts(
         (image_token_id, image_feature_size * num_images),
         (0, seq_len - image_feature_size * num_images),
     )
@@ -111,14 +111,14 @@ def dummy_video_for_siglip(
 def input_processor_for_siglip(
     model_config: ModelConfig,
     hf_config: SiglipVisionConfig,
-    llm_inputs: LLMInputs,
+    inputs: DecoderOnlyInputs,
     *,
     image_token_id: int,
     image_feature_size_override: Optional[Union[int, List[int]]] = None,
 ):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
 
@@ -135,14 +135,14 @@ def input_processor_for_siglip(
 
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
-        llm_inputs.get("prompt"),
-        llm_inputs["prompt_token_ids"],
+        inputs.get("prompt"),
+        inputs["prompt_token_ids"],
         placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(
+    return token_inputs(
         prompt_token_ids=new_token_ids,
         prompt=new_prompt,
         multi_modal_data=multi_modal_data,
@@ -248,8 +248,10 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -266,12 +268,14 @@ def __init__(
             head_size=self.head_dim,
             total_num_heads=self.num_heads,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
 
         self.out_proj = RowParallelLinear(
             input_size=self.embed_dim,
             output_size=self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -314,8 +318,10 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
 
@@ -326,11 +332,13 @@ def __init__(
             config.hidden_size,
             config.intermediate_size,
             quant_config=quant_config if quantizable else None,
+            prefix=f"{prefix}.fc1",
         )
         self.fc2 = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             quant_config=quant_config if quantizable else None,
+            prefix=f"{prefix}.fc2",
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -346,15 +354,20 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.embed_dim = config.hidden_size
 
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = SiglipParallelAttention(config,
-                                                     quant_config=quant_config)
+            self.self_attn = SiglipParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             self.self_attn = SiglipSdpaAttention(config)
 
@@ -363,6 +376,7 @@ def __init__(
         self.mlp = SiglipMLP(
             config,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                         eps=config.layer_norm_eps)
@@ -392,8 +406,10 @@ def __init__(
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
         num_hidden_layers_override: Optional[int] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
 
         if num_hidden_layers_override is None:
@@ -402,8 +418,10 @@ def __init__(
             num_hidden_layers = num_hidden_layers_override
 
         self.layers = nn.ModuleList([
-            SiglipEncoderLayer(config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            SiglipEncoderLayer(config,
+                               quant_config=quant_config,
+                               prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(
@@ -424,7 +442,8 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
@@ -433,7 +452,9 @@ def __init__(
             config.hidden_size, config.num_attention_heads, batch_first=True)
         self.layernorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
-        self.mlp = SiglipMLP(config=config, quant_config=quant_config)
+        self.mlp = SiglipMLP(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.mlp")
 
     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
         batch_size = hidden_state.shape[0]
@@ -454,9 +475,13 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
         num_hidden_layers_override: Optional[int] = None,
-    ):
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
+
         self.config = config
         embed_dim = config.hidden_size
 
@@ -465,26 +490,34 @@ def __init__(
             config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
         )
 
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(embed_dim,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
             self.post_layernorm = None
 
         self.use_head = (True if not hasattr(config, "vision_use_head") else
                          config.vision_use_head)
         if self.use_head:
             self.head = SiglipMultiheadAttentionPoolingHead(
-                config=config, quant_config=quant_config)
+                config=config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.head",
+            )
 
     def forward(
         self,
@@ -517,8 +550,11 @@ def __init__(
         self,
         config: SiglipVisionConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        *,
         num_hidden_layers_override: Optional[int] = None,
-    ):
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         num_heads = config.num_attention_heads
@@ -529,6 +565,8 @@ def __init__(
             config,
             quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            require_post_norm=require_post_norm,
+            prefix=f"{prefix}.vision_model",
         )
 
     def get_input_embeddings(self) -> nn.Module:
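The siglip changes above mainly thread a `prefix` string through every submodule so that each quantizable layer is constructed with its fully qualified name. A minimal sketch (plain Python, not vLLM code) of how the `f"{prefix}...."` convention composes across the layers shown above:

```python
# Illustrative only: reproduces the naming that the constructors above build
# up, e.g. "vision_model.encoder.layers.0.self_attn.qkv_proj", which is what
# quantization configs and checkpoint keys are matched against.
def siglip_layer_param_names(prefix: str, num_layers: int) -> list:
    names = []
    for i in range(num_layers):
        layer = f"{prefix}.encoder.layers.{i}"
        names += [
            f"{layer}.self_attn.qkv_proj",
            f"{layer}.self_attn.out_proj",
            f"{layer}.mlp.fc1",
            f"{layer}.mlp.fc2",
        ]
    return names

print(siglip_layer_param_names("vision_model", 1)[0])
# vision_model.encoder.layers.0.self_attn.qkv_proj
```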
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index b9298ed03114..5a3dd3c02b85 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -29,6 +29,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -263,6 +264,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class SolarModel(nn.Module):
 
     def __init__(
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 81dd7c4daa5e..8f0644bca3e2 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -25,6 +25,7 @@
 from transformers import Starcoder2Config
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class Starcoder2Model(nn.Module):
 
     def __init__(self,
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index e162e3af008e..5f33b872beec 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -18,7 +18,7 @@
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.inputs import INPUT_REGISTRY
-from vllm.inputs.data import LLMInputs
+from vllm.inputs.data import DecoderOnlyInputs, token_inputs
 from vllm.inputs.registry import InputContext
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -117,6 +117,9 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]
 
+    if len(data) == 0:
+        return MultiModalInputs()
+
     # If the audio inputs are embeddings, no need for preprocessing
     if is_list_of(data, torch.Tensor, check="all"):
         return MultiModalInputs({"audio_embeds": data})
@@ -156,10 +159,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     return MultiModalInputs({"audio_features": audio_features})
 
 
-def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
-    multi_modal_data = llm_inputs.get("multi_modal_data")
+def input_processor_for_ultravox(ctx: InputContext, inputs: DecoderOnlyInputs):
+    multi_modal_data = inputs.get("multi_modal_data")
     if multi_modal_data is None or "audio" not in multi_modal_data:
-        return llm_inputs
+        return inputs
 
     feature_extractor = whisper_feature_extractor(ctx)
     audios = multi_modal_data["audio"]
@@ -196,16 +199,16 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
 
     new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
-        llm_inputs.get("prompt"),
-        llm_inputs["prompt_token_ids"],
+        inputs.get("prompt"),
+        inputs["prompt_token_ids"],
         placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN,
         repeat_count=audio_token_counts,
     )
 
     # NOTE: Create a defensive copy of the original inputs
-    return LLMInputs(prompt_token_ids=new_token_ids,
-                     prompt=new_prompt,
-                     multi_modal_data=multi_modal_data)
+    return token_inputs(prompt_token_ids=new_token_ids,
+                        prompt=new_prompt,
+                        multi_modal_data=multi_modal_data)
 
 
 class StackAudioFrames(nn.Module):
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 89b64ba2fd43..6995f5805c5e 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,23 +1,30 @@
 import itertools
 from dataclasses import dataclass, field
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
-                    Protocol, Tuple, Union, overload)
+from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
+                    Optional, Protocol, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (_Backend, backend_name_to_enum,
+                                     get_global_forced_attn_backend)
 from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
+logger = init_logger(__name__)
+
 WeightsMapping = Mapping[str, Optional[str]]
 """If a key maps to a value of `None`, the corresponding weight is ignored."""
 
@@ -72,6 +79,9 @@ class AutoWeightsLoader:
 
     Similarly, the weight loading logic for individual parameters can be
     overridden by defining a ``weight_loader`` method.
+
+    Detailed weight loading information can be viewed by setting the
+    environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
     """
 
     def __init__(
@@ -124,31 +134,40 @@ def _load_param(
         base_prefix: str,
         param: nn.Parameter,
         weights: Iterable[Tuple[str, torch.Tensor]],
-    ) -> None:
+    ) -> Iterable[str]:
         for weight_name, weight_data in weights:
             weight_qualname = self._get_qualname(base_prefix, weight_name)
 
             if self._can_skip(weight_qualname):
+                logger.debug("Skipping weight %s", weight_qualname)
+
                 continue
 
             if weight_name != "":
-                if not self._can_ignore_unexpected(weight_qualname):
-                    raise ValueError(
-                        f"Attempted to load nested weight '{weight_qualname}' "
-                        f"into a single parameter '{base_prefix}'")
+                if self._can_ignore_unexpected(weight_qualname):
+                    logger.debug("Ignoring weight %s", weight_qualname)
 
-                continue
+                    continue
+
+                raise ValueError(
+                    f"Attempted to load nested weight '{weight_qualname}' "
+                    f"into a single parameter '{base_prefix}'")
 
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, weight_data)
 
+            logger.debug("Loaded weight %s with shape %s", weight_qualname,
+                         param.shape)
+
+            yield weight_qualname
+
     def _load_module(
         self,
         base_prefix: str,
         module: nn.Module,
         weights: Iterable[Tuple[str, torch.Tensor]],
-    ) -> None:
+    ) -> Iterable[str]:
         if isinstance(module, PPMissingLayer):
             return
 
@@ -166,30 +185,53 @@ def _load_module(
         for child_prefix, child_weights in self._groupby_prefix(weights):
             prefix = self._get_qualname(base_prefix, child_prefix)
 
-            if self._can_skip(prefix):
-                continue
-
             if child_prefix in child_modules:
-                self._load_module(prefix, child_modules[child_prefix],
-                                  child_weights)
+                if self._can_skip(prefix + "."):
+                    logger.debug("Skipping module %s", prefix)
+
+                    continue
+
+                yield from self._load_module(prefix,
+                                             child_modules[child_prefix],
+                                             child_weights)
             elif child_prefix in child_params:
-                self._load_param(prefix, child_params[child_prefix],
-                                 child_weights)
+                if self._can_skip(prefix):
+                    logger.debug("Skipping param %s", prefix)
+
+                    continue
+
+                yield from self._load_param(prefix, child_params[child_prefix],
+                                            child_weights)
             else:
-                if not self._can_ignore_unexpected(prefix):
-                    msg = f"There is no module or parameter named '{prefix}'"
-                    raise ValueError(msg)
+                can_skip_module = self._can_skip(prefix + ".")
+                can_skip_param = self._can_skip(prefix)
+                if can_skip_module or can_skip_param:
+                    logger.debug("Skipping missing %s", prefix)
+
+                    continue
+
+                can_ignore_module = self._can_ignore_unexpected(prefix + ".")
+                can_ignore_param = self._can_ignore_unexpected(prefix)
+                if can_ignore_module or can_ignore_param:
+                    logger.debug("Ignoring missing %s", prefix)
+
+                    continue
+
+                msg = (f"There is no module or parameter named '{prefix}' "
+                       f"in {type(self.module).__name__}")
+                raise ValueError(msg)
 
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
         *,
         mapper: Optional[WeightsMapper] = None,
-    ) -> None:
+    ) -> List[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
 
-        self._load_module("", self.module, weights)
+        autoloaded_weights = list(self._load_module("", self.module, weights))
+        return autoloaded_weights
 
 
 def init_vllm_registered_model(
@@ -282,10 +324,11 @@ def _embedding_count_expression(embeddings: NestedTensors) -> str:
         _embedding_count_expression(inner) for inner in embeddings)
 
 
-def merge_multimodal_embeddings(input_ids: torch.Tensor,
-                                inputs_embeds: torch.Tensor,
-                                multimodal_embeddings: NestedTensors,
-                                placeholder_token_id: int) -> torch.Tensor:
+def _merge_multimodal_embeddings(
+    inputs_embeds: torch.Tensor,
+    is_multimodal: torch.Tensor,
+    multimodal_embeddings: NestedTensors,
+) -> torch.Tensor:
     """
     Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
     positions in ``inputs_embeds`` corresponding to placeholder tokens in
@@ -294,8 +337,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
     Note:
         This updates ``inputs_embeds`` in place.
     """
-    mask = (input_ids == placeholder_token_id)
-    num_expected_tokens = mask.sum().item()
+    num_expected_tokens = is_multimodal.sum().item()
     assert isinstance(num_expected_tokens, int)
 
     flattened = _flatten_embeddings(multimodal_embeddings)
@@ -305,10 +347,70 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
             f"Attempted to assign {expr} = {flattened.shape[0]} "
             f"multimodal tokens to {num_expected_tokens} placeholders")
 
-    inputs_embeds[mask] = flattened
+    inputs_embeds[is_multimodal] = flattened
     return inputs_embeds
 
 
+def embed_multimodal(
+    input_ids: torch.Tensor,
+    multimodal_token_id: int,
+    get_text_embeds: Callable[[torch.Tensor], torch.Tensor],
+    get_multimodal_embeds: Callable[[torch.Tensor], Union[torch.Tensor,
+                                                          List[torch.Tensor]]],
+) -> torch.Tensor:
+    """
+    Embed token IDs and multimodal inputs and combine their embeddings.
+
+    ``multimodal_token_id`` is used to determine whether a token ID should
+    be embedded using ``get_text_embeds`` or ``get_multimodal_embeds``.
+
+    Compared to ``merge_multimodal_embeddings``, this avoids running
+    ``get_text_embeds`` on ``input_ids[input_ids == multimodal_token_id]``
+    which causes issues when the placeholder token ID exceeds the
+    vocabulary size of the language model.
+    """
+    is_multimodal = input_ids == multimodal_token_id
+    is_text = ~is_multimodal
+
+    text_embeds = get_text_embeds(input_ids[is_text])
+    multimodal_embeds = get_multimodal_embeds(input_ids[is_multimodal])
+
+    merged_embeds = torch.empty(
+        (input_ids.shape[0], text_embeds.shape[1]),
+        dtype=text_embeds.dtype,
+        device=text_embeds.device,
+    )
+
+    merged_embeds[is_text] = text_embeds
+
+    return _merge_multimodal_embeddings(
+        merged_embeds,
+        is_multimodal,
+        multimodal_embeds,
+    )
+
+
+def merge_multimodal_embeddings(
+    input_ids: torch.Tensor,
+    inputs_embeds: torch.Tensor,
+    multimodal_embeddings: NestedTensors,
+    placeholder_token_id: int,
+) -> torch.Tensor:
+    """
+    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder tokens in
+    ``input_ids``.
+
+    Note:
+        This updates ``inputs_embeds`` in place.
+    """
+    return _merge_multimodal_embeddings(
+        inputs_embeds,
+        (input_ids == placeholder_token_id),
+        multimodal_embeddings,
+    )
+
+
 class LayerFn(Protocol):
 
     def __call__(self, prefix: str) -> torch.nn.Module:
@@ -462,7 +564,7 @@ def make_empty_intermediate_tensors(
 
 class LLMWrapper(nn.Module):
     """
-    To align with the key names of LoRA trained with PEFT, we need to add an 
+    To align with the key names of LoRA trained with PEFT, we need to add an
     additional layer to the llm's implementation.
     """
 
@@ -482,3 +584,29 @@ def __getattr__(self, key: str):
     def __call__(self, *args: Any, **kwargs: Any) -> Any:
         llm = super().__getattr__(self.model_name)
         return llm(*args, **kwargs)
+
+
+def get_vit_attn_backend() -> _Backend:
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                logger.warning(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif current_platform.is_cpu():
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
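Among the utils changes, `embed_multimodal` is the piece worth illustrating: it embeds text and placeholder positions separately, so the placeholder ID never has to be a valid row of the text embedding table. A self-contained toy version of that masking (plain tensors, not the vLLM call signature):

```python
import torch

input_ids = torch.tensor([5, 9, 9, 2])   # 9 marks multimodal placeholders
text_table = torch.randn(8, 4)           # vocab size 8, so ID 9 is out of range
mm_embeds = torch.randn(2, 4)            # one embedding per placeholder

is_mm = input_ids == 9
is_text = ~is_mm

merged = torch.empty(input_ids.shape[0], text_table.shape[1])
merged[is_text] = text_table[input_ids[is_text]]  # embed text rows only
merged[is_mm] = mm_embeds                         # same in-place overwrite as
                                                  # _merge_multimodal_embeddings
print(merged.shape)  # torch.Size([4, 4])
```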
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 3bded82033c0..036789642d3c 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -27,6 +27,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -220,6 +221,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class XverseModel(nn.Module):
 
     def __init__(
@@ -266,6 +268,7 @@ def forward(
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py
index d7eec818cbba..c27b1cf6ac7b 100644
--- a/vllm/model_executor/utils.py
+++ b/vllm/model_executor/utils.py
@@ -3,6 +3,7 @@
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.utils import seed_everything
 
 
@@ -28,4 +29,25 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
+
+        # NOTE(woosuk): During weight loading, we often do something like:
+        # narrowed_tensor = param.data.narrow(0, offset, len)
+        # narrowed_tensor.copy_(real_weight)
+        # expecting narrowed_tensor and param.data to share the same storage.
+        # However, on TPUs, narrowed_tensor will lazily propagate to the base
+        # tensor, which is param.data, leading to the redundant memory usage.
+        # This sometimes causes OOM errors during model loading. To avoid this,
+        # we sync the param tensor after its weight loader is called.
+        # TODO(woosuk): Remove this hack once we have a better solution.
+        if current_platform.is_tpu() and key == "weight_loader":
+            value = _make_synced_weight_loader(value)
         setattr(weight, key, value)
+
+
+def _make_synced_weight_loader(original_weight_loader):
+
+    def _synced_weight_loader(param, *args, **kwargs):
+        original_weight_loader(param, *args, **kwargs)
+        torch._sync(param)
+
+    return _synced_weight_loader
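For readers unfamiliar with the idiom the NOTE above refers to, here is the narrow-then-copy pattern on its own (plain PyTorch, runs on CPU); the TPU workaround simply syncs the parameter after its weight loader runs so the lazily propagated copy does not hold extra memory:

```python
import torch

param = torch.zeros(10, 4)
shard = torch.ones(3, 4)

# narrow() returns a view, so copying into it writes straight into `param`
# without allocating a second full-size buffer (the behavior the comment
# above relies on, and which is only lazily realized on TPU/XLA).
narrowed = param.data.narrow(0, 2, 3)  # rows 2..4
narrowed.copy_(shard)

assert param[2:5].eq(1).all()
```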
diff --git a/vllm/outputs.py b/vllm/outputs.py
index 07650241cb63..951976310e7a 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -1,13 +1,13 @@
 import time
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Union
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import RequestOutputKind
 from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
-                           SequenceGroup, SequenceStatus)
+                           SequenceGroup, SequenceGroupBase, SequenceStatus)
 
 
 @dataclass
@@ -114,14 +114,28 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
 
     @classmethod
-    def from_seq_group(cls, seq_group: SequenceGroup,
-                       use_cache: bool) -> Optional["RequestOutput"]:
+    def from_seq_group(
+        cls, seq_group: SequenceGroup, use_cache: bool,
+        seq_id_to_seq_group: Dict[str, SequenceGroupBase]
+    ) -> Optional["RequestOutput"]:
+        finished = seq_group.is_finished()
+
+        if seq_group.request_id in seq_id_to_seq_group:
+            group: SequenceGroupBase = seq_id_to_seq_group[
+                seq_group.request_id]
+            if finished:
+                group.finish_seq(seq_group)
+            assembled_seq_group = group.maybe_assemble_group(seq_group)
+            if assembled_seq_group is None:
+                return None
+            return cls.from_seq_group(assembled_seq_group, use_cache,
+                                      seq_id_to_seq_group)
+
         sampling_params = seq_group.sampling_params
         if sampling_params is None:
             raise ValueError(
                 "Sampling parameters are missing for a CompletionRequest.")
 
-        finished = seq_group.is_finished()
         if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
                 not finished):
             return None
@@ -136,15 +150,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
                 outputs=[],
                 finished=False)
 
-        seqs = seq_group.get_seqs()
-        if len(seqs) == 1:
-            top_n_seqs = seqs
-        else:
-            # Get the top-n sequences.
-            n = sampling_params._real_n or sampling_params.n
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-            top_n_seqs = sorted_seqs[:n]
+        top_n_seqs = seq_group.get_seqs()
 
         # Create the outputs.
         # NOTE: We need omit logprobs here explicitly because the sequence
@@ -208,7 +214,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
 
             else:
                 output = CompletionOutput(
-                    seqs.index(seq), output_text, [output_token_ids]
+                    top_n_seqs.index(seq), output_text, [output_token_ids]
                     if isinstance(output_token_ids, int) else output_token_ids,
                     seq.get_cumulative_logprob() if include_logprobs else None,
                     output_logprobs,
@@ -309,10 +315,13 @@ def __repr__(self):
 class RequestOutputFactory:
 
     @staticmethod
-    def create(seq_group: SequenceGroup, use_cache: bool = False):
+    def create(seq_group: SequenceGroup,
+               seq_id_to_seq_group: Dict[str, SequenceGroupBase],
+               use_cache: bool = False):
         # Determine the type based on a condition, for example:
         if hasattr(seq_group,
                    'embeddings') and seq_group.embeddings is not None:
             return EmbeddingRequestOutput.from_seq_group(seq_group)
         else:
-            return RequestOutput.from_seq_group(seq_group, use_cache)
+            return RequestOutput.from_seq_group(seq_group, use_cache,
+                                                seq_id_to_seq_group)
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index c648862b2d75..58912158139b 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -58,6 +58,13 @@
 except Exception:
     pass
 
+is_neuron = False
+try:
+    import transformers_neuronx  # noqa: F401
+    is_neuron = True
+except ImportError:
+    pass
+
 if is_tpu:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -75,6 +82,9 @@
 elif is_cpu:
     from .cpu import CpuPlatform
     current_platform = CpuPlatform()
+elif is_neuron:
+    from .neuron import NeuronPlatform
+    current_platform = NeuronPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index fa487e2f917d..30bbf5107475 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -137,10 +137,9 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
                             pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                         if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                             return False
-                    except pynvml.NVMLError as error:
-                        logger.error(
+                    except pynvml.NVMLError:
+                        logger.exception(
                             "NVLink detection failed. This is normal if your"
-                            " machine has no NVLink equipped.",
-                            exc_info=error)
+                            " machine has no NVLink equipped.")
                         return False
         return True
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index 00742a290e42..d36367f2bc9c 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
     TPU = enum.auto()
     XPU = enum.auto()
     CPU = enum.auto()
+    NEURON = enum.auto()
     UNSPECIFIED = enum.auto()
 
 
@@ -48,6 +49,9 @@ def is_xpu(self) -> bool:
     def is_cpu(self) -> bool:
         return self._enum == PlatformEnum.CPU
 
+    def is_neuron(self) -> bool:
+        return self._enum == PlatformEnum.NEURON
+
     def is_cuda_alike(self) -> bool:
         """Stateless version of :func:`torch.cuda.is_available`."""
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
new file mode 100644
index 000000000000..07d8398eda52
--- /dev/null
+++ b/vllm/platforms/neuron.py
@@ -0,0 +1,9 @@
+from .interface import Platform, PlatformEnum
+
+
+class NeuronPlatform(Platform):
+    _enum = PlatformEnum.NEURON
+
+    @classmethod
+    def get_device_name(cls, device_id: int = 0) -> str:
+        return "neuron"
diff --git a/vllm/profiler/__init__.py b/vllm/profiler/__init__.py
new file mode 100644
index 000000000000..3e25f5cc283f
--- /dev/null
+++ b/vllm/profiler/__init__.py
@@ -0,0 +1,5 @@
+from .layerwise_profile import layerwise_profile
+
+__all__ = [
+    "layerwise_profile",
+]
diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py
new file mode 100644
index 000000000000..9d9f427e807f
--- /dev/null
+++ b/vllm/profiler/layerwise_profile.py
@@ -0,0 +1,354 @@
+import copy
+from collections import defaultdict
+from dataclasses import asdict, dataclass, field
+from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union
+
+import pandas as pd
+from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
+from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
+from torch.autograd.profiler import FunctionEvent
+from torch.profiler import ProfilerActivity, profile
+
+from vllm.profiler.utils import (TablePrinter, event_has_module,
+                                 event_is_torch_op, event_module_repr,
+                                 event_torch_op_stack_trace, indent_string)
+
+
+@dataclass
+class _ModuleTreeNode:
+    event: _ProfilerEvent
+    parent: Optional['_ModuleTreeNode'] = None
+    children: List['_ModuleTreeNode'] = field(default_factory=list)
+    trace: str = ""
+
+    @property
+    def is_leaf(self):
+        return (self.event.children is None or len(self.event.children) == 0)
+
+    @property
+    def is_torch_op(self):
+        return event_is_torch_op(self.event)
+
+    @property
+    def is_cuda(self):
+        return (self.event.tag == _EventType.Kineto
+                and self.event.typed[1].device_type == DeviceType.CUDA)
+
+
+@dataclass
+class SummaryStatsEntry:
+    name: str
+    cuda_time_us: float
+    pct_cuda_time: float
+    invocations: int
+
+
+@dataclass
+class ModelStatsEntry:
+    name: str
+    cpu_time_us: float
+    cuda_time_us: float
+    pct_cuda_time: float
+    trace: str
+
+
+StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
+
+
+@dataclass
+class _StatsTreeNode:
+    entry: StatsEntry
+    children: List["_StatsTreeNode"]
+    parent: Optional["_StatsTreeNode"]
+
+
+@dataclass
+class LayerwiseProfileResults(profile):
+    _kineto_results: _ProfilerResult
+    _kineto_event_correlation_map: Dict[int,
+                                        List[_KinetoEvent]] = field(init=False)
+    _event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
+    _module_tree: List[_ModuleTreeNode] = field(init=False)
+    _model_stats_tree: List[_StatsTreeNode] = field(init=False)
+    _summary_stats_tree: List[_StatsTreeNode] = field(init=False)
+
+    def __post_init__(self):
+        self._build_correlation_map()
+        self._build_module_tree()
+        self._build_stats_trees()
+
+    def print_model_table(self, column_widths: Optional[Dict[str, int]] = None):
+        _column_widths = dict(name=60,
+                              cpu_time_us=12,
+                              cuda_time_us=12,
+                              pct_cuda_time=12,
+                              trace=60)
+        if column_widths:
+            _column_widths.update(**column_widths)
+        filtered_model_table = [
+            (depth, row)
+            for depth, row in self._flatten_stats_tree(self._model_stats_tree)
+            if row.cuda_time_us > 0 or row.cpu_time_us > 0
+        ]
+        TablePrinter(ModelStatsEntry, _column_widths).print_table(
+            self._indent_row_names_based_on_depth(
+                filtered_model_table,
+                indent_style=lambda indent: "|" + "-" * indent + " "))
+
+    def print_summary_table(self, column_widths: Optional[Dict[str, int]] = None):
+        _column_widths = dict(name=80,
+                              cuda_time_us=12,
+                              pct_cuda_time=12,
+                              invocations=15)
+        if column_widths:
+            _column_widths.update(**column_widths)
+        filtered_summary_table = [(depth, row)
+                                  for depth, row in self._flatten_stats_tree(
+                                      self._summary_stats_tree)
+                                  if row.cuda_time_us > 0]
+        TablePrinter(SummaryStatsEntry, _column_widths).print_table(
+            self._indent_row_names_based_on_depth(
+                filtered_summary_table,
+                indent_style=lambda indent: "|" + "-" * indent + " "))
+
+    def export_model_stats_table_csv(self, filename: str):
+        df = pd.DataFrame([
+            asdict(row)
+            for _, row in self._flatten_stats_tree(self._model_stats_tree)
+        ])
+        df.to_csv(filename)
+
+    def export_summary_stats_table_csv(self, filename: str):
+        df = pd.DataFrame([
+            asdict(row)
+            for _, row in self._flatten_stats_tree(self._summary_stats_tree)
+        ])
+        df.to_csv(filename)
+
+    def convert_stats_to_dict(self) -> Dict:
+        return {
+            "summary_stats":
+            self._convert_stats_tree_to_dict(self._summary_stats_tree),
+            "model_stats":
+            self._convert_stats_tree_to_dict(self._model_stats_tree)
+        }
+
+    @staticmethod
+    def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
+                                                                 StatsEntry]],
+                                         indent_style: Union[Callable[[int],
+                                                                      str],
+                                                             str] = " "):
+        indented_rows = []
+        for depth, row in depths_rows:
+            if row.cuda_time_us == 0:
+                continue
+            indented_row = copy.deepcopy(row)
+            indented_row.name = indent_string(indented_row.name, depth,
+                                              indent_style)
+            indented_rows.append(indented_row)
+        return indented_rows
+
+    def _build_correlation_map(self):
+        self._kineto_event_correlation_map = defaultdict(list)
+        for event in self._kineto_results.events():
+            self._kineto_event_correlation_map[event.correlation_id()].append(
+                event)
+
+    def _build_module_tree(self):
+        self._module_tree = []
+        event_tree = self._kineto_results.experimental_event_tree()
+
+        def _df_traversal(event: _ProfilerEvent,
+                          curr_node: Optional[_ModuleTreeNode] = None):
+
+            # For the tensor parallel case for now only look at task 1
+            if event.start_tid != 1:
+                return
+
+            if event_has_module(event):
+                node = _ModuleTreeNode(event=event, parent=curr_node)
+                if curr_node:
+                    curr_node.children.append(node)
+                else:
+                    self._module_tree.append(node)
+                curr_node = node
+
+            is_leaf = (event.children is None or len(event.children) == 0)
+            if is_leaf and curr_node:
+                node = _ModuleTreeNode(
+                    event=event,
+                    parent=curr_node,
+                    trace=event_torch_op_stack_trace(
+                        event, until=lambda x: event_has_module(x)))
+                curr_node.children.append(node)
+                curr_node = node
+
+            for child in event.children:
+                _df_traversal(child, curr_node)
+
+        for root in event_tree:
+            _df_traversal(root)
+
+    def _get_kineto_gpu_event(self, node: _ModuleTreeNode):
+        if node.event.tag != _EventType.Kineto:
+            return None
+        correlated_kineto_events = self._kineto_event_correlation_map.get(
+            node.event.correlation_id, [])
+        iterator = (x for x in correlated_kineto_events
+                    if x.device_type() == DeviceType.CUDA
+                    and x.name() == node.event.name)
+        return next(iterator, None)
+
+    def _cumulative_cuda_time(self, node: _ModuleTreeNode):
+        'Return cuda time in microseconds'
+
+        def _cumulative_cuda_time_recursive(node: _ModuleTreeNode):
+            if node.is_leaf and (gpu_kineto_event :=
+                                 self._get_kineto_gpu_event(node)):
+                return gpu_kineto_event.duration_ns() / 1000.0
+            else:
+                cumulative_cuda_time = 0
+                for child in node.children:
+                    cumulative_cuda_time += _cumulative_cuda_time_recursive(
+                        child)
+                return cumulative_cuda_time
+
+        return _cumulative_cuda_time_recursive(node)
+
+    def _total_cuda_time(self):
+        return sum(
+            [self._cumulative_cuda_time(root) for root in self._module_tree])
+
+    def _build_stats_trees(self):
+        summary_dict: Dict[Tuple[str, ...], _StatsTreeNode] = {}
+        total_cuda_time = self._total_cuda_time()
+
+        def pct_cuda_time(cuda_time_us):
+            return (cuda_time_us / total_cuda_time) * 100
+
+        def build_summary_stats_tree_df(
+            node: _ModuleTreeNode,
+            parent: Optional[_StatsTreeNode] = None,
+            summary_trace: Tuple[str] = ()):
+
+            if event_has_module(node.event):
+                name = event_module_repr(node.event)
+                cuda_time_us = self._cumulative_cuda_time(node)
+            elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
+                name = gpu_kineto_event.name()
+                cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
+            else:
+                return None
+
+            summary_trace = summary_trace + (name, )
+            if summary_trace in summary_dict:
+                entry = summary_dict[summary_trace].entry
+                entry.cuda_time_us += cuda_time_us
+                entry.invocations += 1
+                entry.pct_cuda_time = pct_cuda_time(entry.cuda_time_us)
+            else:
+                new_node = _StatsTreeNode(entry=SummaryStatsEntry(
+                    name=name,
+                    cuda_time_us=cuda_time_us,
+                    pct_cuda_time=pct_cuda_time(cuda_time_us),
+                    invocations=1),
+                                          children=[],
+                                          parent=parent)
+                if parent:
+                    parent.children.append(new_node)
+                summary_dict[summary_trace] = new_node
+
+            for child in node.children:
+                build_summary_stats_tree_df(child, summary_dict[summary_trace],
+                                            summary_trace)
+
+            return summary_dict[summary_trace]
+
+        self._summary_stats_tree = []
+        for root in self._module_tree:
+            self._summary_stats_tree.append(build_summary_stats_tree_df(root))
+
+        def build_model_stats_tree_df(node: _ModuleTreeNode,
+                                      parent: Optional[_StatsTreeNode] = None):
+            if event_has_module(node.event):
+                name = event_module_repr(node.event)
+                cuda_time_us = self._cumulative_cuda_time(node)
+                cpu_time_us = node.event.duration_time_ns / 1000
+                trace = ""
+            elif (gpu_kineto_event := self._get_kineto_gpu_event(node)):
+                name = gpu_kineto_event.name()
+                cuda_time_us = gpu_kineto_event.duration_ns() / 1000.0
+                cpu_time_us = 0
+                trace = node.trace
+            else:
+                return None
+
+            new_node = _StatsTreeNode(entry=ModelStatsEntry(
+                name=name,
+                cpu_time_us=cpu_time_us,
+                cuda_time_us=cuda_time_us,
+                pct_cuda_time=pct_cuda_time(cuda_time_us),
+                trace=trace),
+                                      parent=parent,
+                                      children=[])
+            if parent:
+                parent.children.append(new_node)
+
+            for child in node.children:
+                build_model_stats_tree_df(child, new_node)
+
+            return new_node
+
+        self._model_stats_tree = []
+        for root in self._module_tree:
+            self._model_stats_tree.append(build_model_stats_tree_df(root))
+
+    def _flatten_stats_tree(
+            self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
+        entries: List[Tuple[int, StatsEntry]] = []
+
+        def df_traversal(node: _StatsTreeNode, depth=0):
+            entries.append((depth, node.entry))
+            for child in node.children:
+                df_traversal(child, depth=depth + 1)
+
+        for root in tree:
+            df_traversal(root)
+
+        return entries
+
+    def _convert_stats_tree_to_dict(self,
+                                    tree: List[_StatsTreeNode]) -> List[Dict]:
+        root_dicts: List[Dict] = []
+
+        def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
+            curr_json_list.append({
+                "entry": asdict(node.entry),
+                "children": []
+            })
+            for child in node.children:
+                df_traversal(child, curr_json_list[-1]["children"])
+
+        for root in tree:
+            df_traversal(root, root_dicts)
+
+        return root_dicts
+
+
+class layerwise_profile(profile):
+
+    def __init__(self):
+        super().__init__(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+            with_stack=True,
+            with_modules=True,
+            experimental_config=_ExperimentalConfig(verbose=True))
+
+    def __enter__(self):
+        return super().__enter__()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        super().__exit__(exc_type, exc_val, exc_tb)
+        self.results = LayerwiseProfileResults(self.profiler.kineto_results)
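A hedged usage sketch based only on the classes above, assuming a CUDA device and a small module to run inside the context:

```python
import torch
import torch.nn as nn

from vllm.profiler import layerwise_profile

model = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

# __exit__ attaches a LayerwiseProfileResults object as `prof.results`.
with layerwise_profile() as prof:
    model(x)
    torch.cuda.synchronize()

prof.results.print_summary_table()
prof.results.export_summary_stats_table_csv("summary_stats.csv")
```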
diff --git a/vllm/profiler/utils.py b/vllm/profiler/utils.py
new file mode 100644
index 000000000000..033035e43432
--- /dev/null
+++ b/vllm/profiler/utils.py
@@ -0,0 +1,145 @@
+import dataclasses
+from typing import Callable, Dict, List, Type, Union
+
+from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
+
+#
+# String / Print Manipulation
+#
+
+
+def trim_string_front(string, width):
+    if len(string) > width:
+        offset = len(string) - width + 3
+        string = string[offset:]
+        if len(string) > 3:
+            string = "..." + string[3:]
+    return string
+
+
+def trim_string_back(string, width):
+    if len(string) > width:
+        offset = len(string) - width + 3
+        string = string[:-offset]
+        if len(string) > 3:
+            string = string + "..."
+    return string
+
+
+class TablePrinter:
+
+    def __init__(self, row_cls: Type[dataclasses.dataclass],
+                 column_widths: Dict[str, int]):
+        self.row_cls = row_cls
+        self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
+        self.column_widths = column_widths
+        assert set(self.column_widths.keys()) == set(self.fieldnames)
+
+    def print_table(self, rows: List[dataclasses.dataclass]):
+        self._print_header()
+        self._print_line()
+        for row in rows:
+            self._print_row(row)
+
+    def _print_header(self):
+        for i, f in enumerate(self.fieldnames):
+            last = (i == len(self.fieldnames) - 1)
+            col_width = self.column_widths[f]
+            print(trim_string_back(f, col_width).ljust(col_width),
+                  end=" | " if not last else "\n")
+
+    def _print_row(self, row):
+        assert isinstance(row, self.row_cls)
+
+        for i, f in enumerate(self.fieldnames):
+            last = (i == len(self.fieldnames) - 1)
+            col_width = self.column_widths[f]
+            val = getattr(row, f)
+
+            val_str = ""
+            if isinstance(val, str):
+                val_str = trim_string_back(val, col_width).ljust(col_width)
+            elif type(val) in [float, int]:
+                val_str = f"{float(val):>.2f}".rjust(col_width)
+            else:
+                val_str = f"{val}".rjust(col_width)
+            print(val_str, end=" | " if not last else "\n")
+
+    def _print_line(self):
+        total_col_width = 0
+        for column_width in self.column_widths.values():
+            total_col_width += column_width
+        print("=" * (total_col_width + 3 * (len(self.column_widths) - 1)))
+
+
+def indent_string(string: str,
+                  indent: int,
+                  indent_style: Union[Callable[[int], str], str] = " ") -> str:
+    if indent:
+        if isinstance(indent_style, str):
+            return indent_style * indent + string
+        else:
+            return indent_style(indent) + string
+    else:
+        return string
+
+
+#
+# _ProfilerEvent utils
+#
+
+
+def event_has_module(event: _ProfilerEvent) -> bool:
+    event_type, typed_event = event.typed
+    if event_type == _EventType.PyCall:
+        return typed_event.module is not None
+    return False
+
+
+def event_is_torch_op(event: _ProfilerEvent) -> bool:
+    return event.tag == _EventType.TorchOp
+
+
+def event_arg_repr(arg) -> str:
+    if arg is None or type(arg) in [float, int, bool, str]:
+        return f"{arg}"
+    elif isinstance(arg, list):
+        return f"[{', '.join([event_arg_repr(x) for x in arg])}]"
+    elif isinstance(arg, tuple):
+        return f"({', '.join([event_arg_repr(x) for x in arg])})"
+    else:
+        assert isinstance(arg,
+                          _TensorMetadata), f"Unsupported type: {type(arg)}"
+        sizes_str = ', '.join([str(x) for x in arg.sizes])
+        return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]"
+
+
+def event_torch_op_repr(event: _ProfilerEvent) -> str:
+    assert event.tag == _EventType.TorchOp
+    args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs])
+    return f"{event.name}({args_str})".replace("aten::", "")
+
+
+def event_module_repr(event: _ProfilerEvent) -> str:
+    assert event_has_module(event)
+    module = event.typed[1].module
+    if module.parameters and len(module.parameters) > 0:
+        args_str = ', '.join(
+            [f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters])
+        return f"{module.cls_name}({args_str})"
+    else:
+        return module.cls_name
+
+
+def event_torch_op_stack_trace(curr_event: _ProfilerEvent,
+                               until: Callable[[_ProfilerEvent], bool]) -> str:
+    trace = ""
+    curr_event = curr_event.parent
+    while curr_event and not until(curr_event):
+        if event_is_torch_op(curr_event):
+            if len(trace) > 0:
+                trace += " <- "
+            trace += event_torch_op_repr(curr_event)
+        curr_event = curr_event.parent
+
+    return trace
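A quick demonstration of the string helpers above, using the same callable indent style that `LayerwiseProfileResults` passes when printing its tables:

```python
from vllm.profiler.utils import indent_string, trim_string_back

print(indent_string("LlamaDecoderLayer", 2, lambda n: "|" + "-" * n + " "))
# |-- LlamaDecoderLayer

print(trim_string_back("a_very_long_kernel_name", 10))
# a_very_...
```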
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 4f2ae75e65f3..9993cec13d64 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -49,14 +49,17 @@ class GuidedDecodingParams:
 
     @staticmethod
     def from_optional(
-        json: Optional[Union[Dict, BaseModel, str]],
+        json: Optional[Union[Dict, BaseModel, str]] = None,
         regex: Optional[str] = None,
         choice: Optional[List[str]] = None,
         grammar: Optional[str] = None,
         json_object: Optional[bool] = None,
         backend: Optional[str] = None,
         whitespace_pattern: Optional[str] = None,
-    ) -> "GuidedDecodingParams":
+    ) -> Optional["GuidedDecodingParams"]:
+        if all(arg is None
+               for arg in (json, regex, choice, grammar, json_object)):
+            return None
         # Extract json schemas from pydantic models
         if isinstance(json, (BaseModel, type(BaseModel))):
             json = json.model_json_schema()
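With this change `from_optional` collapses the no-constraint case to `None`; assuming the rest of the method still builds the params object when any constraint is given, callers can do:

```python
from vllm.sampling_params import GuidedDecodingParams

# No constraints supplied -> no GuidedDecodingParams object at all.
assert GuidedDecodingParams.from_optional() is None

# A single constraint still produces a params object as before.
params = GuidedDecodingParams.from_optional(regex=r"\d+")
assert params is not None
```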
diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py
index eb491dd1554a..9d711b0debcd 100644
--- a/vllm/scalar_type.py
+++ b/vllm/scalar_type.py
@@ -1,4 +1,298 @@
-from ._core_ext import NanRepr, ScalarType
+import functools
+import struct
+from dataclasses import dataclass
+from enum import Enum
+from typing import Optional, Union
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+# This ScalarType class is a parallel implementation of the C++ ScalarType
+# class found in csrc/core/scalar_type.hpp.  These two classes should be kept
+# in sync until the inductor fully supports custom C++ classes.
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+      `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer (excluding the sign bit)
+    if this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0). For example, if we store the
+    type as an unsigned integer with a bias of 128, then the value 0 will be
+    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: use `has_infs()` to check whether infs are supported.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type; returns a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) - 1
+        # (where e is the number of exponent bits); there is some precedent
+        # for non-standard biases, e.g. `float8_e4m3b11fnuz` in
+        # https://github.com/jax-ml/ml_dtypes, but to avoid premature
+        # complication we just assume the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+        return (max_mantissa <<
+                (52 - self.mantissa)) | (max_exponent_double << 52)
+
+    def _floating_point_max(self) -> float:
+        double_raw = self._floating_point_max_int()
+        return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
+
+    def _raw_max(self) -> Union[int, float]:
+        if self.is_floating_point():
+            return self._floating_point_max()
+        else:
+            assert (self.size_bits < 64 or self.size_bits == 64
+                    and self.is_signed()), "Cannot represent max as an int"
+            return (1 << self.mantissa) - 1
+
+    def _raw_min(self) -> Union[int, float]:
+        if self.is_floating_point():
+            assert self.is_signed(
+            ), "We currently assume all floating point types are signed"
+            sign_bit_double = 1 << 63
+
+            max_raw = self._floating_point_max_int()
+            min_raw = max_raw | sign_bit_double
+            return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
+        else:
+            assert (not self.is_signed() or
+                    self.size_bits <= 64), "Cannot represent min as a int64_t"
+
+            if self.is_signed():
+                return -(1 << (self.size_bits - 1))
+            else:
+                return 0
+
+    @functools.cached_property
+    def id(self) -> int:
+        """
+        Convert the ScalarType to an int which can be passed to pytorch custom
+        ops. This layout of the int must be kept in sync with the C++
+        ScalarType's from_id method.
+        """
+        val = 0
+        offset = 0
+
+        def or_and_advance(member, bit_width):
+            nonlocal val
+            nonlocal offset
+            bit_mask = (1 << bit_width) - 1
+            val = val | (int(member) & bit_mask) << offset
+            offset = offset + bit_width
+
+        or_and_advance(self.exponent, 8)
+        or_and_advance(self.mantissa, 8)
+        or_and_advance(self.signed, 1)
+        or_and_advance(self.bias, 32)
+        or_and_advance(self._finite_values_only, 1)
+        or_and_advance(self.nan_repr.value, 8)
+
+        assert offset <= 64, \
+            f"ScalarType fields too big {offset} to fit into an int64"
+
+        return val
+
+    @property
+    def size_bits(self) -> int:
+        return self.exponent + self.mantissa + int(self.signed)
+
+    def min(self) -> Union[int, float]:
+        """
+        Min representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_min() - self.bias
+
+    def max(self) -> Union[int, float]:
+        """
+        Max representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_max() - self.bias
+
+    def is_signed(self) -> bool:
+        """
+        If the type is signed (i.e. has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+          `[u]int<size_bits>[b<bias>]`
+          - if the bias is not present it means it is zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create a unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        assert (nan_repr != NanRepr.IEEE_754), (
+            "use `float_IEEE754` constructor for floating point types that "
+            "follow IEEE 754 conventions")
+        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
 
 # naming generally follows: https://github.com/jax-ml/ml_dtypes
 # for floating point types (leading f) the scheme is:
@@ -17,16 +311,17 @@ class scalar_types:
     uint4 = ScalarType.uint(4, None)
     int8 = ScalarType.int_(8, None)
     uint8 = ScalarType.uint(8, None)
-    float8_e4m3fn = ScalarType.float_(4, 3, True,
-                                      NanRepr.EXTD_RANGE_MAX_MIN.value)
+    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
     float8_e5m2 = ScalarType.float_IEEE754(5, 2)
     float16_e8m7 = ScalarType.float_IEEE754(8, 7)
     float16_e5m10 = ScalarType.float_IEEE754(5, 10)
 
     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
-    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
+    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
 
     # "gptq" types
+    uint2b2 = ScalarType.uint(2, 2)
+    uint3b4 = ScalarType.uint(3, 4)
     uint4b8 = ScalarType.uint(4, 8)
     uint8b128 = ScalarType.uint(8, 128)
 
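
As a sanity check on the pure-Python port above, a small sketch (not part of the diff) exercising the convenience constructors and the bias handling; the expected values follow directly from the definitions in the hunk, assuming the module remains importable as `vllm.scalar_type`.

```python
# Illustrative only: exercising the pure-Python ScalarType defined above.
from vllm.scalar_type import NanRepr, ScalarType, scalar_types

t = scalar_types.uint4b8               # GPTQ-style 4-bit unsigned with bias 8
assert t.size_bits == 4
assert (t.min(), t.max()) == (-8, 7)   # the bias shifts the representable range
assert str(t) == "uint4b8"

fp8 = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
assert str(fp8) == "float8_e4m3fn"     # finite-only ("f") with non-standard NaNs ("n")
```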
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 3bb35ea955c8..fc936fbab0ea 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from array import array
 from collections import defaultdict
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property, reduce
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
 from typing import Sequence as GenericSequence
@@ -13,15 +13,15 @@
 import msgspec
 import torch
 
-from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs
-from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs
+from vllm.inputs.parse import is_encoder_decoder_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
 
 if TYPE_CHECKING:
+    from vllm.inputs import SingletonInputs
     from vllm.multimodal.base import MultiModalDataDict
 
 VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@@ -29,6 +29,11 @@
 VLLM_INVALID_TOKEN_ID = -1
 
 
+def array_full(token_id: int, count: int):
+    """:class:`array` equivalent of :func:`numpy.full`."""
+    return array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
+
+
 # We use dataclass for now because it is used for
 # openai server output, and msgspec is not serializable.
 # TODO(sang): Fix it.
@@ -173,22 +178,34 @@ class SequenceData(msgspec.Struct,
     _mrope_position_delta: Optional[int] = None
 
     @staticmethod
-    def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData":
+    def from_prompt_token_counts(
+            *token_counts: Tuple[int, int]) -> "SequenceData":
+        """
+        Construct a :class:`SequenceData` instance by concatenating
+        prompt token sequences.
+
+        Each tuple represents one token sequence, expressed in the form
+        :code:`(token_id, count)`.
+        """
         if len(token_counts) == 0:
             return SequenceData.from_seqs([])
 
-        arrs = [
-            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count
-            for token_id, count in token_counts
-        ]
+        prompt_token_ids_arr = reduce(
+            array.__iadd__,
+            (array_full(token_id, count) for token_id, count in token_counts),
+        )
 
-        return SequenceData(reduce(array.__add__, arrs))
+        return SequenceData(prompt_token_ids_arr)
 
     @staticmethod
     def from_seqs(
         prompt_token_ids: GenericSequence[int],
         output_token_ids: Optional[GenericSequence[int]] = None,
     ) -> "SequenceData":
+        """
+        Construct a :class:`SequenceData` instance from prompt and output
+        token sequences.
+        """
         prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                      prompt_token_ids)
 
@@ -362,14 +379,14 @@ def __repr__(self) -> str:
 class Sequence:
     """Stores the data, status, and block information of a sequence.
 
-    The sequence is constructed from the LLMInputs instance passed
-    in through the `inputs` constructor argument.
+    The sequence is constructed from the :code:`SingletonInputs` instance
+    passed in through the :code:`inputs` constructor argument.
 
-    For encoder/decoder models, LLMInputs encapsulates both a
+    For encoder/decoder models, SingletonInputs encapsulates both a
     decoder and encoder prompt, creating an ambiguity about which
     prompt to construct the sequence from. The `from_decoder_prompt`
     constructor argument signals whether to construct the Sequence
-    from the LLMInputs decoder prompt, or encoder prompt.
+    from the SingletonInputs decoder prompt, or encoder prompt.
 
     Args:
         seq_id: The ID of the sequence.
@@ -379,16 +396,16 @@ class Sequence:
         eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
         lora_request: LoRA request.
         prompt_adapter_request: Prompt Adapter request.
-        from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
-                             (True) or encoder prompt (False.) Must be True
-                             for decoder-only model.
+        from_decoder_prompt: Construct Sequence from SingletonInputs decoder
+                             prompt (True) or encoder prompt (False). Must be
+                             True for decoder-only model.
 
     """
 
     def __init__(
         self,
         seq_id: int,
-        inputs: "LLMInputs",
+        inputs: "SingletonInputs",
         block_size: int,
         eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
@@ -404,19 +421,19 @@ def __init__(
         self.from_decoder_prompt = from_decoder_prompt
 
         # For decoder-only models, a Sequence is constructed
-        # from an LLMInputs instance (the `inputs` arg.)
+        # from a DecoderOnlyInputs instance (the `inputs` arg.)
         #
         # For encoder/decoder models the same `inputs`
         # instance could be utilized to construct either an
         # encoder sequence or a decoder sequence, because
-        # `LLMInputs` has both decoder- and encoder-oriented
+        # `DecoderOnlyInputs` has both decoder- and encoder-oriented
         # member variables (i.e. it encapsulates both an encoder
         # and a decoder prompt.) The decision of which type of sequence
         # to generate is determined by the `from_decoder_prompt` argument.
         #
         # When constructing a encoder sequence
         # (`from_decoder_prompt` False) it matters that
-        # the `LLMInputs` instance stored in `inputs` is valid
+        # the `DecoderOnlyInputs` instance stored in `inputs` is valid
         # in the sense that its encoder-related member variables are
         # populated; below, an exception is raised if this is
         # not the case.
@@ -424,8 +441,7 @@ def __init__(
         # When constructing a decoder sequence (`from_decoder_prompt` True)
         # it does not matter whether `inputs` has its encoder-related
         # member variables populated.
-        if not (from_decoder_prompt
-                or is_valid_encoder_decoder_llm_inputs(inputs)):
+        if not (from_decoder_prompt or is_encoder_decoder_inputs(inputs)):
             raise ValueError("Cannot extract encoder input prompt from "
                              f"invalid input {inputs}; did you forget the "
                              "encoder input prompt fields?")
@@ -471,15 +487,19 @@ def prompt_token_ids(self) -> List[int]:
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
-        if self.inputs.get("multi_modal_data") and self.inputs.get(
-                "encoder_multi_modal_data"):
+        inputs = self.inputs
+
+        if (inputs.get("multi_modal_data")
+                and inputs.get("encoder_multi_modal_data")):
             raise ValueError(
                 "Multi-modal data in both encoder and decoder is not supported."
             )
-        inputs = self.inputs
-        return self.inputs.get("multi_modal_data") or (cast(
-            EncoderDecoderLLMInputs,
-            inputs).get("encoder_multi_modal_data")) or {}
+
+        return cast(
+            "MultiModalDataDict",
+            (inputs.get("multi_modal_data")
+             or inputs.get("encoder_multi_modal_data") or {}),
+        )
 
     @property
     def mm_processor_kwargs(self) -> Dict[str, Any]:
@@ -532,6 +552,9 @@ def get_output_token_ids_to_return(
             # (which is what we have most of the time)
             return self.data._cached_all_token_ids[-1]
 
+        if num_new_tokens == 0:
+            return []
+
         return self.data._cached_all_token_ids[-num_new_tokens:]
 
     def hash_of_block(self, logical_idx: int) -> int:
@@ -658,6 +681,7 @@ def __init__(
     ) -> None:
         self.request_id = request_id
         self.seqs = seqs
+        self.first_seq = seqs[0]
         self.arrival_time = arrival_time
         self.is_single_seq = len(seqs) == 1
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -682,15 +706,11 @@ def __init__(
 
     @property
     def prompt(self) -> Optional[str]:
-        # All sequences in the group should have the same prompt.
-        # We use the prompt of an arbitrary sequence.
-        return self.seqs[0].prompt
+        return self.first_seq.prompt
 
     @property
     def prompt_token_ids(self) -> List[int]:
-        # All sequences in the group should have the same prompt.
-        # We use the prompt of an arbitrary sequence.
-        return self.seqs[0].prompt_token_ids
+        return self.first_seq.prompt_token_ids
 
     @property
     def encoder_prompt(self) -> Optional[str]:
@@ -710,17 +730,11 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]:
 
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
-        # All sequences in the group should have the same multi-modal data.
-        # We use the multi-modal data of an arbitrary sequence.
-        return self.seqs[0].multi_modal_data
+        return self.first_seq.multi_modal_data
 
     @property
     def mm_processor_kwargs(self) -> Dict[str, Any]:
-        # As with multi-modal data, all sequences in the group should have the
-        # same processor kwargs (i.e., mm_processor_kwargs are optionally
-        # provided per request; note that are independent of whether the model
-        # decoder-only or an encoder-decoder).
-        return self.seqs[0].mm_processor_kwargs
+        return self.first_seq.mm_processor_kwargs
 
     @property
     def lora_int_id(self) -> int:
@@ -765,7 +779,7 @@ def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int,
             assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill
             self.init_multi_step(num_steps=num_lookahead_slots + 1)
 
-    def get_last_latency(self, now: float) -> Optional[float]:
+    def get_last_latency(self, now: float) -> float:
         """Sets the last token time for Request level timings."""
         # If still in prefill phase, raise Error.
         if self.is_prefill():
@@ -785,7 +799,7 @@ def maybe_set_first_token_time(self, time: float) -> None:
         #   in TPOT, rather than recalculating TTFT (since from the )
         #   POV of the user, there is simply a long generation delay.
         if (self.metrics.first_token_time is None
-                and self.seqs[0].get_output_len() == 1):
+                and self.first_seq.get_output_len() == 1):
             self.metrics.first_token_time = time
 
     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -802,18 +816,7 @@ def set_finished_time(self, time: Optional[float]) -> None:
     def get_max_num_running_seqs(self) -> int:
         """The maximum number of sequences running in parallel in the remaining
         lifetime of the request."""
-        if self.sampling_params:
-            n = self.sampling_params.n
-            assert isinstance(n, int)
-            if n > self.num_seqs():
-                # At prompt stage, the sequence group is not yet filled up
-                # and only have one sequence running. However, in the
-                # generation stage, we will have `n` sequences
-                # running.
-                return n
-        # At sampling stages, return the number of actual sequences
-        # that are not finished yet.
-        return self.num_unfinished_seqs()
+        return 0 if self.first_seq.is_finished() else 1
 
     def get_seqs(
         self,
@@ -822,10 +825,7 @@ def get_seqs(
         if status is None:
             return self.seqs
 
-        if self.is_single_seq:
-            return self.seqs if self.seqs[0].status == status else []
-
-        return [seq for seq in self.seqs if seq.status == status]
+        return self.seqs if self.first_seq.status == status else []
 
     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
@@ -833,29 +833,20 @@ def is_encoder_decoder(self) -> bool:
     def get_encoder_seq(self) -> Optional[Sequence]:
         return self.encoder_seq
 
-    def get_unfinished_seqs(self) -> List[Sequence]:
-        if self.is_single_seq:
-            return self.seqs if not self.seqs[0].is_finished() else []
-
-        return [seq for seq in self.seqs if not seq.is_finished()]
-
     def get_finished_seqs(self) -> List[Sequence]:
-        if self.is_single_seq:
-            return self.seqs if self.seqs[0].is_finished() else []
-
-        return [seq for seq in self.seqs if seq.is_finished()]
+        return self.seqs if self.first_seq.is_finished() else []
 
     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs:
-            if not seq.is_finished():
-                seq.data.update_num_computed_tokens(num_new_computed_tokens)
+        seq = self.first_seq
+        if not seq.is_finished():
+            seq.data.update_num_computed_tokens(num_new_computed_tokens)
 
     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.seqs:
-            if not seq.is_finished():
-                num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
+        seq = self.first_seq
+        if not seq.is_finished():
+            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
 
     def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
@@ -869,46 +860,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
 
         return len(self.get_seqs(status))
 
-    def num_unfinished_seqs(self) -> int:
-        if self.is_single_seq:
-            return 1 if not self.seqs[0].is_finished() else 0
-
-        return len(self.get_unfinished_seqs())
-
     def num_finished_seqs(self) -> int:
-        if self.is_single_seq:
-            return 1 if self.seqs[0].is_finished() else 0
-
-        return len(self.get_finished_seqs())
-
-    def find(self, seq_id: int) -> Sequence:
-        if seq_id not in self.seqs_dict:
-            raise ValueError(f"Sequence {seq_id} not found.")
-        return self.seqs_dict[seq_id]
-
-    def add(self, seq: Sequence) -> None:
-        if seq.seq_id in self.seqs_dict:
-            raise ValueError(f"Sequence {seq.seq_id} already exists.")
-        self.seqs_dict[seq.seq_id] = seq
-        self.seqs.append(seq)
-        self.is_single_seq = len(self.seqs) == 1
-
-    def remove(self, seq_id: int) -> None:
-        seq = self.seqs_dict.pop(seq_id, None)
-        if seq is None:
-            raise ValueError(f"Sequence {seq_id} not found.")
-        self.seqs.remove(seq)
-        self.is_single_seq = len(self.seqs) == 1
+        return 1 if self.first_seq.is_finished() else 0
 
     def is_finished(self) -> bool:
-        if self.is_single_seq:
-            return self.seqs[0].is_finished()
-
-        return all(seq.is_finished() for seq in self.seqs)
+        return self.first_seq.is_finished()
 
     def is_prefill(self) -> bool:
-        # Every sequence should be in the same stage.
-        return self.seqs[0].is_prefill()
+        return self.first_seq.is_prefill()
 
     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
@@ -1175,7 +1134,7 @@ class PoolerOutput(
 
     spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
 
-    def __getitem__(self, idx: int):
+    def __getitem__(self, idx: int) -> EmbeddingSequenceGroupOutput:
         return self.outputs[idx]
 
     def __setitem__(self, idx: int, value):
@@ -1378,3 +1337,121 @@ def clone(
             last_sampled_token_ids=self.last_sampled_token_ids.clone()
             if self.last_sampled_token_ids is not None else None,
             async_callback=self.async_callback)
+
+
+@dataclass
+class SequenceGroupBase:
+    group_id: str  # the original request id before splitting
+
+    assembled_seq_group: Optional[SequenceGroup] = None
+
+    # seq id to a unique index inside this group
+    seq_id_to_index: Dict[str, int] = field(default_factory=dict)
+
+    # seq ids to be finished
+    to_be_finished: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    # seq id to finished sequences
+    finished_reqs: Dict[str, SequenceGroup] = field(default_factory=dict)
+
+    streaming: bool = False
+
+    output_produced: bool = False
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, *args, **kwargs):
+        """When we are ready to add a request with request_id and params
+        into the engine, we can split the request into multiple requests.
+        """
+        raise NotImplementedError
+
+    def finish_seq(self, seq: SequenceGroup):
+        """The sequence `seq` finishes, we should record the information.
+        """
+        del self.to_be_finished[seq.request_id]
+        self.finished_reqs[seq.request_id] = seq
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+        """Assemble the sequence group, for producing the final
+        output, or adding request in the engine again.
+        """
+        raise NotImplementedError
+
+
+class ParallelSampleSequenceGroup(SequenceGroupBase):
+
+    @staticmethod
+    def add_request(request_id: str, engine, params, **kwargs):
+        original_params = params
+        params = copy.deepcopy(original_params)
+        params.n = 1
+        group = ParallelSampleSequenceGroup(request_id)
+        seqs = []
+        for i in range(original_params.n):
+            request_id_i = f"{request_id}_parallel_sample_{i}"
+            group.seq_id_to_index[request_id_i] = i
+            seq_group = engine._add_processed_request(
+                request_id_i,
+                params=params,
+                **kwargs,
+            )  # type: ignore
+            assert seq_group is not None
+            engine.seq_id_to_seq_group[request_id_i] = group
+            group.to_be_finished[request_id_i] = seq_group
+            seqs.append(seq_group.seqs[0])
+
+        # for parallel sampling, the `assembled_seq_group` is always
+        # available, since we have all the sequences ready, and they
+        # will not change.
+        group.assembled_seq_group = SequenceGroup(
+            request_id=request_id,
+            seqs=seqs,
+            arrival_time=seq_group.arrival_time,
+            sampling_params=original_params,
+            lora_request=seq_group.lora_request,
+            embeddings=seq_group.embeddings,
+            pooling_params=seq_group.pooling_params,
+            encoder_seq=seq_group.encoder_seq,
+            trace_headers=seq_group.trace_headers,
+            prompt_adapter_request=seq_group.prompt_adapter_request,
+            priority=seq_group.priority,
+        )
+
+        group.streaming = params.output_kind == RequestOutputKind.DELTA
+        group.output_produced = False
+
+    def maybe_assemble_group(
+            self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
+
+        # In streaming mode, return the assembled sequence group for
+        # the first sequence, and return None for the remaining
+        # sequences.
+        if self.streaming:
+            if self.seq_id_to_index[seq_group.request_id] == 0:
+                return self.assembled_seq_group
+            return None
+
+        # In non-streaming mode, return the assembled sequence group
+        # once after all sequences have finished, and None on every
+        # call after that.
+
+        if len(self.to_be_finished) > 0:
+            return None
+
+        assert self.assembled_seq_group is not None
+        params = self.assembled_seq_group.sampling_params
+        assert isinstance(params, SamplingParams)
+        if not self.output_produced:
+            self.output_produced = True
+            if params._real_n is not None:
+                # Get the top-n sequences.
+                n = params._real_n or params.n
+                seqs = self.assembled_seq_group.seqs
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+                top_n_seqs = sorted_seqs[:n]
+                self.assembled_seq_group.seqs = top_n_seqs
+            return self.assembled_seq_group
+        if self.output_produced:
+            return None
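
A short sketch (not part of the patch) of the two small helpers introduced at the top of `vllm/sequence.py`; the token ids below are arbitrary placeholders, and the check assumes `SequenceData`'s existing `get_len()` / `prompt_token_ids` accessors.

```python
# Illustrative only: array_full and the renamed from_prompt_token_counts.
from vllm.sequence import SequenceData, array_full

assert list(array_full(token_id=7, count=3)) == [7, 7, 7]

# Two runs of dummy prompt tokens: four copies of id 100, then two of id 200.
data = SequenceData.from_prompt_token_counts((100, 4), (200, 2))
assert data.get_len() == 6
assert list(data.prompt_token_ids) == [100, 100, 100, 100, 200, 200]
```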
diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py
index aaf6ec5f508c..3aa999fcb9eb 100644
--- a/vllm/spec_decode/draft_model_runner.py
+++ b/vllm/spec_decode/draft_model_runner.py
@@ -179,7 +179,7 @@ def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
                 return False
 
         # TODO: Add support for other attn backends
-        if self.attn_backend.get_name() != "flash-attn":
+        if self.attn_backend.get_name() != "FLASH_ATTN":
             return False
 
         # TODO: Add support for LORA
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index 36e5e1774aa0..a777e5c3f22a 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -67,9 +67,16 @@ def sampler_output(
                 execute_model_req.seq_group_metadata_list):
             seq_data = next(iter(seq_group_metadata.seq_data.values()))
 
+            seq_len = seq_data.get_len()
+            # When seq_len is less than 3072 (3K), we use CPU to perform
+            # the ngram match. Otherwise, we use the device specified in
+            # the model config (normally GPU). 3072 is a rough threshold
+            # based on profiling on H100, and it can be adjusted based
+            # on the actual performance on different hardware.
+            cur_device = "cpu" if seq_len < 3072 else self.device
             input_ids = torch.as_tensor(seq_data.get_token_ids(),
                                         dtype=torch.long,
-                                        device=self.device)
+                                        device=cur_device)
             input_length = seq_data.get_len()
 
             for ngram_size in range(
@@ -91,17 +98,15 @@ def sampler_output(
                 # first_match includes "values" (bool), indicating whether
                 # the match is found, and "indices", indicating the index
                 # of the first match.
-                # Note that "first_match.values.item()" triggers GPU-CPU
-                # sync so it is a bit inefficient, but we have not found
-                # a better way to do this.
                 first_match = matches.max(dim=-1)
                 if first_match.values.item():
                     proposal_start_idx = first_match.indices.add_(ngram_size)
                     spec_indices = (
                         proposal_start_idx).repeat(sample_len) + torch.arange(
-                            sample_len, device=self.device)
+                            sample_len, device=cur_device)
                     spec_indices.clamp_(max=input_ids.shape[-1] - 1)
-                    res = input_ids.gather(dim=-1, index=spec_indices)
+                    res = input_ids.gather(dim=-1,
+                                           index=spec_indices).to(self.device)
                     token_id_list.append(res)
                     token_prob_list.append(
                         torch.nn.functional.one_hot(
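
The comment in the hunk above describes a size-based CPU/GPU split for the n-gram match. Below is a sketch of that heuristic in isolation; the constant and function names are illustrative, not part of vLLM's API.

```python
# Illustrative only: the device-selection heuristic described in the comment above.
NGRAM_CPU_SEQ_LEN_THRESHOLD = 3072  # rough cutoff from H100 profiling, per the patch


def choose_ngram_device(seq_len: int, model_device: str) -> str:
    # Short sequences are matched on CPU to avoid GPU kernel-launch and
    # host/device sync overhead; long sequences stay on the model's device.
    return "cpu" if seq_len < NGRAM_CPU_SEQ_LEN_THRESHOLD else model_device


assert choose_ngram_device(1024, "cuda:0") == "cpu"
assert choose_ngram_device(8192, "cuda:0") == "cuda:0"
```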
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index 50d2767a0375..316db43502d3 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -184,7 +184,7 @@ def create_worker(
 
         if not disable_mqa_scorer:
             if scorer_worker.model_runner.attn_backend.get_name(
-            ) != "flash-attn":
+            ) != "FLASH_ATTN":
                 disable_mqa_scorer = True
                 logger.info(
                     "[Speculative Decoding] Disabling MQA scorer as the "
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index b33449c42ecf..9bd2531d7a15 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -23,8 +23,8 @@
                                              MedusaConfig, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
-                                             Qwen2VLConfig, RWConfig,
-                                             SolarConfig, UltravoxConfig)
+                                             RWConfig, SolarConfig,
+                                             UltravoxConfig)
 # yapf: enable
 from vllm.transformers_utils.utils import check_gguf_file
 
@@ -57,7 +57,6 @@
     "NVLM_D": NVLM_D_Config,
     "solar": SolarConfig,
     "ultravox": UltravoxConfig,
-    "qwen2_vl": Qwen2VLConfig,
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
 
@@ -91,6 +90,43 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
         return False
 
 
+def patch_rope_scaling(config: PretrainedConfig) -> None:
+    """Provide backwards compatibility for RoPE."""
+    text_config = getattr(config, "text_config", None)
+    if text_config is not None:
+        patch_rope_scaling(text_config)
+
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling is not None:
+        patch_rope_scaling_dict(rope_scaling)
+
+
+def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None:
+    if "rope_type" not in rope_scaling and "type" in rope_scaling:
+        rope_scaling["rope_type"] = rope_scaling["type"]
+        logger.info("Replacing legacy 'type' key with 'rope_type'")
+
+    if "rope_type" not in rope_scaling:
+        raise ValueError("rope_scaling should have a 'rope_type' key")
+
+    if rope_scaling["rope_type"] == "su":
+        rope_scaling["rope_type"] = "longrope"
+        logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
+    elif rope_scaling["rope_type"] == "mrope":
+        assert "mrope_section" in rope_scaling
+        rope_scaling["rope_type"] = "default"
+        logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
+
+
+def uses_mrope(config: PretrainedConfig) -> bool:
+    """Detect if the model with this config uses M-ROPE."""
+    rope_scaling = getattr(config, "rope_scaling", None)
+    if rope_scaling is None:
+        return False
+
+    return "mrope_section" in rope_scaling
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -191,9 +227,73 @@ def get_config(
             )
             config.update({key: value})
 
+    patch_rope_scaling(config)
+
     return config
 
 
+def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
+    """Try to register HF model configuration class to serialize by value
+
+        With trust_remote_code, the config class is typically an instance of a
+        custom class imported from the HF modules cache. The class will not be
+        importable in spawned workers by default (and won't exist at all on
+        other nodes), which breaks serialization of the config.
+
+        In this function we tell the cloudpickle serialization library to pass
+        instances of these generated classes by value instead of by reference,
+        i.e. the class definition is serialized along with its data so that the
+        class module does not need to be importable on the receiving end. This
+        registration only works if the modules cache has already been
+        initialized.
+
+
+        See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
+    """
+    if not trust_remote_code:
+        return
+
+    try:
+        import transformers_modules
+    except ImportError:
+        logger.debug("Could not import transformers_modules used for remote"
+                     " code. If remote code is not needed remove"
+                     " `--trust-remote-code`.")
+        return
+
+    try:
+        import cloudpickle
+        cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # ray vendors its own version of cloudpickle
+        from vllm.executor.ray_utils import ray
+        if ray:
+            ray.cloudpickle.register_pickle_by_value(transformers_modules)
+
+        # multiprocessing uses pickle to serialize arguments when using spawn
+        # Here we get pickle to use cloudpickle to serialize ModelConfig objects
+        # that contain instances of the custom config class to avoid
+        # serialization problems if the generated module (and model) has a `.`
+        # in its name
+        import multiprocessing
+        import pickle
+
+        from vllm.config import ModelConfig
+
+        def _reduce_modelconfig(mc: ModelConfig):
+            return (pickle.loads, (cloudpickle.dumps(mc), ))
+
+        multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
+
+    except Exception as e:
+        logger.warning(
+            "Unable to register remote classes used by"
+            " trust_remote_code with by-value serialization. This may"
+            " lead to a later error. If remote code is not needed"
+            " remove `--trust-remote-code`",
+            exc_info=e)
+
+
 def load_params_config(model, revision) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
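
A small sketch (not part of the patch) of how `patch_rope_scaling_dict` normalizes legacy RoPE configuration keys, using a plain dict in place of a real HF config; the scaling values shown are arbitrary examples.

```python
# Illustrative only: legacy RoPE keys normalized by patch_rope_scaling_dict.
from vllm.transformers_utils.config import patch_rope_scaling_dict

legacy = {"type": "su", "factor": 4.0}
patch_rope_scaling_dict(legacy)
assert legacy["rope_type"] == "longrope"  # 'type' mirrored to 'rope_type', 'su' renamed

mrope = {"type": "mrope", "mrope_section": [16, 24, 24]}
patch_rope_scaling_dict(mrope)
assert mrope["rope_type"] == "default"    # M-RoPE is detected later via uses_mrope()
```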
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 8d6385d42d00..f0d79197a82c 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -14,8 +14,6 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
 from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
-from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
-                                                     Qwen2VLVisionConfig)
 from vllm.transformers_utils.configs.solar import SolarConfig
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
@@ -35,6 +33,4 @@
     "NVLM_D_Config",
     "SolarConfig",
     "UltravoxConfig",
-    "Qwen2VLConfig",
-    "Qwen2VLVisionConfig",
 ]
diff --git a/vllm/transformers_utils/configs/qwen2vl.py b/vllm/transformers_utils/configs/qwen2vl.py
deleted file mode 100644
index 92dd962790bc..000000000000
--- a/vllm/transformers_utils/configs/qwen2vl.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Qwen2VL model configuration"""
-
-import os
-from typing import Union
-
-from transformers import PretrainedConfig
-
-
-class Qwen2VLVisionConfig(PretrainedConfig):
-    model_type = "qwen2_vl"
-
-    def __init__(
-        self,
-        depth=32,
-        embed_dim=1280,
-        hidden_size=3584,
-        hidden_act="quick_gelu",
-        mlp_ratio=4,
-        num_heads=16,
-        in_channels=3,
-        patch_size=14,
-        spatial_merge_size=2,
-        temporal_patch_size=2,
-        **kwargs,
-    ):
-        super().__init__(**kwargs)
-
-        self.depth = depth
-        self.embed_dim = embed_dim
-        self.hidden_size = hidden_size
-        self.hidden_act = hidden_act
-        self.mlp_ratio = mlp_ratio
-        self.num_heads = num_heads
-        self.in_channels = in_channels
-        self.patch_size = patch_size
-        self.spatial_merge_size = spatial_merge_size
-        self.temporal_patch_size = temporal_patch_size
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: Union[str,
-                                                                  os.PathLike],
-                        **kwargs) -> "PretrainedConfig":
-        cls._set_token_in_kwargs(kwargs)
-
-        config_dict, kwargs = cls.get_config_dict(
-            pretrained_model_name_or_path, **kwargs)
-
-        if config_dict.get("model_type") == "qwen2_vl":
-            config_dict = config_dict["vision_config"]
-
-        return cls.from_dict(config_dict, **kwargs)
-
-
-class Qwen2VLConfig(PretrainedConfig):
-
-    def __init__(
-        self,
-        vocab_size=152064,
-        hidden_size=8192,
-        intermediate_size=29568,
-        num_hidden_layers=80,
-        num_attention_heads=64,
-        num_key_value_heads=8,
-        hidden_act="silu",
-        max_position_embeddings=32768,
-        initializer_range=0.02,
-        rms_norm_eps=1e-05,
-        use_cache=True,
-        tie_word_embeddings=False,
-        rope_theta=1000000.0,
-        use_sliding_window=False,
-        sliding_window=4096,
-        max_window_layers=80,
-        attention_dropout=0.0,
-        vision_config=None,
-        rope_scaling=None,
-        **kwargs,
-    ):
-        if isinstance(vision_config, dict):
-            self.vision_config = Qwen2VLVisionConfig(**vision_config)
-        elif vision_config is None:
-            self.vision_config = Qwen2VLVisionConfig()
-
-        self.vocab_size = vocab_size
-        self.max_position_embeddings = max_position_embeddings
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.num_hidden_layers = num_hidden_layers
-        self.num_attention_heads = num_attention_heads
-        self.use_sliding_window = use_sliding_window
-        self.sliding_window = sliding_window
-        self.max_window_layers = max_window_layers
-
-        # for backward compatibility
-        if num_key_value_heads is None:
-            num_key_value_heads = num_attention_heads
-
-        self.num_key_value_heads = num_key_value_heads
-        self.hidden_act = hidden_act
-        self.initializer_range = initializer_range
-        self.rms_norm_eps = rms_norm_eps
-        self.use_cache = use_cache
-        self.rope_theta = rope_theta
-        self.attention_dropout = attention_dropout
-        self.rope_scaling = rope_scaling
-
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
-
-        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 2b418f3603a0..7c8423d2b0a3 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -1,8 +1,10 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
                            Sequence, SequenceGroup)
 
+from .detokenizer_utils import (convert_prompt_ids_to_tokens,
+                                detokenize_incrementally)
 from .tokenizer import AnyTokenizer
 from .tokenizer_group import BaseTokenizerGroup
 
@@ -88,7 +90,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
             prefix_offset = next_iter_prefix_offset
             read_offset = next_iter_read_offset
             if prev_tokens is None:
-                prev_tokens = next_iter_tokens
+                prev_tokens = next_iter_tokens.copy()
             else:
                 prev_tokens.extend(next_iter_tokens)
 
@@ -161,167 +163,3 @@ def decode_sequence_inplace(self, seq: Sequence,
         seq.output_text += new_decoded_token_text
 
         return len(new_decoded_token_text)
-
-
-def _replace_none_with_empty(tokens: List[Optional[str]]):
-    for i, token in enumerate(tokens):
-        if token is None:
-            tokens[i] = ""
-
-
-def _convert_tokens_to_string_with_added_encoders(
-    tokenizer: AnyTokenizer,
-    output_tokens: List[str],
-    skip_special_tokens: bool,
-    spaces_between_special_tokens: bool,
-) -> str:
-    # Adapted from
-    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
-    # NOTE(woosuk): The following code is slow because it runs a for loop over
-    # the output_tokens. In Python, running a for loop over a list can be slow
-    # even when the loop body is very simple.
-    sub_texts: List[str] = []
-    current_sub_text: List[str] = []
-    all_special_tokens = set(tokenizer.all_special_tokens)
-    for token in output_tokens:
-        if skip_special_tokens and token in all_special_tokens:
-            continue
-        if token in tokenizer.get_added_vocab():
-            if current_sub_text:
-                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-                sub_texts.append(sub_text)
-                current_sub_text = []
-            sub_texts.append(token)
-        else:
-            current_sub_text.append(token)
-    if current_sub_text:
-        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
-        sub_texts.append(sub_text)
-    if spaces_between_special_tokens:
-        return " ".join(sub_texts)
-    else:
-        return "".join(sub_texts)
-
-
-# 5 is an arbitrary value that should work for all
-# tokenizers (bigger = more conservative).
-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
-
-
-def convert_prompt_ids_to_tokens(
-    tokenizer: AnyTokenizer,
-    prompt_ids: List[int],
-    skip_special_tokens: bool = False,
-) -> Tuple[List[str], int, int]:
-    """Converts the prompt ids to tokens and returns the tokens and offsets
-    for incremental detokenization.
-
-    Note that not all tokens are converted to strings. Only the tokens that
-    are necessary for incremental detokenization are converted to strings.
-    """
-    # We do not need to convert the whole prompt to tokens.
-    # Offset a little more in case we have special tokens.
-    new_tokens = tokenizer.convert_ids_to_tokens(
-        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
-        skip_special_tokens=skip_special_tokens)
-    read_offset = len(new_tokens)
-    prefix_offset = max(
-        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
-    # This is required to guard against out-of-vocab prompt token ids
-    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
-    return new_tokens, prefix_offset, read_offset
-
-
-# Based on
-# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
-# under Apache 2.0 license
-def detokenize_incrementally(
-    tokenizer: AnyTokenizer,
-    all_input_ids: List[int],
-    prev_tokens: Optional[List[str]],
-    prefix_offset: int,
-    read_offset: int,
-    skip_special_tokens: bool = False,
-    spaces_between_special_tokens: bool = True,
-) -> Tuple[List[str], str, int, int]:
-    """Detokenizes the input ids incrementally and returns the new tokens
-    and the new text.
-
-    If `prev_tokens` is None, this function will convert the input ids to
-    tokens and return the tokens and the new text. Otherwise, it will return the
-    new tokens and the new text.
-
-    This function will also return the new prefix offset and the new read
-    offset to be used in the next iteration.
-
-    The offsets are necessary to defeat cleanup algorithms in the decode which
-    decide to add a space or not depending on the surrounding ids.
-
-    Args:
-        tokenizer: The tokenizer to use.
-        all_input_ids: The input ids. The last id is the new token id.
-        prev_tokens: The previous tokens. If None, this function will convert
-            the input ids to tokens and return the tokens and the new text.
-        prefix_offset: The prefix offset.
-        read_offset: The read offset.
-        skip_special_tokens: Whether to skip special tokens.
-        spaces_between_special_tokens: Whether to add spaces between special
-            tokens.
-    """
-    new_token_id = all_input_ids[-1]
-    # This is the first iteration for this sequence
-    is_first_iter = prev_tokens is None
-    if is_first_iter:
-        (prev_tokens, prefix_offset,
-         read_offset) = convert_prompt_ids_to_tokens(
-             tokenizer,
-             all_input_ids[:-1],
-             skip_special_tokens=skip_special_tokens)
-    assert prev_tokens is not None
-
-    # If the new token id is out of bounds, return an empty string.
-    if 0 <= new_token_id < len(tokenizer):
-        # Put new_token_id in a list so skip_special_tokens is respected
-        new_tokens = tokenizer.convert_ids_to_tokens(
-            [new_token_id], skip_special_tokens=skip_special_tokens)
-        if isinstance(new_tokens, str):
-            new_tokens = [new_tokens]
-    else:
-        new_tokens = [""]
-    output_tokens = prev_tokens + new_tokens
-
-    # If this is the first iteration, return all tokens.
-    if is_first_iter:
-        new_tokens = output_tokens
-
-    # The prefix text is necessary only to defeat cleanup algorithms in
-    # the decode which decide to add a space or not depending on the
-    # surrounding ids.
-    if tokenizer.is_fast or not tokenizer.get_added_vocab():
-        prefix_text = tokenizer.convert_tokens_to_string(
-            output_tokens[prefix_offset:read_offset])
-        new_text = tokenizer.convert_tokens_to_string(
-            output_tokens[prefix_offset:])
-    else:
-        prefix_text = _convert_tokens_to_string_with_added_encoders(
-            tokenizer,
-            output_tokens[prefix_offset:read_offset],
-            skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens,
-        )
-        new_text = _convert_tokens_to_string_with_added_encoders(
-            tokenizer,
-            output_tokens[prefix_offset:],
-            skip_special_tokens=skip_special_tokens,
-            spaces_between_special_tokens=spaces_between_special_tokens,
-        )
-
-    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
-        # utf-8 char at the end means it's a potential unfinished byte sequence
-        # from byte fallback tokenization.
-        # If it's in the middle, it's probably a real invalid id generated
-        # by the model
-        return new_tokens, "", prefix_offset, read_offset
-
-    new_text = new_text[len(prefix_text):]
-    return new_tokens, new_text, read_offset, len(output_tokens)
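
Since these helpers are only relocated (the new `detokenizer_utils` module follows below), here is a brief usage sketch of the moved API; the gpt2 tokenizer and the appended token id are stand-ins for any HF tokenizer accepted as `AnyTokenizer` and a newly sampled token.

```python
# Illustrative only: incremental detokenization with the relocated helpers.
from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer_utils import (
    convert_prompt_ids_to_tokens, detokenize_incrementally)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_ids = tokenizer.encode("Hello world") + [0]  # pretend id 0 was just sampled

# Prime the incremental state from the prompt, then decode only the new token.
prev_tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
    tokenizer, all_ids[:-1])
new_tokens, new_text, prefix_offset, read_offset = detokenize_incrementally(
    tokenizer, all_ids, prev_tokens, prefix_offset, read_offset)
print(new_text)  # text contributed by the final token only
```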
diff --git a/vllm/transformers_utils/detokenizer_utils.py b/vllm/transformers_utils/detokenizer_utils.py
new file mode 100644
index 000000000000..37ff8a236e79
--- /dev/null
+++ b/vllm/transformers_utils/detokenizer_utils.py
@@ -0,0 +1,167 @@
+from typing import List, Optional, Tuple
+
+from .tokenizer import AnyTokenizer
+
+
+def _replace_none_with_empty(tokens: List[Optional[str]]):
+    for i, token in enumerate(tokens):
+        if token is None:
+            tokens[i] = ""
+
+
+def _convert_tokens_to_string_with_added_encoders(
+    tokenizer: AnyTokenizer,
+    output_tokens: List[str],
+    skip_special_tokens: bool,
+    spaces_between_special_tokens: bool,
+) -> str:
+    # Adapted from
+    # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921
+    # NOTE(woosuk): The following code is slow because it runs a for loop over
+    # the output_tokens. In Python, running a for loop over a list can be slow
+    # even when the loop body is very simple.
+    sub_texts: List[str] = []
+    current_sub_text: List[str] = []
+    all_special_tokens = set(tokenizer.all_special_tokens)
+    for token in output_tokens:
+        if skip_special_tokens and token in all_special_tokens:
+            continue
+        if token in tokenizer.get_added_vocab():
+            if current_sub_text:
+                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+                sub_texts.append(sub_text)
+                current_sub_text = []
+            sub_texts.append(token)
+        else:
+            current_sub_text.append(token)
+    if current_sub_text:
+        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
+        sub_texts.append(sub_text)
+    if spaces_between_special_tokens:
+        return " ".join(sub_texts)
+    else:
+        return "".join(sub_texts)
+
+
+# 5 is an arbitrary value that should work for all
+# tokenizers (bigger = more conservative).
+INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+def convert_prompt_ids_to_tokens(
+    tokenizer: AnyTokenizer,
+    prompt_ids: List[int],
+    skip_special_tokens: bool = False,
+) -> Tuple[List[str], int, int]:
+    """Converts the prompt ids to tokens and returns the tokens and offsets
+    for incremental detokenization.
+
+    Note that not all tokens are converted to strings. Only the tokens that
+    are necessary for incremental detokenization are converted to strings.
+    """
+    # We do not need to convert the whole prompt to tokens.
+    # Offset a little more in case we have special tokens.
+    new_tokens = tokenizer.convert_ids_to_tokens(
+        prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:],
+        skip_special_tokens=skip_special_tokens)
+    read_offset = len(new_tokens)
+    prefix_offset = max(
+        read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0)
+    # This is required to guard against out-of-vocab prompt token ids
+    _replace_none_with_empty(new_tokens)  # type: ignore[arg-type]
+    return new_tokens, prefix_offset, read_offset
+
+
+# Based on
+# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
+# under Apache 2.0 license
+def detokenize_incrementally(
+    tokenizer: AnyTokenizer,
+    all_input_ids: List[int],
+    prev_tokens: Optional[List[str]],
+    prefix_offset: int,
+    read_offset: int,
+    skip_special_tokens: bool = False,
+    spaces_between_special_tokens: bool = True,
+) -> Tuple[List[str], str, int, int]:
+    """Detokenizes the input ids incrementally and returns the new tokens
+    and the new text.
+
+    If `prev_tokens` is None, the prompt ids are first converted to tokens
+    (see `convert_prompt_ids_to_tokens`) and those tokens are returned along
+    with the new text; otherwise only the newly decoded tokens are returned.
+
+    This function will also return the new prefix offset and the new read
+    offset to be used in the next iteration.
+
+    The offsets are necessary to defeat cleanup algorithms in the decode which
+    decide to add a space or not depending on the surrounding ids.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        all_input_ids: The input ids. The last id is the new token id.
+        prev_tokens: The previous tokens. If None, this function will convert
+            the input ids to tokens and return the tokens and the new text.
+        prefix_offset: The prefix offset.
+        read_offset: The read offset.
+        skip_special_tokens: Whether to skip special tokens.
+        spaces_between_special_tokens: Whether to add spaces between special
+            tokens.
+    """
+    new_token_id = all_input_ids[-1]
+    # This is the first iteration for this sequence
+    is_first_iter = prev_tokens is None
+    if is_first_iter:
+        (prev_tokens, prefix_offset,
+         read_offset) = convert_prompt_ids_to_tokens(
+             tokenizer,
+             all_input_ids[:-1],
+             skip_special_tokens=skip_special_tokens)
+    assert prev_tokens is not None
+
+    # If the new token id is out of bounds, use an empty token string.
+    if 0 <= new_token_id < len(tokenizer):
+        # Put new_token_id in a list so skip_special_tokens is respected
+        new_tokens = tokenizer.convert_ids_to_tokens(
+            [new_token_id], skip_special_tokens=skip_special_tokens)
+        if isinstance(new_tokens, str):
+            new_tokens = [new_tokens]
+    else:
+        new_tokens = [""]
+    output_tokens = prev_tokens + new_tokens
+
+    # If this is the first iteration, return all tokens.
+    if is_first_iter:
+        new_tokens = output_tokens
+
+    # The prefix text is necessary only to defeat cleanup algorithms in
+    # the decode which decide to add a space or not depending on the
+    # surrounding ids.
+    if tokenizer.is_fast or not tokenizer.get_added_vocab():
+        prefix_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:read_offset])
+        new_text = tokenizer.convert_tokens_to_string(
+            output_tokens[prefix_offset:])
+    else:
+        prefix_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:read_offset],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+        new_text = _convert_tokens_to_string_with_added_encoders(
+            tokenizer,
+            output_tokens[prefix_offset:],
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+
+    if len(new_text) <= len(prefix_text) or new_text.endswith("�"):
+        # A replacement char (�) at the end means there is a potentially
+        # unfinished byte sequence from byte fallback tokenization.
+        # If it appears in the middle, it is more likely a genuine invalid id
+        # generated by the model.
+        return new_tokens, "", prefix_offset, read_offset
+
+    new_text = new_text[len(prefix_text):]
+    return new_tokens, new_text, read_offset, len(output_tokens)
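The helpers above are easiest to follow with a concrete driver loop. The sketch below is illustrative only and not part of the diff: it assumes a locally available Hugging Face fast tokenizer such as gpt2, and the prompt and continuation strings are arbitrary.

# Drive detokenize_incrementally() the way an engine would: append one
# "generated" id at a time and accumulate the text deltas.
from transformers import AutoTokenizer

from vllm.transformers_utils.detokenizer_utils import detokenize_incrementally

tokenizer = AutoTokenizer.from_pretrained("gpt2")
all_ids = tokenizer("Hello, incremental").input_ids
generated_ids = tokenizer(" detokenization works").input_ids

prev_tokens = None
prefix_offset = read_offset = 0
text = ""
for new_id in generated_ids:
    all_ids.append(new_id)
    new_tokens, delta, prefix_offset, read_offset = detokenize_incrementally(
        tokenizer, all_ids, prev_tokens, prefix_offset, read_offset)
    # On the first call, new_tokens also covers the converted prompt tail;
    # afterwards it only holds the piece(s) for the newly appended id.
    prev_tokens = (new_tokens if prev_tokens is None
                   else prev_tokens + new_tokens)
    text += delta
print(text)  # the accumulated deltas reconstruct the generated suffix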
diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py
index 98663f7f0bd0..f1523667b046 100644
--- a/vllm/transformers_utils/processor.py
+++ b/vllm/transformers_utils/processor.py
@@ -1,3 +1,4 @@
+from functools import lru_cache
 from typing import Any, cast
 
 
@@ -37,6 +38,9 @@ def get_processor(
     return cast(ProcessorMixin, processor)
 
 
+cached_get_processor = lru_cache(get_processor)
+
+
 def get_image_processor(
     processor_name: str,
     *args: Any,
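The `cached_get_processor = lru_cache(get_processor)` line above relies on `functools.lru_cache` accepting a callable directly (Python 3.8+), turning the loader into a memoized function. A minimal, self-contained sketch of the same pattern, with a hypothetical `load_config` standing in for the processor loader:

from functools import lru_cache


def load_config(name: str) -> dict:
    # Hypothetical expensive loader; the body runs only on a cache miss.
    print(f"loading {name} ...")
    return {"name": name}


cached_load_config = lru_cache(load_config)

a = cached_load_config("facebook/opt-125m")
b = cached_load_config("facebook/opt-125m")  # cache hit: no second "loading"
assert a is b                                # the same object is reused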
diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index aae10d3ee25f..a12d3136e8d6 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -2,11 +2,13 @@
 import re
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
 
+import huggingface_hub
 from huggingface_hub import HfApi, hf_hub_download
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.instruct.request import FIMRequest
 # yapf: disable
-from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
 from mistral_common.tokens.tokenizers.mistral import (
     MistralTokenizer as PublicMistralTokenizer)
 # yapf: enable
@@ -15,15 +17,39 @@
 from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
                                                      Tekkenizer)
 
+from vllm.logger import init_logger
+
 if TYPE_CHECKING:
     from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 
+logger = init_logger(__name__)
+
 
 @dataclass
 class Encoding:
     input_ids: List[int]
 
 
+def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
+    repo_cache = os.path.join(
+        huggingface_hub.constants.HF_HUB_CACHE,
+        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
+            ["models", *repo_id.split("/")]))
+
+    if revision is None:
+        revision_file = os.path.join(repo_cache, "refs", "main")
+        if os.path.isfile(revision_file):
+            with open(revision_file) as file:
+                revision = file.read()
+
+    if revision:
+        revision_dir = os.path.join(repo_cache, "snapshots", revision)
+        if os.path.isdir(revision_dir):
+            return os.listdir(revision_dir)
+
+    return []
+
+
 def find_tokenizer_file(files: List[str]):
     file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
 
@@ -51,18 +77,12 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
             # Make sure special tokens will not raise
             tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
 
-            self._vocab = {
-                token: idx
-                for idx, token in enumerate(tokenizer_.vocab())
-            }
         elif isinstance(tokenizer_, SentencePieceTokenizer):
-            self._vocab = {
-                token: idx
-                for idx, token in enumerate(tokenizer_.vocab())
-            }
+            pass
         else:
             raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
 
+        self._vocab = tokenizer_.vocab()
         self.tokenizer = tokenizer_
 
     @classmethod
@@ -90,9 +110,16 @@ def from_pretrained(cls,
     @staticmethod
     def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                             revision: Optional[str]) -> str:
-        api = HfApi()
-        repo_info = api.model_info(tokenizer_name)
-        files = [s.rfilename for s in repo_info.siblings]
+        try:
+            hf_api = HfApi()
+            files = hf_api.list_repo_files(repo_id=tokenizer_name,
+                                           revision=revision)
+        except ConnectionError as exc:
+            files = list_local_repo_files(repo_id=tokenizer_name,
+                                          revision=revision)
+
+            if len(files) == 0:
+                raise exc
 
         filename = find_tokenizer_file(files)
 
@@ -149,7 +176,10 @@ def __call__(
         return Encoding(input_ids=input_ids)
 
     def get_vocab(self) -> Dict[str, int]:
-        return self._vocab
+        # Convert to a Dict[str, int] to match protocol, but this is a lossy
+        # conversion. There may be multiple token ids that decode to the same
+        # string due to partial UTF-8 byte sequences being converted to �
+        return {token: idx for idx, token in enumerate(self._vocab)}
 
     def get_added_vocab(self) -> Dict[str, int]:
         # Mistral tokenizers have no added vocabulary
@@ -161,11 +191,19 @@ def encode(self, prompt: str) -> List[int]:
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)
 
+    def encode_with_suffix(self, prefix: str, suffix: str) -> List[int]:
+        fim = FIMRequest(prompt=prefix, suffix=suffix)
+        return self.mistral.encode_fim(fim).tokens
+
     def apply_chat_template(self,
                             messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
 
+        last_message = cast(Dict[str, Any], messages[-1])
+        if last_message["role"] == "assistant":
+            last_message["prefix"] = True
+
         request = ChatCompletionRequest(messages=messages,
                                         tools=tools)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)
@@ -183,14 +221,20 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
             if any(isinstance(t, bytes) for t in tokens):
                 # we need to encode and decode all tokens again
                 shift = self.tokenizer.num_special_tokens
-                byte_tokens = [
-                    t.encode("utf-8") if not isinstance(t, bytes) else t
-                    for t in tokens
-                ]
-                ids = [
-                    self.tokenizer._tekken_token2id_nospecial[t] + shift
-                    for t in byte_tokens
-                ]
+
+                def _token_to_id(t: str):
+                    t_bytes = t.encode("utf-8") \
+                        if not isinstance(t, bytes) else t
+                    try:
+                        return shift + \
+                            self.tokenizer._tekken_token2id_nospecial[t_bytes]
+                    except KeyError:
+                        logger.warning(
+                            "Failed to convert token %s to id,"
+                            " replacing with <unk>", t_bytes)
+                        return self.tokenizer.unk_id
+
+                ids = [_token_to_id(t) for t in tokens]
                 decoded = self.tokenizer.decode(ids)
             else:
                 decoded = "".join(tokens)
@@ -220,9 +264,10 @@ def convert_ids_to_tokens(
 
         tokens = [self.tokenizer.id_to_piece(id) for id in ids]
 
-        if any(t.strip() == "�" for t in tokens):
-            # if any stripped decoded token is undefined
-            # because it's invalid unicode then pass bytes
+        if any("ļæ½" in t for t in tokens):
+            # if a decoded token contains the replacement character, then the
+            # token has an incomplete UTF-8 character so we must use a byte
+            # string to avoid losing information
             # See: https://github.com/vllm-project/vllm/pull/8640
             tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]
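The change above from checking a stripped token against the replacement character to checking for its presence, together with the fallback to `id_to_byte_piece`, hinges on how Python surfaces incomplete UTF-8 sequences. A self-contained illustration in plain Python (no mistral_common required; the snowman character is arbitrary):

# A multi-byte character split across token pieces cannot be represented
# losslessly as str: decoding the partial bytes yields U+FFFD (�), so the
# byte pieces themselves must be kept to avoid losing information.
snowman = "☃".encode("utf-8")            # b'\xe2\x98\x83', a 3-byte sequence
partial = snowman[:2]                    # a token boundary may split it
decoded = partial.decode("utf-8", errors="replace")
assert "�" in decoded                    # the str form lost information...
assert partial + snowman[2:] == snowman  # ...but the bytes can be rejoined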
 
diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py
index ce4608224763..36315abcdfcd 100644
--- a/vllm/triton_utils/importing.py
+++ b/vllm/triton_utils/importing.py
@@ -1,11 +1,16 @@
 from importlib.util import find_spec
 
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-HAS_TRITON = find_spec("triton") is not None
+HAS_TRITON = (
+    find_spec("triton") is not None
+    and not current_platform.is_xpu()  # Not compatible
+    and not current_platform.is_neuron()  # Neuron's torch is too old
+)
 
 if not HAS_TRITON:
-    logger.info("Triton not installed; certain GPU-related functions"
-                " will not be available.")
+    logger.info("Triton not installed or not compatible; certain GPU-related"
+                " functions will not be available.")
diff --git a/vllm/utils.py b/vllm/utils.py
index 8debae52b288..0e9b241b6f9f 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -13,10 +13,12 @@
 import sys
 import tempfile
 import threading
+import time
 import uuid
 import warnings
 import weakref
-from asyncio import FIRST_COMPLETED, ensure_future
+from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
+from collections.abc import Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
 from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
@@ -316,15 +318,6 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
-@lru_cache(maxsize=None)
-def is_cpu() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        return "cpu" in version("vllm")
-    except PackageNotFoundError:
-        return False
-
-
 @lru_cache(maxsize=None)
 def is_openvino() -> bool:
     from importlib.metadata import PackageNotFoundError, version
@@ -334,38 +327,6 @@ def is_openvino() -> bool:
         return False
 
 
-@lru_cache(maxsize=None)
-def is_neuron() -> bool:
-    try:
-        import transformers_neuronx
-    except ImportError:
-        transformers_neuronx = None
-    return transformers_neuronx is not None
-
-
-@lru_cache(maxsize=None)
-def is_xpu() -> bool:
-    from importlib.metadata import PackageNotFoundError, version
-    try:
-        is_xpu_flag = "xpu" in version("vllm")
-    except PackageNotFoundError:
-        return False
-    # vllm is not build with xpu
-    if not is_xpu_flag:
-        return False
-    try:
-        import intel_extension_for_pytorch as ipex  # noqa: F401
-        _import_ipex = True
-    except ImportError as e:
-        logger.warning("Import Error for IPEX: %s", e.msg)
-        _import_ipex = False
-    # ipex dependency is not ready
-    if not _import_ipex:
-        logger.warning("not found ipex lib")
-        return False
-    return hasattr(torch, "xpu") and torch.xpu.is_available()
-
-
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
@@ -395,7 +356,7 @@ def seed_everything(seed: int) -> None:
     if current_platform.is_cuda_alike():
         torch.cuda.manual_seed_all(seed)
 
-    if is_xpu():
+    if current_platform.is_xpu():
         torch.xpu.manual_seed_all(seed)
 
 
@@ -436,6 +397,12 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
     return _async_wrapper
 
 
+def _next_task(iterator: AsyncGenerator[T, None],
+               loop: AbstractEventLoop) -> Task:
+    # Can use anext() in python >= 3.10
+    return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]
+
+
 async def iterate_with_cancellation(
     iterator: AsyncGenerator[T, None],
     is_cancelled: Callable[[], Awaitable[bool]],
@@ -444,19 +411,27 @@ async def iterate_with_cancellation(
     at least once per second to check for client cancellation.
     """
 
-    # Can use anext() in python >= 3.10
-    awaits = [ensure_future(iterator.__anext__())]
+    loop = asyncio.get_running_loop()
+
+    awaits: List[Future[T]] = [_next_task(iterator, loop)]
+    next_cancel_check: float = 0
     while True:
-        done, pending = await asyncio.wait(awaits, timeout=1)
-        if await is_cancelled():
-            with contextlib.suppress(BaseException):
-                awaits[0].cancel()
-                await iterator.aclose()
-            raise asyncio.CancelledError("client cancelled")
+        done, pending = await asyncio.wait(awaits, timeout=1.5)
+
+        # Check for cancellation at most once per second
+        time_now = time.time()
+        if time_now >= next_cancel_check:
+            if await is_cancelled():
+                with contextlib.suppress(BaseException):
+                    awaits[0].cancel()
+                    await iterator.aclose()
+                raise asyncio.CancelledError("client cancelled")
+            next_cancel_check = time_now + 1
+
         if done:
             try:
                 item = await awaits[0]
-                awaits[0] = ensure_future(iterator.__anext__())
+                awaits[0] = _next_task(iterator, loop)
                 yield item
             except StopAsyncIteration:
                 # we are done
@@ -477,25 +452,29 @@ async def merge_async_iterators(
     to check for client cancellation.
     """
 
-    # Can use anext() in python >= 3.10
-    awaits = {
-        ensure_future(pair[1].__anext__()): pair
-        for pair in enumerate(iterators)
-    }
-    timeout = None if is_cancelled is None else 1
+    loop = asyncio.get_running_loop()
+
+    awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
+    timeout = None if is_cancelled is None else 1.5
+    next_cancel_check: float = 0
     try:
         while awaits:
             done, pending = await asyncio.wait(awaits.keys(),
                                                return_when=FIRST_COMPLETED,
                                                timeout=timeout)
-            if is_cancelled is not None and await is_cancelled():
-                raise asyncio.CancelledError("client cancelled")
+            if is_cancelled is not None:
+                # Check for cancellation at most once per second
+                time_now = time.time()
+                if time_now >= next_cancel_check:
+                    if await is_cancelled():
+                        raise asyncio.CancelledError("client cancelled")
+                    next_cancel_check = time_now + 1
             for d in done:
                 pair = awaits.pop(d)
                 try:
                     item = await d
                     i, it = pair
-                    awaits[ensure_future(it.__anext__())] = pair
+                    awaits[_next_task(it, loop)] = pair
                     yield i, item
                 except StopAsyncIteration:
                     pass
@@ -772,13 +751,13 @@ def is_pin_memory_available() -> bool:
         print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                            "This may slow down the performance.")
         return False
-    elif is_xpu():
+    elif current_platform.is_xpu():
         print_warning_once("Pin memory is not supported on XPU.")
         return False
-    elif is_neuron():
+    elif current_platform.is_neuron():
         print_warning_once("Pin memory is not supported on Neuron.")
         return False
-    elif is_cpu() or is_openvino():
+    elif current_platform.is_cpu() or is_openvino():
         return False
     return True
 
@@ -793,7 +772,7 @@ def current_memory_usage(self) -> float:
         if current_platform.is_cuda_alike():
             torch.cuda.reset_peak_memory_stats(self.device)
             mem = torch.cuda.max_memory_allocated(self.device)
-        elif is_xpu():
+        elif current_platform.is_xpu():
             torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
             mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
         return mem
@@ -948,6 +927,8 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
     return [item for sublist in lists for item in sublist]
 
 
+# TODO: This function can be removed if transformer_modules classes are
+# serialized by value when communicating between processes
 def init_cached_hf_modules() -> None:
     """
     Lazy initialization of the Hugging Face modules.
@@ -1033,10 +1014,54 @@ def identity(value: T) -> T:
 F = TypeVar('F', bound=Callable[..., Any])
 
 
+def deprecate_args(
+    start_index: int,
+    is_deprecated: Union[bool, Callable[[], bool]] = True,
+    additional_message: Optional[str] = None,
+) -> Callable[[F], F]:
+
+    if not callable(is_deprecated):
+        is_deprecated = partial(identity, is_deprecated)
+
+    def wrapper(fn: F) -> F:
+
+        params = inspect.signature(fn).parameters
+        pos_types = (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        )
+        pos_kws = [
+            kw for kw, param in params.items() if param.kind in pos_types
+        ]
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+            if is_deprecated():
+                deprecated_args = pos_kws[start_index:len(args)]
+                if deprecated_args:
+                    msg = (
+                        f"The positional arguments {deprecated_args} are "
+                        "deprecated and will be removed in a future update.")
+                    if additional_message is not None:
+                        msg += f" {additional_message}"
+
+                    warnings.warn(
+                        DeprecationWarning(msg),
+                        stacklevel=3,  # The inner function takes up one level
+                    )
+
+            return fn(*args, **kwargs)
+
+        return inner  # type: ignore
+
+    return wrapper
+
+
 def deprecate_kwargs(
-        *kws: str,
-        is_deprecated: Union[bool, Callable[[], bool]] = True,
-        additional_message: Optional[str] = None) -> Callable[[F], F]:
+    *kws: str,
+    is_deprecated: Union[bool, Callable[[], bool]] = True,
+    additional_message: Optional[str] = None,
+) -> Callable[[F], F]:
     deprecated_kws = set(kws)
 
     if not callable(is_deprecated):
@@ -1442,3 +1467,24 @@ def dec(self, num=1):
     @property
     def value(self):
         return self._value
+
+
+# Adapted from: https://stackoverflow.com/a/47212782/5082708
+class LazyDict(Mapping, Generic[T]):
+
+    def __init__(self, factory: Dict[str, Callable[[], T]]):
+        self._factory = factory
+        self._dict: Dict[str, T] = {}
+
+    def __getitem__(self, key) -> T:
+        if key not in self._dict:
+            if key not in self._factory:
+                raise KeyError(key)
+            self._dict[key] = self._factory[key]()
+        return self._dict[key]
+
+    def __iter__(self):
+        return iter(self._factory)
+
+    def __len__(self):
+        return len(self._factory)
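A short usage sketch for the LazyDict helper added above (the key and factory are made up): values are built by the registered factories only on first access and then cached, while iteration and len() reflect the factory keys without constructing anything.

from vllm.utils import LazyDict


def build_small_config() -> dict:
    print("building small config")       # printed once, on the first lookup
    return {"hidden_size": 128}


registry: LazyDict[dict] = LazyDict({"small": build_small_config})

assert list(registry) == ["small"] and len(registry) == 1  # nothing built yet
cfg = registry["small"]                  # the factory runs here
assert registry["small"] is cfg          # later lookups reuse the cached value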
diff --git a/vllm/v1/attention/__init__.py b/vllm/v1/attention/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/attention/backends/__init__.py b/vllm/v1/attention/backends/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
new file mode 100644
index 000000000000..0530b1a6762c
--- /dev/null
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -0,0 +1,241 @@
+"""Attention layer with FlashAttention."""
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata, AttentionType)
+from vllm.forward_context import get_forward_context
+from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+
+class FlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn-vllm-v1"
+
+    @staticmethod
+    def get_impl_cls() -> Type["FlashAttentionImpl"]:
+        return FlashAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+
+@dataclass
+class FlashAttentionMetadata:
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ---------------------|
+    #                                   |-- query_len ---|
+
+    max_query_len: int
+    query_start_loc: torch.Tensor
+    max_seq_len: int
+    seq_start_loc: torch.Tensor
+    block_table: torch.Tensor
+    slot_mapping: torch.Tensor
+
+
+class FlashAttentionImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+    ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "FlashAttention does not support block-sparse attention.")
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        self.kv_cache_dtype = kv_cache_dtype
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        if sliding_window is not None:
+            # NOTE(woosuk): flash-attn's sliding window does not work with
+            # paged KV cache.
+            raise ValueError(
+                "Sliding window is not supported in FlashAttention.")
+
+        support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by FlashAttention. "
+                f"Supported head sizes are: {support_head_sizes}.")
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "FlashAttentionImpl")
+
+        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
+        assert k_scale == 1.0 and v_scale == 1.0, (
+            "key/v_scale is not supported in FlashAttention.")
+
+        output = torch.ops.vllm.unified_flash_attention(
+            query,
+            key,
+            value,
+            self.num_heads,
+            self.head_size,
+            self.num_kv_heads,
+            kv_cache,
+            self.kv_cache_dtype,
+            k_scale,
+            v_scale,
+            self.scale,
+            self.sliding_window,
+            self.alibi_slopes,
+            self.logits_soft_cap,
+        )
+        return output
+
+
+@torch.library.custom_op("vllm::unified_flash_attention",
+                         mutates_args=["kv_cache"])
+def unified_flash_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    current_metadata = get_forward_context()
+    if current_metadata is None:
+        # Profiling run.
+        return torch.empty_like(query)
+
+    assert current_metadata is not None
+    assert isinstance(current_metadata, FlashAttentionMetadata)
+    attn_metadata: FlashAttentionMetadata = current_metadata
+
+    num_tokens, hidden_size = query.shape
+    # Reshape the query, key, and value tensors.
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    # Reshape the input keys and values and store them in the cache.
+    key_cache = kv_cache[0]
+    value_cache = kv_cache[1]
+    torch.ops._C_cache_ops.reshape_and_cache_flash(
+        key,
+        value,
+        kv_cache[0],
+        kv_cache[1],
+        attn_metadata.slot_mapping,
+        kv_cache_dtype,
+        k_scale,
+        v_scale,
+    )
+
+    output = flash_attn_varlen_func(
+        q=query,
+        k=key_cache,
+        v=value_cache,
+        cu_seqlens_q=attn_metadata.query_start_loc,
+        max_seqlen_q=attn_metadata.max_query_len,
+        cu_seqlens_k=attn_metadata.seq_start_loc,
+        max_seqlen_k=attn_metadata.max_seq_len,
+        softmax_scale=softmax_scale,
+        causal=True,
+        alibi_slopes=alibi_slopes,
+        window_size=window_size,
+        block_table=attn_metadata.block_table,
+        softcap=logits_soft_cap,
+    )
+    return output.view(num_tokens, hidden_size)
+
+
+@unified_flash_attention.register_fake
+def _(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    k_scale: float,
+    v_scale: float,
+    softmax_scale: float,
+    window_size: Optional[List[int]] = None,
+    alibi_slopes: Optional[torch.Tensor] = None,
+    logits_soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    return torch.empty_like(query)
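The relationship between the FlashAttentionMetadata fields is easiest to see with concrete numbers. The toy, CPU-only sketch below uses made-up values (block_size=16 is only an illustration) and mirrors the fields of the dataclass above without importing the backend:

import torch

block_size = 16
# Request 0: 30 tokens already computed, 4 new tokens this step.
# Request 1: a fresh request with 5 prompt tokens.
context_lens = [30, 0]
query_lens = [4, 5]
seq_lens = [c + q for c, q in zip(context_lens, query_lens)]       # [34, 5]

# Cumulative offsets into the flattened query / sequence dimensions.
query_start_loc = torch.tensor([0, 4, 9], dtype=torch.int32)
seq_start_loc = torch.tensor([0, 34, 39], dtype=torch.int32)
max_query_len, max_seq_len = max(query_lens), max(seq_lens)        # 5, 34

# Physical KV-cache blocks assigned to each request (rows padded in practice).
block_table = torch.tensor([[7, 3, 9], [5, 0, 0]], dtype=torch.int32)

# One slot per *new* token: physical_block_id * block_size + offset_in_block.
# Request 0 writes positions 30..33 (block 3 offsets 14..15, then block 9
# offsets 0..1); request 1 writes positions 0..4 into block 5.
slot_mapping = torch.tensor(
    [3 * 16 + 14, 3 * 16 + 15, 9 * 16 + 0, 9 * 16 + 1,
     5 * 16 + 0, 5 * 16 + 1, 5 * 16 + 2, 5 * 16 + 3, 5 * 16 + 4],
    dtype=torch.int64)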
diff --git a/vllm/v1/core/__init__.py b/vllm/v1/core/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
new file mode 100644
index 000000000000..9b735a8be10d
--- /dev/null
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -0,0 +1,108 @@
+from typing import Dict, List, Optional
+
+import numpy as np
+
+from vllm.logger import init_logger
+from vllm.utils import cdiv
+from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class KVCacheManager:
+
+    def __init__(
+        self,
+        block_size: int,
+        num_gpu_blocks: int,
+        sliding_window: Optional[int] = None,
+        enable_caching: bool = True,
+        num_preallocate_tokens: int = 64,
+    ) -> None:
+        self.block_size = block_size
+        self.num_gpu_blocks = num_gpu_blocks
+        self.sliding_window = sliding_window
+        self.enable_caching = enable_caching
+        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
+        # blocks for each request. For example, when a request reaches the end
+        # of its block table, we preallocate N blocks in advance. This way, we
+        # reduce the overhead of updating free_block_ids and ref_cnts for each
+        # request every step (at the cost of some memory waste).
+        # NOTE(woosuk): This is different from the "lookahead" slots since this
+        # does not guarantee that the request always has N empty blocks. After
+        # the request gets N empty blocks, it starts to use the blocks without
+        # further allocation. When it uses up all the N empty blocks, it gets
+        # N new empty blocks.
+        self.num_preallocate_tokens = num_preallocate_tokens
+        self.num_preallocate_blocks = cdiv(num_preallocate_tokens, block_size)
+
+        self.free_block_ids = list(range(num_gpu_blocks))
+        self.req_to_block_ids: Dict[str, List[int]] = {}
+        self.ref_cnts = np.zeros(num_gpu_blocks, dtype=np.int32)
+
+    def get_computed_blocks(self, request: Request) -> List[int]:
+        if not self.enable_caching:
+            # No prefix caching.
+            return []
+        # TODO(woosuk): Implement hash-based caching.
+        return []
+
+    def append_slots(
+        self,
+        request: Request,
+        num_tokens: int,
+    ) -> Optional[List[int]]:
+        num_required_blocks = cdiv(request.num_computed_tokens + num_tokens,
+                                   self.block_size)
+        req_block_ids = self.req_to_block_ids[request.request_id]
+        if num_required_blocks <= len(req_block_ids):
+            # No new block is needed.
+            return []
+
+        num_new_blocks = num_required_blocks - len(req_block_ids)
+        num_free_blocks = len(self.free_block_ids)
+        if num_new_blocks > num_free_blocks:
+            # Cannot allocate new blocks.
+            return None
+
+        # Allocate new blocks.
+        num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks,
+                             num_free_blocks)
+        new_block_ids = self._get_new_blocks(num_new_blocks)
+        req_block_ids.extend(new_block_ids)
+        self.ref_cnts[new_block_ids] += 1
+        return new_block_ids
+
+    def allocate_slots(
+        self,
+        request: Request,
+        num_tokens: int,
+        computed_block_ids: List[int],
+    ) -> Optional[List[int]]:
+        num_required_blocks = cdiv(num_tokens, self.block_size)
+        num_free_blocks = len(self.free_block_ids)
+        if num_required_blocks > num_free_blocks:
+            # Cannot allocate new blocks.
+            return None
+
+        num_new_blocks = min(num_required_blocks + self.num_preallocate_blocks,
+                             num_free_blocks)
+        new_block_ids = self._get_new_blocks(num_new_blocks)
+        block_ids = computed_block_ids + new_block_ids
+        self.req_to_block_ids[request.request_id] = block_ids
+        self.ref_cnts[block_ids] += 1
+        return new_block_ids
+
+    def free(self, request: Request) -> None:
+        block_ids = self.req_to_block_ids.pop(request.request_id)
+        self.ref_cnts[block_ids] -= 1
+        for block_id in block_ids:
+            ref_cnt = self.ref_cnts[block_id]
+            if ref_cnt == 0:
+                self.free_block_ids.append(block_id)
+
+    def _get_new_blocks(self, num_blocks: int) -> List[int]:
+        assert num_blocks <= len(self.free_block_ids)
+        new_block_ids = self.free_block_ids[-num_blocks:]
+        self.free_block_ids = self.free_block_ids[:-num_blocks]
+        return new_block_ids
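A worked example of the preallocation arithmetic above, using made-up numbers (block_size=16, num_preallocate_tokens=64, and 100 free blocks):

from vllm.utils import cdiv

block_size = 16
num_preallocate_blocks = cdiv(64, block_size)               # 4 extra blocks

# allocate_slots() for a fresh 50-token prompt with no cached blocks:
num_required_blocks = cdiv(50, block_size)                  # 4 blocks
num_free_blocks = 100
num_new_blocks = min(num_required_blocks + num_preallocate_blocks,
                     num_free_blocks)
assert (num_required_blocks, num_new_blocks) == (4, 8)      # 8 blocks granted

# append_slots() later, with 50 tokens computed and 10 more scheduled: the
# request already holds 8 blocks and cdiv(50 + 10, 16) == 4 <= 8, so no new
# blocks are needed and an empty list is returned.
assert cdiv(50 + 10, block_size) == 4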
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
new file mode 100644
index 000000000000..41659ff62747
--- /dev/null
+++ b/vllm/v1/core/scheduler.py
@@ -0,0 +1,412 @@
+from collections import deque
+from dataclasses import dataclass
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
+
+from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
+from vllm.logger import init_logger
+from vllm.multimodal import MultiModalDataDict
+from vllm.sampling_params import SamplingParams
+from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.request import Request, RequestStatus
+
+logger = init_logger(__name__)
+
+
+class Scheduler:
+
+    def __init__(
+        self,
+        scheduler_config: SchedulerConfig,
+        cache_config: CacheConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.scheduler_config = scheduler_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        # TODO: Support LoRA.
+        assert lora_config is None, "V1 does not support LoRA yet."
+
+        num_gpu_blocks = cache_config.num_gpu_blocks
+        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
+        # Create the block space manager.
+        self.kv_cache_manager = KVCacheManager(
+            block_size=self.cache_config.block_size,
+            num_gpu_blocks=num_gpu_blocks,
+            sliding_window=self.cache_config.sliding_window,
+            enable_caching=True)
+        self.block_size = self.cache_config.block_size
+
+        # Scheduling constraints.
+        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
+        self.max_num_scheduled_tokens = \
+            self.scheduler_config.max_num_batched_tokens
+        self.max_model_len = self.scheduler_config.max_model_len
+
+        # req_id -> Request
+        self.requests: Dict[str, Request] = {}
+        # Priority queues for requests.
+        self.waiting: Deque[Request] = deque()
+        self.running: List[Request] = []
+
+        # The request IDs that are finished in between the previous and the
+        # current steps. This is used to notify the workers about the finished
+        # requests so that they can free the cached states for those requests.
+        # This is flushed at the end of each scheduling step.
+        self.finished_req_ids: Set[str] = set()
+
+        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        # them at each scheduling step.
+        # Request id -> RunningRequestData
+        self.running_reqs_data: Dict[str, RunningRequestData] = {}
+
+    def schedule(self) -> "SchedulerOutput":
+        scheduled_new_reqs: List[Request] = []
+        scheduled_resumed_reqs: List[Request] = []
+        scheduled_running_reqs: List[Request] = []
+        preempted_reqs: List[Request] = []
+
+        # NOTE(woosuk) on the scheduling algorithm:
+        # There's no "decoding phase" nor "prefill phase" in the scheduler.
+        # Each request just has the num_computed_tokens and num_tokens,
+        # which is equal to len(prompt_token_ids) + len(output_token_ids).
+        # At each step, the scheduler tries to assign tokens to the requests
+        # so that each request's num_computed_tokens can catch up its
+        # num_tokens. This is general enough to cover chunked prefills,
+        # prefix caching, and the "jump forward" optimization in the future.
+
+        req_to_new_block_ids: Dict[str, List[int]] = {}
+        num_scheduled_tokens: Dict[str, int] = {}
+        token_budget = self.max_num_scheduled_tokens
+
+        # First, schedule the RUNNING requests.
+        req_index = 0
+        while req_index < len(self.running):
+            if token_budget == 0:
+                break
+
+            request = self.running[req_index]
+            num_new_tokens = request.num_tokens - request.num_computed_tokens
+            num_new_tokens = min(num_new_tokens, token_budget)
+            assert num_new_tokens > 0
+
+            while True:
+                new_block_ids = self.kv_cache_manager.append_slots(
+                    request, num_new_tokens)
+                if new_block_ids is None:
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
+                    preempted_req = self.running.pop()
+                    self.kv_cache_manager.free(preempted_req)
+                    preempted_req.status = RequestStatus.PREEMPTED
+                    preempted_req.num_computed_tokens = 0
+
+                    self.waiting.appendleft(preempted_req)
+                    preempted_reqs.append(preempted_req)
+                    if preempted_req == request:
+                        # No more request to preempt.
+                        break
+                else:
+                    # The request can be scheduled.
+                    scheduled_running_reqs.append(request)
+
+                    req_to_new_block_ids[request.request_id] = new_block_ids
+                    num_scheduled_tokens[request.request_id] = num_new_tokens
+                    token_budget -= num_new_tokens
+                    req_index += 1
+                    break
+
+        # Next, schedule the WAITING requests.
+        if not preempted_reqs:
+            while self.waiting:
+                if len(self.running) == self.max_num_running_reqs:
+                    break
+                if token_budget == 0:
+                    break
+
+                request = self.waiting[0]
+                # Get already-cached tokens.
+                computed_block_ids = self.kv_cache_manager.get_computed_blocks(
+                    request)
+                # NOTE(woosuk): Since incomplete blocks are not eligible for
+                # sharing, `num_computed_tokens` is always a multiple of
+                # `block_size`.
+                num_computed_tokens = len(computed_block_ids) * self.block_size
+                # Number of tokens to be scheduled.
+                # We use `request.num_tokens` instead of
+                # `request.num_prompt_tokens` to consider the resumed requests,
+                # which have output tokens.
+                num_new_tokens = request.num_tokens - num_computed_tokens
+                num_new_tokens = min(num_new_tokens, token_budget)
+                assert num_new_tokens > 0
+                new_block_ids = self.kv_cache_manager.allocate_slots(
+                    request, num_new_tokens, computed_block_ids)
+                if new_block_ids is None:
+                    # The request cannot be scheduled.
+                    break
+                request.num_computed_tokens = num_computed_tokens
+
+                self.waiting.popleft()
+                self.running.append(request)
+                if request.status == RequestStatus.WAITING:
+                    scheduled_new_reqs.append(request)
+                elif request.status == RequestStatus.PREEMPTED:
+                    scheduled_resumed_reqs.append(request)
+                else:
+                    raise RuntimeError(
+                        f"Invalid request status: {request.status}")
+
+                req_to_new_block_ids[request.request_id] = (
+                    computed_block_ids + new_block_ids)
+                num_scheduled_tokens[request.request_id] = num_new_tokens
+                token_budget -= num_new_tokens
+                request.status = RequestStatus.RUNNING
+
+        # Check if the scheduling constraints are satisfied.
+        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
+        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
+        assert token_budget >= 0
+        assert len(self.running) <= self.max_num_running_reqs
+        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
+                len(scheduled_running_reqs) == len(self.running))
+
+        # Construct the scheduler output.
+        new_reqs_data = [
+            NewRequestData.from_request(req,
+                                        req_to_new_block_ids[req.request_id],
+                                        req.num_computed_tokens)
+            for req in scheduled_new_reqs
+        ]
+        resumed_reqs_data = [
+            ResumedRequestData.from_request(
+                req, req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens) for req in scheduled_resumed_reqs
+        ]
+        running_reqs_data = [
+            self._make_running_request_data(
+                req, req_to_new_block_ids[req.request_id],
+                req.num_computed_tokens) for req in scheduled_running_reqs
+        ]
+        preempted_req_ids = {req.request_id for req in preempted_reqs}
+        scheduler_output = SchedulerOutput(
+            scheduled_new_reqs=new_reqs_data,
+            scheduled_resumed_reqs=resumed_reqs_data,
+            scheduled_running_reqs=running_reqs_data,
+            num_scheduled_tokens=num_scheduled_tokens,
+            total_num_scheduled_tokens=total_num_scheduled_tokens,
+            preempted_req_ids=preempted_req_ids,
+            # finished_req_ids is an existing state in the scheduler,
+            # instead of being newly scheduled in this step.
+            # It contains the request IDs that are finished in between
+            # the previous and the current steps.
+            finished_req_ids=self.finished_req_ids,
+        )
+
+        self.finished_req_ids = set()
+        return scheduler_output
+
+    def _make_running_request_data(
+        self,
+        request: Request,
+        new_block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "RunningRequestData":
+        # OPTIMIZATION: Cache the RunningRequestData objects to avoid creating
+        # them at each scheduling step.
+        if request.request_id in self.running_reqs_data:
+            req_data = self.running_reqs_data[request.request_id]
+            req_data.new_block_ids = new_block_ids
+            req_data.num_computed_tokens = num_computed_tokens
+        else:
+            req_data = RunningRequestData.from_request(request, new_block_ids,
+                                                       num_computed_tokens)
+            self.running_reqs_data[request.request_id] = req_data
+        return req_data
+
+    def update_from_output(
+        self,
+        scheduler_output: "SchedulerOutput",
+        model_runner_output: "ModelRunnerOutput",
+    ) -> List[Tuple[Request, int]]:
+        # NOTE(woosuk): This method doesn't consider speculative decoding.
+        sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
+        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
+        new_running: List[Request] = []
+        # (request, num_sampled_tokens)
+        sampled: List[Tuple[Request, int]] = []
+        for request in self.running:
+            req_id = request.request_id
+            request.num_computed_tokens += num_scheduled_tokens[req_id]
+            # When the request's num_computed_tokens catches up with its
+            # num_tokens, the request generates output tokens. Otherwise, we
+            # ignore the sampler output for the request.
+            assert request.num_computed_tokens <= request.num_tokens
+            if request.num_computed_tokens == request.num_tokens:
+                req_index = model_runner_output.req_id_to_index[req_id]
+                # NOTE(woosuk): Currently, we assume that each request
+                # generates at most one token at each step.
+                token_id = sampled_token_ids[req_index]
+                request.output_token_ids.append(token_id)
+                sampled.append((request, 1))
+                # TODO: Update the KV cache manager for prefix caching.
+
+                # Check if the request is finished.
+                stopped = self._check_stop(request)
+                if stopped:
+                    continue
+
+            new_running.append(request)
+        self.running = new_running
+        return sampled
+
+    def _check_stop(self, request: Request) -> bool:
+        if (request.num_tokens >= self.max_model_len
+                or request.num_output_tokens >= request.max_tokens):
+            request.status = RequestStatus.FINISHED_LENGTH_CAPPED
+            self._free_request(request)
+            return True
+
+        sampling_params = request.sampling_params
+        last_token_id = request.output_token_ids[-1]
+        if (not sampling_params.ignore_eos
+                and last_token_id == request.eos_token_id):
+            request.status = RequestStatus.FINISHED_STOPPED
+            self._free_request(request)
+            return True
+
+        if last_token_id in (sampling_params.stop_token_ids or ()):
+            request.status = RequestStatus.FINISHED_STOPPED
+            request.stop_reason = last_token_id
+            self._free_request(request)
+            return True
+        return False
+
+    def add_request(self, request: Request) -> None:
+        self.waiting.append(request)
+        self.requests[request.request_id] = request
+
+    def finish_requests(
+        self,
+        request_ids: Union[str, Iterable[str]],
+        finished_status: RequestStatus,
+    ) -> None:
+        """Handles the finish signal from outside the scheduler.
+
+        For example, the API server can abort a request when the client
+        disconnects.
+        """
+        assert RequestStatus.is_finished(finished_status)
+        if isinstance(request_ids, str):
+            request_ids = (request_ids, )
+        request_ids = set(request_ids)
+
+        for req_id in request_ids:
+            request = self.requests.get(req_id)
+            if request is None:
+                # Invalid request ID.
+                continue
+
+            if request.status == RequestStatus.RUNNING:
+                self.running.remove(request)
+            else:
+                self.waiting.remove(request)
+            request.status = finished_status
+            self._free_request(request)
+
+    def _free_request(self, request: Request) -> None:
+        assert request.is_finished()
+        self.kv_cache_manager.free(request)
+        self.running_reqs_data.pop(request.request_id, None)
+        del self.requests[request.request_id]
+        self.finished_req_ids.add(request.request_id)
+
+    def get_num_unfinished_requests(self) -> int:
+        return len(self.waiting) + len(self.running)
+
+    def has_unfinished_requests(self) -> bool:
+        return self.get_num_unfinished_requests() > 0
+
+
+@dataclass
+class NewRequestData:
+
+    req_id: str
+    prompt_token_ids: List[int]
+    prompt: Optional[str]
+    multi_modal_data: Optional[MultiModalDataDict]
+    sampling_params: SamplingParams
+    block_ids: List[int]
+    num_computed_tokens: int
+
+    @classmethod
+    def from_request(
+        cls,
+        request: Request,
+        block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "NewRequestData":
+        return cls(
+            req_id=request.request_id,
+            prompt_token_ids=request.inputs["prompt_token_ids"],
+            prompt=request.inputs.get("prompt"),
+            multi_modal_data=request.inputs.get("multi_modal_data"),
+            sampling_params=request.sampling_params,
+            block_ids=block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+
+@dataclass
+class ResumedRequestData:
+
+    req_id: str
+    block_ids: List[int]
+    num_computed_tokens: int
+
+    @classmethod
+    def from_request(
+        cls,
+        request: Request,
+        block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "ResumedRequestData":
+        return cls(
+            req_id=request.request_id,
+            block_ids=block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+
+@dataclass
+class RunningRequestData:
+
+    req_id: str
+    new_block_ids: List[int]
+    num_computed_tokens: int
+
+    @classmethod
+    def from_request(
+        cls,
+        request: Request,
+        new_block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "RunningRequestData":
+        return cls(
+            req_id=request.request_id,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+
+@dataclass
+class SchedulerOutput:
+
+    scheduled_new_reqs: List[NewRequestData]
+    scheduled_resumed_reqs: List[ResumedRequestData]
+    scheduled_running_reqs: List[RunningRequestData]
+
+    num_scheduled_tokens: Dict[str, int]
+    total_num_scheduled_tokens: int
+
+    preempted_req_ids: Set[str]
+    finished_req_ids: Set[str]
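The NOTE in Scheduler.schedule() about there being no separate prefill and decode phases can be traced with a toy budget loop (the numbers are illustrative, not defaults): a 20-token prompt under an 8-token budget is prefilled in chunks over three steps, after which each step schedules a single new token.

num_tokens, num_computed, budget = 20, 0, 8
scheduled_per_step = []
while num_computed < num_tokens:
    # Mirrors: num_new_tokens = min(num_tokens - num_computed, token_budget)
    num_new_tokens = min(num_tokens - num_computed, budget)
    num_computed += num_new_tokens
    scheduled_per_step.append(num_new_tokens)
scheduled_per_step.append(1)  # once caught up, sampling adds 1 token per step
assert scheduled_per_step == [8, 8, 4, 1]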
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
new file mode 100644
index 000000000000..072e52bcd686
--- /dev/null
+++ b/vllm/v1/engine/llm_engine.py
@@ -0,0 +1,532 @@
+import time
+from typing import (Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type,
+                    Union)
+
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
+                         EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
+                         ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.metrics_types import StatLoggerBase
+from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
+                         EncoderDecoderLLMInputs, InputRegistry, PromptType)
+from vllm.inputs.preprocess import InputPreprocessor
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.transformers_utils.config import try_get_generation_config
+from vllm.transformers_utils.tokenizer_group import (
+    BaseTokenizerGroup, init_tokenizer_from_configs)
+from vllm.usage.usage_lib import UsageContext
+from vllm.v1.core.scheduler import Scheduler
+from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.tokenizer.detokenizer import Detokenizer, DetokenizerInputs
+from vllm.version import __version__ as VLLM_VERSION
+
+logger = init_logger(__name__)
+
+
+class LLMEngine:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        speculative_config: Optional[SpeculativeConfig],
+        decoding_config: Optional[DecodingConfig],
+        observability_config: Optional[ObservabilityConfig],
+        prompt_adapter_config: Optional[PromptAdapterConfig],
+        executor_class: Type[GPUExecutor],
+        log_stats: bool,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+        input_registry: InputRegistry = INPUT_REGISTRY,
+        use_cached_outputs: bool = False,
+    ) -> None:
+        # Override the configs for V1.
+        # FIXME
+        if usage_context == UsageContext.LLM_CLASS:
+            scheduler_config.max_num_seqs = 1024
+            scheduler_config.max_num_batched_tokens = 8192
+        elif usage_context == UsageContext.OPENAI_API_SERVER:
+            scheduler_config.max_num_seqs = 1024
+            scheduler_config.max_num_batched_tokens = 2048
+
+        logger.info(
+            "Initializing an LLM engine (v%s) with config: "
+            "model=%r, speculative_config=%r, tokenizer=%r, "
+            "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
+            "override_neuron_config=%s, "
+            "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
+            "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
+            "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
+            "pipeline_parallel_size=%d, "
+            "disable_custom_all_reduce=%s, quantization=%s, "
+            "enforce_eager=%s, kv_cache_dtype=%s, "
+            "quantization_param_path=%s, device_config=%s, "
+            "decoding_config=%r, observability_config=%r, "
+            "seed=%d, served_model_name=%s, "
+            "num_scheduler_steps=%d, enable_prefix_caching=%s, "
+            "use_async_output_proc=%s, mm_processor_kwargs=%s)",
+            VLLM_VERSION,
+            model_config.model,
+            speculative_config,
+            model_config.tokenizer,
+            model_config.skip_tokenizer_init,
+            model_config.tokenizer_mode,
+            model_config.revision,
+            model_config.override_neuron_config,
+            model_config.rope_scaling,
+            model_config.rope_theta,
+            model_config.tokenizer_revision,
+            model_config.trust_remote_code,
+            model_config.dtype,
+            model_config.max_model_len,
+            load_config.download_dir,
+            load_config.load_format,
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size,
+            parallel_config.disable_custom_all_reduce,
+            model_config.quantization,
+            model_config.enforce_eager,
+            cache_config.cache_dtype,
+            model_config.quantization_param_path,
+            device_config.device,
+            decoding_config,
+            observability_config,
+            model_config.seed,
+            model_config.served_model_name,
+            scheduler_config.num_scheduler_steps,
+            cache_config.enable_prefix_caching,
+            model_config.use_async_output_proc,
+            model_config.mm_processor_kwargs,
+        )
+
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.speculative_config = speculative_config
+        self.load_config = load_config
+        self.decoding_config = decoding_config or DecodingConfig()
+        self.prompt_adapter_config = prompt_adapter_config
+        self.observability_config = observability_config or ObservabilityConfig(
+        )
+        self.log_stats = log_stats
+
+        assert not self.model_config.skip_tokenizer_init
+        self.tokenizer = self._init_tokenizer()
+        if self.tokenizer:
+            # Ping the tokenizer to ensure liveness if it runs in a
+            # different process.
+            self.tokenizer.ping()
+        self.detokenizer = Detokenizer(self.model_config.tokenizer)
+
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)
+        self.input_preprocessor = InputPreprocessor(model_config,
+                                                    self.tokenizer)
+        self.input_registry = input_registry
+        self.input_processor = input_registry.create_input_processor(
+            model_config)
+
+        # Request id -> Request
+        self.requests: Dict[str, Request] = {}
+        # NOTE(woosuk): Since the detokenizer works asynchronously, we need to
+        # keep track of how many steps each request lags behind in
+        # detokenization.
+        # Request id -> how many detokenizer steps the request should wait for.
+        self.num_lagged_steps: Dict[str, int] = {}
+        # OPTIMIZATION: Cache the request output and update it incrementally.
+        # This is used to avoid creating a new RequestOutput object every step.
+        # Request id -> RequestOutput
+        self.request_outputs: Dict[str, RequestOutput] = {}
+
+        self.model_executor = executor_class(
+            model_config=model_config,
+            cache_config=cache_config,
+            parallel_config=parallel_config,
+            scheduler_config=scheduler_config,
+            device_config=device_config,
+            lora_config=lora_config,
+            speculative_config=speculative_config,
+            load_config=load_config,
+            prompt_adapter_config=prompt_adapter_config,
+            observability_config=self.observability_config,
+        )
+        assert self.model_config.task != "embedding"
+        self._initialize_kv_caches()
+
+        # Create the scheduler.
+        # NOTE: the cache_config here has been updated with the number of
+        # GPU and CPU blocks, which are profiled in the distributed executor.
+        self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
+
+    def _initialize_kv_caches(self) -> None:
+        num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
+        )
+
+        if self.cache_config.num_gpu_blocks_override is not None:
+            num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override
+            logger.info(
+                "Overriding num_gpu_blocks=%d with "
+                "num_gpu_blocks_override=%d", num_gpu_blocks,
+                num_gpu_blocks_override)
+            num_gpu_blocks = num_gpu_blocks_override
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
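+        # V1 does not use CPU swap space, so no CPU blocks are allocated.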
+        self.cache_config.num_cpu_blocks = 0
+        self.model_executor.initialize_cache(num_gpu_blocks)
+
+    @classmethod
+    def from_engine_args(
+        cls,
+        engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+    ) -> "LLMEngine":
+        """Creates an LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_config = engine_args.create_engine_config()
+        executor_class = cls._get_executor_cls(engine_config)
+        # Create the LLM engine.
+        engine = cls(
+            **engine_config.to_dict(),
+            executor_class=executor_class,
+            log_stats=not engine_args.disable_log_stats,
+            usage_context=usage_context,
+            stat_loggers=stat_loggers,
+        )
+        return engine
+
+    def _init_tokenizer(self) -> BaseTokenizerGroup:
+        return init_tokenizer_from_configs(
+            model_config=self.model_config,
+            scheduler_config=self.scheduler_config,
+            parallel_config=self.parallel_config,
+            enable_lora=bool(self.lora_config))
+
+    def _verify_args(self) -> None:
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+        if self.lora_config:
+            self.lora_config.verify_with_model_config(self.model_config)
+            self.lora_config.verify_with_scheduler_config(
+                self.scheduler_config)
+        if self.prompt_adapter_config:
+            self.prompt_adapter_config.verify_with_model_config(
+                self.model_config)
+
+    def _add_processed_request(
+        self,
+        request_id: str,
+        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderLLMInputs],
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+        trace_headers: Optional[Mapping[str, str]] = None,
+    ) -> None:
+        assert prompt_adapter_request is None
+        assert trace_headers is None
+        self._validate_model_inputs(processed_inputs)
+        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
+
+        # TODO(woosuk): Support embedding mode.
+        assert isinstance(params, SamplingParams)
+        sampling_params = params.clone()
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields, eos_token_id)
+
+        # TODO(woosuk): Check max_logprobs
+        # TODO(woosuk): Support encoder-decoder models.
+        req = Request(request_id, processed_inputs, params, eos_token_id,
+                      arrival_time)
+        self.requests[request_id] = req
+        self.num_lagged_steps[request_id] = 0
+        self.scheduler.add_request(req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        raise NotImplementedError("TP not implemented yet.")
+
+    def add_request(
+        self,
+        request_id: str,
+        prompt: PromptType,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> None:
+        if lora_request is not None and not self.lora_config:
+            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
+                             "not enabled!")
+        if arrival_time is None:
+            arrival_time = time.time()
+        assert priority == 0, "vLLM V1 does not support priority at the moment."
+
+        preprocessed_inputs = self.input_preprocessor.preprocess(
+            prompt,
+            request_id=request_id,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+        )
+        processed_inputs = self.input_processor(preprocessed_inputs)
+
+        self._add_processed_request(
+            request_id=request_id,
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+            prompt_adapter_request=prompt_adapter_request,
+            trace_headers=trace_headers,
+        )
+
+    def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
+        self.scheduler.finish_requests(request_id,
+                                       RequestStatus.FINISHED_ABORTED)
+        self._free_request(request_id)
+
+    def get_num_unfinished_requests(self) -> int:
+        """Gets the number of unfinished requests."""
+        return len(self.requests)
+
+    def has_unfinished_requests(self) -> bool:
+        """Returns True if there are unfinished requests."""
+        return len(self.requests) > 0
+
+    def step(self) -> List[RequestOutput]:
+        # NOTE(woosuk): This method may return an empty list when the
+        # detokenizer is still processing the outputs. This should not be
+        # treated as the end of the generation process.
+        # FIXME(woosuk): Currently, the step method is inefficient because it
+        # creates RequestOutput objects for all running requests, while they
+        # may not be needed unless the output is streamed to the client.
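+        # Schedule and execute the model, then push the newly sampled tokens
+        # to the detokenizer and drain whatever it has finished so far.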
+        if self.scheduler.has_unfinished_requests():
+            scheduler_output = self.scheduler.schedule()
+            output = self.model_executor.execute_model(scheduler_output)
+            sampled = self.scheduler.update_from_output(
+                scheduler_output, output)
+            self.send_to_detokenizer(sampled)
+        req_outputs = self.recv_from_detokenizer()
+        return req_outputs
+
+    def send_to_detokenizer(self, sampled: List[Tuple[Request, int]]) -> None:
+        inputs = DetokenizerInputs(
+            req_ids=[],
+            prompt_token_ids=[],
+            new_token_ids=[],
+            skip_special_tokens=[],
+            spaces_between_special_tokens=[],
+            free_req_ids=[],  # TODO(woosuk): Implement freeing.
+        )
+        for req, num_tokens in sampled:
+            inputs.req_ids.append(req.request_id)
+            if len(req.output_token_ids) == num_tokens:
+                # This is the first time the request is detokenized, so send
+                # its prompt token ids.
+                inputs.prompt_token_ids.append(req.prompt_token_ids)
+            else:
+                # The prompt token ids are already cached in the detokenizer.
+                inputs.prompt_token_ids.append([])
+            inputs.new_token_ids.append(req.output_token_ids[-num_tokens:])
+            inputs.skip_special_tokens.append(
+                req.sampling_params.skip_special_tokens)
+            inputs.spaces_between_special_tokens.append(
+                req.sampling_params.spaces_between_special_tokens)
+
+            # Update the number of lagged steps.
+            self.num_lagged_steps[req.request_id] += 1
+        self.detokenizer.send(inputs)
+
+    def recv_from_detokenizer(self) -> List[RequestOutput]:
+        detokenizer_output = self.detokenizer.recv()
+        if detokenizer_output is None:
+            return []
+
+        req_outputs: List[RequestOutput] = []
+        num_reqs = len(detokenizer_output.req_ids)
+        for i in range(num_reqs):
+            req_id = detokenizer_output.req_ids[i]
+            if req_id not in self.requests:
+                # The request has been aborted while the detokenizer was
+                # processing the outputs.
+                continue
+
+            req = self.requests[req_id]
+            req.output_text += detokenizer_output.detokenized_texts[i]
+
+            self.num_lagged_steps[req_id] -= 1
+            finished = (self.num_lagged_steps[req_id] == 0
+                        and req.is_finished())
+            req_output = self._make_request_output(
+                req, detokenizer_output.num_output_token_ids[i],
+                detokenizer_output.detokenized_texts[i], finished)
+            req_outputs.append(req_output)
+
+            if finished:
+                self._free_request(req_id)
+        return req_outputs
+
+    def terminate_detokenizer(self) -> None:
+        self.detokenizer.terminate()
+
+    def _make_request_output(
+        self,
+        request: Request,
+        num_output_tokens: int,
+        new_output_text: str,
+        finished: bool,
+    ) -> RequestOutput:
+        req_output = self.request_outputs.get(request.request_id)
+        if req_output is None:
+            # TODO: Support `n` > 1.
+            completion_output = CompletionOutput(
+                index=0,
+                text="",
+                token_ids=[],
+                cumulative_logprob=None,
+                logprobs=None,  # TODO
+                finish_reason=None,
+                stop_reason=None,
+                lora_request=None,
+            )
+            req_output = RequestOutput(
+                request_id=request.request_id,
+                prompt=request.prompt,
+                prompt_token_ids=request.prompt_token_ids,
+                prompt_logprobs=None,  # TODO
+                outputs=[completion_output],
+                finished=False,
+                metrics=None,
+                lora_request=None,
+                encoder_prompt=None,
+                encoder_prompt_token_ids=None,
+            )
+            self.request_outputs[request.request_id] = req_output
+
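+        # Fill the completion according to the requested output kind:
+        # CUMULATIVE returns everything generated so far, DELTA returns only
+        # the newly generated piece, and FINAL_ONLY returns the full output
+        # only once the request finishes.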
+        completion_output = req_output.outputs[0]
+        if request.sampling_params.output_kind == RequestOutputKind.CUMULATIVE:
+            completion_output.text += new_output_text
+            completion_output.token_ids = (
+                request.output_token_ids[:num_output_tokens])
+        elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
+            completion_output.text = new_output_text
+            num_prev_tokens = len(completion_output.token_ids)
+            completion_output.token_ids = request.output_token_ids[
+                num_prev_tokens:num_output_tokens]
+        elif (request.sampling_params.output_kind ==
+              RequestOutputKind.FINAL_ONLY):
+            if finished:
+                completion_output.text = request.output_text
+                completion_output.token_ids = request.output_token_ids
+            else:
+                completion_output.text = ""
+                completion_output.token_ids = []
+
+        if finished:
+            completion_output.finish_reason = request.get_finished_reason()
+            completion_output.stop_reason = request.stop_reason
+            req_output.finished = finished
+        return req_output
+
+    def _free_request(self, request_id: str) -> None:
+        self.requests.pop(request_id, None)
+        self.num_lagged_steps.pop(request_id, None)
+        self.request_outputs.pop(request_id, None)
+
+    def check_health(self) -> None:
+        if self.tokenizer:
+            self.tokenizer.check_health()
+        self.model_executor.check_health()
+
+    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
+                                                   EncoderDecoderLLMInputs]):
+        prompt_ids = inputs.get("prompt_token_ids")
+        if prompt_ids is None or len(prompt_ids) == 0:
+            raise ValueError("Prompt cannot be empty")
+
+        if self.model_config.is_multimodal_model:
+            max_prompt_len = self.model_config.max_model_len
+
+            if len(prompt_ids) > max_prompt_len:
+                raise ValueError(
+                    f"The prompt (total length {len(prompt_ids)}) is too long "
+                    f"to fit into the model (context length {max_prompt_len}). "
+                    "Make sure that `max_model_len` is no smaller than the "
+                    "number of text tokens plus multimodal tokens. For image "
+                    "inputs, the number of image tokens depends on the number "
+                    "of images, and possibly their aspect ratios as well.")
+
+    @classmethod
+    def validate_outputs(cls, outputs, output_type):
+        return outputs
+
+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
+    def get_parallel_config(self) -> ParallelConfig:
+        """Gets the parallel configuration."""
+        return self.parallel_config
+
+    def get_decoding_config(self) -> DecodingConfig:
+        """Gets the decoding configuration."""
+        return self.decoding_config
+
+    def get_scheduler_config(self) -> SchedulerConfig:
+        """Gets the scheduler configuration."""
+        return self.scheduler_config
+
+    def get_lora_config(self) -> LoRAConfig:
+        """Gets the LoRA configuration."""
+        return self.lora_config
+
+    @classmethod
+    def _get_executor_cls(cls, engine_config: EngineConfig):
+        return GPUExecutor
+
+    def is_tracing_enabled(self) -> bool:
+        return False
+
+    def do_log_stats(self, *args, **kwargs) -> None:
+        pass
+
+    def is_encoder_decoder_model(self) -> bool:
+        return False
+
+    def start_profile(self) -> None:
+        pass
+
+    def stop_profile(self) -> None:
+        pass
+
+    def get_tokenizer_group(self, *args, **kwargs):
+        return self.tokenizer
+
+
+def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
+    config = try_get_generation_config(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.revision,
+    )
+
+    if config is None:
+        return {}
+
+    return config.to_diff_dict()
diff --git a/vllm/v1/executor/__init__.py b/vllm/v1/executor/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/gpu_executor.py
new file mode 100644
index 000000000000..c780c7031c3d
--- /dev/null
+++ b/vllm/v1/executor/gpu_executor.py
@@ -0,0 +1,100 @@
+import os
+from typing import Optional, Tuple
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
+from vllm.logger import init_logger
+from vllm.utils import get_distributed_init_method, get_ip, get_open_port
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.gpu_worker import Worker
+
+logger = init_logger(__name__)
+
+
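+# NOTE: This is a single-GPU executor. It creates one Worker (rank 0) in the
+# current process; tensor/pipeline parallelism is not supported yet.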
+class GPUExecutor:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        speculative_config: Optional[SpeculativeConfig],
+        prompt_adapter_config: Optional[PromptAdapterConfig],
+        observability_config: Optional[ObservabilityConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.load_config = load_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.speculative_config = speculative_config
+        self.prompt_adapter_config = prompt_adapter_config
+        self.observability_config = observability_config
+
+        self.worker = self._create_worker()
+        self.worker.initialize()
+        self.worker.load_model()
+
+    def _create_worker(
+            self,
+            local_rank: int = 0,
+            rank: int = 0,
+            distributed_init_method: Optional[str] = None) -> Worker:
+        """Return worker init args for a given rank."""
+        # see https://github.com/NVIDIA/nccl/issues/1234
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+        if distributed_init_method is None:
+            distributed_init_method = get_distributed_init_method(
+                get_ip(), get_open_port())
+        return Worker(
+            model_config=self.model_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config,
+            device_config=self.device_config,
+            cache_config=self.cache_config,
+            load_config=self.load_config,
+            local_rank=local_rank,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            speculative_config=self.speculative_config,
+            prompt_adapter_config=self.prompt_adapter_config,
+            observability_config=self.observability_config,
+        )
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available KV blocks by invoking the
+        underlying worker.
+        """
+        return self.worker.determine_num_available_blocks()
+
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Initialize the KV cache by invoking the underlying worker.
+        """
+        # NOTE: This is logged in the executor because other executors may use
+        # more than one worker. We could log at the engine level instead, but
+        # work remains to abstract away the device for non-GPU configurations.
+        logger.info("# GPU blocks: %d", num_gpu_blocks)
+        self.worker.initialize_cache(num_gpu_blocks)
+        self.worker.compile_or_warm_up_model()
+
+    def execute_model(
+        self,
+        scheduler_output,
+    ) -> ModelRunnerOutput:
+        output = self.worker.execute_model(scheduler_output)
+        return output
+
+    def check_health(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
new file mode 100644
index 000000000000..857498772884
--- /dev/null
+++ b/vllm/v1/outputs.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import torch
+
+
+@dataclass
+class SamplerOutput:
+
+    # [num_reqs]
+    sampled_token_ids: torch.Tensor
+
+    # [num_reqs, max_num_logprobs + 1]
+    logprob_token_ids: Optional[torch.Tensor]
+    # [num_reqs, max_num_logprobs + 1]
+    logprobs: Optional[torch.Tensor]
+
+    # TODO: Support prompt logprobs.
+    prompt_logprob_token_ids: Optional[torch.Tensor]
+    prompt_logprobs: Optional[torch.Tensor]
+
+
+@dataclass
+class ModelRunnerOutput:
+
+    # [num_reqs]
+    req_ids: List[str]
+    # req_id -> index
+    req_id_to_index: Dict[str, int]
+
+    # [num_reqs]
+    sampled_token_ids_cpu: torch.Tensor
+
+    # [num_reqs, max_num_logprobs + 1]
+    logprob_token_ids_cpu: Optional[torch.Tensor]
+    # [num_reqs, max_num_logprobs + 1]
+    logprobs_cpu: Optional[torch.Tensor]
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
new file mode 100644
index 000000000000..be7d4d165d28
--- /dev/null
+++ b/vllm/v1/request.py
@@ -0,0 +1,92 @@
+import enum
+from typing import TYPE_CHECKING, List, Optional, Union
+
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import RequestMetrics
+
+if TYPE_CHECKING:
+    from vllm.inputs import DecoderOnlyInputs
+
+
+class Request:
+
+    def __init__(
+        self,
+        request_id: str,
+        inputs: "DecoderOnlyInputs",
+        sampling_params: SamplingParams,
+        eos_token_id: Optional[int],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> None:
+        self.request_id = request_id
+        self.inputs = inputs
+        self.sampling_params = sampling_params
+        # Because of LoRA, the eos token id can be different for each request.
+        self.eos_token_id = eos_token_id
+        self.metrics = RequestMetrics(arrival_time=arrival_time,
+                                      last_token_time=arrival_time,
+                                      first_scheduled_time=None,
+                                      first_token_time=None,
+                                      time_in_queue=None)
+        self.lora_request = lora_request
+
+        self.status = RequestStatus.WAITING
+        self.stop_reason: Union[int, str, None] = None
+        assert sampling_params.max_tokens is not None
+        self.max_tokens = sampling_params.max_tokens
+
+        self.prompt = inputs.get("prompt")
+        self.prompt_token_ids = inputs["prompt_token_ids"]
+        self.num_prompt_tokens = len(self.prompt_token_ids)
+        self.output_token_ids: List[int] = []
+        self.output_text = ""
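+        # Number of tokens whose KV cache has been computed so far; this
+        # advances as the request is scheduled and executed.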
+        self.num_computed_tokens = 0
+
+    @property
+    def num_tokens(self) -> int:
+        return self.num_prompt_tokens + len(self.output_token_ids)
+
+    @property
+    def num_output_tokens(self) -> int:
+        return len(self.output_token_ids)
+
+    def is_finished(self) -> bool:
+        return RequestStatus.is_finished(self.status)
+
+    def get_finished_reason(self) -> Union[str, None]:
+        return RequestStatus.get_finished_reason(self.status)
+
+
+class RequestStatus(enum.IntEnum):
+    """Status of a sequence."""
+    WAITING = 0
+    RUNNING = 1
+    PREEMPTED = 2
+    # NOTE: anything after PREEMPTED (2) is considered a finished status.
+    FINISHED_STOPPED = 3
+    FINISHED_LENGTH_CAPPED = 4
+    FINISHED_ABORTED = 5
+    FINISHED_IGNORED = 6
+
+    @staticmethod
+    def is_finished(status: "RequestStatus") -> bool:
+        return status > RequestStatus.PREEMPTED
+
+    @staticmethod
+    def get_finished_reason(status: "RequestStatus") -> Union[str, None]:
+        return _FINISHED_REASON_MAP.get(status)
+
+
+# Mapping of finished statuses to their finish reasons.
+# NOTE: The ignored requests are those whose prompts are longer than the
+# model's length cap. Therefore, their finish reason is also "length",
+# matching the OpenAI API.
+_FINISHED_REASON_MAP = {
+    RequestStatus.FINISHED_STOPPED: "stop",
+    RequestStatus.FINISHED_LENGTH_CAPPED: "length",
+    RequestStatus.FINISHED_ABORTED: "abort",
+    RequestStatus.FINISHED_IGNORED: "length",
+}
diff --git a/vllm/v1/sample/__init__.py b/vllm/v1/sample/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
new file mode 100644
index 000000000000..28614377b27b
--- /dev/null
+++ b/vllm/v1/sample/metadata.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class SamplingMetadata:
+
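+    # The all_*/no_* flags are batch-level shortcuts that let the sampler skip
+    # the mixed greedy/random path, the top-k/top-p masking, or per-request
+    # seeded generators when no request in the batch needs them.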
+    temperature: torch.Tensor
+    all_greedy: bool
+    all_random: bool
+
+    top_p: torch.Tensor
+    top_k: torch.Tensor
+    no_top_p: bool
+    no_top_k: bool
+
+    generators: List[Optional[torch.Generator]]
+    no_generator: bool
+
+    max_num_logprobs: int
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
new file mode 100644
index 000000000000..157c4dd6d771
--- /dev/null
+++ b/vllm/v1/sample/sampler.py
@@ -0,0 +1,161 @@
+"""A layer that samples the next tokens from the model's outputs."""
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.v1.outputs import SamplerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+
+_SAMPLING_EPS = 1e-5
+
+
+class Sampler(nn.Module):
+
+    def forward(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        logits = self.apply_temperature(logits, sampling_metadata.temperature)
+        logits = self.apply_top_k_top_p(logits, sampling_metadata)
+
+        probs = self.get_probs(logits)
+        sampled = self.sample(probs, sampling_metadata)
+        # Use int32 to reduce the tensor size.
+        sampled = sampled.to(torch.int32)
+
+        if sampling_metadata.max_num_logprobs > 0:
+            logprobs = self.get_logprobs(logits)
+            # FIXME: Mask the sampled token_id, get topk logprobs,
+            # and concatenate the topk with the sampled token_id.
+            topk_logprobs, topk_indices = torch.topk(
+                logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+            # Use int32 to reduce the tensor size.
+            topk_indices = topk_indices.to(torch.int32)
+        else:
+            topk_logprobs = None
+            topk_indices = None
+
+        sampler_output = SamplerOutput(
+            sampled_token_ids=sampled,
+            logprob_token_ids=topk_indices,
+            logprobs=topk_logprobs,
+            prompt_logprob_token_ids=None,
+            prompt_logprobs=None,
+        )
+        return sampler_output
+
+    def apply_temperature(
+        self,
+        logits: torch.Tensor,
+        temp: torch.Tensor,
+    ) -> torch.Tensor:
+        # Use float32 to apply temperature scaling.
+        logits = logits.to(torch.float32)
+        # Avoid division by zero.
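+        # (Greedy requests use temperature 0; mapping it to 1.0 leaves their
+        # logits unchanged, and the greedy path in sample() handles them.)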
+        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
+        # Use in-place division to avoid creating a new tensor.
+        logits.div_(temp.unsqueeze(dim=1))
+        return logits
+
+    def apply_top_k_top_p(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        return _apply_top_k_top_p(
+            logits,
+            sampling_metadata.no_top_k,
+            sampling_metadata.top_k,
+            sampling_metadata.no_top_p,
+            sampling_metadata.top_p,
+        )
+
+    def get_probs(self, logits: torch.Tensor) -> torch.Tensor:
+        return torch.softmax(logits, dim=-1, dtype=torch.float32)
+
+    def get_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
+        return torch.log_softmax(logits, dim=-1, dtype=torch.float32)
+
+    def greedy_sample(self, probs: torch.Tensor) -> torch.Tensor:
+        return probs.argmax(dim=-1).view(-1)
+
+    def random_sample(
+        self,
+        probs: torch.Tensor,
+        generators: List[Optional[torch.Generator]],
+        no_generator: bool,
+    ) -> torch.Tensor:
+        q = torch.empty_like(probs)
+        # NOTE(woosuk): Requests without their own seed are the common case,
+        # so we first draw the noise for the whole batch without a generator
+        # and then overwrite the rows of the requests that do have their own
+        # seed.
+        q.exponential_()
+        if not no_generator:
+            assert len(generators) == probs.shape[0]
+            # TODO(woosuk): This can be slow because we handle each request
+            # one by one. Optimize this.
+            for i, generator in enumerate(generators):
+                if generator is not None:
+                    q[i].exponential_(generator=generator)
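+        # Exponential-race sampling: argmax(probs / Exp(1) noise) draws from
+        # the categorical distribution defined by probs (equivalent to the
+        # Gumbel-max trick).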
+        return probs.div_(q).argmax(dim=-1).view(-1)
+
+    def sample(
+        self,
+        probs: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        assert not (sampling_metadata.all_greedy
+                    and sampling_metadata.all_random)
+        if sampling_metadata.all_greedy:
+            return self.greedy_sample(probs)
+        if sampling_metadata.all_random:
+            return self.random_sample(probs, sampling_metadata.generators,
+                                      sampling_metadata.no_generator)
+
+        greedy_sampled = self.greedy_sample(probs)
+        random_sampled = self.random_sample(probs,
+                                            sampling_metadata.generators,
+                                            sampling_metadata.no_generator)
+        sampled = torch.where(
+            sampling_metadata.temperature < _SAMPLING_EPS,
+            greedy_sampled,
+            random_sampled,
+        )
+        return sampled
+
+
+# TODO(woosuk): Optimize this with a custom kernel.
+def _apply_top_k_top_p(
+    logits: torch.Tensor,
+    no_top_k: bool,
+    k: torch.Tensor,
+    no_top_p: bool,
+    p: torch.Tensor,
+) -> torch.Tensor:
+    if no_top_k and no_top_p:
+        return logits
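+    # Sort in ascending order so the kept tokens sit at the end of each row,
+    # which makes the top-k and top-p cutoffs easy to locate.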
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+
+    if not no_top_k:
+        # Apply top-k.
+        top_k_mask = logits_sort.size(1) - k.to(torch.long)
+        # Get all the top_k values.
+        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+        top_k_mask = logits_sort < top_k_mask
+        logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    if not no_top_p:
+        # Apply top-p.
+        probs_sort = logits_sort.softmax(dim=-1)
+        probs_sum = probs_sort.cumsum(dim=-1)
+        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+        # at least one
+        top_p_mask[:, -1] = False
+        logits_sort.masked_fill_(top_p_mask, -float("inf"))
+
+    # Re-sort the probabilities.
+    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
+    return logits
diff --git a/vllm/v1/tokenizer/__init__.py b/vllm/v1/tokenizer/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py
new file mode 100644
index 000000000000..4bbcf4717981
--- /dev/null
+++ b/vllm/v1/tokenizer/detokenizer.py
@@ -0,0 +1,215 @@
+import multiprocessing
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import msgspec
+import zmq
+from msgspec import msgpack
+
+from vllm.transformers_utils.detokenizer_utils import (
+    convert_prompt_ids_to_tokens, detokenize_incrementally)
+from vllm.transformers_utils.tokenizer import get_tokenizer
+from vllm.utils import get_open_port
+
+
+class DetokenizerInputs(msgspec.Struct):
+
+    # [num_reqs]
+    req_ids: List[str]
+    # A request's prompt token ids are sent to the detokenizer only the first
+    # time the request is detokenized. Otherwise, an empty list is sent.
+    prompt_token_ids: List[List[int]]
+    new_token_ids: List[List[int]]
+    skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
+
+    # [num_free_reqs]
+    free_req_ids: List[str]
+
+
+class DetokenizerOutputs(msgspec.Struct):
+
+    # [num_reqs]
+    req_ids: List[str]
+    detokenized_texts: List[str]
+    # NOTE(woosuk): The number of output token ids of each request
+    # at the time of detokenization. The detokenizer returns this to the engine
+    # because the request state (including the output token ids) is
+    # asynchronously updated in the engine, while RequestOutput requires the
+    # output token ids to be consistent with the detokenized text.
+    num_output_token_ids: List[int]
+
+
+class Detokenizer:
+
+    def __init__(self, tokenizer_name: str):
+        # FIXME(woosuk): Currently, the detokenizer is just a hacky prototype.
+        # For example, it does not terminate properly. We need to improve this.
+        self.push_port = get_open_port()
+        self.pull_port = get_open_port()
+        self.detokenizer = DetokenizerProc(tokenizer_name, self.push_port,
+                                           self.pull_port)
+        self.detokenizer.start()
+
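+        # The detokenizer process binds the ZMQ sockets; the engine side
+        # connects to them over localhost.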
+        self.zmq_context = zmq.Context()
+        self.push_socket = self.zmq_context.socket(zmq.PUSH)
+        self.push_socket.connect(f"tcp://localhost:{self.push_port}")
+        self.pull_socket = self.zmq_context.socket(zmq.PULL)
+        self.pull_socket.connect(f"tcp://localhost:{self.pull_port}")
+        self.poller = zmq.Poller()
+        self.poller.register(self.pull_socket, zmq.POLLIN)
+        self.msgpack_encoder = msgpack.Encoder()
+        self.msgpack_decoder = msgpack.Decoder(DetokenizerOutputs)
+
+    def send(self, inputs: DetokenizerInputs) -> None:
+        self.push_socket.send(self.msgpack_encoder.encode(inputs),
+                              flags=zmq.NOBLOCK)
+
+    def recv(self) -> Optional[DetokenizerOutputs]:
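+        # Poll with a zero timeout so recv() never blocks the engine loop;
+        # return None if no output is ready yet.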
+        socks = dict(self.poller.poll(timeout=0))
+        if self.pull_socket in socks and socks[self.pull_socket] == zmq.POLLIN:
+            msg = self.pull_socket.recv()
+            return self.msgpack_decoder.decode(msg)
+        return None
+
+    def terminate(self) -> None:
+        self.push_socket.send(b"", flags=zmq.NOBLOCK)
+        self.detokenizer.join()
+
+
+class DetokenizerProc(multiprocessing.Process):
+
+    def __init__(
+        self,
+        tokenizer_name: str,
+        pull_port: int,
+        push_port: int,
+    ):
+        super().__init__()
+        self.tokenizer_name = tokenizer_name
+        # NOTE: The pull_port of the detokenizer must match the push_port of
+        # the engine, and vice versa.
+        self.pull_port = pull_port
+        self.push_port = push_port
+
+    def run(self):
+        # Initialize these objects after the process is forked since they are
+        # not picklable.
+        self.msgpack_encoder = msgpack.Encoder()
+        self.msgpack_decoder = msgpack.Decoder(DetokenizerInputs)
+        self.tokenizer = get_tokenizer(self.tokenizer_name)
+        # req_id -> RequestState
+        self.request_states: Dict[str, RequestState] = {}
+
+        self.zmq_context = zmq.Context()
+        self.pull_socket = self.zmq_context.socket(zmq.PULL)
+        self.pull_socket.bind(f"tcp://*:{self.pull_port}")
+        self.push_socket = self.zmq_context.socket(zmq.PUSH)
+        self.push_socket.bind(f"tcp://*:{self.push_port}")
+
+        while True:
+            message = self.pull_socket.recv()
+            if message == b"":
+                # Terminate signal.
+                break
+            inputs = self.msgpack_decoder.decode(message)
+
+            for req_id in inputs.free_req_ids:
+                self.free(req_id)
+
+            detokenized_texts: List[str] = []
+            num_output_token_ids: List[int] = []
+            num_reqs = len(inputs.req_ids)
+            for i in range(num_reqs):
+                req_id = inputs.req_ids[i]
+                if req_id not in self.request_states:
+                    self.add_request(
+                        request_id=req_id,
+                        prompt_token_ids=inputs.prompt_token_ids[i],
+                        skip_special_tokens=inputs.skip_special_tokens[i],
+                        spaces_between_special_tokens=inputs.
+                        spaces_between_special_tokens[i],
+                    )
+                new_str = self.detokenize(req_id, inputs.new_token_ids[i])
+                detokenized_texts.append(new_str)
+                req_state = self.request_states[req_id]
+                num_output_token_ids.append(
+                    len(req_state.token_ids) - req_state.num_prompt_tokens)
+
+            detokenized = DetokenizerOutputs(
+                req_ids=inputs.req_ids,
+                detokenized_texts=detokenized_texts,
+                num_output_token_ids=num_output_token_ids,
+            )
+            self.push_socket.send(self.msgpack_encoder.encode(detokenized),
+                                  flags=zmq.NOBLOCK)
+
+    def add_request(
+        self,
+        request_id: str,
+        prompt_token_ids: List[int],
+        skip_special_tokens: bool,
+        spaces_between_special_tokens: bool,
+    ) -> None:
+        tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
+            tokenizer=self.tokenizer,
+            prompt_ids=prompt_token_ids,
+            skip_special_tokens=skip_special_tokens,
+        )
+        self.request_states[request_id] = RequestState(
+            req_id=request_id,
+            token_ids=prompt_token_ids,
+            tokens=tokens,
+            num_prompt_tokens=len(prompt_token_ids),
+            prefix_offset=prefix_offset,
+            read_offset=read_offset,
+            skip_special_tokens=skip_special_tokens,
+            spaces_between_special_tokens=spaces_between_special_tokens,
+        )
+
+    def free(self, request_id: str) -> None:
+        del self.request_states[request_id]
+
+    def detokenize(self, request_id: str, new_token_ids: List[int]) -> str:
+        # TODO(woosuk): This method becomes very inefficient when the number of
+        # new_token_ids is more than 1. We need to optimize this.
+        req_state = self.request_states[request_id]
+        decoded_text = ""
+        for new_token_id in new_token_ids:
+            req_state.token_ids.append(new_token_id)
+            (new_tokens, new_decoded_token_text, prefix_offset,
+             read_offset) = detokenize_incrementally(
+                 tokenizer=self.tokenizer,
+                 all_input_ids=req_state.token_ids,
+                 prev_tokens=req_state.tokens,
+                 prefix_offset=req_state.prefix_offset,
+                 read_offset=req_state.read_offset,
+                 skip_special_tokens=req_state.skip_special_tokens,
+                 spaces_between_special_tokens=req_state.
+                 spaces_between_special_tokens,
+             )
+
+            req_state.tokens.extend(new_tokens)
+            req_state.prefix_offset = prefix_offset
+            req_state.read_offset = read_offset
+            req_state.output_text += new_decoded_token_text
+            decoded_text += new_decoded_token_text
+        return decoded_text
+
+
+@dataclass
+class RequestState:
+
+    req_id: str
+
+    token_ids: List[int]
+    tokens: List[str]
+    num_prompt_tokens: int
+
+    prefix_offset: int
+    read_offset: int
+
+    skip_special_tokens: bool
+    spaces_between_special_tokens: bool
+
+    output_text: str = ""
diff --git a/vllm/v1/worker/__init__.py b/vllm/v1/worker/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
new file mode 100644
index 000000000000..e84645ac7a4a
--- /dev/null
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -0,0 +1,690 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set
+from unittest.mock import patch
+
+import numpy as np
+import torch
+import torch.distributed
+import torch.nn as nn
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig)
+from vllm.forward_context import set_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import MultiModalDataDict
+from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
+                        is_pin_memory_available)
+from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
+                                                   FlashAttentionMetadata)
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.sampler import Sampler
+
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler import SchedulerOutput
+
+logger = init_logger(__name__)
+
+
+class GPUModelRunner:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig] = None,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        observability_config: Optional[ObservabilityConfig] = None,
+    ):
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.load_config = load_config
+        self.prompt_adapter_config = prompt_adapter_config
+        self.observability_config = observability_config
+
+        self.device = self.device_config.device
+        self.pin_memory = is_pin_memory_available()
+        self.dtype = self.model_config.dtype
+        if cache_config.cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        else:
+            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
+                cache_config.cache_dtype]
+
+        self.sliding_window = model_config.get_sliding_window()
+        self.block_size = cache_config.block_size
+        self.max_model_len = model_config.max_model_len
+        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
+        self.max_num_tokens = scheduler_config.max_num_batched_tokens
+
+        # Model-related.
+        self.num_attn_layers = model_config.get_num_attention_layers(
+            parallel_config)
+        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
+        self.head_size = model_config.get_head_size()
+
+        # Lazy initialization
+        # self.model: nn.Module  # Set after load_model
+        self.kv_caches: List[torch.Tensor] = []
+
+        # Request states.
+        self.requests: Dict[str, CachedRequestState] = {}
+        # Persistent batch.
+        self.input_batch = InputBatch(
+            max_num_reqs=self.scheduler_config.max_num_seqs,
+            max_model_len=self.max_model_len,
+            max_num_blocks_per_req=self.max_num_blocks_per_req,
+            device=self.device,
+            pin_memory=self.pin_memory,
+        )
+
+    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
+        # Remove stopped requests from the cached states.
+        # Keep the states of the preempted requests.
+        for req_id in scheduler_output.finished_req_ids:
+            self.requests.pop(req_id, None)
+
+        # Remove the requests from the persistent batch.
+        stopped_req_ids = set().union(
+            scheduler_output.preempted_req_ids,
+            scheduler_output.finished_req_ids,
+        )
+        removed_req_indices: List[int] = []
+        for req_id in stopped_req_ids:
+            req_index = self.input_batch.remove_request(req_id)
+            if req_index is not None:
+                removed_req_indices.append(req_index)
+
+        # Update the states of the running requests.
+        for req_data in scheduler_output.scheduled_running_reqs:
+            req_id = req_data.req_id
+            req_state = self.requests[req_id]
+            req_index = self.input_batch.req_id_to_index[req_id]
+
+            # Update the num_computed_tokens.
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            self.input_batch.num_computed_tokens_cpu[req_index] = (
+                req_data.num_computed_tokens)
+
+            # Update the block table.
+            num_new_blocks = len(req_data.new_block_ids)
+            if num_new_blocks == 0:
+                continue
+            start_index = len(req_state.block_ids)
+            end_index = start_index + num_new_blocks
+            req_state.block_ids.extend(req_data.new_block_ids)
+            self.input_batch.block_table_cpu[
+                req_index, start_index:end_index] = req_data.new_block_ids
+
+        req_ids_to_add: List[str] = []
+        # Add new requests to the cached states.
+        for req_data in scheduler_output.scheduled_new_reqs:
+            req_id = req_data.req_id
+            self.requests[req_id] = CachedRequestState(
+                req_id=req_id,
+                prompt_token_ids=req_data.prompt_token_ids,
+                prompt=req_data.prompt,
+                multi_modal_data=req_data.multi_modal_data,
+                sampling_params=req_data.sampling_params,
+                generator=None,  # TODO
+                block_ids=req_data.block_ids,
+                num_computed_tokens=req_data.num_computed_tokens,
+                output_token_ids=[],
+            )
+            req_ids_to_add.append(req_id)
+
+        # Update the cached states of the resumed requests.
+        for req_data in scheduler_output.scheduled_resumed_reqs:
+            req_id = req_data.req_id
+            req_state = self.requests[req_id]
+
+            req_state.block_ids = req_data.block_ids
+            req_state.num_computed_tokens = req_data.num_computed_tokens
+            req_ids_to_add.append(req_id)
+
+        # Add the new or resumed requests to the persistent batch.
+        # The smaller empty indices are filled first.
+        removed_req_indices = sorted(removed_req_indices, reverse=True)
+        for req_id in req_ids_to_add:
+            req_state = self.requests[req_id]
+            if removed_req_indices:
+                # Fill the empty index.
+                req_index = removed_req_indices.pop()
+            else:
+                # Append to the end.
+                req_index = None
+            self.input_batch.add_request(req_state, req_index)
+
+        # Condense the batched states if there are empty indices.
+        if removed_req_indices:
+            self.input_batch.condense(removed_req_indices)
+
+    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        assert total_num_scheduled_tokens > 0
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0
+
+        # OPTIMIZATION: Start copying the block table first.
+        # This way, we can overlap the copy with the following CPU operations.
+        self.input_batch.block_table[:num_reqs].copy_(
+            self.input_batch.block_table_cpu_tensor[:num_reqs],
+            non_blocking=True)
+
+        # Get the number of scheduled tokens for each request.
+        # TODO: The Python loop can be slow. Optimize.
+        num_scheduled_tokens = []
+        max_num_scheduled_tokens = 0
+        for req_id in self.input_batch.req_ids[:num_reqs]:
+            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
+            num_scheduled_tokens.append(num_tokens)
+            max_num_scheduled_tokens = max(max_num_scheduled_tokens,
+                                           num_tokens)
+        num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32)
+        assert max_num_scheduled_tokens > 0
+
+        # Get request indices.
+        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
+        indices = np.arange(num_reqs)
+        req_indices = np.repeat(indices, num_scheduled_tokens)
+
+        # Get batched arange.
+        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
+                                (num_reqs, 1))
+        mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
+        arange = arange_matrix[mask]
+
+        # Get positions.
+        positions = torch.empty((total_num_scheduled_tokens, ),
+                                dtype=torch.int32,
+                                device="cpu",
+                                pin_memory=self.pin_memory)
+        positions_np = positions.numpy()
+        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
+               arange,
+               out=positions_np)
+
+        # Get token indices.
+        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
+        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
+        # where M is the max_model_len.
+        token_indices = positions_np + req_indices * self.max_model_len
+        token_indices = torch.from_numpy(token_indices)
+        input_ids = torch.empty((total_num_scheduled_tokens, ),
+                                dtype=torch.int32,
+                                device="cpu",
+                                pin_memory=self.pin_memory)
+        torch.index_select(torch.from_numpy(
+            self.input_batch.token_ids_cpu).flatten(),
+                           0,
+                           token_indices,
+                           out=input_ids)
+
+        # Calculate the slot mapping.
+        block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
+            token_indices // self.block_size]
+        block_offsets = token_indices % self.block_size
+        slot_mapping = torch.empty((total_num_scheduled_tokens, ),
+                                   dtype=torch.int32,
+                                   device="cpu",
+                                   pin_memory=self.pin_memory)
+        torch.add(block_numbers * self.block_size,
+                  block_offsets,
+                  out=slot_mapping)
+
+        # Prepare the attention metadata.
+        query_start_loc = torch.empty((num_reqs + 1, ),
+                                      dtype=torch.int32,
+                                      device="cpu",
+                                      pin_memory=self.pin_memory)
+        query_start_loc_np = query_start_loc.numpy()
+        query_start_loc_np[0] = 0
+        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])
+
+        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
+                    num_scheduled_tokens)
+        max_seq_len = seq_lens.max()
+        seq_start_loc = torch.empty((num_reqs + 1, ),
+                                    dtype=torch.int32,
+                                    device="cpu",
+                                    pin_memory=self.pin_memory)
+        seq_start_loc_np = seq_start_loc.numpy()
+        seq_start_loc_np[0] = 0
+        np.cumsum(seq_lens, out=seq_start_loc_np[1:])
+
+        input_ids = input_ids.to(self.device, non_blocking=True)
+        positions = positions.to(self.device, non_blocking=True).long()
+        query_start_loc = query_start_loc.to(self.device, non_blocking=True)
+        seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
+        slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
+        attn_metadata = FlashAttentionMetadata(
+            max_query_len=max_num_scheduled_tokens,
+            query_start_loc=query_start_loc,
+            max_seq_len=max_seq_len,
+            seq_start_loc=seq_start_loc,
+            block_table=self.input_batch.block_table[:num_reqs],
+            slot_mapping=slot_mapping,
+        )
+        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
+        # request in the batch. While we should not sample any token from this
+        # partial request, we do so for simplicity. We will ignore the sampled
+        # token from the partial request.
+        # TODO: Support prompt logprobs.
+        logits_indices = query_start_loc[1:] - 1
+        return input_ids, positions, attn_metadata, logits_indices
+
+    def _prepare_sampling(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> SamplingMetadata:
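+        # The sampling tensors only need to be refreshed when the set of
+        # requests in the batch has changed since the last step.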
+        skip_copy = True
+        if (scheduler_output.finished_req_ids
+                or scheduler_output.preempted_req_ids):
+            skip_copy = False
+        if (scheduler_output.scheduled_new_reqs
+                or scheduler_output.scheduled_resumed_reqs):
+            skip_copy = False
+        # Create the sampling metadata.
+        sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy)
+        return sampling_metadata
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> ModelRunnerOutput:
+        self._update_states(scheduler_output)
+        inputs = self._prepare_inputs(scheduler_output)
+        input_ids, positions, attn_metadata, logits_indices = inputs
+
+        with set_forward_context(attn_metadata):
+            hidden_states = self.model(
+                input_ids=input_ids,
+                positions=positions,
+                kv_caches=self.kv_caches,
+                attn_metadata=attn_metadata,
+            )
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, None)
+
+        # Sample the next token and get logprobs if needed.
+        sampling_metadata = self._prepare_sampling(scheduler_output)
+        sampler_output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+
+        # NOTE: CPU-GPU synchronization happens here.
+        sampled_token_ids = sampler_output.sampled_token_ids.cpu()
+        sampled_token_ids_list = sampled_token_ids.tolist()
+        # TODO(woosuk): The following loop can be slow since it iterates over
+        # the requests one by one. Optimize.
+        num_reqs = self.input_batch.num_reqs
+        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            req_state = self.requests[req_id]
+            seq_len = (req_state.num_computed_tokens +
+                       scheduler_output.num_scheduled_tokens[req_id])
+            assert seq_len <= req_state.num_tokens
+            if seq_len == req_state.num_tokens:
+                # Append the sampled token to the output token ids.
+                token_id = sampled_token_ids_list[i]
+                self.input_batch.token_ids_cpu[i, seq_len] = token_id
+                req_state.output_token_ids.append(token_id)
+            else:
+                # Ignore the sampled token from the partial request.
+                # Rewind the generator state as if the token was not sampled.
+                generator = self.input_batch.generators[i]
+                if generator is not None:
+                    offset = generator.get_offset()
+                    generator = generator.set_offset(offset - 1)
+                    self.input_batch.generators[i] = generator
+
+        if sampler_output.logprob_token_ids is None:
+            logprob_token_ids = None
+        else:
+            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
+        if sampler_output.logprobs is None:
+            logprobs = None
+        else:
+            logprobs = sampler_output.logprobs.cpu()
+        model_runner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids[:num_reqs],
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids_cpu=sampled_token_ids,
+            logprob_token_ids_cpu=logprob_token_ids,
+            logprobs_cpu=logprobs,
+        )
+        return model_runner_output
+
+    def load_model(self) -> None:
+        logger.info("Starting to load model %s...", self.model_config.model)
+        with DeviceMemoryProfiler() as m:  # noqa: SIM117
+            with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
+                self.model = get_model(model_config=self.model_config,
+                                       device_config=self.device_config,
+                                       load_config=self.load_config,
+                                       lora_config=self.lora_config,
+                                       parallel_config=self.parallel_config,
+                                       scheduler_config=self.scheduler_config,
+                                       cache_config=self.cache_config)
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info("Loading model weights took %.4f GB",
+                    self.model_memory_usage / float(2**30))
+
+    def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
+        input_ids = torch.zeros(num_tokens,
+                                dtype=torch.int32,
+                                device=self.device)
+        positions = torch.zeros(num_tokens,
+                                dtype=torch.long,
+                                device=self.device)
+        kv_caches = [None for _ in range(self.num_attn_layers)]
+        model(input_ids, positions, kv_caches, attn_metadata=None)
+        return
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        self._dummy_run(self.model, self.max_num_tokens)
+        torch.cuda.synchronize()
+        return
+
+    @torch.inference_mode()
+    def capture_model(self) -> None:
+        # TODO: Implement CUDA graph support.
+        return
+
+    def initialize_kv_cache(self, num_blocks: int) -> None:
+        assert len(self.kv_caches) == 0
+        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
+            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
+        for _ in range(self.num_attn_layers):
+            self.kv_caches.append(
+                torch.zeros(kv_cache_shape,
+                            dtype=self.kv_cache_dtype,
+                            device=self.device))
+
+
+@dataclass
+class CachedRequestState:
+
+    req_id: str
+    prompt_token_ids: List[int]
+    prompt: Optional[str]
+    multi_modal_data: Optional["MultiModalDataDict"]
+    sampling_params: SamplingParams
+    generator: Optional[torch.Generator]
+
+    block_ids: List[int]
+    num_computed_tokens: int
+    output_token_ids: List[int]
+
+    @property
+    def num_tokens(self) -> int:
+        return len(self.prompt_token_ids) + len(self.output_token_ids)
+
+
+class InputBatch:
+
+    def __init__(
+        self,
+        max_num_reqs: int,
+        max_model_len: int,
+        max_num_blocks_per_req: int,
+        device: torch.device,
+        pin_memory: bool,
+    ):
+        self.max_num_reqs = max_num_reqs
+        self.max_model_len = max_model_len
+        self.max_num_blocks_per_req = max_num_blocks_per_req
+        self.device = device
+        self.pin_memory = pin_memory
+
+        self.req_ids: List[Optional[str]] = [None] * max_num_reqs
+        self.req_id_to_index: Dict[str, int] = {}
+
+        self.token_ids_cpu = np.empty((max_num_reqs, max_model_len),
+                                      dtype=np.int32)
+        self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32)
+
+        # Attention-related.
+        self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req),
+                                       device=self.device,
+                                       dtype=torch.int32)
+        self.block_table_cpu_tensor = torch.zeros(
+            (max_num_reqs, max_num_blocks_per_req),
+            device="cpu",
+            dtype=torch.int32,
+            pin_memory=pin_memory,
+        )
+        self.block_table_cpu = self.block_table_cpu_tensor.numpy()
+
+        # Sampling-related.
+        self.temperature = torch.empty((max_num_reqs, ),
+                                       dtype=torch.float32,
+                                       device=device)
+        self.temperature_cpu_tensor = torch.empty((max_num_reqs, ),
+                                                  dtype=torch.float32,
+                                                  device="cpu",
+                                                  pin_memory=pin_memory)
+        self.temperature_cpu = self.temperature_cpu_tensor.numpy()
+        self.greedy_reqs: Set[str] = set()
+        self.random_reqs: Set[str] = set()
+
+        self.top_p = torch.empty((max_num_reqs, ),
+                                 dtype=torch.float32,
+                                 device=device)
+        self.top_p_cpu_tensor = torch.empty((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device="cpu",
+                                            pin_memory=pin_memory)
+        self.top_p_cpu = self.top_p_cpu_tensor.numpy()
+        self.top_p_reqs: Set[str] = set()
+
+        self.top_k = torch.empty((max_num_reqs, ),
+                                 dtype=torch.int32,
+                                 device=device)
+        self.top_k_cpu_tensor = torch.empty((max_num_reqs, ),
+                                            dtype=torch.int32,
+                                            device="cpu",
+                                            pin_memory=pin_memory)
+        self.top_k_cpu = self.top_k_cpu_tensor.numpy()
+        self.top_k_reqs: Set[str] = set()
+
+        self.generators: List[Optional[torch.Generator]] = [None
+                                                            ] * max_num_reqs
+
+        self.num_logprobs: Dict[str, int] = {}
+        self.prompt_logprob_reqs: Set[str] = set()
+
+    def add_request(
+        self,
+        request: "CachedRequestState",
+        req_index: Optional[int] = None,
+    ) -> None:
+        if req_index is None:
+            req_index = self.num_reqs
+        assert req_index < self.max_num_reqs
+
+        self.req_ids[req_index] = request.req_id
+        self.req_id_to_index[request.req_id] = req_index
+
+        # Copy the prompt token ids and output token ids.
+        num_prompt_tokens = len(request.prompt_token_ids)
+        self.token_ids_cpu[
+            req_index, :num_prompt_tokens] = request.prompt_token_ids
+        start_idx = num_prompt_tokens
+        end_idx = start_idx + len(request.output_token_ids)
+        self.token_ids_cpu[req_index,
+                           start_idx:end_idx] = request.output_token_ids
+
+        self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
+        num_blocks = len(request.block_ids)
+        self.block_table_cpu[req_index, :num_blocks] = request.block_ids
+
+        sampling_params = request.sampling_params
+        self.temperature_cpu[req_index] = sampling_params.temperature
+        if sampling_params.sampling_type == SamplingType.GREEDY:
+            self.greedy_reqs.add(request.req_id)
+        elif sampling_params.sampling_type == SamplingType.RANDOM:
+            self.random_reqs.add(request.req_id)
+        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
+            # TODO(woosuk): Support per-request random seed.
+            raise NotImplementedError("Per-request seed is not supported yet.")
+
+        self.top_p_cpu[req_index] = sampling_params.top_p
+        if sampling_params.top_p < 1:
+            self.top_p_reqs.add(request.req_id)
+        self.top_k_cpu[req_index] = sampling_params.top_k
+        if sampling_params.top_k > 0:
+            self.top_k_reqs.add(request.req_id)
+
+        self.generators[req_index] = request.generator
+
+        num_logprobs = sampling_params.logprobs
+        if num_logprobs is not None and num_logprobs > 0:
+            self.num_logprobs[request.req_id] = num_logprobs
+        if sampling_params.prompt_logprobs:
+            self.prompt_logprob_reqs.add(request.req_id)
+
+    def remove_request(self, req_id: str) -> Optional[int]:
+        req_index = self.req_id_to_index.pop(req_id, None)
+        if req_index is None:
+            return None
+        self.req_ids[req_index] = None
+
+        self.greedy_reqs.discard(req_id)
+        self.random_reqs.discard(req_id)
+        self.top_p_reqs.discard(req_id)
+        self.top_k_reqs.discard(req_id)
+        self.generators[req_index] = None
+        self.num_logprobs.pop(req_id, None)
+        self.prompt_logprob_reqs.discard(req_id)
+        return req_index
+
+    def clear(self) -> None:
+        self.req_ids = [None] * self.max_num_reqs
+        self.req_id_to_index.clear()
+        self.greedy_reqs.clear()
+        self.random_reqs.clear()
+        self.top_p_reqs.clear()
+        self.top_k_reqs.clear()
+        self.generators = [None] * self.max_num_reqs
+        self.num_logprobs.clear()
+        self.prompt_logprob_reqs.clear()
+
+    def condense(self, empty_req_indices: List[int]) -> None:
+        if self.num_reqs == 0:
+            # The batched states are empty.
+            return
+
+        # NOTE(woosuk): This function assumes that the empty_req_indices
+        # is sorted in descending order.
+        last_req_index = self.num_reqs + len(empty_req_indices) - 1
+        while empty_req_indices:
+            # Find the largest non-empty index.
+            while last_req_index in empty_req_indices:
+                last_req_index -= 1
+
+            # Find the smallest empty index.
+            empty_index = empty_req_indices.pop()
+            if empty_index >= last_req_index:
+                break
+
+            # Swap the states.
+            req_id = self.req_ids[last_req_index]
+            self.req_ids[empty_index] = req_id
+            self.req_ids[last_req_index] = None
+            self.req_id_to_index[req_id] = empty_index
+
+            # TODO(woosuk): Optimize the copy of token_ids_cpu and
+            # block_table_cpu.
+            self.token_ids_cpu[empty_index] = self.token_ids_cpu[
+                last_req_index]
+            self.num_computed_tokens_cpu[
+                empty_index] = self.num_computed_tokens_cpu[last_req_index]
+            self.block_table_cpu[empty_index] = self.block_table_cpu[
+                last_req_index]
+            self.temperature_cpu[empty_index] = self.temperature_cpu[
+                last_req_index]
+            self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
+            self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
+            self.generators[empty_index] = self.generators[last_req_index]
+
+            # Decrement last_req_index since it is now empty.
+            last_req_index -= 1
+
+    def make_sampling_metadata(
+        self,
+        skip_copy: bool = False,
+    ) -> SamplingMetadata:
+        if not skip_copy:
+            self.temperature[:self.num_reqs].copy_(
+                self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True)
+            self.top_p[:self.num_reqs].copy_(
+                self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
+            self.top_k[:self.num_reqs].copy_(
+                self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
+        return SamplingMetadata(
+            temperature=self.temperature[:self.num_reqs],
+            all_greedy=self.all_greedy,
+            all_random=self.all_random,
+            top_p=self.top_p[:self.num_reqs],
+            top_k=self.top_k[:self.num_reqs],
+            no_top_p=self.no_top_p,
+            no_top_k=self.no_top_k,
+            generators=self.generators[:self.num_reqs],
+            no_generator=self.no_generator,
+            max_num_logprobs=self.max_num_logprobs,
+        )
+
+    @property
+    def num_reqs(self) -> int:
+        return len(self.req_id_to_index)
+
+    @property
+    def all_greedy(self) -> bool:
+        return len(self.random_reqs) == 0
+
+    @property
+    def all_random(self) -> bool:
+        return len(self.greedy_reqs) == 0
+
+    @property
+    def no_top_p(self) -> bool:
+        return len(self.top_p_reqs) == 0
+
+    @property
+    def no_top_k(self) -> bool:
+        return len(self.top_k_reqs) == 0
+
+    @property
+    def no_generator(self) -> bool:
+        return len(self.generators) == 0
+
+    @property
+    def max_num_logprobs(self) -> int:
+        if self.num_logprobs:
+            return max(self.num_logprobs.values())
+        else:
+            return 0
+
+    @property
+    def no_logprob(self) -> bool:
+        return len(self.num_logprobs) == 0
+
+    @property
+    def no_prompt_logprob(self) -> bool:
+        return len(self.prompt_logprob_reqs) == 0
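The indexing in _prepare_inputs above is terse, so the following standalone NumPy sketch walks through the same arithmetic with made-up sizes (nothing here is taken from the PR): each token's KV-cache slot is its block-table entry times the block size plus the in-block offset, query_start_loc is a prefix sum with a leading zero, and logits_indices selects the last scheduled token of every request.

import numpy as np

# Hypothetical sizes for illustration only.
block_size = 4
block_table = np.array([7, 2, 9], dtype=np.int32)         # one request's blocks
positions = np.array([0, 1, 2, 3, 4, 5], dtype=np.int64)  # scheduled positions

# slot = block_table[pos // block_size] * block_size + pos % block_size
block_numbers = block_table[positions // block_size]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets
print(slot_mapping)            # [28 29 30 31  8  9]

# Cumulative token offsets per request, with a leading zero as above.
num_scheduled_tokens = np.array([3, 1, 2], dtype=np.int32)
query_start_loc = np.zeros(len(num_scheduled_tokens) + 1, dtype=np.int32)
np.cumsum(num_scheduled_tokens, out=query_start_loc[1:])
print(query_start_loc)         # [0 3 4 6]

# Last token of each request, used to gather logits for sampling.
logits_indices = query_start_loc[1:] - 1
print(logits_indices)          # [2 3 5]
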
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
new file mode 100644
index 000000000000..8c5ca2ec3566
--- /dev/null
+++ b/vllm/v1/worker/gpu_worker.py
@@ -0,0 +1,245 @@
+"""A GPU worker class."""
+import gc
+import os
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import torch
+import torch.distributed
+
+from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
+                         ModelConfig, ObservabilityConfig, ParallelConfig,
+                         PromptAdapterConfig, SchedulerConfig,
+                         SpeculativeConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment,
+                              set_custom_all_reduce)
+from vllm.logger import init_logger
+from vllm.model_executor import set_random_seed
+from vllm.platforms import current_platform
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+
+logger = init_logger(__name__)
+
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler import SchedulerOutput
+
+
+class Worker:
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        cache_config: CacheConfig,
+        load_config: LoadConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        speculative_config: Optional[SpeculativeConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
+        observability_config: Optional[ObservabilityConfig] = None,
+    ):
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.cache_config = cache_config
+        self.load_config = load_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.lora_config = lora_config
+        self.speculative_config = speculative_config
+        self.prompt_adapter_config = prompt_adapter_config
+        self.observability_config = observability_config
+
+        if self.model_config.trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+
+        self.model_runner = GPUModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            cache_config,
+            load_config,
+            lora_config=lora_config,
+        )
+
+    def initialize(self):
+        if self.device_config.device.type == "cuda":
+            # torch.distributed.all_reduce does not free the input tensor until
+            # the synchronization point. This causes the memory usage to grow
+            # as the number of all_reduce calls increases. This env var disables
+            # this behavior.
+            # Related issue:
+            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+            # This env var set by Ray causes exceptions with graph building.
+            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            torch.cuda.set_device(self.device)
+
+            _check_if_gpu_supports_dtype(self.model_config.dtype)
+            gc.collect()
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+        else:
+            raise RuntimeError(
+                f"Not support device type: {self.device_config.device}")
+        # Initialize the distributed environment.
+        init_worker_distributed_environment(self.parallel_config, self.rank,
+                                            self.distributed_init_method,
+                                            self.local_rank)
+        # Set random seed.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self) -> None:
+        self.model_runner.load_model()
+
+    @torch.inference_mode()
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model to determine how many
+        KV blocks may be allocated without OOMs.
+
+        The engine first profiles the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Profile the memory usage of the model and get the maximum number of
+        # cache blocks that can be allocated with the remaining free memory.
+        torch.cuda.empty_cache()
+
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()
+
+        # Calculate the number of blocks that can be allocated with the
+        # profiled peak memory.
+        torch.cuda.synchronize()
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
+        assert peak_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
+        cache_block_size = _get_cache_block_size(self.cache_config,
+                                                 self.model_config,
+                                                 self.parallel_config)
+        num_gpu_blocks = int(
+            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
+             peak_memory) // cache_block_size)
+        num_gpu_blocks = max(num_gpu_blocks, 0)
+        # if self.model_runner.lora_manager:
+        #     self.model_runner.remove_all_loras()
+        gc.collect()
+        torch.cuda.empty_cache()
+        return num_gpu_blocks, 0
+
+    def initialize_cache(self, num_gpu_blocks: int) -> None:
+        """Allocate GPU and CPU KV cache with the specified number of blocks."""
+        if num_gpu_blocks <= 0:
+            raise ValueError("No available memory for the cache blocks. "
+                             "Try increasing `gpu_memory_utilization` when "
+                             "initializing the engine.")
+
+        max_seq_len = self.cache_config.block_size * num_gpu_blocks
+        max_model_len = self.model_config.max_model_len
+        if max_model_len > max_seq_len:
+            raise ValueError(
+                f"The model's max seq len ({max_model_len}) "
+                "is larger than the maximum number of tokens that can be "
+                f"stored in KV cache ({max_seq_len}). Try increasing "
+                "`gpu_memory_utilization` or decreasing `max_model_len` when "
+                "initializing the engine.")
+
+        self.model_runner.initialize_kv_cache(num_gpu_blocks)
+
+    def compile_or_warm_up_model(self) -> None:
+        if not self.model_config.enforce_eager:
+            self.model_runner.capture_model()
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+        set_random_seed(self.model_config.seed)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> ModelRunnerOutput:
+        output = self.model_runner.execute_model(scheduler_output)
+        # TODO(woosuk): Send the output to the engine process.
+        return output
+
+
+def init_worker_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    local_rank: int = -1,
+) -> None:
+    """Initialize the distributed environment."""
+    set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
+
+    init_distributed_environment(parallel_config.world_size, rank,
+                                 distributed_init_method, local_rank)
+
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:  # noqa: SIM102
+        if not current_platform.has_device_capability(80):
+            capability = current_platform.get_device_capability()
+            gpu_name = current_platform.get_device_name()
+
+            if capability is None:
+                compute_str = "does not have a compute capability"
+            else:
+                version_str = capability.as_version_str()
+                compute_str = f"has compute capability {version_str}"
+
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU {compute_str}. "
+                "You can use float16 instead by explicitly setting the"
+                "`dtype` flag in CLI, for example: --dtype=half.")
+
+
+def _get_cache_block_size(
+    cache_config: CacheConfig,
+    model_config: ModelConfig,
+    parallel_config: ParallelConfig,
+) -> int:
+    head_size = model_config.get_head_size()
+    num_heads = model_config.get_num_kv_heads(parallel_config)
+    num_attention_layers = model_config.get_num_attention_layers(
+        parallel_config)
+
+    key_cache_block = cache_config.block_size * num_heads * head_size
+    value_cache_block = key_cache_block
+    total = num_attention_layers * (key_cache_block + value_cache_block)
+    if cache_config.cache_dtype == "auto":
+        dtype = model_config.dtype
+    else:
+        dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+    dtype_size = get_dtype_size(dtype)
+    return dtype_size * total
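_get_cache_block_size above returns the bytes a single KV-cache block occupies across all attention layers, and determine_num_available_blocks divides the post-profiling headroom by it. A standalone sketch of that formula with hypothetical model dimensions (not taken from the PR):

def cache_block_size_bytes(block_size: int, num_kv_heads: int, head_size: int,
                           num_attention_layers: int, dtype_size: int) -> int:
    # Mirrors _get_cache_block_size: the K block and V block are the same size.
    key_cache_block = block_size * num_kv_heads * head_size
    value_cache_block = key_cache_block
    return dtype_size * num_attention_layers * (key_cache_block +
                                                value_cache_block)

# Example: 16-token blocks, 8 KV heads of size 128, 32 layers, fp16 (2 bytes).
print(cache_block_size_bytes(block_size=16,
                             num_kv_heads=8,
                             head_size=128,
                             num_attention_layers=32,
                             dtype_size=2))   # 2097152 bytes, i.e. 2 MiB/block
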
diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py
index 090f95e6e892..ac3270d1c990 100644
--- a/vllm/worker/cache_engine.py
+++ b/vllm/worker/cache_engine.py
@@ -53,7 +53,6 @@ def __init__(
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(self.head_size,
-                                             model_config.get_sliding_window(),
                                              model_config.dtype,
                                              cache_config.cache_dtype,
                                              self.block_size,
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 795511aea675..5032896600b3 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -19,6 +19,7 @@
                              MultiModalInputs)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
+from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -419,7 +420,6 @@ def __init__(
         self.block_size = cache_config.block_size
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -439,10 +439,7 @@ def __init__(
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keeping "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
-        if rope_scaling is None:
-            return False
-        return rope_scaling.get("type", None) == "mrope"
+        return uses_mrope(self.model_config.hf_config)
 
     def load_model(self) -> None:
         self.model = get_model(model_config=self.model_config,
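The uses_mrope helper imported here replaces the inline rope_scaling check deleted in this hunk (and in model_runner.py below). Based only on the removed lines, a minimal sketch of what such a helper presumably looks like; the real implementation in vllm.transformers_utils.config may differ:

def uses_mrope(hf_config) -> bool:
    # Same behavior as the removed inline check: a missing or null
    # rope_scaling entry means the model does not use mrope.
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is None:
        return False
    return rope_scaling.get("type", None) == "mrope"
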
diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index b84562851f0f..ab93471b5af7 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -57,7 +57,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index f88b1d84fbcd..8b74f06e77be 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -47,6 +47,7 @@
     LRUCacheWorkerPromptAdapterManager)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
+from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d,
                         flatten_2d_lists, is_hip, is_pin_memory_available,
                         supports_dynamo)
@@ -573,17 +574,12 @@ def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
             # paged attn. We can remove it if we make paged attn kernel
             # to properly handle sliding window attn.
             curr_sliding_window_block = self.sliding_window_blocks
-            if self.scheduler_config.use_v2_block_manager:
-                # number of elements in last block
-                suff_len = inter_data.seq_lens[seq_idx] % self.block_size
-                sliding_seq_len = min(
-                    inter_data.seq_lens[seq_idx],
-                    self.block_aligned_sliding_window + suff_len)
-                if suff_len > 0:
-                    curr_sliding_window_block += 1
-            else:
-                sliding_seq_len = min(inter_data.seq_lens[seq_idx],
-                                      self.sliding_window)
+            # number of elements in last block
+            suff_len = inter_data.seq_lens[seq_idx] % self.block_size
+            sliding_seq_len = min(inter_data.seq_lens[seq_idx],
+                                  self.block_aligned_sliding_window + suff_len)
+            if suff_len > 0:
+                curr_sliding_window_block += 1
 
         inter_data.curr_sliding_window_blocks[
             seq_idx] = curr_sliding_window_block
@@ -832,7 +828,7 @@ def build(self) -> ModelInputForGPU:
 
         cuda_graph_pad_size = self._get_cuda_graph_pad_size(
             num_seqs=len(seq_lens),
-            max_decode_seq_len=max_encoder_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
             max_encoder_seq_len=max_encoder_seq_len)
 
         batch_size = len(input_tokens)
@@ -1015,7 +1011,6 @@ def __init__(
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
@@ -1379,10 +1374,7 @@ def list_prompt_adapters(self) -> Set[int]:
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
         mrope requires keeping "rope_deltas" between prompt and decoding phases."""
-        rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {})
-        if rope_scaling is None:
-            return False
-        return rope_scaling.get("type", None) == "mrope"
+        return uses_mrope(self.model_config.hf_config)
 
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
@@ -1744,10 +1736,13 @@ def execute_model(
         return [output]
 
 
-class CUDAGraphRunner:
+# NOTE: this is an nn.Module so the profiler can properly capture/group
+# kernel calls made within the graph
+class CUDAGraphRunner(nn.Module):
 
     def __init__(self, model: nn.Module, backend_name: str,
                  attn_state: AttentionState, is_encoder_decoder_model: bool):
+        super().__init__()
         self.model = model
         self.backend_name = backend_name
         self.attn_state = attn_state
@@ -1860,7 +1855,7 @@ def forward(
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
 
-        if self.backend_name != "placeholder-attn":
+        if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)
 
@@ -1894,9 +1889,6 @@ def forward(
 
         return self.output_buffers
 
-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
-
 
 def _get_graph_batch_size(batch_size: int) -> int:
     """Returns the padded batch size given actual batch size.
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 0cd0047bebf2..be2f0d79154d 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -29,8 +29,8 @@
 
 logger = init_logger(__name__)
 
-MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]
-MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"]
+MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
+MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]
 
 def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
     -> List[str]:
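These backend identifiers switch from the old hyphenated spellings to the upper-case names used elsewhere (e.g. the VLLM_ATTENTION_BACKEND values in the CI pipeline). The body of _get_supported_attention_backends is not shown in this hunk; presumably it just picks between the two lists, along the lines of:

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
    -> List[str]:
    # Sketch of the presumed selection logic; the actual body is outside
    # this hunk.
    if chunked_prefill_enabled:
        return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS
    return MULTI_STEP_ATTENTION_BACKENDS
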
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py
index 760b18427e22..a164fbe3393c 100644
--- a/vllm/worker/openvino_model_runner.py
+++ b/vllm/worker/openvino_model_runner.py
@@ -75,7 +75,6 @@ def __init__(
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py
index 24425fece850..bc245d19663d 100644
--- a/vllm/worker/openvino_worker.py
+++ b/vllm/worker/openvino_worker.py
@@ -71,7 +71,6 @@ def __init__(
         # Get attention backend.
         self.attn_backend = get_attn_backend(
             self.head_size,
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index c13e95f60af5..87ced7818a67 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -114,7 +114,6 @@ def __init__(
             dtype=np.int32)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
@@ -123,6 +122,15 @@ def __init__(
         )
         self.cached_step_outputs: List[torch.Tensor] = []
 
+        smem_size = 512 * 1024
+        block_table_size = 4 * self.block_tables.size
+        if block_table_size >= smem_size:
+            logger.warning(
+                "The max_model_len (%d) is too large. This may degrade the "
+                "performance due to the insufficient smem size. Consider "
+                "setting --max-model-len to a smaller value.",
+                self.model_config.max_model_len)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
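The new warning fires once the flattened block table no longer fits in the 512 KiB of SMEM assumed above; each entry is an int32, so the threshold is 131072 entries in total. A quick back-of-the-envelope check with hypothetical scheduler limits (none of these values come from the PR):

smem_size = 512 * 1024
max_num_seqs = 256                        # hypothetical batch limit
block_size = 16                           # hypothetical TPU block size
max_model_len = 8192                      # hypothetical model length
blocks_per_seq = max_model_len // block_size
block_table_size = 4 * max_num_seqs * blocks_per_seq   # int32 entries
print(block_table_size >= smem_size)      # True: 4 * 256 * 512 = 524288 bytes
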
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index ab61e4377f90..fd30962e5d6b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -92,7 +92,7 @@ def __init__(
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_runner_cls is not None:
             ModelRunnerClass = model_runner_cls
-        elif self._is_embedding_model():
+        elif model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
         elif self._is_encoder_decoder_model():
             ModelRunnerClass = EncoderDecoderModelRunner
@@ -147,9 +147,6 @@ def stop_profile(self):
     def _is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model
 
-    def _is_embedding_model(self):
-        return self.model_config.is_embedding_model
-
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
@@ -217,42 +214,79 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()
+        torch.cuda.synchronize()
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # GPU outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass.
+        torch.cuda.empty_cache()
+        torch_allocated_bytes = torch.cuda.memory_stats(
+        )["allocated_bytes.all.current"]
+        total_allocated_bytes = torch.cuda.mem_get_info(
+        )[1] - torch.cuda.mem_get_info()[0]
+        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
         cache_block_size = self.get_cache_block_size_bytes()
         if cache_block_size == 0:
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+
+        logger.info(
+            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " memory_usage_post_profile=%.2fGib"
+            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
+            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
+            (peak_memory - non_torch_allocations) / (1024**3),
+            total_allocated_bytes / (1024**3),
+            non_torch_allocations / (1024**3),
+            available_kv_cache_memory / (1024**3),
+            self.cache_config.gpu_memory_utilization)
+
+        # Final cleanup
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()
         gc.collect()
-        torch.cuda.empty_cache()
+
         return num_gpu_blocks, num_cpu_blocks
 
+    def _assert_memory_footprint_increased_during_profiling(self):
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        assert self.init_gpu_memory - free_gpu_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 5ff4626c060b..75a6de3b24ba 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -374,7 +374,6 @@ def __init__(
 
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
-            self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index 9ad070d042a3..917866f2d985 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -17,7 +17,7 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
-from vllm.utils import is_xpu
+from vllm.platforms import current_platform
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.worker import Worker
 from vllm.worker.worker_base import LoraNotSupportedWorkerBase
@@ -53,7 +53,7 @@ def __init__(
         observability_config: Optional[ObservabilityConfig] = None,
     ) -> None:
         assert device_config.device_type == "xpu"
-        assert is_xpu()
+        assert current_platform.is_xpu()
 
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -91,7 +91,8 @@ def __init__(
         self.gpu_cache: Optional[List[List[torch.Tensor]]]
 
     def init_device(self) -> None:
-        if self.device_config.device.type == "xpu" and is_xpu():
+        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
+        ):
             self.device = torch.device(f"xpu:{self.local_rank}")
             torch.xpu.set_device(self.device)
             torch.xpu.empty_cache()