diff --git a/container/Dockerfile.vllm b/container/Dockerfile.vllm index eb983377f1..25623ad1f2 100644 --- a/container/Dockerfile.vllm +++ b/container/Dockerfile.vllm @@ -69,7 +69,8 @@ RUN apt-get update -y && \ tmux \ vim \ autoconf \ - libtool + libtool \ + net-tools # These headers are missing with the hpcx installer, required # by UCX to find RDMA devices @@ -120,12 +121,21 @@ WORKDIR /workspace # Copy nixl source, and use commit hash as cache hint COPY --from=nixl_base /opt/nixl /opt/nixl COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt -RUN cd /opt/nixl && \ - mkdir build && \ - meson setup build/ --buildtype=release --prefix=/usr/local/nixl && \ - cd build/ && \ - ninja && \ - ninja install +RUN if [ "$ARCH" = "arm64" ]; then \ + cd /opt/nixl && \ + mkdir build && \ + meson setup build/ --buildtype=release --prefix=/usr/local/nixl -Dgds_path=/usr/local/cuda/targets/sbsa-linux && \ + cd build/ && \ + ninja && \ + ninja install; \ + else \ + cd /opt/nixl && \ + mkdir build && \ + meson setup build/ --buildtype=release --prefix=/usr/local/nixl && \ + cd build/ && \ + ninja && \ + ninja install; \ + fi ### NATS & ETCD SETUP ### # nats @@ -152,65 +162,37 @@ ENV VIRTUAL_ENV=/opt/dynamo/venv ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Install NIXL Python module -RUN cd /opt/nixl && uv build . --out-dir /workspace/wheels/nixl +# TODO: Move gds_path selection based on arch into NIXL build +RUN if [ "$ARCH" = "arm64" ]; then \ + cd /opt/nixl && uv build . --out-dir /workspace/wheels/nixl \ + --config-settings=setup-args="-Dgds_path=/usr/local/cuda/targets/sbsa-linux"; \ + else \ + cd /opt/nixl && uv build . --out-dir /workspace/wheels/nixl; \ + fi # Install the wheel # TODO: Move NIXL wheel install to the wheel_builder stage RUN uv pip install /workspace/wheels/nixl/*.whl -# Install patched vllm - keep this early in Dockerfile to avoid +# Install vllm - keep this early in Dockerfile to avoid # rebuilds from unrelated source code changes -ARG VLLM_REF="0.8.4" -ARG VLLM_PATCH="vllm_v${VLLM_REF}-dynamo-kv-disagg-patch.patch" -ARG VLLM_PATCHED_PACKAGE_NAME="ai_dynamo_vllm" -ARG VLLM_PATCHED_PACKAGE_VERSION="0.8.4.post4" -ARG VLLM_MAX_JOBS=4 +ARG VLLM_REF="059d4cd" +ENV CUDA_HOME=/usr/local/cuda RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \ --mount=type=cache,target=/root/.cache/uv \ - mkdir /tmp/vllm && \ - uv pip install pip wheel && \ - # NOTE: vLLM build from source on ARM can take several hours, see VLLM_MAX_JOBS details. - if [ "$ARCH" = "arm64" ]; then \ - # PyTorch 2.7 supports CUDA 12.8 and aarch64 installs - # NIXL has a torch dependency, so need to force-reinstall to install the correct version - uv pip install torch==2.7.0 torchvision torchaudio --force-reinstall --index-url https://download.pytorch.org/whl/cu128 && \ - # Download vLLM source with version matching patch - git clone --branch v${VLLM_REF} --depth 1 https://github.com/vllm-project/vllm.git /tmp/vllm/vllm-${VLLM_REF} && \ - cd /tmp/vllm/vllm-${VLLM_REF}/ && \ - # Patch vLLM source with dynamo additions - patch -p1 < /tmp/deps/vllm/${VLLM_PATCH} && \ - # WAR: Set package version check to 'vllm' instead of 'ai_dynamo_vllm' to avoid - # platform detection issues on ARM install. - # TODO: Rename package from vllm to ai_dynamo_vllm like x86 path below to remove this WAR. 
- sed -i 's/version("ai_dynamo_vllm")/version("vllm")/g' vllm/platforms/__init__.py && \ - # Remove pytorch from vllm install dependencies - python use_existing_torch.py && \ - # Build/install vllm from source - uv pip install -r requirements/build.txt && \ - # MAX_JOBS set to avoid running OOM on vllm-flash-attn build, this can - # significantly impact the overall build time. Each job can take up - # to -16GB RAM each, so tune according to available system memory. - MAX_JOBS=${VLLM_MAX_JOBS} uv pip install -vv . --no-build-isolation ; \ - # Handle x86_64: Download wheel, unpack, setup for later steps - else \ - python -m pip download --only-binary=:all: --no-deps --dest /tmp/vllm vllm==v${VLLM_REF} && \ - # Patch vLLM pre-built download with dynamo additions - cd /tmp/vllm && \ - wheel unpack *.whl && \ - cd vllm-${VLLM_REF}/ && \ - patch -p1 < /tmp/deps/vllm/${VLLM_PATCH} && \ - # Rename the package from vllm to ai_dynamo_vllm - mv vllm-${VLLM_REF}.dist-info ${VLLM_PATCHED_PACKAGE_NAME}-${VLLM_PATCHED_PACKAGE_VERSION}.dist-info && \ - sed -i "s/^Name: vllm/Name: ${VLLM_PATCHED_PACKAGE_NAME}/g" ${VLLM_PATCHED_PACKAGE_NAME}-${VLLM_PATCHED_PACKAGE_VERSION}.dist-info/METADATA && \ - sed -i "s/^Version: ${VLLM_REF}/Version: ${VLLM_PATCHED_PACKAGE_VERSION}/g" ${VLLM_PATCHED_PACKAGE_NAME}-${VLLM_PATCHED_PACKAGE_VERSION}.dist-info/METADATA && \ - # Update wheel tag from linux_${ARCH_ALT} to manylinux1_${ARCH_ALT} in WHEEL file - sed -i "s/Tag: cp38-abi3-linux_${ARCH_ALT}/Tag: cp38-abi3-manylinux1_${ARCH_ALT}/g" ${VLLM_PATCHED_PACKAGE_NAME}-${VLLM_PATCHED_PACKAGE_VERSION}.dist-info/WHEEL && \ - # Also update the tag in RECORD file to match - sed -i "s/-cp38-abi3-linux_${ARCH_ALT}.whl/-cp38-abi3-manylinux1_${ARCH_ALT}.whl/g" ${VLLM_PATCHED_PACKAGE_NAME}-${VLLM_PATCHED_PACKAGE_VERSION}.dist-info/RECORD && \ - mkdir -p /workspace/dist && \ - wheel pack . --dest-dir /workspace/dist && \ - uv pip install /workspace/dist/${VLLM_PATCHED_PACKAGE_NAME}-*.whl ; \ - fi + uv pip install pip cuda-python && \ + mkdir /opt/vllm && \ + cd /opt/vllm && \ + git clone https://github.com/vllm-project/vllm.git && \ + cd vllm && \ + git checkout $VLLM_REF && \ + VLLM_USE_PRECOMPILED=1 uv pip install -e . 
&& \ + cd tools/ep_kernels && \ + bash install_python_libraries.sh && \ + cd ep_kernels_workspace && \ + git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && \ + cd DeepGEMM && \ + python setup.py install # Common dependencies RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ @@ -324,8 +306,6 @@ RUN SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=$HOME/.comman RUN mkdir -p /home/$USERNAME/.cache/ -ENV VLLM_KV_CAPI_PATH=$HOME/dynamo/.build/target/debug/libdynamo_llm_capi.so - ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] ################################## @@ -443,12 +423,7 @@ RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/la sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \ echo "cat ~/.launch_screen" >> ~/.bashrc -# Tell vllm to use the Dynamo LLM C API for KV Cache Routing -ENV VLLM_KV_CAPI_PATH=/opt/dynamo/bindings/lib/libdynamo_llm_capi.so - -ARG ARCH_ALT -ENV NIXL_PLUGIN_DIR=/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu/plugins -ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu:/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu/plugins:/usr/local/ucx/lib:$LD_LIBRARY_PATH +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/ ######################################## ########## Development Image ########### @@ -519,16 +494,13 @@ COPY --from=base /workspace/wheels/nixl/*.whl wheelhouse/ COPY --from=wheel_builder /workspace/dist/*.whl wheelhouse/ RUN uv pip install ai-dynamo[vllm] --find-links wheelhouse && \ uv pip install nixl --find-links wheelhouse && \ - ln -sf $VIRTUAL_ENV/bin/* /usr/local/bin/ - -# Tell vllm to use the Dynamo LLM C API for KV Cache Routing -ENV VLLM_KV_CAPI_PATH="/opt/dynamo/bindings/lib/libdynamo_llm_capi.so" + ln -sf $VIRTUAL_ENV/bin/* /usr/local/bin/ && \ + rm -r wheelhouse # Copy launch banner RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \ sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \ echo "cat ~/.launch_screen" >> ~/.bashrc - ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] CMD [] diff --git a/container/Dockerfile.vllm_v1 b/container/Dockerfile.vllm_v1 deleted file mode 100644 index 0ad6785813..0000000000 --- a/container/Dockerfile.vllm_v1 +++ /dev/null @@ -1,499 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -ARG BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" -# FIXME: NCCL will hang with 25.03, so use 25.01 for now -# Please check https://github.com/ai-dynamo/dynamo/pull/1065 -# for details and reproducer to manually test if the image -# can be updated to later versions. -ARG BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" -ARG RELEASE_BUILD -ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda" -ARG RUNTIME_IMAGE_TAG="12.8.1-runtime-ubuntu24.04" - -# Define general architecture ARGs for supporting both x86 and aarch64 builds. -# ARCH: Used for package suffixes (e.g., amd64, arm64) -# ARCH_ALT: Used for Rust targets, manylinux suffix (e.g., x86_64, aarch64) -# -# Default values are for x86/amd64: -# --build-arg ARCH=amd64 --build-arg ARCH_ALT=x86_64 -# -# For arm64/aarch64, build with: -# --build-arg ARCH=arm64 --build-arg ARCH_ALT=aarch64 -# -# NOTE: There isn't an easy way to define one of these values based on the other value -# without adding if statements everywhere, so just define both as ARGs for now. 
-ARG ARCH=amd64 -ARG ARCH_ALT=x86_64 - -FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS nixl_base - -# Redeclare ARCH and ARCH_ALT so they're available in this stage -ARG ARCH -ARG ARCH_ALT - -WORKDIR /opt/nixl -# Add a cache hint that only changes when the nixl commit changes -ARG NIXL_COMMIT -# This line acts as a cache key - it only changes when NIXL_COMMIT changes -RUN echo "NIXL commit: ${NIXL_COMMIT}" > /opt/nixl/commit.txt -# Copy the nixl source -COPY --from=nixl . . - -################################## -########## Base Image ############ -################################## - -FROM ${BASE_IMAGE}:${BASE_IMAGE_TAG} AS base - -# Redeclare ARCH and ARCH_ALT so they're available in this stage -ARG ARCH -ARG ARCH_ALT - -USER root -ARG PYTHON_VERSION=3.12 - -RUN apt-get update -y && \ - apt-get install -y \ - # NIXL build dependencies - cmake \ - meson \ - ninja-build \ - pybind11-dev \ - # Rust build dependencies - clang \ - libclang-dev \ - git \ - # Install utilities - nvtop \ - tmux \ - vim \ - autoconf \ - libtool \ - net-tools - -# These headers are missing with the hpcx installer, required -# by UCX to find RDMA devices -RUN apt-get update -y && \ - apt-get install -y --no-install-recommends \ - --reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev \ - libnuma-dev librdmacm-dev ibverbs-providers - -ARG NIXL_UCX_REF=v1.19.x - -WORKDIR /workspace - -### UCX EFA Setup ### -RUN rm -rf /opt/hpcx/ucx -RUN rm -rf /usr/local/ucx -RUN echo "Building UCX with reference $NIXL_UCX_REF" -RUN cd /usr/local/src && \ - git clone https://github.com/openucx/ucx.git && \ - cd ucx && \ - git checkout $NIXL_UCX_REF && \ - ./autogen.sh && ./configure \ - --prefix=/usr/local/ucx \ - --enable-shared \ - --disable-static \ - --disable-doxygen-doc \ - --enable-optimizations \ - --enable-cma \ - --enable-devel-headers \ - --with-cuda=/usr/local/cuda \ - --with-verbs \ - --with-efa \ - --with-dm \ - --with-gdrcopy=/usr/local \ - --enable-mt && \ - make -j && \ - make -j install-strip && \ - ldconfig - -ENV LD_LIBRARY_PATH=/usr/lib:/usr/local/ucx/lib:$LD_LIBRARY_PATH -ENV CPATH=/usr/include -ENV PATH=/usr/bin:$PATH -ENV PKG_CONFIG_PATH=/usr/lib/pkgconfig -SHELL ["/bin/bash", "-c"] - -WORKDIR /workspace - -### NIXL SETUP ### -# Copy nixl source, and use commit hash as cache hint -COPY --from=nixl_base /opt/nixl /opt/nixl -COPY --from=nixl_base /opt/nixl/commit.txt /opt/nixl/commit.txt -RUN if [ "$ARCH" = "arm64" ]; then \ - cd /opt/nixl && \ - mkdir build && \ - meson setup build/ --buildtype=release --prefix=/usr/local/nixl -Dgds_path=/usr/local/cuda/targets/sbsa-linux && \ - cd build/ && \ - ninja && \ - ninja install; \ - else \ - cd /opt/nixl && \ - mkdir build && \ - meson setup build/ --buildtype=release --prefix=/usr/local/nixl && \ - cd build/ && \ - ninja && \ - ninja install; \ - fi - -### NATS & ETCD SETUP ### -# nats -RUN wget --tries=3 --waitretry=5 https://github.com/nats-io/nats-server/releases/download/v2.10.28/nats-server-v2.10.28-${ARCH}.deb && \ - dpkg -i nats-server-v2.10.28-${ARCH}.deb && rm nats-server-v2.10.28-${ARCH}.deb -# etcd -ENV ETCD_VERSION="v3.5.21" -RUN wget --tries=3 --waitretry=5 https://github.com/etcd-io/etcd/releases/download/$ETCD_VERSION/etcd-$ETCD_VERSION-linux-${ARCH}.tar.gz -O /tmp/etcd.tar.gz && \ - mkdir -p /usr/local/bin/etcd && \ - tar -xvf /tmp/etcd.tar.gz -C /usr/local/bin/etcd --strip-components=1 && \ - rm /tmp/etcd.tar.gz -ENV PATH=/usr/local/bin/etcd/:$PATH - - -### VIRTUAL ENVIRONMENT SETUP ### - -# Install uv and create virtualenv -COPY 
--from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ -RUN mkdir /opt/dynamo && \ - uv venv /opt/dynamo/venv --python 3.12 - -# Activate virtual environment -ENV VIRTUAL_ENV=/opt/dynamo/venv -ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" - -# Install NIXL Python module -# TODO: Move gds_path selection based on arch into NIXL build -RUN if [ "$ARCH" = "arm64" ]; then \ - cd /opt/nixl && uv build . --out-dir /workspace/wheels/nixl \ - --config-settings=setup-args="-Dgds_path=/usr/local/cuda/targets/sbsa-linux"; \ - else \ - cd /opt/nixl && uv build . --out-dir /workspace/wheels/nixl; \ - fi - -# Install the wheel -# TODO: Move NIXL wheel install to the wheel_builder stage -RUN uv pip install /workspace/wheels/nixl/*.whl - -# Install vllm - keep this early in Dockerfile to avoid -# rebuilds from unrelated source code changes -ARG VLLM_REF="059d4cd" -ENV CUDA_HOME=/usr/local/cuda -RUN --mount=type=bind,source=./container/deps/,target=/tmp/deps \ - --mount=type=cache,target=/root/.cache/uv \ - uv pip install pip cuda-python && \ - mkdir /opt/vllm && \ - cd /opt/vllm && \ - git clone https://github.com/vllm-project/vllm.git && \ - cd vllm && \ - git checkout $VLLM_REF && \ - VLLM_USE_PRECOMPILED=1 uv pip install -e . && \ - cd tools/ep_kernels && \ - bash install_python_libraries.sh && \ - cd ep_kernels_workspace && \ - git clone --recursive https://github.com/deepseek-ai/DeepGEMM.git && \ - cd DeepGEMM && \ - python setup.py install - -# Common dependencies -RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ - uv pip install --requirement /tmp/requirements.txt - -# Install test dependencies -RUN --mount=type=bind,source=./container/deps/requirements.test.txt,target=/tmp/requirements.txt \ - uv pip install --requirement /tmp/requirements.txt - -# ### MISC UTILITY SETUP ### - -# Finish pyright install -RUN pyright --help > /dev/null 2>&1 - -# Enable Git operations in the /workspace directory -RUN printf "[safe]\n directory=/workspace\n" > /root/.gitconfig - -# Install prometheus -ARG PROM_VERSION=3.4.1 -RUN apt-get update && apt-get install -y --no-install-recommends \ - curl tar ca-certificates && \ - rm -rf /var/lib/apt/lists/* -RUN ARCH=$(dpkg --print-architecture) && \ - case "$ARCH" in \ - amd64) PLATFORM=linux-amd64 ;; \ - arm64) PLATFORM=linux-arm64 ;; \ - *) echo "Unsupported architecture: $ARCH" && exit 1 ;; \ - esac && \ - curl -fsSL https://github.com/prometheus/prometheus/releases/download/v${PROM_VERSION}/prometheus-${PROM_VERSION}.${PLATFORM}.tar.gz \ - | tar -xz -C /tmp && \ - mv /tmp/prometheus-${PROM_VERSION}.${PLATFORM}/prometheus /usr/local/bin/ && \ - chmod +x /usr/local/bin/prometheus && \ - rm -rf /tmp/prometheus-${PROM_VERSION}.${PLATFORM} - -### BUILDS ### - -# Rust build/dev dependencies -RUN apt update -y && \ - apt install --no-install-recommends -y \ - build-essential \ - protobuf-compiler \ - cmake \ - libssl-dev \ - pkg-config - -ENV RUSTUP_HOME=/usr/local/rustup \ - CARGO_HOME=/usr/local/cargo \ - PATH=/usr/local/cargo/bin:$PATH \ - RUST_VERSION=1.87.0 - -# Define Rust target based on ARCH_ALT ARG -ARG RUSTARCH=${ARCH_ALT}-unknown-linux-gnu - -# Install Rust using RUSTARCH derived from ARCH_ALT -RUN wget --tries=3 --waitretry=5 "https://static.rust-lang.org/rustup/archive/1.28.1/${RUSTARCH}/rustup-init" && \ - # TODO: Add SHA check back based on RUSTARCH - chmod +x rustup-init && \ - ./rustup-init -y --no-modify-path --profile default --default-toolchain $RUST_VERSION --default-host ${RUSTARCH} && \ - rm rustup-init && \ - chmod -R a+w 
$RUSTUP_HOME $CARGO_HOME - -ARG CARGO_BUILD_JOBS -# Set CARGO_BUILD_JOBS to 16 if not provided -# This is to prevent cargo from building $(nproc) jobs in parallel, -# which might exceed the number of opened files limit. -ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} - -####################################### -########## Local Development ########## -####################################### - -FROM base AS local-dev - -# https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user -# Will use the default ubuntu user, but give sudo access -# Needed so files permissions aren't set to root ownership when writing from inside container - -# Don't want ubuntu to be editable, just change uid and gid. User ubuntu is hardcoded in .devcontainer -ENV USERNAME=ubuntu -ARG USER_UID=1000 -ARG USER_GID=1000 - -RUN apt-get update && apt-get install -y sudo gnupg2 gnupg1 \ - && echo "$USERNAME ALL=(root) NOPASSWD:ALL" > /etc/sudoers.d/$USERNAME \ - && chmod 0440 /etc/sudoers.d/$USERNAME \ - && mkdir -p /home/$USERNAME \ - && chown -R $USERNAME:$USERNAME /home/$USERNAME \ - && rm -rf /var/lib/apt/lists/* \ - && chsh -s /bin/bash $USERNAME - -# This is a slow operation (~40s on my cpu) -# Much better than chown -R $USERNAME:$USERNAME /opt/dynamo/venv (~10min on my cpu) -COPY --from=base --chown=$USER_UID:$USER_GID /opt/dynamo/venv/ /opt/dynamo/venv/ -RUN chown $USERNAME:$USERNAME /opt/dynamo/venv -COPY --from=base --chown=$USERNAME:$USERNAME /usr/local/bin /usr/local/bin - -# so we can use maturin develop -RUN uv pip install maturin[patchelf] - -USER $USERNAME -ENV HOME=/home/$USERNAME -ENV PYTHONPATH=$HOME/dynamo/deploy/sdk/src:$PYTHONPATH:$HOME/dynamo/components/planner/src:$PYTHONPATH -ENV CARGO_TARGET_DIR=$HOME/dynamo/.build/target -WORKDIR $HOME - -# https://code.visualstudio.com/remote/advancedcontainers/persist-bash-history -RUN SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=$HOME/.commandhistory/.bash_history" \ - && mkdir -p $HOME/.commandhistory \ - && touch $HOME/.commandhistory/.bash_history \ - && echo "$SNIPPET" >> "$HOME/.bashrc" - -RUN mkdir -p /home/$USERNAME/.cache/ - -ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] - -################################## -##### Wheel Build Image ########## -################################## - -# Redeclare ARCH_ALT ARG so it's available for interpolation in the FROM instruction -ARG ARCH_ALT - -FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS wheel_builder - -ARG CARGO_BUILD_JOBS -# Set CARGO_BUILD_JOBS to 16 if not provided -# This is to prevent cargo from building $(nproc) jobs in parallel, -# which might exceed the number of opened files limit. -ENV CARGO_BUILD_JOBS=${CARGO_BUILD_JOBS:-16} -# Use build arg RELEASE_BUILD = true to generate wheels for Python 3.10, 3.11 and 3.12. 
-ARG RELEASE_BUILD - -WORKDIR /workspace - -RUN yum update -y \ - && yum install -y llvm-toolset \ - && yum install -y python3.12-devel \ - && yum install -y protobuf-compiler \ - && yum clean all \ - && rm -rf /var/cache/yum - -ENV RUSTUP_HOME=/usr/local/rustup \ - CARGO_HOME=/usr/local/cargo \ - CARGO_TARGET_DIR=/workspace/target \ - VIRTUAL_ENV=/opt/dynamo/venv - -COPY --from=base $RUSTUP_HOME $RUSTUP_HOME -COPY --from=base $CARGO_HOME $CARGO_HOME -COPY --from=base /usr/local/nixl /opt/nvidia/nvda_nixl -COPY --from=base /workspace /workspace -COPY --from=base $VIRTUAL_ENV $VIRTUAL_ENV -ENV PATH=$CARGO_HOME/bin:$VIRTUAL_ENV/bin:$PATH - -# Copy configuration files -COPY pyproject.toml /workspace/ -COPY README.md /workspace/ -COPY LICENSE /workspace/ -COPY Cargo.toml /workspace/ -COPY Cargo.lock /workspace/ -COPY rust-toolchain.toml /workspace/ -COPY hatch_build.py /workspace/ - -# Copy source code -COPY lib/ /workspace/lib/ -COPY components /workspace/components -COPY launch /workspace/launch -COPY deploy/sdk /workspace/deploy/sdk - -RUN cargo build \ - --release \ - --locked \ - --features dynamo-llm/block-manager \ - --workspace - -# Build dynamo wheel -RUN uv build --wheel --out-dir /workspace/dist && \ - cd /workspace/lib/bindings/python && \ - uv pip install maturin[patchelf] && \ - maturin build --release --features block-manager --out /workspace/dist && \ - if [ "$RELEASE_BUILD" = "true" ]; then \ - # do not enable KVBM feature, ensure compatibility with lower glibc - uv run --python 3.11 maturin build --release --out /workspace/dist && \ - uv run --python 3.10 maturin build --release --out /workspace/dist; \ - fi - -####################################### -########## CI Minimum Image ########### -####################################### -FROM base AS ci_minimum - -ENV DYNAMO_HOME=/workspace -ENV CARGO_TARGET_DIR=/workspace/target - -WORKDIR /workspace - -COPY --from=wheel_builder /workspace /workspace -COPY --from=wheel_builder /opt/nvidia/nvda_nixl /opt/nvidia/nvda_nixl -# Copy Cargo cache to avoid re-downloading dependencies -COPY --from=wheel_builder $CARGO_HOME $CARGO_HOME - -# Copy rest of the code -COPY . /workspace - -# Build C bindings, creates lib/bindings/c/include -# -# TODO: In theory the 'cargo build' in earlier stage covers this, we "just" need to copy the -# `lib/bindings/c/include` folder that build.rs generated across. -# I couldn't get that to work, hence TODO. -RUN cd /workspace/lib/bindings/c && cargo build --release --locked - -# Package the bindings -RUN mkdir -p /opt/dynamo/bindings/wheels && \ - mkdir /opt/dynamo/bindings/lib && \ - cp dist/ai_dynamo*cp312*.whl /opt/dynamo/bindings/wheels/. && \ - cp target/release/libdynamo_llm_capi.so /opt/dynamo/bindings/lib/. && \ - cp -r lib/bindings/c/include /opt/dynamo/bindings/. 
&& \ - cp target/release/dynamo-run /usr/local/bin && \ - cp target/release/metrics /usr/local/bin && \ - cp target/release/mock_worker /usr/local/bin - -RUN uv pip install /workspace/dist/ai_dynamo_runtime*cp312*.whl && \ - uv pip install /workspace/dist/ai_dynamo*any.whl - -RUN uv pip install /workspace/benchmarks - -# Copy launch banner -RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \ - sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \ - echo "cat ~/.launch_screen" >> ~/.bashrc - -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nvidia/nvda_nixl/lib/x86_64-linux-gnu/ - -######################################## -########## Development Image ########### -######################################## -FROM ci_minimum AS dev - -ENTRYPOINT ["/opt/nvidia/nvidia_entrypoint.sh"] - -CMD [] - -#################################### -########## Runtime Image ########### -#################################### - -FROM ${RUNTIME_IMAGE}:${RUNTIME_IMAGE_TAG} AS runtime - -WORKDIR /workspace -ENV DYNAMO_HOME=/workspace -ENV VIRTUAL_ENV=/opt/dynamo/venv -ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" - -# Install build-essential and python3-dev as apt dependencies -RUN apt-get update && \ - DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ - build-essential \ - python3-dev && \ - rm -rf /var/lib/apt/lists/* - -### COPY BINDINGS ### -# Copy all bindings (wheels, lib, include) from ci_minimum -COPY --from=ci_minimum /opt/dynamo/bindings /opt/dynamo/bindings -### COPY NATS & ETCD ### -# Copy nats and etcd from base image -COPY --from=base /usr/bin/nats-server /usr/bin/nats-server -COPY --from=base /usr/local/bin/etcd/ /usr/local/bin/etcd/ - -# Copy UCX from base image as plugin for NIXL -# Copy NIXL source from base image (required for NIXL plugins) -COPY --from=base /usr/local/ucx /usr/local/ucx -COPY --from=base /usr/local/nixl /usr/local/nixl -ARG ARCH_ALT -ENV NIXL_PLUGIN_DIR=/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu/plugins -ENV LD_LIBRARY_PATH=/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu:/usr/local/nixl/lib/${ARCH_ALT}-linux-gnu/plugins:/usr/local/ucx/lib:$LD_LIBRARY_PATH - -# Setup the python environment -COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ -RUN uv venv $VIRTUAL_ENV --python 3.12 && \ - echo "source $VIRTUAL_ENV/bin/activate" >> ~/.bashrc - -# Common dependencies -RUN --mount=type=bind,source=./container/deps/requirements.txt,target=/tmp/requirements.txt \ - uv pip install --requirement /tmp/requirements.txt - -# Install the wheels and symlink executables to /usr/local/bin so dynamo components can use them -# Dynamo components currently do not have the VIRTUAL_ENV in their PATH, so we need to symlink the executables -#Copy NIXL and Dynamo wheels into wheelhouse -COPY --from=base /workspace/wheels/nixl/*.whl wheelhouse/ -COPY --from=wheel_builder /workspace/dist/*.whl wheelhouse/ -RUN uv pip install ai-dynamo --find-links wheelhouse && \ - uv pip install nixl --find-links wheelhouse && \ - ln -sf $VIRTUAL_ENV/bin/* /usr/local/bin/ && \ - rm -r wheelhouse - -# Copy launch banner -RUN --mount=type=bind,source=./container/launch_message.txt,target=/workspace/launch_message.txt \ - sed '/^#\s/d' /workspace/launch_message.txt > ~/.launch_screen && \ - echo "cat ~/.launch_screen" >> ~/.bashrc - -# Copy examples -COPY ./examples examples/ - -ENTRYPOINT [ "/usr/bin/bash" ] -CMD [] diff --git a/container/build.sh b/container/build.sh index eda7ac43bd..29578f4ad2 100755 --- a/container/build.sh +++ b/container/build.sh @@ 
-49,7 +49,7 @@ PYTHON_PACKAGE_VERSION=${current_tag:-$latest_tag.dev+$commit_id} # dependencies are specified in the /container/deps folder and # installed within framework specific sections of the Dockerfile. -declare -A FRAMEWORKS=(["VLLM"]=1 ["TENSORRTLLM"]=2 ["NONE"]=3 ["SGLANG"]=4 ["VLLM_V1"]=5) +declare -A FRAMEWORKS=(["VLLM"]=1 ["TENSORRTLLM"]=2 ["NONE"]=3 ["SGLANG"]=4) DEFAULT_FRAMEWORK=VLLM SOURCE_DIR=$(dirname "$(readlink -f "$0")") @@ -111,9 +111,6 @@ NONE_BASE_IMAGE_TAG="24.04" SGLANG_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" SGLANG_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" -VLLM_V1_BASE_IMAGE="nvcr.io/nvidia/cuda-dl-base" -VLLM_V1_BASE_IMAGE_TAG="25.01-cuda12.8-devel-ubuntu24.04" - NIXL_COMMIT=3c47a48955e6f96bd5d4fb43a9d80bb64722f8e4 NIXL_REPO=ai-dynamo/nixl.git @@ -403,8 +400,6 @@ elif [[ $FRAMEWORK == "NONE" ]]; then DOCKERFILE=${SOURCE_DIR}/Dockerfile.none elif [[ $FRAMEWORK == "SGLANG" ]]; then DOCKERFILE=${SOURCE_DIR}/Dockerfile.sglang -elif [[ $FRAMEWORK == "VLLM_V1" ]]; then - DOCKERFILE=${SOURCE_DIR}/Dockerfile.vllm_v1 fi NIXL_DIR="/tmp/nixl/nixl_src" diff --git a/container/deps/vllm/README.md b/container/deps/vllm/README.md deleted file mode 100644 index 0b417eca9a..0000000000 --- a/container/deps/vllm/README.md +++ /dev/null @@ -1,18 +0,0 @@ - - -Apply this patch to Python source code from vLLM release [v0.7.2](https://github.com/vllm-project/vllm/releases/tag/v0.7.2). \ No newline at end of file diff --git a/container/deps/vllm/prepare_patch.sh b/container/deps/vllm/prepare_patch.sh deleted file mode 100755 index 3ff5205aa6..0000000000 --- a/container/deps/vllm/prepare_patch.sh +++ /dev/null @@ -1,83 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -e - -# Function to print usage -print_usage() { - echo "Usage: $0 --original-ref --fork-repo --fork-ref --output " - echo - echo "Arguments:" - echo " --original-ref The tag or branch name from the original vllm-project/vllm repo" - echo " --fork-repo The URL of the forked repository" - echo " --fork-ref The tag or branch name from the forked repository" - echo " --output Path where the generated patch file should be saved" - echo - echo "Example:" - echo " $0 --original-ref v0.2.0 --fork-repo https://github.com/user/vllm.git --fork-ref feature-branch --output ./my-patch.diff" - exit 1 -} - -# Parse named arguments -while [[ $# -gt 0 ]]; do - case $1 in - --original-ref) - ORIGINAL_REF="$2" - shift 2 - ;; - --fork-repo) - FORK_REPO="$2" - shift 2 - ;; - --fork-ref) - FORK_REF="$2" - shift 2 - ;; - --output) - PATCH_OUTPUT="$2" - shift 2 - ;; - *) - print_usage - ;; - esac -done - -# Check if all required arguments are provided -if [ -z "$ORIGINAL_REF" ] || [ -z "$FORK_REPO" ] || [ -z "$FORK_REF" ] || [ -z "$PATCH_OUTPUT" ]; then - print_usage -fi - -# Convert patch output path to absolute path if it's relative -if [[ ! 
"$PATCH_OUTPUT" = /* ]]; then - PATCH_OUTPUT="$(pwd)/${PATCH_OUTPUT}" -fi - -TEMP_DIR=$(mktemp -d) - -# Clean up temp directory on script exit -trap 'rm -rf "$TEMP_DIR"' EXIT - -# Clone original vLLM to a temp directory -git clone https://github.com/vllm-project/vllm.git "$TEMP_DIR/original_vllm" - -cd "$TEMP_DIR/original_vllm" - -git remote add fork "$FORK_REPO" -git fetch fork "$FORK_REF" -git diff "$ORIGINAL_REF" fork/"$FORK_REF" > "$PATCH_OUTPUT" - -echo "Patch created successfully: $PATCH_OUTPUT" \ No newline at end of file diff --git a/container/deps/vllm/tests/test_patch_install.py b/container/deps/vllm/tests/test_patch_install.py deleted file mode 100644 index 045c80e4fd..0000000000 --- a/container/deps/vllm/tests/test_patch_install.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest - -try: - import vllm -except ImportError: - vllm = None # type: ignore - -pytestmark = pytest.mark.pre_merge - - -# TODO: Consider `pytest.mark.vllm` and running tests based on environment -@pytest.mark.skipif(vllm is None, reason="Skipping vllm tests, vllm not installed") -def test_version(): - # Verify that the image has the patched version of vllm - assert vllm.__version__.endswith("0.8.4") # type: ignore diff --git a/container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch b/container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch deleted file mode 100644 index 559ef28ccd..0000000000 --- a/container/deps/vllm/vllm_v0.7.2-dynamo-kv-disagg-patch.patch +++ /dev/null @@ -1,4778 +0,0 @@ -diff --git a/vllm/config.py b/vllm/config.py -index 9ba497576..db2dc002f 100644 ---- a/vllm/config.py -+++ b/vllm/config.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import ast - import copy -@@ -2620,6 +2633,9 @@ class KVTransferConfig(BaseModel): - # The KV connector for vLLM to transmit KV caches between vLLM instances. - kv_connector: Optional[str] = None - -+ # Whether to use NIXL prepped xfer for KV cache transfer. -+ use_prepped_xfer: bool = True -+ - # The device used by kv connector to buffer the KV cache. - # Currently only support 'cuda'. 
- kv_buffer_device: Optional[str] = "cuda" -@@ -2629,7 +2645,7 @@ class KVTransferConfig(BaseModel): - kv_buffer_size: float = 1e9 - - # Whether this vLLM instance produces, consumes KV cache, or both. Choices -- # are 'kv_producer', 'kv_consumer', and 'both'. -+ # are 'kv_producer', 'kv_consumer', and 'kv_both'. - kv_role: Optional[str] = None - - # The rank of this vLLM instance in the KV cache transfer. Typical value: -@@ -2647,6 +2663,14 @@ class KVTransferConfig(BaseModel): - # The KV connector port, used to build distributed connection - kv_port: int = 14579 - -+ -+ # This does not need to be set by the user. It is set by the connector. -+ kv_producers_parallel_size: Optional[int] = None -+ kv_producers_tensor_parallel_size: Optional[int] = None -+ kv_producers_pipeline_parallel_size: Optional[int] = None -+ kv_consumers_tensor_parallel_size: Optional[int] = None -+ kv_consumers_pipeline_parallel_size: Optional[int] = None -+ - def compute_hash(self) -> str: - """ - WARNING: Whenever a new field is added to this config, -@@ -2680,11 +2704,16 @@ class KVTransferConfig(BaseModel): - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") - -- if self.kv_connector is not None and self.kv_role is None: -+ if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") - -+ if self.use_prepped_xfer is False: -+ logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.") -+ self.use_prepped_xfer = True -+ -+ - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ -@@ -2694,6 +2723,8 @@ class KVTransferConfig(BaseModel): - def need_kv_parallel_group(self) -> bool: - # for those database-based connector, vLLM does not need to create - # parallel group, and in that case the kv parallel size will be 1. -+ if self.kv_connector == "DynamoNixlConnector": -+ return False - return self.kv_connector is not None and self.kv_parallel_size > 1 - - @property -@@ -2706,6 +2737,18 @@ class KVTransferConfig(BaseModel): - return self.kv_connector is not None and \ - self.kv_role in ["kv_consumer", "kv_both"] - -+ @property -+ def tensor_parallel_multiplier(self) -> int: -+ return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size -+ -+ @property -+ def kv_consumers_parallel_size(self) -> int: -+ return self.kv_parallel_size - self.kv_producers_parallel_size -+ -+ @property -+ def kv_world_size(self) -> int: -+ return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier -+ - - class CompilationLevel: - # constants for the levels of the compilation process -diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py -index 359b5b263..7bac45ff0 100644 ---- a/vllm/core/block/cpu_gpu_block_allocator.py -+++ b/vllm/core/block/cpu_gpu_block_allocator.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - from typing import Dict, FrozenSet, List, Optional, Tuple - -@@ -6,6 +19,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, - DeviceAwareBlockAllocator) - from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator - from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -+from vllm.core.event_manager import KVCacheEventManager - from vllm.platforms import current_platform - from vllm.utils import Device - -@@ -28,6 +42,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - num_gpu_blocks: int, - num_cpu_blocks: int, - block_size: int, -+ event_manager: Optional[KVCacheEventManager] = None, - ) -> DeviceAwareBlockAllocator: - """Creates a CpuGpuBlockAllocator instance with the specified - configuration. -@@ -64,6 +79,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - cpu_block_ids = block_ids[num_gpu_blocks:] - - if allocator_type == "naive": -+ assert event_manager is None, "Event API not supported with naive allocator." - gpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_gpu_blocks, -@@ -82,12 +98,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, -+ event_manager=event_manager, - ) - - cpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, -+ event_manager=event_manager, - ) - else: - raise ValueError(f"Unknown allocator type {allocator_type=}") -@@ -95,10 +113,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - return CpuGpuBlockAllocator( - cpu_block_allocator=cpu_allocator, - gpu_block_allocator=gpu_allocator, -+ event_manager=event_manager, - ) - - def __init__(self, cpu_block_allocator: BlockAllocator, -- gpu_block_allocator: BlockAllocator): -+ gpu_block_allocator: BlockAllocator, -+ event_manager: Optional[KVCacheEventManager] = None,): - assert not ( - cpu_block_allocator.all_block_ids - & gpu_block_allocator.all_block_ids -@@ -108,6 +128,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - Device.CPU: cpu_block_allocator, - Device.GPU: gpu_block_allocator, - } -+ self.event_manager = event_manager - - self._swap_mapping: Dict[int, int] = {} - self._null_block: Optional[Block] = None -diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py -index c388366b8..3c223b519 100644 ---- a/vllm/core/block/naive_block.py -+++ b/vllm/core/block/naive_block.py -@@ -1,8 +1,21 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - from collections import deque - from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union -- -+import heapq - from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, - get_all_blocks_recursively) - from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -@@ -38,7 +51,7 @@ class NaiveBlockAllocator(BlockAllocator): - if block_ids is None: - block_ids = range(num_blocks) - -- self._free_block_indices: Deque[BlockId] = deque(block_ids) -+ self._free_block_indices: List[BlockId] = list(block_ids) - self._all_block_indices = frozenset(block_ids) - assert len(self._all_block_indices) == num_blocks - -@@ -134,7 +147,8 @@ class NaiveBlockAllocator(BlockAllocator): - if not self._free_block_indices: - raise BlockAllocator.NoFreeBlocksError() - -- block_id = self._free_block_indices.popleft() -+ block_id = heapq.heappop(self._free_block_indices) -+ # TODO: figure out why sometime block_id is None - self._refcounter.incr(block_id) - return block_id - -@@ -148,7 +162,7 @@ class NaiveBlockAllocator(BlockAllocator): - - refcount = self._refcounter.decr(block_id) - if refcount == 0: -- self._free_block_indices.appendleft(block_id) -+ heapq.heappush(self._free_block_indices, block_id) - - def free(self, block: Block, keep_block_object: bool = False) -> None: - # Release the physical block id -diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py -index 1ca9e49da..26fabb243 100644 ---- a/vllm/core/block/prefix_caching_block.py -+++ b/vllm/core/block/prefix_caching_block.py -@@ -1,10 +1,23 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """Token blocks.""" - import sys - from bisect import bisect_left - from os.path import commonprefix - from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, -- Tuple) -+ Tuple, TYPE_CHECKING) - - from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, - get_all_blocks_recursively) -@@ -23,6 +36,9 @@ PrefixHash = int - # then we know this block hasn't been accessed yet. 
- _DEFAULT_LAST_ACCESSED_TIME = -1 - -+if TYPE_CHECKING: -+ from vllm.core.event_manager import KVCacheEventManager -+ - logger = init_logger(__name__) - - -@@ -80,6 +96,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): - block_size: int, - block_ids: Optional[Iterable[int]] = None, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU, -+ event_manager: Optional["KVCacheEventManager"] = None, - ): - if block_ids is None: - block_ids = range(num_blocks) -@@ -131,6 +148,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): - - self.metric_data = CacheMetricData() - -+ self.event_manager = event_manager -+ -+ # Implements Block.Factory. - def _create_block( - self, - prev_block: Optional[Block], -@@ -337,6 +357,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): - assert self._refcounter.get(_block_id) == 0 - assert _block_id == block_id - -+ if self.event_manager: -+ self.event_manager.enqueue_removed_event(content_hash_to_evict) -+ - self._cached_blocks.pop(content_hash_to_evict) - - self._refcounter.incr(block_id) -@@ -513,6 +536,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): - # Mark this block as touched so that it can be marked as - # computed after the entire batch of sequences are scheduled. - self._touched_blocks.add(block.block_id) -+ -+ if self.event_manager: -+ self.event_manager.enqueue_stored_event(block.prev_block, block) -+ - return block.block_id - - # Reuse the cached content hash -@@ -579,9 +606,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - # Mark all touched blocks as computed. -- for block_id in self._touched_blocks: -- self._block_tracker[block_id].computed = True -- self._touched_blocks.clear() -+ for block_id in block_ids: -+ if block_id in self._touched_blocks: -+ logger.debug("Mark block as computed: %s", block_id) -+ self._block_tracker[block_id].computed = True -+ self._touched_blocks.remove(block_id) - - def _track_block_id(self, block_id: Optional[BlockId], - computed: bool) -> None: -diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py -index c5b3b04f3..d3a4b77f8 100644 ---- a/vllm/core/block_manager.py -+++ b/vllm/core/block_manager.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- """A block manager that manages token blocks.""" - from typing import Dict, List, Optional - from typing import Sequence as GenericSequence -@@ -10,7 +23,10 @@ from vllm.core.block.interfaces import Block - from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - LastAccessBlocksTracker) - from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -+from vllm.core.event_manager import KVCacheEventManager - from vllm.core.interfaces import AllocStatus, BlockSpaceManager -+from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE, -+ VLLM_WORKER_ID) - from vllm.sequence import Sequence, SequenceGroup, SequenceStatus - from vllm.utils import Device - -@@ -60,6 +76,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - def __init__( - self, -+ model_name: str, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, -@@ -91,11 +108,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - self.watermark_blocks = int(watermark * num_gpu_blocks) - -+ kv_event_manager_params = [ -+ VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE, -+ VLLM_KV_COMPONENT -+ ] -+ set_kv_event_manager_params = len( -+ [param for param in kv_event_manager_params if param is not None]) -+ -+ if set_kv_event_manager_params == len(kv_event_manager_params): -+ self.event_manager = KVCacheEventManager( -+ namespace=VLLM_KV_NAMESPACE, -+ component=VLLM_KV_COMPONENT, -+ worker_id=VLLM_WORKER_ID, -+ lib_path=VLLM_KV_CAPI_PATH, -+ kv_block_size=block_size) -+ else: -+ self.event_manager = None -+ - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, -+ event_manager=self.event_manager, - ) - - self.block_tables: Dict[SeqId, BlockTable] = {} -@@ -108,7 +143,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - def can_allocate(self, - seq_group: SequenceGroup, -- num_lookahead_slots: int = 0) -> AllocStatus: -+ num_lookahead_slots: int = 0, -+ is_remote_decode: bool = False) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - -@@ -121,6 +157,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - num_lookahead_slots=num_lookahead_slots, - ) - -+ # if remote decode, we need to allocate twice as many blocks for staging -+ if is_remote_decode: -+ num_required_blocks *= 2 -+ - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None -diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py -new file mode 100644 -index 000000000..79eb8db67 ---- /dev/null -+++ b/vllm/core/event_manager.py -@@ -0,0 +1,121 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+import ctypes -+import logging -+import uuid -+from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64 -+from typing import Optional -+ -+from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash -+ -+logger = logging.getLogger(__name__) -+ -+ -+class DynamoResult: -+ OK = 0 -+ ERR = 1 -+ -+ -+class KVCacheEventManager: -+ -+ def __init__(self, namespace: str, component: str, worker_id: int, -+ lib_path: str, kv_block_size: int): -+ self.lib = None -+ -+ try: -+ self.lib = ctypes.CDLL(lib_path) -+ self.lib.dynamo_llm_init.argtypes = [ -+ c_char_p, -+ c_char_p, -+ c_int64, -+ c_uint32, -+ ] -+ self.lib.dynamo_llm_init.restype = c_uint32 -+ -+ result = self.lib.dynamo_llm_init( -+ namespace.encode(), component.encode(), worker_id, kv_block_size -+ ) -+ if result == DynamoResult.OK: -+ logger.info( -+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events" -+ ) -+ else: -+ logger.info("KVCacheEventManager initialization failed!") -+ -+ except Exception as e: -+ print(f"Failed to load {lib_path}") -+ raise e -+ -+ self.lib.dynamo_kv_event_publish_stored.argtypes = [ -+ ctypes.c_uint64, # event_id -+ ctypes.POINTER(ctypes.c_uint32), # token_ids -+ ctypes.POINTER(ctypes.c_size_t), # num_block_tokens -+ ctypes.POINTER(ctypes.c_uint64), # block_ids -+ ctypes.c_size_t, # num_blocks -+ ctypes.POINTER(ctypes.c_uint64), # parent_hash -+ ctypes.c_uint64, # lora_id -+ ] -+ self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t -+ -+ self.lib.dynamo_kv_event_publish_removed.argtypes = [ -+ ctypes.c_uint64, # event_id -+ ctypes.POINTER(ctypes.c_uint64), # block_ids -+ ctypes.c_size_t, # num_blocks -+ ] -+ self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t -+ -+ self.event_id_counter = 0 -+ -+ def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], -+ block: PrefixCachingBlock): -+ token_ids_arr = (ctypes.c_uint32 * -+ len(block.token_ids))(*block.token_ids) -+ num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids)) -+ block_hash = (ctypes.c_uint64 * 1)(block.content_hash) -+ parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) -+ if parent is not None else None) -+ -+ # Publish the event -+ result = self.lib.dynamo_kv_event_publish_stored( -+ self.event_id_counter, # uint64_t event_id -+ token_ids_arr, # const uint32_t *token_ids -+ num_block_tokens, # const uintptr_t *num_block_tokens -+ block_hash, # const uint64_t *block_ids -+ 1, # uintptr_t num_blocks -+ parent_hash, # const uint64_t *parent_hash -+ 0, # uint64_t lora_id -+ ) -+ -+ if result == DynamoResult.OK: -+ logger.debug(f"Store - Published KV Event: {block.content_hash}") -+ else: -+ logger.debug( -+ f"Store - Failed to Publish KV Event: {block.content_hash}") -+ -+ self.event_id_counter += 1 -+ -+ def enqueue_removed_event(self, block_hash: PrefixHash): -+ result = self.lib.dynamo_kv_event_publish_removed( -+ self.event_id_counter, -+ (ctypes.c_uint64 * 1)(block_hash), -+ 1, -+ ) -+ -+ if result == DynamoResult.OK: -+ logger.debug(f"Remove - Published KV Event: {block_hash}") -+ else: -+ logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}") -+ -+ self.event_id_counter += 1 -diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py -index f507847ad..3f3cba766 100644 ---- a/vllm/core/scheduler.py -+++ b/vllm/core/scheduler.py -@@ -1,25 +1,38 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import enum - import os - import random - import time -+import copy - from collections import deque - from dataclasses import dataclass, field - from typing import Callable, Deque, Dict, Iterable, List, Optional - from typing import Sequence as GenericSequence --from typing import Set, Tuple, Union -+from typing import Set, Tuple, Union, Any - --from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -+from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig - from vllm.core.interfaces import AllocStatus, BlockSpaceManager - from vllm.logger import init_logger - from vllm.lora.request import LoRARequest - from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceGroupMetadataDelta, -- SequenceStatus) -+ SequenceStatus, SequenceStage) - from vllm.utils import Device, PyObjectCache -- - logger = init_logger(__name__) - - # Test-only. If configured, decode is preempted with -@@ -285,6 +298,7 @@ class SchedulerPrefillOutputs: - # Ignored sequence groups. - ignored_seq_groups: List[SequenceGroup] - num_lookahead_slots: int -+ num_remote_prefill_groups: int - - @classmethod - def create_empty(cls) -> "SchedulerPrefillOutputs": -@@ -292,6 +306,7 @@ class SchedulerPrefillOutputs: - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=0, -+ num_remote_prefill_groups=0, - ) - - -@@ -325,12 +340,14 @@ class Scheduler: - - def __init__( - self, -+ model_config: ModelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: -+ self.model_config = model_config - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely -@@ -356,6 +373,7 @@ class Scheduler: - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( -+ model_name=self.model_config.served_model_name, - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, -@@ -371,6 +389,16 @@ class Scheduler: - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: Deque[SequenceGroup] = deque() -+ -+ # Sequence groups in the REMOTE_PREFILLING state. -+ # Contain requests that are being prefilled by a remote worker. -+ self.remote_prefilling: Deque[SequenceGroup] = deque() -+ # Contain requests that are being prefilled by a local worker. -+ self.prefill_sending: Deque[SequenceGroup] = deque() -+ -+ self._remote_prefill_outputs: Dict[str, int] = {} -+ -+ - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. 
-@@ -501,7 +529,7 @@ class Scheduler: - - def has_unfinished_seqs(self) -> bool: - return len(self.waiting) != 0 or len(self.running) != 0 or len( -- self.swapped) != 0 -+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0 - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_manager.get_prefix_cache_hit_rate(device) -@@ -523,6 +551,8 @@ class Scheduler: - budget: SchedulingBudget, - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, -+ finished_prefills: Optional[Set[str]] = None, -+ finished_transfers: Optional[Set[str]] = None - ) -> SchedulerRunningOutputs: - """Schedule sequence groups that are running. - -@@ -537,6 +567,8 @@ class Scheduler: - chunked number of tokens are scheduled if - `budget.num_batched_tokens` has not enough capacity to schedule - all tokens. -+ finished_remote_prefill_request_ids: Set of request ids of remote -+ prefills that have finished. - - Returns: - SchedulerRunningOutputs. -@@ -566,6 +598,38 @@ class Scheduler: - preempted: List[SequenceGroup] = ret.preempted - swapped_out: List[SequenceGroup] = ret.swapped_out - -+ remote_prefilling_queue = self.remote_prefilling -+ leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque() -+ while remote_prefilling_queue: -+ seq_group = remote_prefilling_queue.popleft() -+ if seq_group.request_id not in finished_prefills: -+ leftover_remote_prefilling_sequences.append(seq_group) -+ continue -+ -+ else: -+ finished_prefills.remove(seq_group.request_id) -+ assert len(seq_group.seqs) == 1 -+ seq = seq_group.seqs[0] -+ # we computed all but the last token in prefill, we need to decode the first token on decode -+ seq_group.update_num_computed_tokens(seq.get_len() - 1) -+ seq.status = SequenceStatus.RUNNING -+ seq.data._stage = SequenceStage.DECODE -+ self.running.appendleft(seq_group) -+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences) -+ -+ remote_transfers_queue = self.prefill_sending -+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque() -+ while remote_transfers_queue: -+ seq_group = remote_transfers_queue.popleft() -+ if seq_group.request_id not in finished_transfers: -+ leftover_remote_transfers_sequences.append(seq_group) -+ else: -+ finished_transfers.remove(seq_group.request_id) -+ assert len(seq_group.seqs) == 1 -+ seq = seq_group.seqs[0] -+ self.free_seq(seq) -+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences) -+ - running_queue = self.running - assert len(self._async_stopped) == 0 - while running_queue: -@@ -925,6 +989,7 @@ class Scheduler: - seq_groups: List[ScheduledSequenceGroup] = [] - - waiting_queue = self.waiting -+ num_remote_prefill_groups = 0 - - leftover_waiting_sequences: Deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: -@@ -961,8 +1026,10 @@ class Scheduler: - True, enable_chunking) - - # If the sequence group cannot be allocated, stop. 
-+ is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode - can_allocate = self.block_manager.can_allocate( -- seq_group, num_lookahead_slots=num_lookahead_slots) -+ seq_group, num_lookahead_slots=num_lookahead_slots, -+ is_remote_decode=is_remote_decode) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: -@@ -1008,7 +1075,18 @@ class Scheduler: - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() -- self._allocate_and_set_running(seq_group) -+ -+ seq_group_copy = copy.deepcopy(seq_group) -+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1 -+ -+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id) -+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) -+ is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group) -+ num_remote_prefill_groups += is_remote_prefill -+ if is_remote_decode: -+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) -+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy) -+ self.prefill_sending.append(seq_group_copy) - - if enable_chunking and self.scheduler_config.is_multi_step: - blocks_to_copy: List[Tuple[int, int]] = [] -@@ -1046,9 +1124,11 @@ class Scheduler: - seq_groups=seq_groups, - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots( -- is_prefill=True, enable_chunking=enable_chunking)) -+ is_prefill=True, enable_chunking=enable_chunking), -+ num_remote_prefill_groups=num_remote_prefill_groups -+ ) - -- def _schedule_default(self) -> SchedulerOutputs: -+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: - """Schedule queued requests. - - The current policy is designed to optimize the throughput. First, -@@ -1066,9 +1146,13 @@ class Scheduler: - for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) -- curr_loras = set( -+ for seq_group in self.remote_prefilling: -+ budget.add_num_seqs(seq_group.request_id, -+ seq_group.get_max_num_running_seqs()) -+ -+ curr_loras = (set( - seq_group.lora_int_id for seq_group in self.running -- if seq_group.lora_int_id > 0) if self.lora_enabled else None -+ if seq_group.lora_int_id > 0) if self.lora_enabled else None) - - prefills = SchedulerPrefillOutputs.create_empty() - running_scheduled = SchedulerRunningOutputs.create_empty() -@@ -1090,7 +1174,9 @@ class Scheduler: - if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, -- enable_chunking=False) -+ enable_chunking=False, -+ finished_prefills=finished_prefills, -+ finished_transfers=finished_transfers) - - # If any sequence group is preempted, do not swap in any sequence - # group. because it means there's no slot for new running requests. -@@ -1106,7 +1192,12 @@ class Scheduler: - self.waiting.extendleft(running_scheduled.preempted) - # Update new running requests. 
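For a request flagged is_remote_decode, the hunk above allocates twice: the original group plus a deep copy whose single sequence gets seq_id + 1. The copy is parked on prefill_sending and owns the staging blocks that will be written out to the decode engine. A toy illustration of just the duplication step (SimpleSeq and SimpleGroup are stand-ins for the vLLM types):

import copy
from dataclasses import dataclass, field
from typing import List

@dataclass
class SimpleSeq:
    seq_id: int

@dataclass
class SimpleGroup:
    request_id: str
    seqs: List[SimpleSeq] = field(default_factory=list)

def make_staging_copy(seq_group: SimpleGroup) -> SimpleGroup:
    # Mirror the patch: deep-copy the group and bump the (single) sequence id by one,
    # so the copy gets its own block-table entry under seq_id + 1.
    staging = copy.deepcopy(seq_group)
    staging.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
    return staging

group = SimpleGroup("req-0", [SimpleSeq(seq_id=10)])
staging = make_staging_copy(group)
assert (group.seqs[0].seq_id, staging.seqs[0].seq_id) == (10, 11)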
- if len(prefills.seq_groups) > 0: -- self.running.extend([s.seq_group for s in prefills.seq_groups]) -+ for s in prefills.seq_groups: -+ seq_group = s.seq_group -+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: -+ self.remote_prefilling.append(seq_group) -+ else: -+ self.running.append(seq_group) - - self.running.extend(running_scheduled.decode_seq_groups_list) - -@@ -1248,12 +1339,14 @@ class Scheduler: - len(running_scheduled.swapped_out)), - ) - -- def _schedule(self) -> SchedulerOutputs: -+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: - """Schedule queued requests.""" - if self.scheduler_config.chunked_prefill_enabled: -+ if finished_prefills or finished_transfers: -+ raise ValueError("Chunked prefill does not support remote prefills") - return self._schedule_chunked_prefill() - else: -- return self._schedule_default() -+ return self._schedule_default(finished_prefills, finished_transfers) - - def _can_append_slots(self, seq_group: SequenceGroup, - enable_chunking: bool) -> bool: -@@ -1287,14 +1380,16 @@ class Scheduler: - return no_single_seq - - def schedule( -- self -+ self, -+ finished_prefills: Optional[Set[str]] = None, -+ finished_transfers: Optional[Set[str]] = None - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. -- scheduler_start_time = time.perf_counter() - -- scheduler_outputs: SchedulerOutputs = self._schedule() -+ scheduler_start_time = time.perf_counter() -+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers) - now = time.time() - - if not self.cache_config.enable_prefix_caching: -@@ -1333,7 +1428,8 @@ class Scheduler: - encoder_seq_data = None - cross_block_table = None - -- for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): -+ running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING) -+ for seq in running_or_remote_prefilling_seqs: - seq_id = seq.seq_id - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) -@@ -1342,7 +1438,9 @@ class Scheduler: - if self.cache_config.enable_prefix_caching: - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( -- seq_group.get_seqs(status=SequenceStatus.RUNNING))) -+ running_or_remote_prefilling_seqs -+ ) -+ ) - - do_sample = True - is_prompt = seq_group.is_prefill() -@@ -1364,9 +1462,30 @@ class Scheduler: - < seqs[0].data.get_len()): - do_sample = False - -+ is_remote_prefill = False -+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: -+ is_remote_prefill = True -+ logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums) -+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: -+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids -+ -+ # Since we know that prefill is scheduled we can -+ # assume that the blocks computed on decode -+ # will be fetched by the time we run prefill -+ logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids) -+ if 
seq_group.remote_prefill_params.decode_computed_block_ids: -+ computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids) -+ prefill_block_ids = block_tables[seq_group.seqs[0].seq_id] -+ prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)] -+ -+ assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't suport prefix caching" -+ common_computed_block_nums = prefill_fetched_block_ids -+ -+ - # It assumes the scheduled_seq_groups is ordered by - # prefill < decoding. - if is_first_prefill or not self.scheduler_config.send_delta_data: -+ logger.debug("Assinged blocks: %s", block_tables) - seq_group_metadata = SequenceGroupMetadata( - request_id=seq_group.request_id, - is_prompt=is_prompt, -@@ -1392,6 +1511,7 @@ class Scheduler: - if scheduler_outputs.num_prefill_groups > 0 else None, - mm_processor_kwargs=seq_group.mm_processor_kwargs, - prompt_adapter_request=seq_group.prompt_adapter_request, -+ do_remote_prefill=is_remote_prefill, - ) - else: - # When SPMD mode is enabled, we only send delta data except for -@@ -1490,11 +1610,17 @@ class Scheduler: - - self._async_stopped.clear() - -- def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: -+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool: - self.block_manager.allocate(seq_group) -+ is_remote_prefill = False - for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): -- seq.status = SequenceStatus.RUNNING -- -+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: -+ seq.status = SequenceStatus.REMOTE_PREFILLING -+ is_remote_prefill = True -+ else: -+ seq.status = SequenceStatus.RUNNING -+ return is_remote_prefill -+ - def _append_slots(self, - seq_group: SequenceGroup, - blocks_to_copy: List[Tuple[int, int]], -diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py -new file mode 100644 -index 000000000..a2f9ce99e ---- /dev/null -+++ b/vllm/distributed/device_communicators/kv_rearrange.py -@@ -0,0 +1,125 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
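In the scheduler hunk above, a remote-decode request may arrive with blocks the decode engine has already computed; those decode block ids are mapped onto the prefill's freshly allocated block table and passed through common_computed_block_nums so the prefill worker skips them. A standalone rendering of that mapping with toy ids:

from typing import List

def prefill_blocks_already_on_decode(decode_block_ids: List[int],
                                     decode_computed_block_ids: List[int],
                                     prefill_block_ids: List[int]) -> List[int]:
    """Return the prefill-side block ids whose positions correspond to blocks the
    decode engine has already computed (mirrors the patched scheduler logic)."""
    computed = set(decode_computed_block_ids)
    return [prefill_block_ids[i]
            for i, block_id in enumerate(decode_block_ids)
            if block_id in computed and i < len(prefill_block_ids)]

# Toy ids: decode already computed its first two blocks, which map onto positions
# 0 and 1 of the prefill block table.
print(prefill_blocks_already_on_decode(
    decode_block_ids=[7, 8, 9],
    decode_computed_block_ids=[7, 8],
    prefill_block_ids=[100, 101, 102]))   # -> [100, 101]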
-+ -+import torch -+import triton -+import triton.language as tl -+ -+@triton.jit -+def rearrange_kernel_read( -+ t1_ptr, -+ t2_ptr, -+ N, -+ B, -+ H, -+ C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE: tl.constexpr, -+): -+ pid = tl.program_id(0) -+ -+ block_start = pid * BLOCK_SIZE -+ offsets = block_start + tl.arange(0, BLOCK_SIZE) -+ -+ curr_n = offsets // block_size -+ curr_b = offsets // token_size % B -+ curr_h = offsets // C % H -+ curr_c = offsets % C -+ -+ src_pos = offsets -+ -+ tp_group = curr_h * d // H -+ dst_h = curr_h % (H // d) -+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c -+ -+ dst_pos = tensor_subset_size * tp_group + tp_group_offset -+ -+ tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) -+ -+@triton.jit -+def rearrange_kernel_write( -+ t1_ptr, -+ t2_ptr, -+ N, -+ B, -+ H, -+ C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE: tl.constexpr, -+): -+ pid = tl.program_id(0) -+ -+ block_start = pid * BLOCK_SIZE -+ offsets = block_start + tl.arange(0, BLOCK_SIZE) -+ -+ curr_n = offsets // block_size -+ curr_b = offsets // token_size % B -+ curr_h = offsets // C % H -+ curr_c = offsets % C -+ -+ src_pos = offsets -+ -+ tp_group = curr_h * d // H -+ dst_h = curr_h % (H // d) -+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c -+ -+ dst_pos = tensor_subset_size * tp_group + tp_group_offset -+ -+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) -+ -+ -+ -+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str): -+ N, B, H, C = t1.shape -+ -+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source" -+ assert H % d == 0, "H must be divisible by d" -+ -+ block_size = B * H * C -+ token_size = H * C -+ tensor_size = N * block_size -+ tensor_subset_size = tensor_size // d -+ -+ BLOCK_SIZE = 1024 -+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,) -+ -+ if direction == "read": -+ rearrange_kernel_read[grid]( -+ t1, t2, -+ N, B, H, C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE=BLOCK_SIZE -+ ) -+ elif direction == "write": -+ rearrange_kernel_write[grid]( -+ t1, t2, -+ N, B, H, C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE=BLOCK_SIZE -+ ) -+ else: -+ raise ValueError(f"Invalid direction: {direction}") -\ No newline at end of file -diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py -new file mode 100644 -index 000000000..136a0bd37 ---- /dev/null -+++ b/vllm/distributed/device_communicators/nixl.py -@@ -0,0 +1,394 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
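The kv_rearrange kernels above permute an (N, B, H, C) cache slice between its native layout and a layout packed into d contiguous slabs, one per destination tensor-parallel rank (slab i holds heads i*H/d through (i+1)*H/d - 1). A plain-PyTorch sketch of the same index math, useful as a reference when testing the Triton kernels; this is my reading of the kernels, so verify it against them before relying on it:

import torch

def rearrange_reference(src: torch.Tensor, d: int, direction: str) -> torch.Tensor:
    """CPU/GPU reference for rearrange_tensors: 'write' packs heads into d rank slabs,
    'read' unpacks them back into the native (N, B, H, C) ordering."""
    N, B, H, C = src.shape
    assert H % d == 0, "H must be divisible by d"
    if direction == "write":
        # (N, B, H, C) -> d slabs of (N, B, H//d, C), stored back to back.
        return src.reshape(N, B, d, H // d, C).permute(2, 0, 1, 3, 4).reshape(N, B, H, C)
    if direction == "read":
        # Inverse permutation: d slabs back to the native head ordering.
        return src.reshape(d, N, B, H // d, C).permute(1, 2, 0, 3, 4).reshape(N, B, H, C)
    raise ValueError(f"Invalid direction: {direction}")

# Round-trip check on random data.
x = torch.randn(4, 16, 8, 32)
packed = rearrange_reference(x, d=2, direction="write")
assert torch.equal(rearrange_reference(packed, d=2, direction="read"), x)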
-+ -+import torch -+from typing import List, Tuple -+from vllm.config import VllmConfig -+from vllm.logger import init_logger -+import msgspec -+import time -+import uuid -+from collections import defaultdict -+from .kv_rearrange import rearrange_tensors -+ -+logger = init_logger(__name__) -+ -+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used -+try: -+ from nixl._api import nixl_agent as NixlWrapper -+ logger.info("NIXL is available") -+except ImportError: -+ logger.warning("NIXL is not available") -+ NixlWrapper = None -+ -+class NixlMetadata( -+ msgspec.Struct, -+ omit_defaults=True, # type: ignore[call-arg] -+ # required for @cached_property. -+ dict=True): -+ engine_id: str -+ agent_metadata: List[bytes] -+ kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values -+ num_blocks: int -+ -+ -+class DynamoNixlConnector: -+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): -+ self.vllm_config = vllm_config -+ if NixlWrapper is None: -+ logger.error("NIXL is not available") -+ raise RuntimeError("NIXL is not available") -+ logger.info("Initializing NIXL wrapper") -+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) -+ -+ self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer -+ -+ self.num_layers = None -+ self.num_blocks = None -+ self.num_heads = None -+ self.block_len = None -+ self.kv_caches = None -+ self.kv_caches_base_addr = {} -+ self.kv_cache_shape = {} -+ -+ self._registered_descs = [] -+ self._remote_agents = {} -+ self.engine_id = engine_id -+ self.rank = rank -+ self._tp_size = {} -+ self.src_xfer_side_handles = {} -+ self.dst_xfer_side_handles = defaultdict(dict) -+ self.dst_num_blocks = {} -+ -+ self._transfers = defaultdict(list) -+ -+ -+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size -+ -+ -+ @property -+ def agent_name(self): -+ return self.nixl_wrapper.name -+ -+ def register_kv_caches(self, kv_caches: List[torch.Tensor]): -+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape -+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() -+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) -+ self.num_layers = len(kv_caches) -+ self.num_blocks = num_blocks -+ self.num_heads = num_heads -+ self.kv_caches = kv_caches -+ kv_caches_base_addr = [] -+ caches_data = [] -+ for key_cache, value_cache in kv_caches: -+ base_addr = key_cache.data_ptr() -+ region_len = 2 * num_blocks * self.block_len -+ caches_data.append((base_addr, region_len, self.rank, "")) -+ kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr())) -+ -+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr -+ -+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") -+ logger.debug("Registering descs: %s", caches_data) -+ self.nixl_wrapper.register_memory(descs) -+ self._registered_descs.append(descs) -+ -+ def get_agent_metadata(self): -+ return self.nixl_wrapper.get_agent_metadata() -+ -+ def shutdown(self): -+ for descs_list in self._registered_descs: -+ self.nixl_wrapper.deregister_memory(descs_list) -+ for agent_names in self._remote_agents.values(): -+ for agent_name in agent_names: -+ self.nixl_wrapper.remove_remote_agent(agent_name) -+ for src_xfer_side_handle in self.src_xfer_side_handles.values(): -+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) -+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): -+ for dst_xfer_side_handle in 
dst_xfer_side_handles.values(): -+ self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle) -+ -+ def _get_ranges(self, block_ids): -+ # This function should return a list of ranges of block ids that are contiguous -+ # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] -+ # The ranges are sorted by the starting block id -+ # The function should also make sure that the block ids are contiguous -+ # If the block ids are not contiguous, the function should raise an error -+ ranges = [] -+ for i in range(len(block_ids)): -+ if i == 0 or block_ids[i] != block_ids[i-1] + 1: -+ ranges.append([block_ids[i], block_ids[i]]) -+ else: -+ ranges[-1][1] = block_ids[i] -+ return ranges -+ -+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None): -+ -+ if layer_ids == "all": -+ layer_ids = list(range(self.num_layers)) -+ if block_ids == "all": -+ block_ids = list(range(self.num_blocks)) -+ -+ descs_ids = [] -+ -+ -+ if i is not None: -+ num_blocks = self.num_blocks -+ for layer_id in layer_ids: -+ for is_value in [0, 1]: -+ staging_range_idx = 0 -+ for block_id in block_ids: -+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: -+ staging_range_idx += 1 -+ start_offset = staging_ranges[staging_range_idx][0] -+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) -+ descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) -+ else: -+ num_blocks = self.dst_num_blocks[engine_id] -+ for layer_id in layer_ids: -+ for is_value in [0, 1]: -+ for block_id in block_ids: -+ descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id) -+ return descs_ids -+ -+ def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False): -+ # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length -+ # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] -+ # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]) -+ src_overlapping_ranges, dst_overlapping_ranges = [], [] -+ -+ original_src_ranges = [] -+ org_src_range = tuple(src_ranges[0]) -+ -+ src_idx, dst_idx = 0, 0 -+ while src_idx < len(src_ranges) and dst_idx < len(dst_ranges): -+ src_range = src_ranges[src_idx] -+ dst_range = dst_ranges[dst_idx] -+ -+ # Calculate the length of each range -+ src_len = src_range[-1] - src_range[0] + 1 -+ dst_len = dst_range[-1] - dst_range[0] + 1 -+ -+ # If ranges have the same length, add them directly -+ if src_len == dst_len: -+ src_overlapping_ranges.append([src_range[0], src_range[-1]]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) -+ original_src_ranges.append(org_src_range) -+ src_idx += 1 -+ dst_idx += 1 -+ if src_idx < len(src_ranges): -+ org_src_range = tuple(src_ranges[src_idx]) -+ # If source range is longer, split it -+ elif src_len > dst_len: -+ src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) -+ original_src_ranges.append(org_src_range) -+ # Update source range for next iteration -+ src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]] -+ dst_idx += 1 -+ # If destination range is longer, split it -+ else: # src_len < dst_len -+ 
src_overlapping_ranges.append([src_range[0], src_range[-1]]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1]) -+ original_src_ranges.append(org_src_range) -+ # Update destination range for next iteration -+ dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]] -+ src_idx += 1 -+ if src_idx < len(src_ranges): -+ org_src_range = tuple(src_ranges[src_idx]) -+ if return_original_src_ranges: -+ return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges -+ return src_overlapping_ranges, dst_overlapping_ranges -+ -+ def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id): -+ logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id) -+ -+ assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids) -+ -+ if len(local_block_ids) == 0: -+ logger.debug("No blocks to read") -+ return -+ -+ start_time = time.perf_counter() -+ -+ local_ranges = self._get_ranges(local_block_ids) -+ staging_ranges = self._get_ranges(staging_block_ids) -+ -+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) -+ -+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] -+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) -+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] -+ handles = [] -+ -+ logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000) -+ create_xfer_start_time = time.perf_counter() -+ -+ for i in range(tp_multiplier): -+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) -+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids) -+ remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] -+ handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids, -+ remote_xfer_side_handle, remote_block_descs_ids, -+ "") -+ handles.append(handle) -+ status = self.nixl_wrapper.transfer(handle) -+ -+ logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) -+ -+ transfer_start_time = time.perf_counter() -+ -+ for handle in handles: -+ while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE": -+ if status == "PROC": -+ time.sleep(0.001) -+ else: -+ raise RuntimeError("Read transfer failed with state %s", status) -+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors? 
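_get_ranges and _get_same_length_ranges drive the read/write paths above: block-id lists are collapsed into inclusive [start, end] runs, and source/destination runs are split until each emitted pair covers the same number of blocks. A compact, self-contained version of both helpers using the worked examples from the code comments (the original also tracks the pre-split source ranges, which this sketch omits):

from typing import List, Tuple

def get_ranges(block_ids: List[int]) -> List[List[int]]:
    """Collapse a sorted id list into inclusive runs, e.g. [0, 1, 2, 4, 5, 6] -> [[0, 2], [4, 6]]."""
    ranges: List[List[int]] = []
    for i, block_id in enumerate(block_ids):
        if i == 0 or block_id != block_ids[i - 1] + 1:
            ranges.append([block_id, block_id])      # start a new run
        else:
            ranges[-1][1] = block_id                 # extend the current run
    return ranges

def pair_equal_length_ranges(src: List[List[int]], dst: List[List[int]]
                             ) -> Tuple[List[List[int]], List[List[int]]]:
    """Split src/dst runs until each emitted pair spans the same number of blocks."""
    src, dst = [list(r) for r in src], [list(r) for r in dst]    # work on copies
    out_src: List[List[int]] = []
    out_dst: List[List[int]] = []
    i = j = 0
    while i < len(src) and j < len(dst):
        s_len = src[i][1] - src[i][0] + 1
        d_len = dst[j][1] - dst[j][0] + 1
        take = min(s_len, d_len)
        out_src.append([src[i][0], src[i][0] + take - 1])
        out_dst.append([dst[j][0], dst[j][0] + take - 1])
        if s_len == take:
            i += 1
        else:
            src[i][0] += take
        if d_len == take:
            j += 1
        else:
            dst[j][0] += take
    return out_src, out_dst

assert get_ranges([0, 1, 2, 4, 5, 6]) == [[0, 2], [4, 6]]
assert pair_equal_length_ranges([[0, 2], [4, 8]], [[1, 3], [5, 7], [9, 10]]) == \
    ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]])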
-+ -+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000) -+ -+ rearrange_start_time = time.perf_counter() -+ -+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): -+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) -+ for kv_cache in self.kv_caches: -+ for cache in kv_cache: -+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read") -+ -+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000) -+ logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000) -+ -+ def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg): -+ logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg) -+ -+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last -+ # isl[-1] token is calculated in the first iteration in decode. -+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have \ -+ # one less block due to the missing last token. -+ remote_block_ids = remote_block_ids[:len(local_block_ids)] -+ -+ assert len(staging_block_ids) == len(local_block_ids) -+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] -+ -+ if len(local_block_ids) == 0: -+ logger.debug("No blocks to write") -+ for i in range(tp_multiplier): -+ self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg) -+ return -+ -+ start_time = time.perf_counter() -+ -+ local_ranges = self._get_ranges(local_block_ids) -+ staging_ranges = self._get_ranges(staging_block_ids) -+ -+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) -+ -+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): -+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) -+ for kv_cache in self.kv_caches: -+ for cache in kv_cache: -+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write") -+ -+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) -+ -+ create_xfer_start_time = time.perf_counter() -+ -+ # getting block descs ids -+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) -+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] -+ -+ for i in range(tp_multiplier): -+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) -+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids) -+ remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] -+ handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids, -+ remote_xfer_side_handle, remote_block_descs_ids, -+ notify_msg) -+ self._transfers[notify_msg].append(handle) -+ status = self.nixl_wrapper.transfer(handle) -+ -+ logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) -+ -+ transfer_start_time = 
time.perf_counter() -+ logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000) -+ -+ def get_notifs(self): -+ return self.nixl_wrapper.update_notifs() -+ -+ def get_new_notifs(self): -+ return self.nixl_wrapper.get_new_notifs() -+ -+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks): -+ self._tp_size[engine_id] = agent_tp -+ agent_names = [] -+ for agent_meta in agent_metadata: -+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) -+ agent_names.append(agent_name) -+ self._remote_agents[engine_id] = agent_names -+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr -+ -+ tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id] -+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}" -+ -+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier) -+ dst_block_len = self.block_len // tp_multiplier -+ if tp_multiplier not in self.src_xfer_side_handles: -+ # create descs and xfer side handles -+ blocks_data = [] -+ for layer_id in range(self.num_layers): -+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]: -+ for block_id in range(self.num_blocks): -+ block_offset = block_id * self.block_len -+ for i in range(tp_multiplier): -+ tp_multiplier_offset = i * dst_block_len -+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) -+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) -+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") -+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs) -+ -+ # create dst xfer side handles -+ self.dst_num_blocks[engine_id] = num_blocks -+ for i in range(tp_multiplier): -+ blocks_data = [] -+ for layer_id in range(self.num_layers): -+ for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]: -+ for block_id in range(num_blocks): -+ block_offset = block_id * dst_block_len -+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i)) -+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i) -+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") -+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs) -+ -+ return agent_names -+ -+ def get_done_tranfers(self) -> List[str]: -+ done_req_ids = [] -+ for req_id, handles in self._transfers.items(): -+ running_reqs = [] -+ for handle in handles: -+ xfer_state = self.nixl_wrapper.check_xfer_state(handle) -+ if xfer_state == "DONE": -+ # self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors? 
-+ continue -+ if xfer_state == "PROC": -+ running_reqs.append(handle) -+ else: -+ raise RuntimeError("Transfer failed with state %s", xfer_state) -+ if len(running_reqs) == 0: -+ done_req_ids.append(req_id) -+ else: -+ self._transfers[req_id] = running_reqs -+ return done_req_ids -diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py -new file mode 100644 -index 000000000..418fc7154 ---- /dev/null -+++ b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py -@@ -0,0 +1,363 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+""" -+Simple KV Cache Connector for Distributed Machine Learning Inference -+ -+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache -+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or -+MooncakePipe. -+ -+But the logic can be extended to support other pipe and lookup buffer. -+""" -+import re -+from typing import TYPE_CHECKING, List, Optional, Tuple, Union -+ -+import torch -+ -+from vllm import _custom_ops as ops -+from vllm.config import VllmConfig, KVTransferConfig -+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -+from vllm.distributed.utils import StatelessProcessGroup -+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( -+ SimpleBuffer) -+from vllm.logger import init_logger -+from vllm.sequence import IntermediateTensors -+ -+if TYPE_CHECKING: -+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -+ -+logger = init_logger(__name__) -+ -+ -+class DynamoConnector(KVConnectorBase): -+ -+ def __init__( -+ self, -+ rank: int, -+ local_rank: int, -+ config: VllmConfig, -+ world_group, -+ ): -+ -+ self.config = config.kv_transfer_config -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ self.rank = rank -+ -+ if self.config.kv_connector != "DynamoNcclConnector": -+ raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class") -+ -+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( -+ PyNcclPipe) -+ from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import ( -+ DynamoNcclDataPlane) -+ -+ logger.info( -+ "Initializing DynamoNcclConnector under kv_transfer_config %s", -+ self.config) -+ -+ self.lookup_buffer_size = self.config.kv_buffer_size -+ -+ self.producer_data_pipe: PyNcclPipe -+ self.consumer_data_pipe: PyNcclPipe -+ self.producer_signal_pipe: PyNcclPipe -+ self.consumer_signal_pipe: PyNcclPipe -+ -+ self._broadcast_and_enhance_kv_config(rank, config, world_group) -+ -+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ -+ # 2 pipes for every rank in the world -+ if self.config.is_kv_producer: -+ port_offset_base = rank + 1 -+ else: -+ 
port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 -+ -+ -+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier -+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) -+ -+ self.data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, -+ local_rank=local_rank, -+ config=self.config, -+ port_offset=port_offset_base, -+ ) -+ -+ self.data_plane = DynamoNcclDataPlane( -+ data_pipe=self.data_pipe, -+ port=self._get_data_plane_port(self.global_kv_rank), -+ ) -+ -+ def send_kv_caches_and_hidden_states( -+ self, -+ model_executable: torch.nn.Module, -+ model_input: "ModelInputForGPUWithSamplingMetadata", -+ kv_caches: List[torch.Tensor], -+ hidden_or_intermediate_states: Union[torch.Tensor, -+ IntermediateTensors], -+ ) -> None: -+ -+ input_tokens_tensor = model_input.input_tokens -+ seq_lens = model_input.attn_metadata.seq_lens -+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() -+ start_layer = model_executable.model.start_layer -+ end_layer = model_executable.model.end_layer -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) -+ -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ if not is_deepseek: -+ num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(hidden_size / num_attention_heads) -+ else: -+ num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(4.5 * hidden_size / num_attention_heads) -+ -+ # query_lens contains new KV caches that are added to vLLM. -+ # so we will send them to decode instance -+ # FIXME(Kuntai): This assume that all requests are prefill. 
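DynamoConnector lays the topology out flat: producer instances occupy the first kv-group ranks, each consumer instance contributes tensor_parallel_multiplier ranks after them, and every global kv rank gets its own data-plane port above kv_port (the helpers appear further down in this file). The formulas below are copied from the patch and evaluated for a made-up topology; the numbers are purely illustrative:

def kv_group_rank(kv_rank: int, rank: int, producers: int, multiplier: int) -> int:
    if kv_rank < producers:
        return kv_rank
    consumer = kv_rank - producers
    return producers + consumer * multiplier + rank % multiplier

def global_kv_rank(kv_rank: int, rank: int, producers: int,
                   producers_tp: int, consumers_tp: int) -> int:
    if kv_rank <= producers:                     # branch condition as written in the patch
        return kv_rank * producers_tp + rank
    consumer = kv_rank - producers
    return producers * producers_tp + consumer * consumers_tp + rank

def data_plane_port(kv_port: int, producers_tp: int, g_rank: int) -> int:
    return kv_port + producers_tp + 1 + g_rank

# Made-up topology: two prefill (producer) instances with TP=1, one decode (consumer)
# instance with TP=2, so tensor_parallel_multiplier == 2.
producers, producers_tp, consumers_tp, multiplier, kv_port = 2, 1, 2, 2, 14579
for kv_rank, rank in [(0, 0), (1, 0), (2, 0), (2, 1)]:
    g = global_kv_rank(kv_rank, rank, producers, producers_tp, consumers_tp)
    print(kv_rank, rank, kv_group_rank(kv_rank, rank, producers, multiplier),
          g, data_plane_port(kv_port, producers_tp, g))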
-+ for idx, slen in enumerate(seq_lens): -+ start_pos = sum(seq_lens[:idx]) -+ end_pos = start_pos + slen -+ current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) -+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) -+ -+ for target_rank in range(self.config.tensor_parallel_multiplier): -+ -+ keys, values = [], [] -+ -+ for layer_id in range(start_layer, end_layer): -+ kv_cache = kv_caches[layer_id - start_layer] -+ -+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] -+ -+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier -+ head_start = target_rank * num_heads_per_rank -+ head_end = head_start + num_heads_per_rank -+ -+ if not is_deepseek: -+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) -+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) -+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ else: -+ key_cache = kv_cache -+ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -+ values.append(torch.empty(0)) -+ -+ keys = torch.cat(keys, dim=0) -+ values = torch.cat(values, dim=0) -+ -+ decode_global_rank = decode_first_global_rank + target_rank -+ decode_port = self._get_data_plane_port(decode_global_rank) -+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] -+ self._send(decode_hostname, decode_port, current_request_id, keys, values, -+ partial_hidden_or_intermediate_states) -+ -+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) -+ -+ def recv_kv_caches_and_hidden_states( -+ self, model_executable: torch.nn.Module, -+ model_input: "ModelInputForGPUWithSamplingMetadata", -+ kv_caches: List[torch.Tensor] -+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, -+ "ModelInputForGPUWithSamplingMetadata"]: -+ -+ # When bypass_model_exec is set to False, it means that at least for one -+ # request its corresponding KV cache or hidden state is missing. -+ # In this case we need to do prefilling to recompute missing KV cache -+ # and hidden states. -+ bypass_model_exec = True -+ -+ input_tokens_tensor = model_input.input_tokens -+ seq_lens = model_input.attn_metadata.seq_lens -+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten() -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) -+ -+ hidden_or_intermediate_states_for_one_req = [] -+ -+ input_tokens_list = [] -+ start_pos_list = [] -+ -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ -+ # enumerate different requests -+ # FIXME(Kuntai): This impl assumes that all requests are prefill. 
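When the decode side runs with a larger tensor-parallel size, the send loop above shards each layer's KV along the head dimension: destination rank i receives heads [i * num_heads / multiplier, (i + 1) * num_heads / multiplier). A simplified slicing sketch with illustrative shapes (the patch additionally gathers rows via slot_mapping, which is omitted here):

import torch

def shard_heads_for_target(key_cache: torch.Tensor, value_cache: torch.Tensor,
                           num_heads: int, target_rank: int, multiplier: int):
    """Give decode rank `target_rank` its contiguous share of KV heads, mirroring the
    head_start/head_end computation in the patched send path."""
    heads_per_rank = num_heads // multiplier
    head_start = target_rank * heads_per_rank
    head_end = head_start + heads_per_rank
    return key_cache[:, head_start:head_end], value_cache[:, head_start:head_end]

# Illustrative shapes: 6 tokens, 8 KV heads of size 64, decode TP twice the prefill TP.
tokens, num_heads, head_size, multiplier = 6, 8, 64, 2
k = torch.randn(tokens, num_heads, head_size)
v = torch.randn(tokens, num_heads, head_size)
k0, v0 = shard_heads_for_target(k, v, num_heads, target_rank=0, multiplier=multiplier)
assert k0.shape == (tokens, num_heads // multiplier, head_size)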
-+ for idx, slen in enumerate(seq_lens): -+ -+ start_pos = sum(seq_lens[:idx]) -+ end_pos = start_pos + slen -+ current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ num_tokens = slen -+ -+ # collecting data for rebuilding the input -+ input_tokens_list.append(current_tokens) -+ start_pos_list.append(start_pos) -+ -+ ret = self._recv(current_request_id) -+ keys: torch.Tensor = ret[0] -+ values: torch.Tensor = ret[1] -+ hidden: torch.Tensor = ret[2] -+ -+ # put received KV caches into paged memory -+ for i in range(model_executable.model.start_layer, -+ model_executable.model.end_layer): -+ -+ kv_cache = kv_caches[i - model_executable.model.start_layer] -+ layer = model_executable.model.layers[i] -+ -+ if not is_deepseek: -+ key_cache, value_cache = kv_cache[0], kv_cache[1] -+ ops.reshape_and_cache_flash( -+ keys[i - model_executable.model.start_layer].to( -+ key_cache.device), -+ values[i - model_executable.model.start_layer].to( -+ value_cache.device), -+ key_cache, -+ value_cache, -+ slot_mapping[start_pos:end_pos], -+ layer.self_attn.attn.kv_cache_dtype, -+ layer.self_attn.attn._k_scale, -+ layer.self_attn.attn._v_scale, -+ ) -+ else: -+ key_cache = kv_cache -+ copy_from =keys[i - model_executable.model.start_layer].to( -+ key_cache.device) -+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from -+ -+ hidden_or_intermediate_states_for_one_req.append(hidden) -+ -+ if not bypass_model_exec: -+ # Some of the KV cache is not retrieved -+ # Here we will fall back to normal model forwarding -+ # But optionally you can adjust model_input so that you only do -+ # prefilling on those tokens that are missing KV caches. -+ logger.debug( -+ "[rank%d]: Failed to receive all KVs and hidden " -+ "states, redo model forwarding.", torch.distributed.get_rank()) -+ hidden_or_intermediate_states = None -+ -+ else: -+ logger.debug( -+ "[rank%d]: Successfully received all KVs and hidden " -+ "states, skip model forwarding.", torch.distributed.get_rank()) -+ hidden_or_intermediate_states = torch.cat( -+ hidden_or_intermediate_states_for_one_req, dim=0) -+ -+ return hidden_or_intermediate_states, bypass_model_exec, model_input -+ -+ def close(self): -+ self.data_pipe.close() -+ # self.data_plane.close() -+ -+ @staticmethod -+ def parse_request_id(request_id: str) -> Tuple[str, int]: -+ # Regular expression to match the string hostname and integer decode_kv_rank -+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" -+ -+ # Use re.search to find the pattern in the request_id -+ match = re.search(pattern, request_id) -+ if match: -+ # Extract the ranks -+ decode_hostname = match.group(1) -+ decode_rank = int(match.group(2)) -+ -+ return decode_hostname, decode_rank -+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") -+ -+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor): -+ remote_address = f"{hostname}:{port}" -+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address) -+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address) -+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address) -+ -+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: -+ keys = self.data_plane.recv_tensor(f"{request_id}_keys") -+ values = self.data_plane.recv_tensor(f"{request_id}_values") -+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden") -+ return keys, values, hidden 
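Routing relies on a suffix baked into the request id: the regex above extracts the decode hostname and decode kv rank, which select the destination data-plane address. A quick usage check; the regex matches the patch, while the request-id prefix below is made up:

import re
from typing import Tuple

def parse_request_id(request_id: str) -> Tuple[str, int]:
    """Extract (decode_hostname, decode_kv_rank) from a Dynamo-style request id."""
    match = re.search(r"___decode_hostname_(.*)___decode_kv_rank_(\d+)", request_id)
    if not match:
        raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
    return match.group(1), int(match.group(2))

req_id = "cmpl-1234___decode_hostname_decode-node-0___decode_kv_rank_3"   # illustrative id
assert parse_request_id(req_id) == ("decode-node-0", 3)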
-+ -+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank < config.kv_producers_parallel_size: -+ return kv_rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier -+ -+ -+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank <= config.kv_producers_parallel_size: -+ return kv_rank * config.kv_producers_tensor_parallel_size + rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank -+ -+ -+ def _get_data_plane_port(self, global_kv_rank: int) -> int: -+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank -+ -+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): -+ if rank == 0: -+ config_group = StatelessProcessGroup.create( -+ host=self.config.kv_ip, -+ port=self.config.kv_port, -+ rank=self.config.kv_rank, -+ world_size=self.config.kv_parallel_size, -+ ) -+ parallel_configs = config_group.all_gather_obj({ -+ "kv_role": self.config.kv_role, -+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size, -+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, -+ }) -+ logger.debug("parallel_configs: %s", parallel_configs) -+ kv_config_enhanced = { -+ "kv_producers_tensor_parallel_size": None, -+ "kv_consumers_tensor_parallel_size": None, -+ "kv_producers_pipeline_parallel_size": None, -+ "kv_consumers_pipeline_parallel_size": None, -+ "kv_producers_parallel_size": 0, -+ } -+ for parallel_config in parallel_configs: -+ kv_role = parallel_config["kv_role"] -+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" -+ -+ if kv_role == "kv_producer": -+ kv_config_enhanced["kv_producers_parallel_size"] += 1 -+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: -+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] -+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] -+ else: -+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" -+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" -+ world_group.broadcast_object(kv_config_enhanced) -+ else: -+ kv_config_enhanced = world_group.broadcast_object() -+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) -+ -+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] -+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] -+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] -+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] -+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] -diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py 
b/vllm/distributed/kv_transfer/kv_connector/factory.py -index fe4805334..0e16f0b31 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/factory.py -+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import importlib - from typing import TYPE_CHECKING, Callable, Dict, Type -@@ -27,13 +40,13 @@ class KVConnectorFactory: - - @classmethod - def create_connector(cls, rank: int, local_rank: int, -- config: "VllmConfig") -> KVConnectorBase: -+ config: "VllmConfig", world_group) -> KVConnectorBase: - connector_name = config.kv_transfer_config.kv_connector - if connector_name not in cls._registry: - raise ValueError(f"Unsupported connector type: {connector_name}") - - connector_cls = cls._registry[connector_name]() -- return connector_cls(rank, local_rank, config) -+ return connector_cls(rank, local_rank, config, world_group) - - - # Register various connectors here. -@@ -48,3 +61,8 @@ KVConnectorFactory.register_connector( - "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") -+ -+KVConnectorFactory.register_connector( -+ "DynamoNcclConnector", -+ "vllm.distributed.kv_transfer.kv_connector.dynamo_connector", -+ "DynamoConnector") -diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -index 2033e9762..983bc69a3 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - Simple KV Cache Connector for Distributed Machine Learning Inference - -@@ -8,13 +21,15 @@ MooncakePipe. - - But the logic can be extended to support other pipe and lookup buffer. 
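The factory.py hunk above threads world_group through create_connector and registers the new backend under "DynamoNcclConnector"; connectors stay lazily imported by module path and class name. A self-contained rendering of that registry pattern (MiniConnectorFactory is a stand-in; the real factory reads the connector name from the kv transfer config rather than taking it as an argument):

import importlib
from typing import Callable, Dict

class MiniConnectorFactory:
    """Stand-in showing the lazy-registration pattern used by KVConnectorFactory."""
    _registry: Dict[str, Callable[[], type]] = {}

    @classmethod
    def register_connector(cls, name: str, module_path: str, class_name: str) -> None:
        # Store a loader instead of importing eagerly, so optional backends stay optional.
        cls._registry[name] = lambda: getattr(importlib.import_module(module_path), class_name)

    @classmethod
    def create_connector(cls, name: str, rank: int, local_rank: int, config, world_group):
        if name not in cls._registry:
            raise ValueError(f"Unsupported connector type: {name}")
        connector_cls = cls._registry[name]()
        # Patched signature: connectors now also receive the world process group.
        return connector_cls(rank, local_rank, config, world_group)

# Registration mirrors the patch; instantiation would need a real VllmConfig and world group.
MiniConnectorFactory.register_connector(
    "DynamoNcclConnector",
    "vllm.distributed.kv_transfer.kv_connector.dynamo_connector",
    "DynamoConnector")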
- """ -+import re - from typing import TYPE_CHECKING, List, Optional, Tuple, Union - - import torch - - from vllm import _custom_ops as ops --from vllm.config import VllmConfig -+from vllm.config import VllmConfig, KVTransferConfig - from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -+from vllm.distributed.utils import StatelessProcessGroup - from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) - from vllm.logger import init_logger -@@ -33,6 +48,7 @@ class SimpleConnector(KVConnectorBase): - rank: int, - local_rank: int, - config: VllmConfig, -+ world_group, - ): - - self.config = config.kv_transfer_config -@@ -71,20 +87,31 @@ class SimpleConnector(KVConnectorBase): - self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - -+ self._broadcast_and_enhance_kv_config(rank, config, world_group) -+ -+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ - # 2 pipes for every rank in the world -- port_offset_base = 2 * rank -+ if self.config.is_kv_producer: -+ port_offset_base = 2 * rank + 1 -+ else: -+ port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 - -+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - if self.config.is_kv_producer: - - if self.config.kv_connector == "PyNcclConnector": - self.producer_data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.producer_signal_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, -@@ -108,11 +135,13 @@ class SimpleConnector(KVConnectorBase): - # its recv pipe to the send pipe of KV producder - if self.config.kv_connector == "PyNcclConnector": - self.consumer_data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.consumer_signal_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, -@@ -131,21 +160,25 @@ class SimpleConnector(KVConnectorBase): - self.config.kv_buffer_size, - ) - -- def select(self, input_tokens: Optional[torch.Tensor], -+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - -+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) -+ - assert self.consumer_buffer is not None, "Please initialize the "\ - "consumer buffer before calling select." -- return self.consumer_buffer.drop_select(input_tokens, roi) -+ return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - -- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, -+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - -+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) -+ - assert self.producer_buffer is not None, "Please initialize the "\ - "producer buffer before calling insert." 
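The SimpleConnector changes above give producers and consumers different pipe port offsets: a producer uses 2 * rank + 1, while consumer ranks that map onto the same producer (tensor_parallel_multiplier > 1) share 2 * (rank // multiplier) + 1, with the signal pipe one offset higher. A tiny sketch of that arithmetic with illustrative values:

from typing import Tuple

def pipe_port_offsets(rank: int, is_kv_producer: bool, multiplier: int) -> Tuple[int, int]:
    """(data pipe offset, signal pipe offset) as computed in the patched SimpleConnector."""
    if is_kv_producer:
        base = 2 * rank + 1
    else:
        base = 2 * (rank // multiplier) + 1
    return base, base + 1

# Illustrative: prefill TP=1, decode TP=2 -> decode ranks 0 and 1 share one offset pair.
print(pipe_port_offsets(0, True, 2))                                   # (1, 2)
print(pipe_port_offsets(0, False, 2), pipe_port_offsets(1, False, 2))  # (1, 2) (1, 2)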
- -- self.producer_buffer.insert(input_tokens, roi, key, value, hidden) -+ self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) - - def send_kv_caches_and_hidden_states( - self, -@@ -161,12 +194,20 @@ class SimpleConnector(KVConnectorBase): - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) - - model_config = model_executable.model.config -- num_heads = int(model_config.num_key_value_heads / self.tp_size) -- hidden_size = model_config.hidden_size -- num_attention_heads = model_config.num_attention_heads -- head_size = int(hidden_size / num_attention_heads) -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ if not is_deepseek: -+ num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(hidden_size / num_attention_heads) -+ else: -+ num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(4.5 * hidden_size / num_attention_heads) - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance -@@ -175,27 +216,40 @@ class SimpleConnector(KVConnectorBase): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ _, decode_kv_rank = self.parse_request_id(current_request_id) -+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) -+ -+ for target_rank in range(self.config.tensor_parallel_multiplier): - -- keys, values = [], [] -+ keys, values = [], [] - -- for layer_id in range(start_layer, end_layer): -- kv_cache = kv_caches[layer_id - start_layer] -+ for layer_id in range(start_layer, end_layer): -+ kv_cache = kv_caches[layer_id - start_layer] - -- key_cache = kv_cache[0].reshape(-1, num_heads, head_size) -- value_cache = kv_cache[1].reshape(-1, num_heads, head_size) -+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - -- current_slot_mapping = slot_mapping_flat[start_pos:end_pos] -+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier -+ head_start = target_rank * num_heads_per_rank -+ head_end = head_start + num_heads_per_rank - -- keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -- values.append(value_cache[current_slot_mapping].unsqueeze(0)) -+ if not is_deepseek: -+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) -+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) -+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ else: -+ key_cache = kv_cache -+ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -+ values.append(torch.empty(0)) - -- keys = torch.cat(keys, dim=0) -- values = torch.cat(values, dim=0) -+ keys = torch.cat(keys, dim=0) -+ values = torch.cat(values, dim=0) - -- self.insert(current_tokens, -- torch.ones_like(current_tokens, -- dtype=bool), keys, values, -- hidden_or_intermediate_states[start_pos:end_pos]) -+ self.insert(starting_kv_group_rank, target_rank, current_tokens, -+ torch.ones_like(current_tokens, -+ dtype=bool), 
keys, values, -+ hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - -@@ -215,6 +269,7 @@ class SimpleConnector(KVConnectorBase): - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) - - hidden_or_intermediate_states_for_one_req = [] - -@@ -222,6 +277,9 @@ class SimpleConnector(KVConnectorBase): - num_computed_tokens_list = [] - start_pos_list = [] - -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): -@@ -229,13 +287,15 @@ class SimpleConnector(KVConnectorBase): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ prefill_rank, _ = self.parse_request_id(current_request_id) - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - -- ret = self.select(current_tokens, -+ ret = self.select(prefill_rank, current_tokens, - torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. -@@ -267,19 +327,25 @@ class SimpleConnector(KVConnectorBase): - kv_cache = kv_caches[i - model_executable.model.start_layer] - layer = model_executable.model.layers[i] - -- key_cache, value_cache = kv_cache[0], kv_cache[1] -- ops.reshape_and_cache_flash( -- keys[i - model_executable.model.start_layer].to( -- key_cache.device), -- values[i - model_executable.model.start_layer].to( -- value_cache.device), -- key_cache, -- value_cache, -- slot_mapping[start_pos:end_pos], -- layer.self_attn.attn.kv_cache_dtype, -- layer.self_attn.attn._k_scale, -- layer.self_attn.attn._v_scale, -- ) -+ if not is_deepseek: -+ key_cache, value_cache = kv_cache[0], kv_cache[1] -+ ops.reshape_and_cache_flash( -+ keys[i - model_executable.model.start_layer].to( -+ key_cache.device), -+ values[i - model_executable.model.start_layer].to( -+ value_cache.device), -+ key_cache, -+ value_cache, -+ slot_mapping[start_pos:end_pos], -+ layer.self_attn.attn.kv_cache_dtype, -+ layer.self_attn.attn._k_scale, -+ layer.self_attn.attn._v_scale, -+ ) -+ else: -+ key_cache = kv_cache -+ copy_from =keys[i - model_executable.model.start_layer].to( -+ key_cache.device) -+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from - - hidden_or_intermediate_states_for_one_req.append(hidden) - -@@ -312,3 +378,77 @@ class SimpleConnector(KVConnectorBase): - # MooncakePipe reuses data_pipe for signal_pipe, so we only have to - # close the data_pipe. 
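The receive path above now branches on the model family: standard attention caches go through vLLM's ops.reshape_and_cache_flash, while the deepseek/MLA-style branch treats the per-layer cache as a single tensor indexed by flat slot ids and scatters the received rows directly. A simplified sketch of only that direct-scatter branch, with made-up sizes:

import torch

def write_received_kv_by_slot(kv_cache: torch.Tensor, received: torch.Tensor,
                              slot_mapping: torch.Tensor) -> None:
    """Simplified version of the deepseek branch: the per-layer cache is one tensor
    indexed by flat slot ids, so received entries are written in place.
    (The standard branch instead calls vllm's ops.reshape_and_cache_flash.)"""
    kv_cache[slot_mapping] = received.to(kv_cache.device)

cache = torch.zeros(16, 576)            # toy layer cache: 16 slots, 576 features per slot
recv = torch.randn(4, 576)
write_received_kv_by_slot(cache, recv, torch.tensor([3, 4, 5, 6]))
assert torch.equal(cache[3:7], recv)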
- pass -+ -+ @staticmethod -+ def parse_request_id(request_id): -+ # Regular expression to match the ranks -+ pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" -+ -+ # Use re.search to find the pattern in the request_id -+ match = re.search(pattern, request_id) -+ -+ if match: -+ # Extract the ranks -+ prefill_rank = int(match.group(1)) -+ decode_rank = int(match.group(2)) -+ -+ return prefill_rank, decode_rank -+ else: -+ return None, None -+ -+ -+ -+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank < config.kv_producers_parallel_size: -+ return kv_rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier -+ -+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): -+ if rank == 0: -+ if self.config.kv_connector == "PyNcclConnector": -+ config_group = StatelessProcessGroup.create( -+ host=self.config.kv_ip, -+ port=self.config.kv_port, -+ rank=self.config.kv_rank, -+ world_size=self.config.kv_parallel_size, -+ ) -+ parallel_configs = config_group.all_gather_obj({ -+ "kv_role": self.config.kv_role, -+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size, -+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, -+ }) -+ logger.debug("parallel_configs: %s", parallel_configs) -+ kv_config_enhanced = { -+ "kv_producers_tensor_parallel_size": None, -+ "kv_consumers_tensor_parallel_size": None, -+ "kv_producers_pipeline_parallel_size": None, -+ "kv_consumers_pipeline_parallel_size": None, -+ "kv_producers_parallel_size": 0, -+ } -+ for parallel_config in parallel_configs: -+ kv_role = parallel_config["kv_role"] -+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" -+ -+ if kv_role == "kv_producer": -+ kv_config_enhanced["kv_producers_parallel_size"] += 1 -+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: -+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] -+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] -+ else: -+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" -+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" -+ world_group.broadcast_object(kv_config_enhanced) -+ -+ else: -+ raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch") -+ else: -+ kv_config_enhanced = world_group.broadcast_object() -+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) -+ -+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] -+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] -+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] -+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] -+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] -diff --git 
a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -index 5e1b62352..7b4cb406e 100644 ---- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -+++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - Implements a distributed key-value (KV) cache transfer mechanism. - -@@ -12,7 +25,8 @@ - import threading - import time - from collections import deque --from typing import Deque, List, Optional, Union -+from concurrent.futures import ThreadPoolExecutor -+from typing import Deque, List, Optional, Union, Dict - - import torch - -@@ -46,7 +60,7 @@ class SimpleBuffer(KVLookupBufferBase): - self.buffer_lock = threading.Lock() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe -- self.request_handling_thread: Optional[threading.Thread] = None -+ self.request_handling_thread: Optional[ThreadPoolExecutor] = None - - self.normal_signal = torch.tensor([0], device="cpu") - self.end_signal = None -@@ -57,10 +71,16 @@ class SimpleBuffer(KVLookupBufferBase): - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - -- tokens_sender = tokens_roi_sender[0] -- tokens_recver = tokens_roi_recver[0] -- roi_sender = tokens_roi_sender[1] -- roi_recver = tokens_roi_recver[1] -+ target_rank_sender = tokens_roi_sender[0] -+ target_rank_recver = tokens_roi_recver[0] -+ -+ if target_rank_sender.item() != target_rank_recver.item(): -+ return 0 -+ -+ tokens_sender = tokens_roi_sender[1] -+ tokens_recver = tokens_roi_recver[1] -+ roi_sender = tokens_roi_sender[2] -+ roi_recver = tokens_roi_recver[2] - - if tokens_recver is None: - # consumer sends an empty request -@@ -80,14 +100,14 @@ class SimpleBuffer(KVLookupBufferBase): - - return 0 - -- def _send_tensor_and_dec_size(self, -- tensor: Optional[torch.Tensor]) -> None: -+ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], -+ target_rank: int) -> None: - - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - if tensor.dtype == torch.bool: - tensor = tensor.float() -- self.data_pipe.send_tensor(tensor) -+ self.data_pipe.send_tensor(tensor, target_rank) - - def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - -@@ -100,7 +120,7 @@ class SimpleBuffer(KVLookupBufferBase): - - raise AssertionError(f"Unknown data type {type(data)}") - -- def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, -+ def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - -@@ -115,7 +135,7 @@ class SimpleBuffer(KVLookupBufferBase): - if isinstance(hidden, 
torch.Tensor): - hidden = hidden.clone() - -- buffer_item = [input_tokens, roi, key, value, hidden] -+ buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden] - - with self.buffer_lock: - for data in buffer_item: -@@ -125,53 +145,54 @@ class SimpleBuffer(KVLookupBufferBase): - def _is_end_signal(self, signal): - return signal is None - -- def drop_select_handler(self): -+ def drop_select_handler(self, rank: int): - - try: - -- while True: -- signal = self.signal_pipe.recv_tensor() -- if self._is_end_signal(signal): -- logger.info("Received end signal!") -- break -- -- input_tokens = self.data_pipe.recv_tensor() -- -- roi = self.data_pipe.recv_tensor() -- assert roi is not None, "Please provide the roi when sending "\ -- "drop-select request" -- roi = (roi > 0.5) -- tokens_roi_recver = [input_tokens, roi] -- -- matched_length = 0 -- -- # perform input tokens and roi matching -- # FIXME: this matching is O(n), ideally it should be O(1) -- # but this buffer size won't (and shouldn't) be too large so -- # the fix is not urgent. -- with self.buffer_lock: -- -- for _ in range(len(self.buffer)): -- -- temp_length = self._matches(self.buffer[0], -- tokens_roi_recver) -- if temp_length > 0: -- matched_length = temp_length -- break -- # rotate the element we just accessed to the end -- self.buffer.rotate(-1) -- -- if matched_length > 0: -- # need to clone the tensor -- # in case the tensor is freed before sending finishes -- matched_item = self.buffer.popleft() -- for tensor in matched_item: -- self._send_tensor_and_dec_size(tensor) -- -- else: -- # no match, just send None -- for _ in range(5): -- self.data_pipe.send_tensor(None) -+ signal = self.signal_pipe.recv_tensor(rank) -+ if self._is_end_signal(signal): -+ logger.info("Received end signal!") -+ return -+ target_kv_rank = self.data_pipe.recv_tensor(rank) -+ # assert target_rank.item() == rank, "Target rank does not match"\ -+ # "the rank of the drop-select handler" -+ input_tokens = self.data_pipe.recv_tensor(rank) -+ roi = self.data_pipe.recv_tensor(rank) -+ assert roi is not None, "Please provide the roi when sending "\ -+ "drop-select request" -+ roi = (roi > 0.5) -+ tokens_roi_recver = [target_kv_rank, input_tokens, roi] -+ -+ matched_length = 0 -+ -+ # perform input tokens and roi matching -+ # FIXME: this matching is O(n), ideally it should be O(1) -+ # but this buffer size won't (and shouldn't) be too large so -+ # the fix is not urgent. 
-+ with self.buffer_lock: -+ -+ for _ in range(len(self.buffer)): -+ -+ temp_length = self._matches(self.buffer[0], -+ tokens_roi_recver) -+ if temp_length > 0: -+ matched_length = temp_length -+ break -+ # rotate the element we just accessed to the end -+ self.buffer.rotate(-1) -+ -+ if matched_length > 0: -+ # need to clone the tensor -+ # in case the tensor is freed before sending finishes -+ matched_item = self.buffer.popleft() -+ target_rank = matched_item[0].item() -+ for tensor in matched_item[1:]: -+ self._send_tensor_and_dec_size(tensor, rank) -+ -+ else: -+ # no match, just send None -+ for _ in range(5): -+ self.data_pipe.send_tensor(None, rank) - - except RuntimeError as e: - if 'Connection closed by peer' not in str(e): -@@ -180,10 +201,10 @@ class SimpleBuffer(KVLookupBufferBase): - logger.debug("Closing drop_select_handler") - - def drop_select( -- self, input_tokens: Optional[torch.Tensor], -+ self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - -- assert self.request_handling_thread is None, \ -+ assert not self.request_handling_thread, \ - "drop_select should be called by the KV cache consumer "\ - "(e.g. the decode vLLM instance)" - -@@ -192,26 +213,28 @@ class SimpleBuffer(KVLookupBufferBase): - if isinstance(roi, torch.Tensor): - roi = roi.clone().float() - -- self.signal_pipe.send_tensor(self.normal_signal) -- self.data_pipe.send_tensor(input_tokens) -- self.data_pipe.send_tensor(roi) -+ self.signal_pipe.send_tensor(self.normal_signal, rank) -+ -+ self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) -+ self.data_pipe.send_tensor(input_tokens, rank) -+ self.data_pipe.send_tensor(roi, rank) - -- input_tokens = self.data_pipe.recv_tensor() -- roi = self.data_pipe.recv_tensor() -+ input_tokens = self.data_pipe.recv_tensor(rank) -+ roi = self.data_pipe.recv_tensor(rank) - if roi is not None: - # convert from float tensor to bool tensor - # as PyNccl does not support sending bool tensor - roi = (roi > 0.5) -- key = self.data_pipe.recv_tensor() -- value = self.data_pipe.recv_tensor() -- hidden = self.data_pipe.recv_tensor() -+ key = self.data_pipe.recv_tensor(rank) -+ value = self.data_pipe.recv_tensor(rank) -+ hidden = self.data_pipe.recv_tensor(rank) - - return [input_tokens, roi, key, value, hidden] - - def full_handler(self): - time.sleep(0.001) - -- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, -+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - -@@ -222,20 +245,19 @@ class SimpleBuffer(KVLookupBufferBase): - while self.buffer_size > self.buffer_size_threshold: - self.full_handler() - -- self._add_to_buffer(input_tokens, roi, key, value, hidden) -+ self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden) - - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. 
-+ target_rank_global = target_rank + kv_group_rank - if self.request_handling_thread is None: -- self.request_handling_thread = threading.Thread( -- target=self.drop_select_handler) -- self.request_handling_thread.start() -+ self.request_handling_thread = ThreadPoolExecutor(max_workers=1) -+ self.request_handling_thread.submit(self.drop_select_handler, target_rank_global) - - def close(self): - -- if hasattr(self, "request_handling_thread" -- ) and self.request_handling_thread is not None: -- self.request_handling_thread.join() -+ if hasattr(self, "request_handling_thread") and self.request_handling_thread: -+ self.request_handling_thread.shutdown() - - else: - # TODO: have a explicit close signal and have a explicit way to -diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py -index 40589fb3e..a3991c39d 100644 ---- a/vllm/distributed/kv_transfer/kv_pipe/base.py -+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - This file defines an interface `KVPipeBase` - that provides an abstraction for sending and receiving tensors, or None, via -@@ -23,7 +36,7 @@ class KVPipeBase(ABC): - """ - - @abstractmethod -- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: -+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: - """Send a tensor, or None, via the pipe. - - Need to support sending None -- important for error handling. -@@ -41,7 +54,7 @@ class KVPipeBase(ABC): - raise NotImplementedError - - @abstractmethod -- def recv_tensor(self) -> Optional[torch.Tensor]: -+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: - """Receive a tensor (can be None) from the pipeline. - - Returns: -diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py -new file mode 100644 -index 000000000..ca5345359 ---- /dev/null -+++ b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py -@@ -0,0 +1,139 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
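For context on the rank-addressed pipe API in the kv_pipe/base.py hunk above and the DynamoNcclDataPlane class in the dynamo_nccl_pipe.py hunk below: every send_tensor/recv_tensor call now names an explicit peer rank, and the data plane negotiates that rank over a small ZeroMQ REQ/REP exchange before the tensor moves over the NCCL pipe. The following standalone sketch is not part of this diff; it only replays the "PUT <rank> <tensor_id>" handshake, and the inproc address and rank numbers are invented for illustration.

    import zmq

    ctx = zmq.Context.instance()
    rep = ctx.socket(zmq.REP)            # receiving side, like listen_for_requests()
    rep.bind("inproc://dataplane-demo")
    req = ctx.socket(zmq.REQ)            # sending side, like _send_tensor()
    req.connect("inproc://dataplane-demo")

    req.send_string("PUT 3 layer0_block42")          # sender announces its rank and a tensor id
    cmd, src_rank, tensor_id = rep.recv_string().split()
    rep.send_string("7")                             # receiver answers with its own rank
    dst_rank = int(req.recv_string())

    # From here the sender would call data_pipe.send_tensor(tensor, dst_rank) and the
    # receiver data_pipe.recv_tensor(int(src_rank)); the tensor id is the key under
    # which the receiving side keeps the pending result until it is asked for it.
    print(cmd, src_rank, tensor_id, dst_rank)        # PUT 3 layer0_block42 7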
-+ -+import logging -+import threading -+import typing -+import zmq -+import socket -+import time -+import torch -+ -+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe -+ -+ -+logger = logging.getLogger(__name__) -+ -+ -+class DynamoNcclDataPlane: -+ def __init__( -+ self, -+ data_pipe: PyNcclPipe, -+ hostname: str = "", -+ port: int = 0, -+ ) -> None: -+ -+ self.data_pipe = data_pipe -+ if not hostname: -+ hostname = socket.gethostname() -+ if port == 0: -+ raise ValueError("Port cannot be 0") -+ self._hostname = hostname -+ self._port = port -+ self.store = {} -+ self.context = zmq.Context() -+ self.rep_socket = self.context.socket(zmq.REP) -+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}") -+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}") -+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True) -+ self._listener_thread.start() -+ self.req_sockets = {} -+ logger.info(f"Rank {self.rank} connected to the server") -+ -+ @property -+ def rank(self): -+ return self.data_pipe.kv_group_rank -+ -+ def send_tensor( -+ self, -+ tensor: torch.Tensor, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ): -+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}") -+ return self._send_tensor(tensor, tensor_id, remote_address) -+ -+ def recv_tensor( -+ self, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ) -> torch.Tensor: -+ ret = self._recv_tensor(tensor_id, remote_address) -+ return ret -+ -+ def _send_tensor( -+ self, -+ tensor: torch.Tensor, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ): -+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}") -+ if remote_address is None: -+ self.store[tensor_id] = tensor -+ else: -+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape) -+ # tensor_dtype = str(tensor.dtype) -+ if remote_address not in self.req_sockets: -+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ) -+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}") -+ -+ req_socket = self.req_sockets[remote_address] -+ # req_socket.connect(f"tcp://{remote_address}") -+ req_socket.send_string(f"PUT {self.rank} {tensor_id}") -+ dst_rank = req_socket.recv_string() -+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}") -+ self.data_pipe.send_tensor(tensor, int(dst_rank)) -+ -+ def _recv_tensor( -+ self, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ) -> torch.Tensor: -+ logger.debug(f"Rank {self.rank} receiving tensor") -+ if remote_address is not None: -+ raise NotImplementedError("Getting tensor from remote rank not implemented") -+ if tensor_id in self.store: -+ logger.debug(f"Popping tensor {tensor_id} from store") -+ future = self.store.pop(tensor_id) -+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait -+ logger.debug(f"Rank {self.rank} received tensor") -+ return tensor -+ -+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}") -+ time.sleep(0.001) -+ return self._recv_tensor(tensor_id, remote_address) -+ # raise NotImplementedError("Tensor not found in store") -+ -+ def _receive_tensor( -+ self, -+ tensor_id: str, -+ rank: int, -+ ): -+ future = self.data_pipe.recv_tensor(rank) -+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store") -+ self.store[tensor_id] = future -+ -+ def 
listen_for_requests(self): -+ while True: -+ cmd, rank, tensor_id = self.rep_socket.recv_string().split() -+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}") -+ self.rep_socket.send_string(f"{self.rank}") -+ if cmd == "GET": -+ raise NotImplementedError("Getting tensor from remote rank not implemented") -+ elif cmd == "PUT": -+ rank = int(rank) -+ # shape = [int(dim) for dim in shape.split("_")] -+ # dtype = getattr(torch, dtype) -+ self._receive_tensor(tensor_id, rank) -diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -index 7aa53d07a..8fb256aff 100644 ---- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - This module implements a PyNccl pipe for sending and receiving - Optional[torch.Tensor] between distributed ranks with advanced -@@ -45,33 +58,33 @@ class PyNcclPipe(KVPipeBase): - METADATA_DTYPE = torch.int64 - - def __init__(self, -+ kv_group_rank: int, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None, - port_offset: int = 0): - self.config = config - self.local_rank = local_rank -- self.kv_rank = self.config.kv_rank -+ self.kv_group_rank = kv_group_rank - self.kv_parallel_size = self.config.kv_parallel_size -+ self.kv_world_size = self.config.kv_world_size - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: - self.device = self._select_device(device) - - # build distributed connection and send/recv implementation -+ logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset) - self.group = StatelessProcessGroup.create( - host=self.config.kv_ip, - port=self.config.kv_port + port_offset, -- rank=self.kv_rank, -- world_size=self.kv_parallel_size, -+ rank=self.kv_group_rank, -+ world_size=self.kv_world_size, - ) - # add a barrier to make sure the connection is initiated properly - self.group.barrier() - impl = self._get_device_send_recv_impl(self.group) - self.device_send_func, self.device_recv_func = impl -- # set target rank -- self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size -- self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size - - # transportation-related variables - self.transport_thread: Optional[ThreadPoolExecutor] = None -@@ -145,16 +158,16 @@ class PyNcclPipe(KVPipeBase): - dtype=metadata["dtype"], - device=self.device) - -- def _send_metadata(self, metadata: Metadata): -+ def _send_metadata(self, metadata: Metadata, target_rank: int): - """ - Send the metadata dictionary to the target rank. 
- - Parameters: - - metadata: A dictionary with keys "dtype" and "shape". - """ -- self.group.send_obj(metadata, self.target_rank_for_send) -+ self.group.send_obj(metadata, target_rank) - -- def _recv_metadata(self) -> Metadata: -+ def _recv_metadata(self, src_rank: int) -> Metadata: - """ - Receive the metadata dictionary from the target rank. - -@@ -162,9 +175,9 @@ class PyNcclPipe(KVPipeBase): - - metadata: A dictionary with keys "dtype" and "shape" describing - the tensor. - """ -- return self.group.recv_obj(self.target_rank_for_recv) -+ return self.group.recv_obj(src_rank) - -- def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: -+ def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: - """ - The actual implementation of sending the tensor and its metadata to the - target rank. -@@ -174,12 +187,12 @@ class PyNcclPipe(KVPipeBase): - being sent. - """ - metadata = self._make_metadata(tensor) -- self._send_metadata(metadata) -+ self._send_metadata(metadata, target_rank) - if tensor is not None: - self.device_send_func(tensor.to(self.device), -- self.target_rank_for_send) -+ target_rank) - -- def _recv_impl(self) -> Optional[torch.Tensor]: -+ def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: - """ - The actual implementation of receiving a tensor and its metadata from - the target rank. -@@ -187,21 +200,22 @@ class PyNcclPipe(KVPipeBase): - Returns: - - buffer: The received tensor, or None if no tensor is received. - """ -- metadata = self._recv_metadata() -+ metadata = self._recv_metadata(src_rank) - if metadata["dtype"] is None: - return None - buffer = self._prepare_recv_buffer(metadata) -- self.device_recv_func(buffer, self.target_rank_for_recv) -+ self.device_recv_func(buffer, src_rank) - - return buffer - - def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], -- tensor_size: int) -> None: -+ tensor_size: int, -+ target_rank: int) -> None: - """ - Wrapper for _send_impl to handle exceptions and update buffer size. - """ - try: -- self._send_impl(tensor) -+ self._send_impl(tensor, target_rank) - - with self.buffer_size_lock: - self.buffer_size -= tensor_size -@@ -220,7 +234,7 @@ class PyNcclPipe(KVPipeBase): - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - -- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: -+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: - """ - Sends a tensor and its metadata to the destination rank in a - non-blocking way. -@@ -228,6 +242,7 @@ class PyNcclPipe(KVPipeBase): - Parameters: - - tensor: The tensor to send, or None if no tensor is being sent. - """ -+ logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - -@@ -241,32 +256,39 @@ class PyNcclPipe(KVPipeBase): - with self.buffer_size_lock: - self.buffer_size += tensor_size - -- self.transport_thread.submit(self.send_tensor_wrapper, tensor, -- tensor_size) -+ future = self.transport_thread.submit(self.send_tensor_wrapper, tensor, -+ tensor_size, -+ target_rank) -+ return future - -- def recv_tensor(self) -> Optional[torch.Tensor]: -+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: - """ - Receives a tensor and its metadata from the source rank. Blocking call. 
- - Returns: - - tensor: The received tensor, or None if no tensor is received. - """ -+ -+ logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) -+ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - -- future = self.transport_thread.submit(self._recv_impl) -+ future = self.transport_thread.submit(self._recv_impl, src_rank) - -- try: -- tensor = future.result() -- except Exception as e: -- logger.error("Encountering exception in KV receiving thread") -- logger.error("%s", e) -- logger.error("My device: %s", self.device) -- import traceback -- traceback.print_exc() -- raise e -+ return future -+ -+ # try: -+ # tensor = future.result() -+ # except Exception as e: -+ # logger.error("Encountering exception in KV receiving thread") -+ # logger.error("%s", e) -+ # logger.error("My device: %s", self.device) -+ # import traceback -+ # traceback.print_exc() -+ # raise e - -- return tensor -+ # return tensor - - def close(self): - """ -diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py -index 1e80e0bd7..f06c7a5f6 100644 ---- a/vllm/distributed/kv_transfer/kv_transfer_agent.py -+++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """A centralized entrypoint to perform distributed KV cache transfer. - - This implementation is a shim wrapper on two APIs exposed by `kv_connector`: -@@ -35,6 +48,7 @@ class KVTransferAgent: - rank: int, - local_rank: int, - config: "VllmConfig", -+ world_group, - ): - - self.config = config -@@ -47,7 +61,7 @@ class KVTransferAgent: - "TransferAgent should only be used when kv_connector is set." - - self.connector = KVConnectorFactory.create_connector( -- rank, local_rank, config) -+ rank, local_rank, config, world_group) - - def send_kv_caches_and_hidden_states( - self, -diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py -index 321902d11..03409899e 100644 ---- a/vllm/distributed/parallel_state.py -+++ b/vllm/distributed/parallel_state.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - # Copyright 2023 The vLLM team. 
- # Adapted from -@@ -1085,7 +1098,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: - _KV_TRANSFER = kv_transfer.KVTransferAgent( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, -- config=vllm_config) -+ config=vllm_config, -+ world_group=get_world_group()) - - - def ensure_model_parallel_initialized( -diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py -index d82d9ad9d..61c1e429d 100644 ---- a/vllm/engine/llm_engine.py -+++ b/vllm/engine/llm_engine.py -@@ -1,14 +1,31 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import copy - import time -+import pickle -+import uuid - from collections import Counter as collectionsCounter - from collections import deque -+from collections import defaultdict - from contextlib import contextmanager - from dataclasses import dataclass -+from concurrent.futures import ThreadPoolExecutor - from functools import partial - from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable, -- List, Mapping, NamedTuple, Optional) -+ List, Mapping, NamedTuple, Optional, Tuple) - from typing import Sequence as GenericSequence - from typing import Set, Type, Union, cast, overload - -@@ -60,6 +77,9 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, - usage_message) - from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind - from vllm.version import __version__ as VLLM_VERSION -+from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType -+from vllm.distributed.device_communicators.nixl import NixlMetadata -+ - - logger = init_logger(__name__) - _LOCAL_LOGGING_INTERVAL_SEC = 5 -@@ -90,7 +110,7 @@ class OutputData(NamedTuple): - # outputs from multiple steps. 
- is_first_step_output: Optional[bool] - skip: List[int] -- -+ remote_prefill_requests: Optional[List[RemotePrefillRequest]] - - class SchedulerContext: - -@@ -104,11 +124,14 @@ class SchedulerContext: - - self.multi_step_stream_outputs: bool = multi_step_stream_outputs - -+ self.remote_prefill_requests: List[RemotePrefillRequest] = [] -+ - def append_output(self, outputs: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool, -- is_first_step_output: Optional[bool]): -+ is_first_step_output: Optional[bool], -+ remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None): - self.output_queue.append( - OutputData(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, -@@ -116,7 +139,9 @@ class SchedulerContext: - is_async=is_async, - is_last_step=is_last_step, - is_first_step_output=is_first_step_output, -- skip=[])) -+ skip=[], -+ remote_prefill_requests=remote_prefill_requests)) -+ - - - class LLMEngine: -@@ -348,7 +373,7 @@ class LLMEngine: - # GPU and CPU blocks, which are profiled in the distributed executor. - self.scheduler = [ - Scheduler( -- self.scheduler_config, self.cache_config, self.lora_config, -+ self.model_config, self.scheduler_config, self.cache_config, self.lora_config, - self.parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if self.model_config.use_async_output_proc else None) -@@ -405,6 +430,40 @@ class LLMEngine: - - self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {} - -+ self.engine_id = str(uuid.uuid4()) -+ self._nixl_agents_names: Optional[List[str]] = None -+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ self._nixl_agents_names = self._initialize_nixl() -+ -+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) -+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) -+ self._finished_prefills = set() -+ self._finished_transfers = set() -+ -+ @property -+ def is_nixl_initialized(self) -> bool: -+ return getattr(self, "_nixl_agents_names", None) is not None -+ -+ def get_nixl_metadata(self) -> NixlMetadata: -+ if not self.is_nixl_initialized: -+ raise RuntimeError("Nixl is not initialized") -+ agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata") -+ kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr") -+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks) -+ -+ def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]: -+ if not self.is_nixl_initialized: -+ raise RuntimeError("Nixl is not initialized") -+ engine_id = nixl_metadata.engine_id -+ agents_metadata = nixl_metadata.agent_metadata -+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr -+ num_blocks = nixl_metadata.num_blocks -+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks)) -+ -+ def _initialize_nixl(self) -> List[bytes]: -+ agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,)) -+ return agents_names -+ - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). 
- -@@ -500,6 +559,8 @@ class LLMEngine: - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): -+ if self.is_nixl_initialized: -+ model_executor.collective_rpc("shutdown_nixl") - model_executor.shutdown() - - def get_tokenizer_group( -@@ -552,11 +613,14 @@ class LLMEngine: - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Optional[SequenceGroup]: - """Add a processed request to the engine's request pool. - return the created sequence group. - """ - if isinstance(params, SamplingParams) and params.n > 1: -+ if remote_prefill_params is not None: -+ raise ValueError("Remote prefill params are not supported for multi-step sampling") - ParallelSampleSequenceGroup.add_request( - request_id, - self, -@@ -574,6 +638,8 @@ class LLMEngine: - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) -+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: -+ next(self.seq_counter) # empty sequence for staging - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - if is_encoder_decoder_inputs(processed_inputs): -@@ -584,7 +650,7 @@ class LLMEngine: - encoder_inputs = None - - seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, -- lora_request, prompt_adapter_request) -+ lora_request, prompt_adapter_request, remote_prefill_params) - - encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request, -@@ -601,8 +667,12 @@ class LLMEngine: - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, -- priority=priority) -+ priority=priority, -+ remote_prefill_params=remote_prefill_params, -+ ) - elif isinstance(params, PoolingParams): -+ if remote_prefill_params is not None: -+ raise ValueError("Remote prefill params are not supported for pooling") - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, -@@ -673,6 +743,7 @@ class LLMEngine: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: -@@ -765,6 +836,7 @@ class LLMEngine: - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - priority=priority, -+ remote_prefill_params=remote_prefill_params, - ) - - def _validate_token_prompt(self, prompt: PromptType, -@@ -799,6 +871,7 @@ class LLMEngine: - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - encoder_seq: Optional[Sequence] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs -@@ -829,7 +902,9 @@ class LLMEngine: - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, -- priority=priority) -+ priority=priority, -+ remote_prefill_params=remote_prefill_params -+ ) - - return seq_group - -@@ -995,11 +1070,11 @@ class LLMEngine: - # When we process only one request, no pop is required - # (since later we will process all 
of the rest) - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, -- is_last_step, is_first_step_output, skip) = ctx.output_queue[0] -+ is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, -- skip) = ctx.output_queue.popleft() -+ skip, remote_prefill_requests) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( -@@ -1325,15 +1400,55 @@ class LLMEngine: - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() -+ ctx.remote_prefill_requests.clear() - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current - # batch has completed. -+ remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = [] -+ running_seq_group_metadata_list: List[SequenceGroupMetadata] = [] -+ remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] -+ running_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] -+ - if not self._has_remaining_steps(seq_group_metadata_list): -- # Schedule iteration -+ - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc -- ) = self.scheduler[virtual_engine].schedule() -+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers) -+ -+ -+ # Separate remote prefill and running seq groups -+ for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups): -+ if seq_group_metadata.do_remote_prefill: -+ remote_prefill_seq_group_metadata_list.append(seq_group_metadata) -+ remote_prefill_scheduled_seq_groups.append(scheduled_seq_group) -+ else: -+ running_seq_group_metadata_list.append(seq_group_metadata) -+ running_scheduled_seq_groups.append(scheduled_seq_group) -+ -+ seq_group_metadata_list = running_seq_group_metadata_list -+ scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups -+ -+ # Send remote prefill requests before model execution -+ for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups): -+ assert len(scheduled_seq_group.seq_group.seqs) == 1 -+ assert self._nixl_agents_names -+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id -+ block_table = seq_group_metadata.block_tables[seq_id] -+ if len(block_table) == len(seq_group_metadata.computed_block_nums): -+ logger.debug("No blocks to prefill") -+ self._finished_prefills.add(seq_group_metadata.request_id) -+ continue -+ remote_prefill_request = RemotePrefillRequest( -+ request_id=seq_group_metadata.request_id, -+ # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway -+ prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks) -+ sampling_params=scheduled_seq_group.seq_group.sampling_params, -+ block_ids=block_table, -+ engine_id=self.engine_id, -+ computed_block_ids=seq_group_metadata.computed_block_nums, -+ ) -+ scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs -@@ -1383,9 +1498,46 @@ class LLMEngine: - 
execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - -- outputs = self.model_executor.execute_model( -+ # After model execution, we need to transfer the memory from the prefill to the decode -+ memory_transfer_reqs = [] -+ for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list): -+ remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params -+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: -+ assert len(scheduled_seq_group.seq_group.seqs) == 1 -+ req_id = scheduled_seq_group.seq_group.request_id -+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id -+ block_table = seq_group_metadata.block_tables[seq_id] -+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1] -+ -+ num_computed_blocks = len(seq_group_metadata.computed_block_nums) -+ computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks] -+ -+ if computed_decode_block_ids: -+ kv_recv_req = MemoryTransferRequest( -+ request_id=req_id, -+ local_block_ids=block_table[:num_computed_blocks], -+ staging_block_ids=staging_block_ids[:num_computed_blocks], -+ remote_block_ids=computed_decode_block_ids, -+ remote_engine_id=remote_prefill_params.decode_engine_id, -+ notify_msg=req_id, -+ op_type=MemoryOpType.READ -+ ) -+ memory_transfer_reqs.append(kv_recv_req) -+ -+ kv_send_req = MemoryTransferRequest( -+ request_id=req_id, -+ local_block_ids=block_table[num_computed_blocks:], -+ staging_block_ids=staging_block_ids[num_computed_blocks:], -+ remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:], -+ remote_engine_id=remote_prefill_params.decode_engine_id, -+ notify_msg=req_id, -+ op_type=MemoryOpType.WRITE -+ ) -+ memory_transfer_reqs.append(kv_send_req) -+ execute_model_req.memory_transfer_requests = memory_transfer_reqs -+ -+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( - execute_model_req=execute_model_req) -- - # We need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: -@@ -1396,7 +1548,26 @@ class LLMEngine: - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case -- outputs = [] -+ execute_model_req = ExecuteModelRequest( -+ seq_group_metadata_list=[], -+ blocks_to_swap_in=[], -+ blocks_to_swap_out=[], -+ blocks_to_copy=[]) -+ -+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( -+ execute_model_req=execute_model_req) -+ -+ for req_id, notif_count in request_notif_counter.items(): -+ self._request_notif_counter[req_id] += notif_count -+ if self._request_notif_counter[req_id] > -1: -+ self._finished_prefills.add(req_id) -+ del self._request_notif_counter[req_id] -+ -+ for req_id, done_count in request_done_counter.items(): -+ self._request_done_counter[req_id] += done_count -+ if self._request_done_counter[req_id] > -1: -+ self._finished_transfers.add(req_id) -+ del self._request_done_counter[req_id] - - # Finish the current step for all the sequence groups. - if self.scheduler_config.is_multi_step: -@@ -1456,7 +1627,7 @@ class LLMEngine: - # queued control plane messages, such as add/remove lora adapters. 
- logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() -- -+ - return ctx.request_outputs - - def _has_remaining_steps( -diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py -index 3cf1850ee..d20a5f20b 100644 ---- a/vllm/engine/multiprocessing/__init__.py -+++ b/vllm/engine/multiprocessing/__init__.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import uuid - from dataclasses import dataclass, field -@@ -14,13 +27,17 @@ from vllm.outputs import RequestOutput - from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import SamplingParams - from vllm.utils import deprecate_kwargs -- -+from vllm.remote_prefill import RemotePrefillParams -+from vllm.distributed.device_communicators.nixl import NixlMetadata - VLLM_RPC_SUCCESS_STR = "SUCCESS" - - IPC_INPUT_EXT = "_input_socket" - IPC_OUTPUT_EXT = "_output_socket" - IPC_HEALTH_EXT = "_health_socket" - IPC_DATA_EXT = "_data_socket" -+IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket" -+IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket" -+IPC_METRICS_EXT = "_metrics_socket" - - - class MQEngineDeadError(RuntimeError): -@@ -36,6 +53,7 @@ class RPCProcessRequest: - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - priority: int = 0 -+ remote_prefill_params: Optional[RemotePrefillParams] = None - - @overload - def __init__( -@@ -78,6 +96,7 @@ class RPCProcessRequest: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: -@@ -95,7 +114,7 @@ class RPCProcessRequest: - self.trace_headers = trace_headers - self.prompt_adapter_request = prompt_adapter_request - self.priority = priority -- -+ self.remote_prefill_params = remote_prefill_params - - @dataclass - class RPCError: -@@ -116,7 +135,7 @@ class RPCStartupRequest(Enum): - @dataclass - class RPCStartupResponse: - tracing_enabled: bool -- -+ nixl_metadata: Optional[bytes] = None - - class RPCUProfileRequest(Enum): - START_PROFILE = 1 -@@ -157,3 +176,13 @@ def ENGINE_DEAD_ERROR( - return MQEngineDeadError( - "Engine loop is not running. 
Inspect the stacktrace to " - f"find the original error: {repr(error)}.") -+ -+@dataclass -+class KvMetrics: -+ request_active_slots: int -+ request_total_slots: int -+ kv_active_blocks: int -+ kv_total_blocks: int -+ num_requests_waiting: int -+ gpu_cache_usage_perc: float -+ gpu_prefix_cache_hit_rate: float -diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py -index 85b5f31e3..c53b9eced 100644 ---- a/vllm/engine/multiprocessing/client.py -+++ b/vllm/engine/multiprocessing/client.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import asyncio - import copy -@@ -8,6 +21,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, cast, overload) - - import cloudpickle -+import msgspec - import psutil - import zmq - import zmq.asyncio -@@ -19,20 +33,23 @@ from vllm import PoolingParams - from vllm.config import DecodingConfig, ModelConfig, VllmConfig - from vllm.core.scheduler import SchedulerOutputs - from vllm.engine.arg_utils import AsyncEngineArgs -+from vllm.engine.metrics import Stats - # yapf conflicts with isort for this block - # yapf: disable - from vllm.engine.async_llm_engine import ( - build_guided_decoding_logits_processor_async) - from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, -- IPC_OUTPUT_EXT, RPC_REQUEST_T, -- VLLM_RPC_SUCCESS_STR, RPCAbortRequest, -+ IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT, -+ RPC_REQUEST_T, -+ VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest, -+ IPC_METRICS_EXT, - RPCAdapterLoadedResponse, RPCError, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetPrefixCacheRequest, - RPCStartupRequest, RPCStartupResponse, -- RPCUProfileRequest) -+ RPCUProfileRequest, KvMetrics) - from vllm.engine.protocol import EngineClient - # yapf: enable - from vllm.envs import VLLM_RPC_TIMEOUT -@@ -46,6 +63,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import SamplingParams - from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - from vllm.utils import deprecate_kwargs -+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback -+from vllm.distributed.device_communicators.nixl import NixlMetadata - - logger = init_logger(__name__) - -@@ -91,6 +110,7 @@ class MQLLMEngineClient(EngineClient): - self._errored_with: Optional[BaseException] = None - - # Get the configs. -+ self.vllm_config = engine_config - self.model_config = engine_config.model_config - self.decoding_config = engine_config.decoding_config - -@@ -115,6 +135,10 @@ class MQLLMEngineClient(EngineClient): - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") - -+ # Metrics. 
-+ self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL) -+ self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}") -+ - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - -@@ -129,8 +153,27 @@ class MQLLMEngineClient(EngineClient): - # Loop to check health of the LLMEngine periodically. - # Started after the MQLLMEngine is ready. - self.health_loop: Optional[asyncio.Task] = None -+ -+ # Loop to check metrics of the LLMEngine periodically. -+ # Started after the MQLLMEngine is ready. -+ self.metrics_loop: Optional[asyncio.Task] = None -+ self.metrics_publisher = None -+ - self._engine_process = psutil.Process(engine_pid) - -+ self.nixl_metadata: Optional[NixlMetadata] = None -+ self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL) -+ self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH) -+ self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {} -+ if self.using_nixl_connector: -+ self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") -+ self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") -+ -+ -+ @property -+ def using_nixl_connector(self) -> bool: -+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector" -+ - @staticmethod - def is_unsupported_config(engine_args: AsyncEngineArgs): - # Pipeline parallel not yet supported -@@ -180,6 +223,61 @@ class MQLLMEngineClient(EngineClient): - except Exception as e: - self._set_errored(e) - -+ async def run_remote_prefill_request_handler_loop(self): -+ try: -+ while True: -+ if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT): -+ frames = await self.remote_prefill_request_socket.recv(copy=False) -+ remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest) -+ await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request) -+ except asyncio.CancelledError: -+ logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.") -+ -+ async def run_metrics_loop(self, timeout: int): -+ """Background loop that continually checks to ensure the engine process -+ is still alive. 
-+ """ -+ try: -+ while True: -+ # Check if the engine process is running: -+ if not self._engine_process.is_running() or ( -+ self._engine_process.status() == psutil.STATUS_ZOMBIE): -+ # NB: is_running() returns True for zombies -+ self._set_errored( -+ RuntimeError( -+ f"Engine process (pid {self._engine_process.pid}) " -+ "died.")) -+ break -+ -+ if await self.metrics_socket.poll(timeout=timeout): -+ # Metrics received- check the message -+ message: Frame = await self.metrics_socket.recv(copy=False) -+ metrics = pickle.loads(message.buffer) -+ if self.metrics_publisher is not None and isinstance( -+ metrics, KvMetrics -+ ): -+ self.metrics_publisher.publish(metrics.request_active_slots, -+ metrics.request_total_slots, -+ metrics.kv_active_blocks, -+ metrics.kv_total_blocks, -+ metrics.num_requests_waiting, -+ metrics.gpu_cache_usage_perc, -+ metrics.gpu_prefix_cache_hit_rate) -+ logger.debug("Metrics successful.") -+ -+ # TODO: Investigate sending whole stats object -+ -+ except asyncio.CancelledError: -+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") -+ -+ except psutil.NoSuchProcess: -+ self._set_errored( -+ RuntimeError( -+ f"Engine process (pid {self._engine_process.pid}) died.")) -+ -+ except Exception as e: -+ self._set_errored(e) -+ - async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to Request Queues""" - -@@ -278,12 +376,26 @@ class MQLLMEngineClient(EngineClient): - # Wait until server is ready. - response = await self._wait_for_server_rpc(socket) - -+ if response.nixl_metadata is not None: -+ assert self.using_nixl_connector -+ self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata) -+ - self.tracing_flag = response.tracing_enabled - - # Start health_loop. - if self.health_loop is None: - self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) -+ -+ if self.using_nixl_connector: -+ self.remote_prefill_loop = asyncio.create_task( -+ self.run_remote_prefill_request_handler_loop()) -+ -+ # Start metrics_loop. -+ if self.metrics_loop is None: -+ self.metrics_loop = asyncio.create_task( -+ self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT)) -+ - - def close(self): - """Destroy the ZeroMQ Context.""" -@@ -293,6 +405,8 @@ class MQLLMEngineClient(EngineClient): - # Cancel background tasks. 
- if self.health_loop is not None: - self.health_loop.cancel() -+ if self.metrics_loop is not None: -+ self.metrics_loop.cancel() - if self.output_loop is not None: - self.output_loop.cancel() - -@@ -415,6 +529,9 @@ class MQLLMEngineClient(EngineClient): - """ - if self._errored_with is not None: - raise self._errored_with -+ -+ async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata): -+ await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False) - - @property - def is_running(self) -> bool: -@@ -473,6 +590,7 @@ class MQLLMEngineClient(EngineClient): - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None # DEPRECATED - ) -> AsyncGenerator[RequestOutput, None]: -@@ -502,7 +620,8 @@ class MQLLMEngineClient(EngineClient): - - return self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, -- prompt_adapter_request, priority) -+ prompt_adapter_request, priority, -+ remote_prefill_params) - - @overload - def encode( -@@ -586,6 +705,7 @@ class MQLLMEngineClient(EngineClient): - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - PoolingRequestOutput, None]]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" -@@ -630,6 +750,12 @@ class MQLLMEngineClient(EngineClient): - else: - lp_bytes = None - -+ if remote_prefill_params is not None: -+ self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback -+ remote_prefill_params.remote_prefill_request_callback = None -+ else: -+ remote_prefill_request_callback = None -+ - request_bytes = pickle.dumps( - RPCProcessRequest( - prompt=prompt, -@@ -639,11 +765,11 @@ class MQLLMEngineClient(EngineClient): - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, -+ remote_prefill_params=remote_prefill_params, - )) - - # 3) Send the RPCGenerateRequest to the MQLLMEngine. -- parts = (request_bytes, -- lp_bytes) if lp_bytes else (request_bytes, ) -+ parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,) - await self.input_socket.send_multipart(parts, copy=False) - - # 4) Stream the RequestOutputs from the output queue. Note -@@ -705,3 +831,6 @@ class MQLLMEngineClient(EngineClient): - # Raise on error, otherwise happily return None - if isinstance(request_output, BaseException): - raise request_output -+ -+ def set_metrics_publisher(self, metrics_publisher): -+ self.metrics_publisher = metrics_publisher -diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py -index a0dd79586..ea0d2cd68 100644 ---- a/vllm/engine/multiprocessing/engine.py -+++ b/vllm/engine/multiprocessing/engine.py -@@ -1,37 +1,130 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import pickle - import signal - from contextlib import contextmanager --from typing import Iterator, List, Optional, Union -+from typing import Iterator, List, Optional, Union, Dict - - import cloudpickle -+import time - import zmq -- -+import msgspec - from vllm import AsyncEngineArgs, SamplingParams - from vllm.engine.llm_engine import LLMEngine - # yapf conflicts with isort for this block - # yapf: disable - from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, -- IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, -- VLLM_RPC_SUCCESS_STR, RPCAbortRequest, -+ REQUEST_OUTPUTS_T, -+ VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT, -+ RPCAbortRequest, -+ IPC_OUTPUT_EXT, IPC_METRICS_EXT, - RPCAdapterLoadedResponse, RPCError, - RPCLoadAdapterRequest, - RPCProcessRequest, - RPCResetPrefixCacheRequest, - RPCStartupRequest, RPCStartupResponse, -- RPCUProfileRequest) -+ RPCUProfileRequest, IPC_REMOTE_NIXL_METADATA_EXT, -+ KvMetrics) - # yapf: enable - from vllm.logger import init_logger - from vllm.outputs import RequestOutput - from vllm.usage.usage_lib import UsageContext -+from vllm.remote_prefill import RemotePrefillRequest -+from vllm.distributed.device_communicators.nixl import NixlMetadata -+ -+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo -+from dataclasses import dataclass, field - - logger = init_logger(__name__) - - POLLING_TIMEOUT_MS = 10000 - HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) - -+class KvStatLogger(StatLoggerBase): -+ def __init__( -+ self, -+ max_num_seqs: int, -+ num_total_gpu_blocks: int, -+ metrics_socket -+ ): -+ # Must query initialized scheduler for max infos -+ self.request_total_slots = max_num_seqs -+ self.kv_total_blocks = num_total_gpu_blocks -+ self.metrics_socket = metrics_socket -+ -+ # KV metrics -+ self._send_kv_metrics(0, 0, 0, 0.0, 0.0) -+ -+ def log(self, stats: Stats) -> None: -+ self._send_kv_metrics( -+ stats.num_running_sys, -+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks), -+ stats.num_waiting_sys, -+ stats.gpu_cache_usage_sys, -+ stats.gpu_prefix_cache_hit_rate -+ ) -+ -+ def info(self, type: str, obj: SupportsMetricsInfo) -> None: -+ pass -+ -+ def _send_kv_metrics( -+ self, -+ active_slots, -+ active_kv_blocks, -+ num_requests_waiting, -+ gpu_cache_usage_perc, -+ gpu_prefix_cache_hit_rate, -+ ): -+ if not self.metrics_socket.closed: -+ metrics_bytes = pickle.dumps( -+ KvMetrics( -+ active_slots, -+ self.request_total_slots, -+ active_kv_blocks, -+ self.kv_total_blocks, -+ num_requests_waiting, -+ gpu_cache_usage_perc, -+ gpu_prefix_cache_hit_rate, -+ ) -+ ) -+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) -+ -+# TODO: Send entire stats object to the client -+# class StatLogger(StatLoggerBase): -+# def __init__( -+# self, -+# metrics_socket -+# ): -+# self.metrics_socket = metrics_socket -+ -+# def log(self, stats: Stats) -> None: -+# self._send_metrics(stats) -+ -+# def info(self, type: str, obj: SupportsMetricsInfo) -> None: -+# pass -+ -+# def _send_metrics(self, stats: Stats): -+# if not 
self.metrics_socket.closed: -+# metrics_bytes = pickle.dumps(stats) -+# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) -+ -+ -+ -+ - - class MQLLMEngine: - """A multiprocessing wrapper for :class:`LLMEngine`. -@@ -94,12 +187,37 @@ class MQLLMEngine: - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") - -+ # Send metrics back to client. -+ self.metrics_socket = self.ctx.socket(zmq.constants.PUSH) -+ self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}") -+ - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Error state. - self._errored_with: Optional[BaseException] = None - -+ self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH) -+ self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL) -+ if self.engine.is_nixl_initialized: -+ self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") -+ self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") -+ -+ -+ # Attach logger for continuous metrics publishing -+ self.kv_stat_logger = KvStatLogger( -+ self.engine.scheduler_config.max_num_seqs, -+ self.engine.cache_config.num_gpu_blocks, -+ self.metrics_socket -+ ) -+ self.engine.add_logger("kv_metrics", self.kv_stat_logger) -+ -+ # TODO investigate sending whole stats object -+ # self.general_stat_logger = StatLogger( -+ # self.metrics_socket -+ # ) -+ # self.engine.add_logger("general_metrics", self.general_stat_logger) -+ - @property - def dead_error(self) -> BaseException: - if self._errored_with is not None: -@@ -171,8 +289,17 @@ class MQLLMEngine: - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() -- response = RPCStartupResponse( -- tracing_enabled=tracing_enabled) -+ -+ # Send nixl metadata to the client -+ if self.engine.is_nixl_initialized: -+ nixl_metadata = self.engine.get_nixl_metadata() -+ encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata) -+ response = RPCStartupResponse( -+ tracing_enabled=tracing_enabled, -+ nixl_metadata=encoded_nixl_metadata) -+ else: -+ response = RPCStartupResponse( -+ tracing_enabled=tracing_enabled) - - except Exception as e: - response = e -@@ -185,6 +312,7 @@ class MQLLMEngine: - - while True: - if not self.engine.has_unfinished_requests(): -+ logger.debug("No unfinished requests") - # Poll until there is work to do. 
- while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - # When there's no work, check on engine health and send -@@ -220,6 +348,13 @@ class MQLLMEngine: - def handle_new_input(self): - """Handle new input from the socket""" - try: -+ if self.engine.is_nixl_initialized: -+ while self.remote_nixl_metadata_socket.poll(timeout=0) != 0: -+ frames = self.remote_nixl_metadata_socket.recv(copy=False) -+ nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata) -+ logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id) -+ self.engine.add_remote_nixl_metadata(nixl_metadata) -+ - while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) -@@ -262,6 +397,11 @@ class MQLLMEngine: - self._send_outputs(rpc_err) - - try: -+ if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill: -+ def remote_prefill_request_callback(request: RemotePrefillRequest): -+ logger.debug("Sending remote prefill request: %s", request.request_id) -+ self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False) -+ request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback - self.engine.add_request( - request_id=request_id, - prompt=request.prompt, -@@ -269,7 +409,9 @@ class MQLLMEngine: - lora_request=request.lora_request, - trace_headers=request.trace_headers, - prompt_adapter_request=request.prompt_adapter_request, -- priority=request.priority) -+ priority=request.priority, -+ remote_prefill_params=request.remote_prefill_params, -+ ) - - if self.log_requests: - logger.info("Added request %s.", request.request_id) -diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py -index 107220d54..e0e0590b6 100644 ---- a/vllm/entrypoints/openai/serving_chat.py -+++ b/vllm/entrypoints/openai/serving_chat.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- - import asyncio - import json -@@ -34,6 +47,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams - from vllm.sequence import Logprob - from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer - from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls -+from vllm.remote_prefill import RemotePrefillParams - - logger = init_logger(__name__) - -@@ -112,6 +126,7 @@ class OpenAIServingChat(OpenAIServing): - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, - ErrorResponse]: - """ -@@ -243,6 +258,7 @@ class OpenAIServingChat(OpenAIServing): - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=request.priority, -+ remote_prefill_params=remote_prefill_params, - ) - - generators.append(generator) -diff --git a/vllm/envs.py b/vllm/envs.py -index 745b068b7..0f1a022fb 100644 ---- a/vllm/envs.py -+++ b/vllm/envs.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import os - import tempfile -@@ -87,6 +100,10 @@ if TYPE_CHECKING: - VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False - VLLM_RAY_PER_WORKER_GPUS: float = 1.0 - VLLM_RAY_BUNDLE_INDICES: str = "" -+ VLLM_KV_CAPI_PATH: Optional[str] = None -+ VLLM_KV_NAMESPACE: Optional[str] = None -+ VLLM_KV_COMPONENT: Optional[str] = None -+ VLLM_WORKER_ID: Optional[int] = None - - - def get_default_cache_root(): -@@ -572,6 +589,21 @@ environment_variables: Dict[str, Callable[[], Any]] = { - # models the alignment is already naturally aligned to 256 bytes. - "VLLM_CUDA_MEM_ALIGN_KV_CACHE": - lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))), -+ -+ # Path to the C API Library -+ "VLLM_KV_CAPI_PATH": -+ lambda: os.environ.get("VLLM_KV_CAPI_PATH", None), -+ -+ # Identifiers to publish KV related information -+ "VLLM_KV_NAMESPACE": -+ lambda: os.environ.get("VLLM_KV_NAMESPACE", None), -+ "VLLM_KV_COMPONENT": -+ lambda: os.environ.get("VLLM_KV_COMPONENT", None), -+ -+ # Worker ID used for identifying workers in distributed settings -+ "VLLM_WORKER_ID": -+ lambda: int(os.getenv("VLLM_WORKER_ID", "0")) -+ if "VLLM_WORKER_ID" in os.environ else None, - } - - # end-env-vars-definition -diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py -index 773f5abe7..365685e13 100644 ---- a/vllm/model_executor/models/deepseek_v2.py -+++ b/vllm/model_executor/models/deepseek_v2.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -@@ -585,6 +598,8 @@ class DeepseekV2Model(nn.Module): - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - -+ self.config = config -+ - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - -diff --git a/vllm/outputs.py b/vllm/outputs.py -index 786380c37..e9c3a5e16 100644 ---- a/vllm/outputs.py -+++ b/vllm/outputs.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import time - from dataclasses import dataclass -@@ -6,16 +19,16 @@ from typing import Dict, Generic, List, MutableSequence, Optional - from typing import Sequence as GenericSequence - from typing import Union - -+import msgspec - import torch - from typing_extensions import TypeVar, deprecated - - from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalPlaceholderDict --from vllm.sampling_params import RequestOutputKind -+from vllm.sampling_params import RequestOutputKind, SamplingParams - from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceGroupBase, SequenceStatus) - -- - @dataclass - class CompletionOutput: - """The output data of one completion output of a request. -diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py -new file mode 100644 -index 000000000..83f6cd575 ---- /dev/null -+++ b/vllm/remote_prefill.py -@@ -0,0 +1,82 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+ -+from dataclasses import dataclass -+from typing import Callable, Optional, List -+from enum import Enum -+ -+import msgspec -+ -+from vllm.sampling_params import SamplingParams -+ -+ -+class RemotePrefillRequest( -+ msgspec.Struct, -+ omit_defaults=True, # type: ignore[call-arg] -+ # required for @cached_property. -+ dict=True): -+ """The request data of one remote prefill output of a request. -+ -+ Args: -+ engine_id: The unique ID of the engine. -+ request_id: The unique ID of the request. -+ prompt_token_ids: The token IDs of the prompt. -+ sampling_params: The sampling parameters. -+ block_ids: The block IDs of the request. -+ computed_block_ids: The computed block IDs of the request. -+ """ -+ engine_id: str -+ request_id: str -+ prompt_token_ids: List[int] -+ sampling_params: SamplingParams -+ block_ids: List[int] -+ computed_block_ids: List[int] -+ -+ -+class MemoryOpType(str, Enum): -+ WRITE = "WRITE" -+ READ = "READ" -+ -+ -+class MemoryTransferRequest( -+ msgspec.Struct, -+ array_like=True, # type: ignore[call-arg] -+ omit_defaults=True): # type: ignore[call-arg] -+ """The request data of one memory transfer output of a request. -+ -+ Args: -+ request_id: The unique ID of the request. -+ """ -+ request_id: str -+ local_block_ids: List[int] -+ staging_block_ids: List[int] -+ remote_block_ids: List[int] -+ remote_engine_id: str -+ notify_msg: str -+ op_type: MemoryOpType -+ -+ -+RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None] -+ -+ -+@dataclass -+class RemotePrefillParams: -+ """Remote prefill parameters for text generation.""" -+ is_remote_prefill: bool = False -+ is_remote_decode: bool = False -+ decode_block_ids: Optional[List[int]] = None -+ decode_computed_block_ids: Optional[List[int]] = None -+ decode_engine_id: Optional[str] = None -+ remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None -\ No newline at end of file -diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py -index 97f9e2129..5849befba 100644 ---- a/vllm/sampling_params.py -+++ b/vllm/sampling_params.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - """Sampling parameters for text generation.""" - import copy - from dataclasses import dataclass -@@ -83,7 +97,7 @@ class RequestOutputKind(Enum): - DELTA = 1 - # Do not return intermediate RequestOuputs - FINAL_ONLY = 2 -- -+ - - class SamplingParams( - msgspec.Struct, -diff --git a/vllm/sequence.py b/vllm/sequence.py -index 534b9e606..c33bbde1c 100644 ---- a/vllm/sequence.py -+++ b/vllm/sequence.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - """Sequence and its related classes.""" - import copy - import enum -@@ -20,6 +34,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict - from vllm.pooling_params import PoolingParams - from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import RequestOutputKind, SamplingParams -+from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest - - VLLM_TOKEN_ID_ARRAY_TYPE = "l" - -@@ -59,13 +74,14 @@ class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 -- SWAPPED = 2 -- # Note: anything after SWAPPED (2) will be considered -+ REMOTE_PREFILLING = 2 -+ SWAPPED = 3 -+ # Note: anything after SWAPPED (3) will be considered - # as a finished status. -- FINISHED_STOPPED = 3 -- FINISHED_LENGTH_CAPPED = 4 -- FINISHED_ABORTED = 5 -- FINISHED_IGNORED = 6 -+ FINISHED_STOPPED = 4 -+ FINISHED_LENGTH_CAPPED = 5 -+ FINISHED_ABORTED = 6 -+ FINISHED_IGNORED = 7 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: -@@ -409,6 +425,7 @@ class Sequence: - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = SingletonInputsAdapter(inputs) -@@ -416,7 +433,7 @@ class Sequence: - self.eos_token_id = eos_token_id - self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request -- -+ self.remote_prefill_params = remote_prefill_params - self.data = SequenceData.from_seqs(self.prompt_token_ids) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" -@@ -639,6 +656,7 @@ class SequenceGroup: - trace_headers: OpenTelemetry trace headers. - prompt_adapter_request: Prompt Adapter request. - priority: User-defined priority of the request. -+ remote_prefill_params: Remote prefill parameters. - """ - - def __init__( -@@ -654,6 +672,7 @@ class SequenceGroup: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> None: - self.request_id = request_id - self.seqs = seqs -@@ -678,7 +697,7 @@ class SequenceGroup: - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority -- -+ self.remote_prefill_params = remote_prefill_params - self.cached_request_output = None - - @property -@@ -927,6 +946,9 @@ class SequenceGroupMetadata( - query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. -+ do_remote_prefill: True if remote prefill is required. -+ do_remote_decode: True if remote decode is required. -+ decode_memory_desc: The memory descriptor for the decoder blocks. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. 
-@@ -966,6 +988,9 @@ class SequenceGroupMetadata( - cross_block_table: Optional[List[int]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - token_chunk_size: Optional[int] = None -+ do_remote_prefill: bool = False -+ do_remote_decode: bool = False -+ decode_memory_desc: Optional[bytes] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. -@@ -1310,6 +1335,8 @@ class ExecuteModelRequest( - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None -+ # The memory transfer requests. -+ memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None - - @property - def is_first_multi_step(self) -> bool: -diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py -index 12baecde6..11034b391 100644 ---- a/vllm/worker/model_runner.py -+++ b/vllm/worker/model_runner.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import dataclasses - import gc -@@ -1824,6 +1837,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - - if self.vllm_config.kv_transfer_config is None: - return False -+ -+ if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - -@@ -1849,6 +1865,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - - if self.vllm_config.kv_transfer_config is None: - return False -+ -+ if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - -diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py -index 582aa460e..0be784a40 100644 ---- a/vllm/worker/worker.py -+++ b/vllm/worker/worker.py -@@ -1,8 +1,22 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+ - """A GPU worker class.""" - import gc - import os --from typing import Dict, List, Optional, Set, Tuple, Type, Union -+from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any - - import torch - import torch.distributed -@@ -31,6 +45,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner - from vllm.worker.pooling_model_runner import PoolingModelRunner - from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) -+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector -+from vllm.remote_prefill import MemoryOpType -+ - - logger = init_logger(__name__) - -@@ -306,6 +323,46 @@ class Worker(LocalOrDistributedWorkerBase): - self._init_cache_engine() - self._warm_up_model() - -+ def initialize_nixl(self, engine_id: str) -> List[bytes]: -+ -+ # TODO ptarasiewicz nixl can also support DRAM -+ assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector" -+ -+ self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank? -+ assert len(self.cache_engine) == 1, "Only one cache engine is supported for now" -+ self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache) -+ return self.nixl_connector.agent_name -+ -+ def get_nixl_agent_metadata(self) -> bytes: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ return self.nixl_connector.get_agent_metadata() -+ -+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank? 
-+ return agent_name -+ -+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] -+ -+ def _read_blocks(self, worker_input: WorkerInput) -> None: -+ for i, op_type in enumerate(worker_input.op_type): -+ if op_type == MemoryOpType.READ: -+ self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i]) -+ -+ def _write_blocks(self, worker_input: WorkerInput) -> None: -+ if not self.is_driver_worker: -+ torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize -+ -+ for i, op_type in enumerate(worker_input.op_type): -+ if op_type == MemoryOpType.WRITE: -+ self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i]) -+ -+ def shutdown_nixl(self) -> None: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ self.nixl_connector.shutdown() -+ - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ -@@ -367,6 +424,8 @@ class Worker(LocalOrDistributedWorkerBase): - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) -+ -+ mem_transfer_reqs = execute_model_req.memory_transfer_requests or [] - - return WorkerInput( - num_seq_groups=num_seq_groups, -@@ -375,6 +434,12 @@ class Worker(LocalOrDistributedWorkerBase): - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - num_steps=num_steps, -+ local_block_ids=[r.local_block_ids for r in mem_transfer_reqs], -+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs], -+ remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs], -+ remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs], -+ notify_msg=[r.notify_msg for r in mem_transfer_reqs], -+ op_type=[r.op_type for r in mem_transfer_reqs], - ) - - @torch.inference_mode() -diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py -index 819b81fbf..7d1b1836d 100644 ---- a/vllm/worker/worker_base.py -+++ b/vllm/worker/worker_base.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+ - - import dataclasses - import os -@@ -9,6 +23,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union - import cloudpickle - import torch - import torch.nn as nn -+from collections import defaultdict - - from vllm.config import (ObservabilityConfig, VllmConfig, - set_current_vllm_config) -@@ -23,6 +38,9 @@ from vllm.utils import (enable_trace_function_call_for_thread, - from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) -+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector -+from vllm.remote_prefill import MemoryOpType -+ - - logger = init_logger(__name__) - -@@ -53,6 +71,8 @@ class WorkerBase(ABC): - from vllm.platforms import current_platform - self.current_platform = current_platform - -+ self.nixl_connector: Optional[DynamoNixlConnector] = None -+ - @abstractmethod - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device -@@ -216,6 +236,13 @@ class WorkerInput: - virtual_engine: int = 0 - num_steps: int = 1 - -+ local_block_ids: Optional[List[List[int]]] = None -+ staging_block_ids: Optional[List[List[int]]] = None -+ remote_block_ids: Optional[List[List[int]]] = None -+ remote_engine_id: Optional[List[str]] = None -+ notify_msg: Optional[List[str]] = None -+ op_type: Optional[List[MemoryOpType]] = None -+ - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], -@@ -232,6 +259,12 @@ class WorkerInput: - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), -+ local_block_ids=tensor_dict.pop("local_block_ids"), -+ staging_block_ids=tensor_dict.pop("staging_block_ids"), -+ remote_block_ids=tensor_dict.pop("remote_block_ids"), -+ remote_engine_id=tensor_dict.pop("remote_engine_id"), -+ notify_msg=tensor_dict.pop("notify_msg"), -+ op_type=tensor_dict.pop("op_type"), - ) - - def as_broadcastable_tensor_dict( -@@ -246,6 +279,12 @@ class WorkerInput: - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, -+ "local_block_ids": self.local_block_ids, -+ "staging_block_ids": self.staging_block_ids, -+ "remote_block_ids": self.remote_block_ids, -+ "remote_engine_id": self.remote_engine_id, -+ "notify_msg": self.notify_msg, -+ "op_type": self.op_type, - } - - return tensor_dict -@@ -316,13 +355,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) -- model_input = ( -- self.model_runner.make_model_input_from_broadcasted_tensor_dict( -- broadcast_data)) -+ if worker_input.num_seq_groups > 0: -+ model_input = ( -+ self.model_runner.make_model_input_from_broadcasted_tensor_dict( -+ broadcast_data)) - -- kwargs = extract_previous_hidden_states(broadcast_data) -+ kwargs = extract_previous_hidden_states(broadcast_data) - -- return model_input, worker_input, kwargs -+ return model_input, worker_input, kwargs -+ else: -+ return None, worker_input, {} - - def _get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest -@@ -396,49 +438,88 @@ class LocalOrDistributedWorkerBase(WorkerBase): - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. 
-- if worker_input.num_seq_groups == 0: -- return [] -- -- intermediate_tensors = None -- orig_model_execute_time = 0.0 -- if not get_pp_group().is_first_rank: -- intermediate_tensors = IntermediateTensors( -- get_pp_group().recv_tensor_dict( -- all_gather_group=get_tp_group())) -+ if worker_input.num_seq_groups > 0: -+ -+ self._read_blocks(worker_input) -+ -+ intermediate_tensors = None -+ orig_model_execute_time = 0.0 -+ if not get_pp_group().is_first_rank: -+ intermediate_tensors = IntermediateTensors( -+ get_pp_group().recv_tensor_dict( -+ all_gather_group=get_tp_group())) -+ if (self.observability_config is not None -+ and self.observability_config.collect_model_execute_time): -+ orig_model_execute_time = intermediate_tensors.tensors.get( -+ "model_execute_time", torch.tensor(0)).item() -+ -+ output = self.model_runner.execute_model( -+ model_input=model_input, -+ kv_caches=self.kv_cache[worker_input.virtual_engine] -+ if self.kv_cache is not None else None, -+ intermediate_tensors=intermediate_tensors, -+ num_steps=num_steps, -+ **kwargs, -+ ) -+ -+ model_execute_time = time.perf_counter() - start_time -+ if not get_pp_group().is_last_rank: -+ # output is IntermediateTensors -+ assert isinstance(output, IntermediateTensors) -+ if (self.observability_config is not None -+ and self.observability_config.collect_model_execute_time): -+ output.tensors["model_execute_time"] = torch.tensor( -+ model_execute_time + orig_model_execute_time) -+ get_pp_group().send_tensor_dict(output.tensors, -+ all_gather_group=get_tp_group()) -+ return [None] - if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time): -- orig_model_execute_time = intermediate_tensors.tensors.get( -- "model_execute_time", torch.tensor(0)).item() -+ and self.observability_config.collect_model_execute_time -+ and output is not None): -+ for o in output: -+ o.model_execute_time = (orig_model_execute_time + -+ model_execute_time) - -- output = self.model_runner.execute_model( -- model_input=model_input, -- kv_caches=self.kv_cache[worker_input.virtual_engine] -- if self.kv_cache is not None else None, -- intermediate_tensors=intermediate_tensors, -- num_steps=num_steps, -- **kwargs, -- ) -- -- model_execute_time = time.perf_counter() - start_time -- if not get_pp_group().is_last_rank: -- # output is IntermediateTensors -- assert isinstance(output, IntermediateTensors) -- if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time): -- output.tensors["model_execute_time"] = torch.tensor( -- model_execute_time + orig_model_execute_time) -- get_pp_group().send_tensor_dict(output.tensors, -- all_gather_group=get_tp_group()) -- return [None] -- if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time -- and output is not None): -- for o in output: -- o.model_execute_time = (orig_model_execute_time + -- model_execute_time) -+ self._write_blocks(worker_input) - -+ else: -+ output = [] -+ -+ # collect kv transfer notifications from non driver workers -+ -+ if self.nixl_connector is not None: -+ new_notifs = self.nixl_connector.get_new_notifs() -+ rank = get_tp_group().rank -+ all_new_notifs = [new_notifs] -+ if rank > 0: -+ get_tp_group().send_object(new_notifs, dst=0) -+ else: -+ for i in range(1, get_tp_group().world_size): -+ all_new_notifs.append(get_tp_group().recv_object(src=i)) -+ -+ request_notif_counter = defaultdict(int) -+ for notifs in all_new_notifs: -+ for req_ids in notifs.values(): -+ for 
req_id in req_ids: -+ request_notif_counter[req_id.decode("utf-8")] += 1 -+ -+ if request_notif_counter: -+ logger.debug("Request notif counter: %s", request_notif_counter) -+ -+ request_done_counter = defaultdict(int) -+ for req_id in self.nixl_connector.get_done_tranfers(): -+ request_done_counter[req_id] += 1 -+ else: -+ request_notif_counter = {} -+ request_done_counter = {} - # output is List[SamplerOutput] -- return output -+ return output, request_notif_counter, request_done_counter -+ -+ def _read_blocks(self, worker_input: WorkerInput) -> None: -+ pass -+ -+ def _write_blocks(self, worker_input: WorkerInput) -> None: -+ pass - - def _execute_model_spmd( - self, diff --git a/container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch b/container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch deleted file mode 100644 index b4939df62a..0000000000 --- a/container/deps/vllm/vllm_v0.8.4-dynamo-kv-disagg-patch.patch +++ /dev/null @@ -1,4840 +0,0 @@ -diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py -index 54278f5f6..7eaf92feb 100644 ---- a/vllm/attention/backends/mla/common.py -+++ b/vllm/attention/backends/mla/common.py -@@ -300,7 +300,8 @@ class MLACommonState(AttentionState, Generic[T]): - cache_config = runner.cache_config - - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled -- self.enable_prefix_caching = cache_config.enable_prefix_caching -+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work -+ self.enable_prefix_caching = True # cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - self.context_chunk_workspace_size = min( -@@ -735,8 +736,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]): - self.block_size = input_builder.block_size - self.chunked_prefill_enabled = \ - self.runner.scheduler_config.chunked_prefill_enabled -- self.enable_prefix_caching = \ -- self.runner.cache_config.enable_prefix_caching -+ # TODO ptarasiewicz: we pretend that prefix caching is enabled to make fetching from Decode kv cache work -+ self.enable_prefix_caching = True # self.runner.cache_config.enable_prefix_caching - - if self.chunked_prefill_enabled or self.enable_prefix_caching: - attn_state = self.input_builder.runner.attn_state -diff --git a/vllm/config.py b/vllm/config.py -index 2912361ee..eea9cb65d 100644 ---- a/vllm/config.py -+++ b/vllm/config.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import ast - import copy -@@ -3091,6 +3104,9 @@ class KVTransferConfig(BaseModel): - # The KV connector for vLLM to transmit KV caches between vLLM instances. - kv_connector: Optional[str] = None - -+ # Whether to use NIXL prepped xfer for KV cache transfer. 
-+ use_prepped_xfer: bool = True -+ - # The device used by kv connector to buffer the KV cache. - # Currently only support 'cuda'. - kv_buffer_device: Optional[str] = "cuda" -@@ -3100,7 +3116,7 @@ class KVTransferConfig(BaseModel): - kv_buffer_size: float = 1e9 - - # Whether this vLLM instance produces, consumes KV cache, or both. Choices -- # are 'kv_producer', 'kv_consumer', and 'both'. -+ # are 'kv_producer', 'kv_consumer', and 'kv_both'. - kv_role: Optional[str] = None - - # The rank of this vLLM instance in the KV cache transfer. Typical value: -@@ -3155,11 +3171,16 @@ class KVTransferConfig(BaseModel): - f"Supported roles are `kv_producer`, `kv_consumer`, " - f"and `kv_both`") - -- if self.kv_connector is not None and self.kv_role is None: -+ if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None: - raise ValueError("Please specify kv_disagg_role when kv_connector " - "is set, supported roles are `kv_producer`, " - "`kv_consumer`, and `kv_both`") - -+ if self.use_prepped_xfer is False: -+ logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.") -+ self.use_prepped_xfer = True -+ -+ - @property - def is_kv_transfer_instance(self) -> bool: - return self.kv_connector is not None and \ -diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py -index d64142e77..6279767cb 100644 ---- a/vllm/core/block/cpu_gpu_block_allocator.py -+++ b/vllm/core/block/cpu_gpu_block_allocator.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - from typing import Dict, FrozenSet, List, Optional, Tuple - -@@ -6,6 +19,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, - DeviceAwareBlockAllocator) - from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator - from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator -+from vllm.core.event_manager import KVCacheEventManager - from vllm.platforms import current_platform - from vllm.utils import Device - -@@ -28,6 +42,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - num_gpu_blocks: int, - num_cpu_blocks: int, - block_size: int, -+ event_manager: Optional[KVCacheEventManager] = None, - ) -> DeviceAwareBlockAllocator: - """Creates a CpuGpuBlockAllocator instance with the specified - configuration. -@@ -64,6 +79,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - cpu_block_ids = block_ids[num_gpu_blocks:] - - if allocator_type == "naive": -+ assert event_manager is None, "Event API not supported with naive allocator." 
- gpu_allocator: BlockAllocator = NaiveBlockAllocator( - create_block=NaiveBlock, # type: ignore - num_blocks=num_gpu_blocks, -@@ -82,12 +98,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - num_blocks=num_gpu_blocks, - block_size=block_size, - block_ids=gpu_block_ids, -+ event_manager=event_manager, - ) - - cpu_allocator = PrefixCachingBlockAllocator( - num_blocks=num_cpu_blocks, - block_size=block_size, - block_ids=cpu_block_ids, -+ event_manager=event_manager, - ) - else: - raise ValueError(f"Unknown allocator type {allocator_type=}") -@@ -95,10 +113,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - return CpuGpuBlockAllocator( - cpu_block_allocator=cpu_allocator, - gpu_block_allocator=gpu_allocator, -+ event_manager=event_manager, - ) - - def __init__(self, cpu_block_allocator: BlockAllocator, -- gpu_block_allocator: BlockAllocator): -+ gpu_block_allocator: BlockAllocator, -+ event_manager: Optional[KVCacheEventManager] = None,): - assert not ( - cpu_block_allocator.all_block_ids - & gpu_block_allocator.all_block_ids -@@ -108,6 +128,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): - Device.CPU: cpu_block_allocator, - Device.GPU: gpu_block_allocator, - } -+ self.event_manager = event_manager - - self._swap_mapping: Dict[int, int] = {} - self._null_block: Optional[Block] = None -diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py -index c388366b8..3c223b519 100644 ---- a/vllm/core/block/naive_block.py -+++ b/vllm/core/block/naive_block.py -@@ -1,8 +1,21 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- - from collections import deque - from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union -- -+import heapq - from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, - get_all_blocks_recursively) - from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device -@@ -38,7 +51,7 @@ class NaiveBlockAllocator(BlockAllocator): - if block_ids is None: - block_ids = range(num_blocks) - -- self._free_block_indices: Deque[BlockId] = deque(block_ids) -+ self._free_block_indices: List[BlockId] = list(block_ids) - self._all_block_indices = frozenset(block_ids) - assert len(self._all_block_indices) == num_blocks - -@@ -134,7 +147,8 @@ class NaiveBlockAllocator(BlockAllocator): - if not self._free_block_indices: - raise BlockAllocator.NoFreeBlocksError() - -- block_id = self._free_block_indices.popleft() -+ block_id = heapq.heappop(self._free_block_indices) -+ # TODO: figure out why sometime block_id is None - self._refcounter.incr(block_id) - return block_id - -@@ -148,7 +162,7 @@ class NaiveBlockAllocator(BlockAllocator): - - refcount = self._refcounter.decr(block_id) - if refcount == 0: -- self._free_block_indices.appendleft(block_id) -+ heapq.heappush(self._free_block_indices, block_id) - - def free(self, block: Block, keep_block_object: bool = False) -> None: - # Release the physical block id -diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py -index 1ca9e49da..26fabb243 100644 ---- a/vllm/core/block/prefix_caching_block.py -+++ b/vllm/core/block/prefix_caching_block.py -@@ -1,10 +1,23 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """Token blocks.""" - import sys - from bisect import bisect_left - from os.path import commonprefix - from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, -- Tuple) -+ Tuple, TYPE_CHECKING) - - from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, - get_all_blocks_recursively) -@@ -23,6 +36,9 @@ PrefixHash = int - # then we know this block hasn't been accessed yet. - _DEFAULT_LAST_ACCESSED_TIME = -1 - -+if TYPE_CHECKING: -+ from vllm.core.event_manager import KVCacheEventManager -+ - logger = init_logger(__name__) - - -@@ -80,6 +96,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): - block_size: int, - block_ids: Optional[Iterable[int]] = None, - eviction_policy: EvictionPolicy = EvictionPolicy.LRU, -+ event_manager: Optional["KVCacheEventManager"] = None, - ): - if block_ids is None: - block_ids = range(num_blocks) -@@ -131,6 +148,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): - - self.metric_data = CacheMetricData() - -+ self.event_manager = event_manager -+ -+ # Implements Block.Factory. 
- def _create_block( - self, - prev_block: Optional[Block], -@@ -337,6 +357,9 @@ class PrefixCachingBlockAllocator(BlockAllocator): - assert self._refcounter.get(_block_id) == 0 - assert _block_id == block_id - -+ if self.event_manager: -+ self.event_manager.enqueue_removed_event(content_hash_to_evict) -+ - self._cached_blocks.pop(content_hash_to_evict) - - self._refcounter.incr(block_id) -@@ -513,6 +536,10 @@ class PrefixCachingBlockAllocator(BlockAllocator): - # Mark this block as touched so that it can be marked as - # computed after the entire batch of sequences are scheduled. - self._touched_blocks.add(block.block_id) -+ -+ if self.event_manager: -+ self.event_manager.enqueue_stored_event(block.prev_block, block) -+ - return block.block_id - - # Reuse the cached content hash -@@ -579,9 +606,11 @@ class PrefixCachingBlockAllocator(BlockAllocator): - - def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - # Mark all touched blocks as computed. -- for block_id in self._touched_blocks: -- self._block_tracker[block_id].computed = True -- self._touched_blocks.clear() -+ for block_id in block_ids: -+ if block_id in self._touched_blocks: -+ logger.debug("Mark block as computed: %s", block_id) -+ self._block_tracker[block_id].computed = True -+ self._touched_blocks.remove(block_id) - - def _track_block_id(self, block_id: Optional[BlockId], - computed: bool) -> None: -diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py -index c6bf6d163..c5514f935 100644 ---- a/vllm/core/block_manager.py -+++ b/vllm/core/block_manager.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- """A block manager that manages token blocks.""" - from typing import Dict, List, Optional - from typing import Sequence as GenericSequence -@@ -10,7 +23,10 @@ from vllm.core.block.interfaces import Block - from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, - LastAccessBlocksTracker) - from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec -+from vllm.core.event_manager import KVCacheEventManager - from vllm.core.interfaces import AllocStatus, BlockSpaceManager -+from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE, -+ VLLM_WORKER_ID) - from vllm.sequence import Sequence, SequenceGroup, SequenceStatus - from vllm.utils import Device - -@@ -60,6 +76,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - def __init__( - self, -+ model_name: str, - block_size: int, - num_gpu_blocks: int, - num_cpu_blocks: int, -@@ -91,11 +108,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - self.watermark_blocks = int(watermark * num_gpu_blocks) - -+ kv_event_manager_params = [ -+ VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE, -+ VLLM_KV_COMPONENT -+ ] -+ set_kv_event_manager_params = len( -+ [param for param in kv_event_manager_params if param is not None]) -+ -+ if set_kv_event_manager_params == len(kv_event_manager_params): -+ self.event_manager = KVCacheEventManager( -+ namespace=VLLM_KV_NAMESPACE, -+ component=VLLM_KV_COMPONENT, -+ worker_id=VLLM_WORKER_ID, -+ lib_path=VLLM_KV_CAPI_PATH, -+ kv_block_size=block_size) -+ else: -+ self.event_manager = None -+ - self.block_allocator = CpuGpuBlockAllocator.create( - allocator_type="prefix_caching" if enable_caching else "naive", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, - block_size=block_size, -+ event_manager=self.event_manager, - ) - - self.block_tables: Dict[SeqId, BlockTable] = {} -@@ -108,7 +143,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - - def can_allocate(self, - seq_group: SequenceGroup, -- num_lookahead_slots: int = 0) -> AllocStatus: -+ num_lookahead_slots: int = 0, -+ is_remote_decode: bool = False) -> AllocStatus: - # FIXME(woosuk): Here we assume that all sequences in the group share - # the same prompt. This may not be true for preempted sequences. - -@@ -121,6 +157,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager): - num_lookahead_slots=num_lookahead_slots, - ) - -+ # if remote decode, we need to allocate twice as many blocks for staging -+ if is_remote_decode: -+ num_required_blocks *= 2 -+ - if seq_group.is_encoder_decoder(): - encoder_seq = seq_group.get_encoder_seq() - assert encoder_seq is not None -diff --git a/vllm/core/event_manager.py b/vllm/core/event_manager.py -new file mode 100644 -index 000000000..79eb8db67 ---- /dev/null -+++ b/vllm/core/event_manager.py -@@ -0,0 +1,121 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+import ctypes -+import logging -+import uuid -+from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64 -+from typing import Optional -+ -+from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash -+ -+logger = logging.getLogger(__name__) -+ -+ -+class DynamoResult: -+ OK = 0 -+ ERR = 1 -+ -+ -+class KVCacheEventManager: -+ -+ def __init__(self, namespace: str, component: str, worker_id: int, -+ lib_path: str, kv_block_size: int): -+ self.lib = None -+ -+ try: -+ self.lib = ctypes.CDLL(lib_path) -+ self.lib.dynamo_llm_init.argtypes = [ -+ c_char_p, -+ c_char_p, -+ c_int64, -+ c_uint32, -+ ] -+ self.lib.dynamo_llm_init.restype = c_uint32 -+ -+ result = self.lib.dynamo_llm_init( -+ namespace.encode(), component.encode(), worker_id, kv_block_size -+ ) -+ if result == DynamoResult.OK: -+ logger.info( -+ "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events" -+ ) -+ else: -+ logger.info("KVCacheEventManager initialization failed!") -+ -+ except Exception as e: -+ print(f"Failed to load {lib_path}") -+ raise e -+ -+ self.lib.dynamo_kv_event_publish_stored.argtypes = [ -+ ctypes.c_uint64, # event_id -+ ctypes.POINTER(ctypes.c_uint32), # token_ids -+ ctypes.POINTER(ctypes.c_size_t), # num_block_tokens -+ ctypes.POINTER(ctypes.c_uint64), # block_ids -+ ctypes.c_size_t, # num_blocks -+ ctypes.POINTER(ctypes.c_uint64), # parent_hash -+ ctypes.c_uint64, # lora_id -+ ] -+ self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t -+ -+ self.lib.dynamo_kv_event_publish_removed.argtypes = [ -+ ctypes.c_uint64, # event_id -+ ctypes.POINTER(ctypes.c_uint64), # block_ids -+ ctypes.c_size_t, # num_blocks -+ ] -+ self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t -+ -+ self.event_id_counter = 0 -+ -+ def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock], -+ block: PrefixCachingBlock): -+ token_ids_arr = (ctypes.c_uint32 * -+ len(block.token_ids))(*block.token_ids) -+ num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids)) -+ block_hash = (ctypes.c_uint64 * 1)(block.content_hash) -+ parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash) -+ if parent is not None else None) -+ -+ # Publish the event -+ result = self.lib.dynamo_kv_event_publish_stored( -+ self.event_id_counter, # uint64_t event_id -+ token_ids_arr, # const uint32_t *token_ids -+ num_block_tokens, # const uintptr_t *num_block_tokens -+ block_hash, # const uint64_t *block_ids -+ 1, # uintptr_t num_blocks -+ parent_hash, # const uint64_t *parent_hash -+ 0, # uint64_t lora_id -+ ) -+ -+ if result == DynamoResult.OK: -+ logger.debug(f"Store - Published KV Event: {block.content_hash}") -+ else: -+ logger.debug( -+ f"Store - Failed to Publish KV Event: {block.content_hash}") -+ -+ self.event_id_counter += 1 -+ -+ def enqueue_removed_event(self, block_hash: PrefixHash): -+ result = self.lib.dynamo_kv_event_publish_removed( -+ self.event_id_counter, -+ (ctypes.c_uint64 * 1)(block_hash), -+ 1, -+ ) -+ -+ if result == DynamoResult.OK: -+ logger.debug(f"Remove - Published KV Event: {block_hash}") -+ else: -+ logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}") -+ -+ self.event_id_counter += 1 -diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py -index cf85a2135..f157aa231 100644 ---- a/vllm/core/scheduler.py -+++ b/vllm/core/scheduler.py -@@ -1,16 +1,30 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import enum - import os - import random - import time -+import copy - from collections import deque - from dataclasses import dataclass, field - from typing import Callable, Deque, Dict, Iterable, List, Optional - from typing import Sequence as GenericSequence --from typing import Set, Tuple, Union -+from typing import Set, Tuple, Union, Any - --from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig -+from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig - from vllm.core.interfaces import AllocStatus, BlockSpaceManager - from vllm.logger import init_logger - from vllm.lora.request import LoRARequest -@@ -20,7 +34,6 @@ from vllm.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadataDelta, SequenceStage, - SequenceStatus) - from vllm.utils import Device, PyObjectCache -- - logger = init_logger(__name__) - - # Test-only. If configured, decode is preempted with -@@ -292,6 +305,7 @@ class SchedulerPrefillOutputs: - # Ignored sequence groups. - ignored_seq_groups: List[SequenceGroup] - num_lookahead_slots: int -+ num_remote_prefill_groups: int - - @classmethod - def create_empty(cls) -> "SchedulerPrefillOutputs": -@@ -299,6 +313,7 @@ class SchedulerPrefillOutputs: - seq_groups=[], - ignored_seq_groups=[], - num_lookahead_slots=0, -+ num_remote_prefill_groups=0, - ) - - -@@ -426,12 +441,14 @@ class Scheduler: - - def __init__( - self, -+ model_config: ModelConfig, - scheduler_config: SchedulerConfig, - cache_config: CacheConfig, - lora_config: Optional[LoRAConfig], - pipeline_parallel_size: int = 1, - output_proc_callback: Optional[Callable] = None, - ) -> None: -+ self.model_config = model_config - self.scheduler_config = scheduler_config - self.cache_config = cache_config - # Note for LoRA scheduling: the current policy is extremely -@@ -457,6 +474,7 @@ class Scheduler: - - # Create the block space manager. - self.block_manager = BlockSpaceManagerImpl( -+ model_name=self.model_config.served_model_name, - block_size=self.cache_config.block_size, - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks, -@@ -473,6 +491,16 @@ class Scheduler: - # Sequence groups in the SWAPPED state. - # Contain decode requests that are swapped out. - self.swapped: Deque[SequenceGroup] = deque() -+ -+ # Sequence groups in the REMOTE_PREFILLING state. -+ # Contain requests that are being prefilled by a remote worker. -+ self.remote_prefilling: Deque[SequenceGroup] = deque() -+ # Contain requests that are being prefilled by a local worker. -+ self.prefill_sending: Deque[SequenceGroup] = deque() -+ -+ self._remote_prefill_outputs: Dict[str, int] = {} -+ -+ - # Sequence groups finished requests ids since last step iteration. - # It lets the model know that any state associated with these requests - # can and must be released after the current step. 
-@@ -628,8 +656,8 @@ class Scheduler: - self.block_manager.free_cross(seq_group) - - def has_unfinished_seqs(self) -> bool: -- return (len(self.waiting) != 0 or len(self.running) != 0 -- or len(self.swapped) != 0) -+ return len(self.waiting) != 0 or len(self.running) != 0 or len( -+ self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0 - - def get_prefix_cache_hit_rate(self, device: Device) -> float: - return self.block_manager.get_prefix_cache_hit_rate(device) -@@ -652,6 +680,8 @@ class Scheduler: - curr_loras: Optional[Set[int]], - enable_chunking: bool = False, - partial_prefill_metadata: Optional[PartialPrefillMetadata] = None, -+ finished_prefills: Optional[Set[str]] = None, -+ finished_transfers: Optional[Set[str]] = None, - ) -> SchedulerRunningOutputs: - """Schedule sequence groups that are running. - -@@ -669,6 +699,9 @@ class Scheduler: - partial_prefill_metadata: information about the partial prefills - that are currently running - -+ finished_remote_prefill_request_ids: Set of request ids of remote -+ prefills that have finished. -+ - Returns: - SchedulerRunningOutputs. - """ -@@ -697,6 +730,38 @@ class Scheduler: - preempted: List[SequenceGroup] = ret.preempted - swapped_out: List[SequenceGroup] = ret.swapped_out - -+ remote_prefilling_queue = self.remote_prefilling -+ leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque() -+ while remote_prefilling_queue: -+ seq_group = remote_prefilling_queue.popleft() -+ if seq_group.request_id not in finished_prefills: -+ leftover_remote_prefilling_sequences.append(seq_group) -+ continue -+ -+ else: -+ finished_prefills.remove(seq_group.request_id) -+ assert len(seq_group.seqs) == 1 -+ seq = seq_group.seqs[0] -+ # we computed all but the last token in prefill, we need to decode the first token on decode -+ seq_group.update_num_computed_tokens(seq.get_len() - 1) -+ seq.status = SequenceStatus.RUNNING -+ seq.data._stage = SequenceStage.DECODE -+ self.running.appendleft(seq_group) -+ remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences) -+ -+ remote_transfers_queue = self.prefill_sending -+ leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque() -+ while remote_transfers_queue: -+ seq_group = remote_transfers_queue.popleft() -+ if seq_group.request_id not in finished_transfers: -+ leftover_remote_transfers_sequences.append(seq_group) -+ else: -+ finished_transfers.remove(seq_group.request_id) -+ assert len(seq_group.seqs) == 1 -+ seq = seq_group.seqs[0] -+ self.free_seq(seq) -+ remote_transfers_queue.extendleft(leftover_remote_transfers_sequences) -+ - running_queue = self.running - assert len(self._async_stopped) == 0 - while running_queue: -@@ -1068,11 +1133,13 @@ class Scheduler: - ignored_seq_groups=[], - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), -+ num_remote_prefill_groups=0 - ) - ignored_seq_groups: List[SequenceGroup] = [] - seq_groups: List[ScheduledSequenceGroup] = [] - - waiting_queue = self.waiting -+ num_remote_prefill_groups = 0 - - leftover_waiting_sequences: Deque[SequenceGroup] = deque() - while self._passed_delay(time.time()) and waiting_queue: -@@ -1121,8 +1188,10 @@ class Scheduler: - True, enable_chunking) - - # If the sequence group cannot be allocated, stop. 
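A small arithmetic sketch of the remote-prefill hand-off above (the update_num_computed_tokens(seq.get_len() - 1) call); the prompt length is hypothetical:

    # Hypothetical numbers; the point is the deliberate off-by-one in the hand-off.
    prompt_len = 32                  # seq.get_len() once the remote prefill result arrives

    num_computed = prompt_len - 1    # what update_num_computed_tokens() records
    remaining = prompt_len - num_computed

    # Exactly one prompt token is left uncomputed, so the first local step runs a
    # one-token pass that also produces the first generated token on the decode side.
    assert remaining == 1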
-+ is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode - can_allocate = self.block_manager.can_allocate( -- seq_group, num_lookahead_slots=num_lookahead_slots) -+ seq_group, num_lookahead_slots=num_lookahead_slots, -+ is_remote_decode=is_remote_decode) - if can_allocate == AllocStatus.LATER: - break - elif can_allocate == AllocStatus.NEVER: -@@ -1170,7 +1239,18 @@ class Scheduler: - if curr_loras is not None and lora_int_id > 0: - curr_loras.add(lora_int_id) - waiting_queue.popleft() -- self._allocate_and_set_running(seq_group) -+ -+ seq_group_copy = copy.deepcopy(seq_group) -+ seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1 -+ -+ logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id) -+ logger.debug("Seq id: %s", seq_group.seqs[0].seq_id) -+ is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group) -+ num_remote_prefill_groups += is_remote_prefill -+ if is_remote_decode: -+ logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id) -+ self._allocate_and_set_running_or_remote_prefill(seq_group_copy) -+ self.prefill_sending.append(seq_group_copy) - - if partial_prefill_metadata is not None: - partial_prefill_metadata.maybe_increment_partial_prefills( -@@ -1214,9 +1294,10 @@ class Scheduler: - ignored_seq_groups=ignored_seq_groups, - num_lookahead_slots=self._get_num_lookahead_slots( - is_prefill=True, enable_chunking=enable_chunking), -+ num_remote_prefill_groups=num_remote_prefill_groups - ) - -- def _schedule_default(self) -> SchedulerOutputs: -+ def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: - """Schedule queued requests. - - The current policy is designed to optimize the throughput. First, -@@ -1234,6 +1315,9 @@ class Scheduler: - for seq_group in self.running: - budget.add_num_seqs(seq_group.request_id, - seq_group.get_max_num_running_seqs()) -+ for seq_group in self.remote_prefilling: -+ budget.add_num_seqs(seq_group.request_id, -+ seq_group.get_max_num_running_seqs()) - curr_loras = (set( - seq_group.lora_int_id for seq_group in self.running - if seq_group.lora_int_id > 0) if self.lora_enabled else None) -@@ -1258,7 +1342,9 @@ class Scheduler: - if len(prefills.seq_groups) == 0: - running_scheduled = self._schedule_running(budget, - curr_loras, -- enable_chunking=False) -+ enable_chunking=False, -+ finished_prefills=finished_prefills, -+ finished_transfers=finished_transfers) - - # If any sequence group is preempted, do not swap in any sequence - # group. because it means there's no slot for new running requests. -@@ -1275,7 +1361,12 @@ class Scheduler: - self.waiting.extendleft(running_scheduled.preempted) - # Update new running requests. 
- if len(prefills.seq_groups) > 0: -- self.running.extend([s.seq_group for s in prefills.seq_groups]) -+ for s in prefills.seq_groups: -+ seq_group = s.seq_group -+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: -+ self.remote_prefilling.append(seq_group) -+ else: -+ self.running.append(seq_group) - - self.running.extend(running_scheduled.decode_seq_groups_list) - -@@ -1452,12 +1543,14 @@ class Scheduler: - ] - return finishing + not_finishing - -- def _schedule(self) -> SchedulerOutputs: -+ def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs: - """Schedule queued requests.""" - if self.scheduler_config.chunked_prefill_enabled: -+ if finished_prefills or finished_transfers: -+ raise ValueError("Chunked prefill does not support remote prefills") - return self._schedule_chunked_prefill() - else: -- return self._schedule_default() -+ return self._schedule_default(finished_prefills, finished_transfers) - - def _can_append_slots(self, seq_group: SequenceGroup, - enable_chunking: bool) -> bool: -@@ -1491,14 +1584,16 @@ class Scheduler: - return no_single_seq - - def schedule( -- self -+ self, -+ finished_prefills: Optional[Set[str]] = None, -+ finished_transfers: Optional[Set[str]] = None - ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: - # Schedule sequence groups. - # This function call changes the internal states of the scheduler - # such as self.running, self.swapped, and self.waiting. -- scheduler_start_time = time.perf_counter() - -- scheduler_outputs: SchedulerOutputs = self._schedule() -+ scheduler_start_time = time.perf_counter() -+ scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers) - now = time.time() - - if not self.cache_config.enable_prefix_caching: -@@ -1537,7 +1632,8 @@ class Scheduler: - encoder_seq_data = None - cross_block_table = None - -- for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): -+ running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING) -+ for seq in running_or_remote_prefilling_seqs: - seq_id = seq.seq_id - seq_data[seq_id] = seq.data - block_tables[seq_id] = self.block_manager.get_block_table(seq) -@@ -1546,7 +1642,9 @@ class Scheduler: - if self.cache_config.enable_prefix_caching: - common_computed_block_nums = ( - self.block_manager.get_common_computed_block_ids( -- seq_group.get_seqs(status=SequenceStatus.RUNNING))) -+ running_or_remote_prefilling_seqs -+ ) -+ ) - - do_sample = True - is_prompt = seq_group.is_prefill() -@@ -1568,9 +1666,30 @@ class Scheduler: - < seqs[0].data.get_len()): - do_sample = False - -+ is_remote_prefill = False -+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill: -+ is_remote_prefill = True -+ logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums) -+ if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode: -+ block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids -+ -+ # Since we know that prefill is scheduled we can -+ # assume that the blocks computed on decode -+ # will be fetched by the time we run prefill -+ logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids) -+ if 
seq_group.remote_prefill_params.decode_computed_block_ids:
-+ computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids)
-+ prefill_block_ids = block_tables[seq_group.seqs[0].seq_id]
-+ prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)]
-+
-+ assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't support prefix caching"
-+ common_computed_block_nums = prefill_fetched_block_ids
-+
-+
- # It assumes the scheduled_seq_groups is ordered by
- # prefill < decoding.
- if is_first_prefill or not self.scheduler_config.send_delta_data:
-+ logger.debug("Assigned blocks: %s", block_tables)
- seq_group_metadata = SequenceGroupMetadata(
- request_id=seq_group.request_id,
- is_prompt=is_prompt,
-@@ -1598,6 +1717,7 @@ class Scheduler:
- if scheduler_outputs.num_prefill_groups > 0 else None),
- mm_processor_kwargs=seq_group.mm_processor_kwargs,
- prompt_adapter_request=seq_group.prompt_adapter_request,
-+ do_remote_prefill=is_remote_prefill,
- )
- else:
- # When SPMD mode is enabled, we only send delta data except for
-@@ -1696,10 +1816,16 @@ class Scheduler:
-
- self._async_stopped.clear()
-
-- def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
-+ def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool:
- self.block_manager.allocate(seq_group)
-+ is_remote_prefill = False
- for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
-- seq.status = SequenceStatus.RUNNING
-+ if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
-+ seq.status = SequenceStatus.REMOTE_PREFILLING
-+ is_remote_prefill = True
-+ else:
-+ seq.status = SequenceStatus.RUNNING
-+ return is_remote_prefill
-
- def _append_slots(
- self,
-diff --git a/vllm/distributed/device_communicators/kv_rearrange.py b/vllm/distributed/device_communicators/kv_rearrange.py
-new file mode 100644
-index 000000000..a2f9ce99e
---- /dev/null
-+++ b/vllm/distributed/device_communicators/kv_rearrange.py
-@@ -0,0 +1,125 @@
-+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-+# SPDX-License-Identifier: Apache-2.0
-+#
-+# Licensed under the Apache License, Version 2.0 (the "License");
-+# you may not use this file except in compliance with the License.
-+# You may obtain a copy of the License at
-+#
-+# http://www.apache.org/licenses/LICENSE-2.0
-+#
-+# Unless required by applicable law or agreed to in writing, software
-+# distributed under the License is distributed on an "AS IS" BASIS,
-+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-+# See the License for the specific language governing permissions and
-+# limitations under the License.
-+ -+import torch -+import triton -+import triton.language as tl -+ -+@triton.jit -+def rearrange_kernel_read( -+ t1_ptr, -+ t2_ptr, -+ N, -+ B, -+ H, -+ C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE: tl.constexpr, -+): -+ pid = tl.program_id(0) -+ -+ block_start = pid * BLOCK_SIZE -+ offsets = block_start + tl.arange(0, BLOCK_SIZE) -+ -+ curr_n = offsets // block_size -+ curr_b = offsets // token_size % B -+ curr_h = offsets // C % H -+ curr_c = offsets % C -+ -+ src_pos = offsets -+ -+ tp_group = curr_h * d // H -+ dst_h = curr_h % (H // d) -+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c -+ -+ dst_pos = tensor_subset_size * tp_group + tp_group_offset -+ -+ tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos)) -+ -+@triton.jit -+def rearrange_kernel_write( -+ t1_ptr, -+ t2_ptr, -+ N, -+ B, -+ H, -+ C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE: tl.constexpr, -+): -+ pid = tl.program_id(0) -+ -+ block_start = pid * BLOCK_SIZE -+ offsets = block_start + tl.arange(0, BLOCK_SIZE) -+ -+ curr_n = offsets // block_size -+ curr_b = offsets // token_size % B -+ curr_h = offsets // C % H -+ curr_c = offsets % C -+ -+ src_pos = offsets -+ -+ tp_group = curr_h * d // H -+ dst_h = curr_h % (H // d) -+ tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c -+ -+ dst_pos = tensor_subset_size * tp_group + tp_group_offset -+ -+ tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos)) -+ -+ -+ -+def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str): -+ N, B, H, C = t1.shape -+ -+ assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source" -+ assert H % d == 0, "H must be divisible by d" -+ -+ block_size = B * H * C -+ token_size = H * C -+ tensor_size = N * block_size -+ tensor_subset_size = tensor_size // d -+ -+ BLOCK_SIZE = 1024 -+ grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,) -+ -+ if direction == "read": -+ rearrange_kernel_read[grid]( -+ t1, t2, -+ N, B, H, C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE=BLOCK_SIZE -+ ) -+ elif direction == "write": -+ rearrange_kernel_write[grid]( -+ t1, t2, -+ N, B, H, C, -+ d, -+ tensor_subset_size, -+ block_size, -+ token_size, -+ BLOCK_SIZE=BLOCK_SIZE -+ ) -+ else: -+ raise ValueError(f"Invalid direction: {direction}") -\ No newline at end of file -diff --git a/vllm/distributed/device_communicators/nixl.py b/vllm/distributed/device_communicators/nixl.py -new file mode 100644 -index 000000000..bd4ac984e ---- /dev/null -+++ b/vllm/distributed/device_communicators/nixl.py -@@ -0,0 +1,445 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
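A minimal PyTorch sketch of the data movement the rearrange_kernel_read / rearrange_kernel_write Triton kernels above are written for, assuming contiguous (N, B, H, C) tensors; it mirrors the same index math (heads split into d contiguous groups, each group packed as one back-to-back slice of the staging buffer):

    import torch

    def rearrange_reference(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
        # t1: local layout (N, B, H, C); t2: staging layout, d contiguous slices,
        # each of shape (N, B, H // d, C), packed back to back in memory.
        N, B, H, C = t1.shape
        assert t2.shape == (N, B, H, C) and H % d == 0
        t1_v = t1.view(N, B, d, H // d, C)
        t2_v = t2.view(d, N, B, H // d, C)
        if direction == "write":   # local -> staging (cf. rearrange_kernel_write)
            t2_v.copy_(t1_v.permute(2, 0, 1, 3, 4))
        elif direction == "read":  # staging -> local (cf. rearrange_kernel_read)
            t1_v.copy_(t2_v.permute(1, 2, 0, 3, 4))
        else:
            raise ValueError(f"Invalid direction: {direction}")

Under these assumptions, slice t2_v[i] appears to correspond to the portion exchanged with remote tensor-parallel rank i, which is what the per-i staging descriptors in nixl.py below select.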
-+ -+import torch -+from typing import List, Tuple -+from vllm.config import VllmConfig -+from vllm.logger import init_logger -+import msgspec -+import time -+import uuid -+from collections import defaultdict -+from .kv_rearrange import rearrange_tensors -+ -+logger = init_logger(__name__) -+ -+# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used -+try: -+ from nixl._api import nixl_agent as NixlWrapper -+ logger.info("NIXL is available") -+except ImportError: -+ logger.warning("NIXL is not available") -+ NixlWrapper = None -+ -+class NixlMetadata( -+ msgspec.Struct, -+ omit_defaults=True, # type: ignore[call-arg] -+ # required for @cached_property. -+ dict=True): -+ engine_id: str -+ agent_metadata: List[bytes] -+ kv_caches_base_addr: List[List[List[int]]] # base address for each rank for each layer for keys and values -+ num_blocks: int -+ -+ -+class DynamoNixlConnector: -+ def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int): -+ self.vllm_config = vllm_config -+ if NixlWrapper is None: -+ logger.error("NIXL is not available") -+ raise RuntimeError("NIXL is not available") -+ logger.info("Initializing NIXL wrapper") -+ self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None) -+ -+ self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer -+ -+ self.num_layers = None -+ self.num_blocks = None -+ self.num_heads = None -+ self.block_len = None -+ self.kv_caches = None -+ self.kv_caches_base_addr = {} -+ self.kv_cache_shape = {} -+ -+ self._registered_descs = [] -+ self._remote_agents = {} -+ self.engine_id = engine_id -+ self.rank = rank -+ self._tp_size = {} -+ self.src_xfer_side_handles = {} -+ self.dst_xfer_side_handles = defaultdict(dict) -+ self.dst_num_blocks = {} -+ -+ self._transfers = defaultdict(list) -+ -+ -+ self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size -+ self._is_mla = "deepseek" in vllm_config.model_config.architectures[0].lower() -+ -+ -+ @property -+ def agent_name(self): -+ return self.nixl_wrapper.name -+ -+ def register_kv_caches(self, kv_caches: List[torch.Tensor]): -+ logger.debug("--------------------------------") -+ logger.debug("Registering kv caches for engine %s", self.engine_id) -+ logger.debug(f"Is deepseek: {self._is_mla}") -+ logger.debug(f"kv_cache shape: {kv_caches[0].shape}") -+ logger.debug("--------------------------------") -+ -+ if self._is_mla: -+ num_blocks, block_size, head_dim = kv_caches[0].shape -+ self.block_len = head_dim * block_size * kv_caches[0].element_size() -+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) -+ self.num_layers = len(kv_caches) -+ self.num_blocks = num_blocks -+ self.num_heads = 1 -+ self.kv_caches = kv_caches -+ self.num_cache_entries = 1 -+ -+ kv_caches_base_addr = [] -+ caches_data = [] -+ for kv_cache in kv_caches: -+ base_addr = kv_cache.data_ptr() -+ region_len = self.num_cache_entries * num_blocks * self.block_len -+ caches_data.append((base_addr, region_len, self.rank, "")) -+ kv_caches_base_addr.append([base_addr,]) -+ -+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr -+ -+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") -+ logger.debug("Registering descs: %s", caches_data) -+ self.nixl_wrapper.register_memory(descs) -+ self._registered_descs.append(descs) -+ else: -+ _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape -+ self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size() -+ logger.debug("Per layer kv cache size: %s", kv_caches[0].shape) 
-+ self.num_layers = len(kv_caches) -+ self.num_blocks = num_blocks -+ self.num_heads = num_heads -+ self.kv_caches = kv_caches -+ self.num_cache_entries = 2 -+ kv_caches_base_addr = [] -+ caches_data = [] -+ for key_cache, value_cache in kv_caches: -+ base_addr = key_cache.data_ptr() -+ region_len = self.num_cache_entries * num_blocks * self.block_len -+ caches_data.append((base_addr, region_len, self.rank, "")) -+ kv_caches_base_addr.append([key_cache.data_ptr(), value_cache.data_ptr()]) -+ -+ self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr -+ -+ descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM") -+ logger.debug("Registering descs: %s", caches_data) -+ self.nixl_wrapper.register_memory(descs) -+ self._registered_descs.append(descs) -+ -+ def get_agent_metadata(self): -+ return self.nixl_wrapper.get_agent_metadata() -+ -+ def shutdown(self): -+ for descs_list in self._registered_descs: -+ self.nixl_wrapper.deregister_memory(descs_list) -+ for agent_names in self._remote_agents.values(): -+ for agent_name in agent_names: -+ self.nixl_wrapper.remove_remote_agent(agent_name) -+ for src_xfer_side_handle in self.src_xfer_side_handles.values(): -+ self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle) -+ for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): -+ for dst_xfer_side_handle in dst_xfer_side_handles.values(): -+ self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) -+ -+ def _get_ranges(self, block_ids): -+ # This function should return a list of ranges of block ids that are contiguous -+ # For example, if block_ids is [0, 1, 2, 4, 5, 6], the function should return [[0, 2], [4, 6]] -+ # The ranges are sorted by the starting block id -+ # The function should also make sure that the block ids are contiguous -+ # If the block ids are not contiguous, the function should raise an error -+ ranges = [] -+ for i in range(len(block_ids)): -+ if i == 0 or block_ids[i] != block_ids[i-1] + 1: -+ ranges.append([block_ids[i], block_ids[i]]) -+ else: -+ ranges[-1][1] = block_ids[i] -+ return ranges -+ -+ def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None): -+ -+ if layer_ids == "all": -+ layer_ids = list(range(self.num_layers)) -+ if block_ids == "all": -+ block_ids = list(range(self.num_blocks)) -+ -+ descs_ids = [] -+ -+ -+ if i is not None: -+ num_blocks = self.num_blocks -+ for layer_id in layer_ids: -+ for entry_index in range(self.num_cache_entries): -+ staging_range_idx = 0 -+ for block_id in block_ids: -+ if staging_ranges is not None: -+ if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]: -+ staging_range_idx += 1 -+ start_offset = staging_ranges[staging_range_idx][0] -+ i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1) -+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks * tp_multiplier + entry_index * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset)) -+ else: -+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks + entry_index * num_blocks + block_id) -+ else: -+ num_blocks = self.dst_num_blocks[engine_id] -+ for layer_id in layer_ids: -+ for entry_index in range(self.num_cache_entries): -+ for block_id in block_ids: -+ descs_ids.append(layer_id * self.num_cache_entries * num_blocks + entry_index * num_blocks + block_id) -+ return descs_ids -+ -+ def _get_same_length_ranges(self, src_ranges, dst_ranges, 
return_original_src_ranges=False): -+ # This function should return a list of ranges for both src and dst so that corresponding ranges are the same length -+ # For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]] -+ # The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]) -+ src_overlapping_ranges, dst_overlapping_ranges = [], [] -+ -+ original_src_ranges = [] -+ org_src_range = tuple(src_ranges[0]) -+ -+ src_idx, dst_idx = 0, 0 -+ while src_idx < len(src_ranges) and dst_idx < len(dst_ranges): -+ src_range = src_ranges[src_idx] -+ dst_range = dst_ranges[dst_idx] -+ -+ # Calculate the length of each range -+ src_len = src_range[-1] - src_range[0] + 1 -+ dst_len = dst_range[-1] - dst_range[0] + 1 -+ -+ # If ranges have the same length, add them directly -+ if src_len == dst_len: -+ src_overlapping_ranges.append([src_range[0], src_range[-1]]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) -+ original_src_ranges.append(org_src_range) -+ src_idx += 1 -+ dst_idx += 1 -+ if src_idx < len(src_ranges): -+ org_src_range = tuple(src_ranges[src_idx]) -+ # If source range is longer, split it -+ elif src_len > dst_len: -+ src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[-1]]) -+ original_src_ranges.append(org_src_range) -+ # Update source range for next iteration -+ src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]] -+ dst_idx += 1 -+ # If destination range is longer, split it -+ else: # src_len < dst_len -+ src_overlapping_ranges.append([src_range[0], src_range[-1]]) -+ dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1]) -+ original_src_ranges.append(org_src_range) -+ # Update destination range for next iteration -+ dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]] -+ src_idx += 1 -+ if src_idx < len(src_ranges): -+ org_src_range = tuple(src_ranges[src_idx]) -+ if return_original_src_ranges: -+ return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges -+ return src_overlapping_ranges, dst_overlapping_ranges -+ -+ def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id): -+ logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id) -+ -+ assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids) -+ -+ if len(local_block_ids) == 0: -+ logger.debug("No blocks to read") -+ return -+ -+ start_time = time.perf_counter() -+ -+ if self._is_mla: -+ # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all -+ staging_rearranging_ranges = None -+ staging_block_ids = local_block_ids -+ else: -+ local_ranges = self._get_ranges(local_block_ids) -+ staging_ranges = self._get_ranges(staging_block_ids) -+ -+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) -+ -+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] -+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) -+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] -+ handles = [] -+ -+ logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000) -+ create_xfer_start_time = time.perf_counter() -+ -+ for i in range(tp_multiplier): -+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, 
tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
-+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
-+ remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
-+ handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
-+ remote_xfer_side_handle, remote_block_descs_ids,
-+ "")
-+ handles.append(handle)
-+ status = self.nixl_wrapper.transfer(handle)
-+
-+ logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
-+
-+ transfer_start_time = time.perf_counter()
-+
-+ for handle in handles:
-+ while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
-+ if status == "PROC":
-+ time.sleep(0.001)
-+ else:
-+ raise RuntimeError(f"Read transfer failed with state {status}")
-+ # self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
-+
-+ logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)
-+
-+ rearrange_start_time = time.perf_counter()
-+
-+ if not self._is_mla:
-+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
-+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
-+ for kv_cache in self.kv_caches:
-+ for cache in kv_cache:
-+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
-+
-+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
-+ logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
-+
-+ def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
-+ logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)
-+
-+ # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
-+ # isl[-1] token is calculated in the first iteration in decode.
-+ # If isl equals to a multiple of tokens_per_block + 1, prefill engine will have one less block due to the missing last token.
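A short worked example of the edge case described in the comment above; the numbers are hypothetical:

    import math

    tokens_per_block = 16            # hypothetical KV block size
    isl = 2 * tokens_per_block + 1   # prompt length: a multiple of the block size, plus one

    decode_blocks = math.ceil(isl / tokens_per_block)          # 3 blocks on the decode engine
    prefill_blocks = math.ceil((isl - 1) / tokens_per_block)   # 2 blocks on the prefill engine

    # The prefill engine ends up one block short, which is why remote_block_ids
    # is truncated to len(local_block_ids) below.
    assert decode_blocks == prefill_blocks + 1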
-+ remote_block_ids = remote_block_ids[:len(local_block_ids)] -+ -+ assert len(staging_block_ids) == len(local_block_ids) -+ tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id] -+ -+ if len(local_block_ids) == 0: -+ logger.debug("No blocks to write") -+ for i in range(tp_multiplier): -+ self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg) -+ return -+ -+ start_time = time.perf_counter() -+ -+ if self._is_mla: -+ # TODO ptarasiewicz: we skip staging when is_mla is true, we shouldn't assign staging blocks at all -+ staging_rearranging_ranges = None -+ staging_block_ids = local_block_ids -+ else: -+ local_ranges = self._get_ranges(local_block_ids) -+ staging_ranges = self._get_ranges(staging_block_ids) -+ -+ local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges) -+ -+ for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges): -+ logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range) -+ for kv_cache in self.kv_caches: -+ for cache in kv_cache: -+ rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write") -+ -+ logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000) -+ -+ create_xfer_start_time = time.perf_counter() -+ -+ # getting block descs ids -+ remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids) -+ local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier] -+ -+ logger.debug("Creating xfer handles") -+ for i in range(tp_multiplier): -+ staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges) -+ assert len(staging_block_descs_ids) == len(remote_block_descs_ids) -+ remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i] -+ handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids, -+ remote_xfer_side_handle, remote_block_descs_ids, -+ notify_msg) -+ self._transfers[notify_msg].append(handle) -+ status = self.nixl_wrapper.transfer(handle) -+ -+ logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000) -+ -+ transfer_start_time = time.perf_counter() -+ logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000) -+ -+ def get_notifs(self): -+ return self.nixl_wrapper.update_notifs() -+ -+ def get_new_notifs(self): -+ return self.nixl_wrapper.get_new_notifs() -+ -+ def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks): -+ self._tp_size[engine_id] = agent_tp -+ agent_names = [] -+ for agent_meta in agent_metadata: -+ agent_name = self.nixl_wrapper.add_remote_agent(agent_meta) -+ agent_names.append(agent_name) -+ self._remote_agents[engine_id] = agent_names -+ self.kv_caches_base_addr[engine_id] = kv_caches_base_addr -+ -+ tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id] -+ assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}" -+ -+ logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier) -+ if self._is_mla: -+ dst_block_len = self.block_len -+ else: -+ 
dst_block_len = self.block_len // tp_multiplier -+ if tp_multiplier not in self.src_xfer_side_handles: -+ # create descs and xfer side handles -+ blocks_data = [] -+ for layer_id in range(self.num_layers): -+ for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]: -+ for block_id in range(self.num_blocks): -+ block_offset = block_id * self.block_len -+ for i in range(1 if self._is_mla else tp_multiplier): -+ tp_multiplier_offset = i * dst_block_len -+ blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank)) -+ logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i) -+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") -+ self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs) -+ -+ # create dst xfer side handles -+ self.dst_num_blocks[engine_id] = num_blocks -+ for i in range(tp_multiplier): -+ blocks_data = [] -+ for layer_id in range(self.num_layers): -+ for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]: -+ for block_id in range(num_blocks): -+ block_offset = block_id * dst_block_len -+ blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i)) -+ logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i) -+ descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM") -+ self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs) -+ -+ return agent_names -+ -+ def get_done_tranfers(self) -> List[str]: -+ done_req_ids = [] -+ for req_id, handles in self._transfers.items(): -+ running_reqs = [] -+ for handle in handles: -+ xfer_state = self.nixl_wrapper.check_xfer_state(handle) -+ if xfer_state == "DONE": -+ # self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors? -+ continue -+ if xfer_state == "PROC": -+ running_reqs.append(handle) -+ else: -+ raise RuntimeError("Transfer failed with state %s", xfer_state) -+ if len(running_reqs) == 0: -+ done_req_ids.append(req_id) -+ else: -+ self._transfers[req_id] = running_reqs -+ return done_req_ids -diff --git a/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py -new file mode 100644 -index 000000000..418fc7154 ---- /dev/null -+++ b/vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py -@@ -0,0 +1,363 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+""" -+Simple KV Cache Connector for Distributed Machine Learning Inference -+ -+The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache -+producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or -+MooncakePipe. -+ -+But the logic can be extended to support other pipe and lookup buffer. -+""" -+import re -+from typing import TYPE_CHECKING, List, Optional, Tuple, Union -+ -+import torch -+ -+from vllm import _custom_ops as ops -+from vllm.config import VllmConfig, KVTransferConfig -+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -+from vllm.distributed.utils import StatelessProcessGroup -+from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( -+ SimpleBuffer) -+from vllm.logger import init_logger -+from vllm.sequence import IntermediateTensors -+ -+if TYPE_CHECKING: -+ from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata -+ -+logger = init_logger(__name__) -+ -+ -+class DynamoConnector(KVConnectorBase): -+ -+ def __init__( -+ self, -+ rank: int, -+ local_rank: int, -+ config: VllmConfig, -+ world_group, -+ ): -+ -+ self.config = config.kv_transfer_config -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ self.rank = rank -+ -+ if self.config.kv_connector != "DynamoNcclConnector": -+ raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class") -+ -+ from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( -+ PyNcclPipe) -+ from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import ( -+ DynamoNcclDataPlane) -+ -+ logger.info( -+ "Initializing DynamoNcclConnector under kv_transfer_config %s", -+ self.config) -+ -+ self.lookup_buffer_size = self.config.kv_buffer_size -+ -+ self.producer_data_pipe: PyNcclPipe -+ self.consumer_data_pipe: PyNcclPipe -+ self.producer_signal_pipe: PyNcclPipe -+ self.consumer_signal_pipe: PyNcclPipe -+ -+ self._broadcast_and_enhance_kv_config(rank, config, world_group) -+ -+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ -+ # 2 pipes for every rank in the world -+ if self.config.is_kv_producer: -+ port_offset_base = rank + 1 -+ else: -+ port_offset_base = rank // self.config.tensor_parallel_multiplier + 1 -+ -+ -+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier -+ self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config) -+ -+ self.data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, -+ local_rank=local_rank, -+ config=self.config, -+ port_offset=port_offset_base, -+ ) -+ -+ self.data_plane = DynamoNcclDataPlane( -+ data_pipe=self.data_pipe, -+ port=self._get_data_plane_port(self.global_kv_rank), -+ ) -+ -+ def send_kv_caches_and_hidden_states( -+ self, -+ model_executable: torch.nn.Module, -+ model_input: "ModelInputForGPUWithSamplingMetadata", -+ kv_caches: List[torch.Tensor], -+ hidden_or_intermediate_states: Union[torch.Tensor, -+ IntermediateTensors], -+ ) -> None: -+ -+ input_tokens_tensor = model_input.input_tokens -+ seq_lens = model_input.attn_metadata.seq_lens -+ slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() -+ start_layer = model_executable.model.start_layer -+ end_layer = model_executable.model.end_layer -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) -+ -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ if not is_deepseek: -+ 
num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(hidden_size / num_attention_heads) -+ else: -+ num_heads = int(model_config.num_key_value_heads / self.tp_size) -+ hidden_size = model_config.hidden_size -+ num_attention_heads = model_config.num_attention_heads -+ head_size = int(4.5 * hidden_size / num_attention_heads) -+ -+ # query_lens contains new KV caches that are added to vLLM. -+ # so we will send them to decode instance -+ # FIXME(Kuntai): This assume that all requests are prefill. -+ for idx, slen in enumerate(seq_lens): -+ start_pos = sum(seq_lens[:idx]) -+ end_pos = start_pos + slen -+ current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id) -+ decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config) -+ -+ for target_rank in range(self.config.tensor_parallel_multiplier): -+ -+ keys, values = [], [] -+ -+ for layer_id in range(start_layer, end_layer): -+ kv_cache = kv_caches[layer_id - start_layer] -+ -+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] -+ -+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier -+ head_start = target_rank * num_heads_per_rank -+ head_end = head_start + num_heads_per_rank -+ -+ if not is_deepseek: -+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) -+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) -+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ else: -+ key_cache = kv_cache -+ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -+ values.append(torch.empty(0)) -+ -+ keys = torch.cat(keys, dim=0) -+ values = torch.cat(values, dim=0) -+ -+ decode_global_rank = decode_first_global_rank + target_rank -+ decode_port = self._get_data_plane_port(decode_global_rank) -+ partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos] -+ self._send(decode_hostname, decode_port, current_request_id, keys, values, -+ partial_hidden_or_intermediate_states) -+ -+ logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) -+ -+ def recv_kv_caches_and_hidden_states( -+ self, model_executable: torch.nn.Module, -+ model_input: "ModelInputForGPUWithSamplingMetadata", -+ kv_caches: List[torch.Tensor] -+ ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool, -+ "ModelInputForGPUWithSamplingMetadata"]: -+ -+ # When bypass_model_exec is set to False, it means that at least for one -+ # request its corresponding KV cache or hidden state is missing. -+ # In this case we need to do prefilling to recompute missing KV cache -+ # and hidden states. -+ bypass_model_exec = True -+ -+ input_tokens_tensor = model_input.input_tokens -+ seq_lens = model_input.attn_metadata.seq_lens -+ slot_mapping = model_input.attn_metadata.slot_mapping.flatten() -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) -+ -+ hidden_or_intermediate_states_for_one_req = [] -+ -+ input_tokens_list = [] -+ start_pos_list = [] -+ -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ -+ # enumerate different requests -+ # FIXME(Kuntai): This impl assumes that all requests are prefill. 
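A small numeric sketch of the per-target-rank head slicing in send_kv_caches_and_hidden_states above, with hypothetical parallelism settings (prefill TP 1, decode TP 2, 8 KV heads):

    num_key_value_heads = 8          # hypothetical model setting
    prefill_tp_size = 1              # this worker's tensor parallel size
    tensor_parallel_multiplier = 2   # decode TP divided by prefill TP

    num_heads = num_key_value_heads // prefill_tp_size            # 8 KV heads held locally
    num_heads_per_rank = num_heads // tensor_parallel_multiplier  # 4 heads per decode rank

    for target_rank in range(tensor_parallel_multiplier):
        head_start = target_rank * num_heads_per_rank
        head_end = head_start + num_heads_per_rank
        print(target_rank, (head_start, head_end))  # 0 -> (0, 4), 1 -> (4, 8)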
-+ for idx, slen in enumerate(seq_lens): -+ -+ start_pos = sum(seq_lens[:idx]) -+ end_pos = start_pos + slen -+ current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ num_tokens = slen -+ -+ # collecting data for rebuilding the input -+ input_tokens_list.append(current_tokens) -+ start_pos_list.append(start_pos) -+ -+ ret = self._recv(current_request_id) -+ keys: torch.Tensor = ret[0] -+ values: torch.Tensor = ret[1] -+ hidden: torch.Tensor = ret[2] -+ -+ # put received KV caches into paged memory -+ for i in range(model_executable.model.start_layer, -+ model_executable.model.end_layer): -+ -+ kv_cache = kv_caches[i - model_executable.model.start_layer] -+ layer = model_executable.model.layers[i] -+ -+ if not is_deepseek: -+ key_cache, value_cache = kv_cache[0], kv_cache[1] -+ ops.reshape_and_cache_flash( -+ keys[i - model_executable.model.start_layer].to( -+ key_cache.device), -+ values[i - model_executable.model.start_layer].to( -+ value_cache.device), -+ key_cache, -+ value_cache, -+ slot_mapping[start_pos:end_pos], -+ layer.self_attn.attn.kv_cache_dtype, -+ layer.self_attn.attn._k_scale, -+ layer.self_attn.attn._v_scale, -+ ) -+ else: -+ key_cache = kv_cache -+ copy_from =keys[i - model_executable.model.start_layer].to( -+ key_cache.device) -+ kv_cache[slot_mapping[start_pos:end_pos]] = copy_from -+ -+ hidden_or_intermediate_states_for_one_req.append(hidden) -+ -+ if not bypass_model_exec: -+ # Some of the KV cache is not retrieved -+ # Here we will fall back to normal model forwarding -+ # But optionally you can adjust model_input so that you only do -+ # prefilling on those tokens that are missing KV caches. -+ logger.debug( -+ "[rank%d]: Failed to receive all KVs and hidden " -+ "states, redo model forwarding.", torch.distributed.get_rank()) -+ hidden_or_intermediate_states = None -+ -+ else: -+ logger.debug( -+ "[rank%d]: Successfully received all KVs and hidden " -+ "states, skip model forwarding.", torch.distributed.get_rank()) -+ hidden_or_intermediate_states = torch.cat( -+ hidden_or_intermediate_states_for_one_req, dim=0) -+ -+ return hidden_or_intermediate_states, bypass_model_exec, model_input -+ -+ def close(self): -+ self.data_pipe.close() -+ # self.data_plane.close() -+ -+ @staticmethod -+ def parse_request_id(request_id: str) -> Tuple[str, int]: -+ # Regular expression to match the string hostname and integer decode_kv_rank -+ pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)" -+ -+ # Use re.search to find the pattern in the request_id -+ match = re.search(pattern, request_id) -+ if match: -+ # Extract the ranks -+ decode_hostname = match.group(1) -+ decode_rank = int(match.group(2)) -+ -+ return decode_hostname, decode_rank -+ raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank") -+ -+ def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor): -+ remote_address = f"{hostname}:{port}" -+ self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address) -+ self.data_plane.send_tensor(values, f"{request_id}_values", remote_address) -+ self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address) -+ -+ def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: -+ keys = self.data_plane.recv_tensor(f"{request_id}_keys") -+ values = self.data_plane.recv_tensor(f"{request_id}_values") -+ hidden = self.data_plane.recv_tensor(f"{request_id}_hidden") -+ return keys, values, hidden 
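A quick usage sketch of the parse_request_id pattern above; the request id is made up but follows the ___decode_hostname_..___decode_kv_rank_.. convention the regex expects:

    import re

    pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
    request_id = "cmpl-42___decode_hostname_decode-node-0___decode_kv_rank_3"  # hypothetical

    match = re.search(pattern, request_id)
    assert match is not None
    decode_hostname, decode_kv_rank = match.group(1), int(match.group(2))
    print(decode_hostname, decode_kv_rank)  # decode-node-0 3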
-+ -+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank < config.kv_producers_parallel_size: -+ return kv_rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier -+ -+ -+ def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank <= config.kv_producers_parallel_size: -+ return kv_rank * config.kv_producers_tensor_parallel_size + rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank -+ -+ -+ def _get_data_plane_port(self, global_kv_rank: int) -> int: -+ return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank -+ -+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): -+ if rank == 0: -+ config_group = StatelessProcessGroup.create( -+ host=self.config.kv_ip, -+ port=self.config.kv_port, -+ rank=self.config.kv_rank, -+ world_size=self.config.kv_parallel_size, -+ ) -+ parallel_configs = config_group.all_gather_obj({ -+ "kv_role": self.config.kv_role, -+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size, -+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, -+ }) -+ logger.debug("parallel_configs: %s", parallel_configs) -+ kv_config_enhanced = { -+ "kv_producers_tensor_parallel_size": None, -+ "kv_consumers_tensor_parallel_size": None, -+ "kv_producers_pipeline_parallel_size": None, -+ "kv_consumers_pipeline_parallel_size": None, -+ "kv_producers_parallel_size": 0, -+ } -+ for parallel_config in parallel_configs: -+ kv_role = parallel_config["kv_role"] -+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" -+ -+ if kv_role == "kv_producer": -+ kv_config_enhanced["kv_producers_parallel_size"] += 1 -+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: -+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] -+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] -+ else: -+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" -+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" -+ world_group.broadcast_object(kv_config_enhanced) -+ else: -+ kv_config_enhanced = world_group.broadcast_object() -+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) -+ -+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] -+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] -+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] -+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] -+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] -diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py 
b/vllm/distributed/kv_transfer/kv_connector/factory.py -index e37ce6dc7..f1ba144c7 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/factory.py -+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import importlib - from typing import TYPE_CHECKING, Callable, Dict, Type -@@ -27,13 +40,13 @@ class KVConnectorFactory: - - @classmethod - def create_connector(cls, rank: int, local_rank: int, -- config: "VllmConfig") -> KVConnectorBase: -+ config: "VllmConfig", world_group) -> KVConnectorBase: - connector_name = config.kv_transfer_config.kv_connector - if connector_name not in cls._registry: - raise ValueError(f"Unsupported connector type: {connector_name}") - - connector_cls = cls._registry[connector_name]() -- return connector_cls(rank, local_rank, config) -+ return connector_cls(rank, local_rank, config, world_group) - - - # Register various connectors here. -diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -index 49b97d7b5..c77c570ea 100644 ---- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -+++ b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - Simple KV Cache Connector for Distributed Machine Learning Inference - -@@ -8,14 +21,16 @@ MooncakePipe. - - But the logic can be extended to support other pipe and lookup buffer. 
- """ -+import re - from typing import TYPE_CHECKING, List, Optional, Tuple, Union - - import torch - - import vllm.envs as envs - from vllm import _custom_ops as ops --from vllm.config import VllmConfig -+from vllm.config import VllmConfig, KVTransferConfig - from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -+from vllm.distributed.utils import StatelessProcessGroup - from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) - from vllm.logger import init_logger -@@ -34,9 +49,11 @@ class SimpleConnector(KVConnectorBase): - rank: int, - local_rank: int, - config: VllmConfig, -+ world_group, - ): - - self.config = config.kv_transfer_config -+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) - self.tp_size = config.parallel_config.tensor_parallel_size - self.is_deepseek_mla = config.model_config.is_deepseek_mla - self.use_mla_opt = not envs.VLLM_MLA_DISABLE -@@ -74,20 +91,31 @@ class SimpleConnector(KVConnectorBase): - self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - -+ self._broadcast_and_enhance_kv_config(rank, config, world_group) -+ -+ self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config) -+ self.tp_size = config.parallel_config.tensor_parallel_size -+ - # 2 pipes for every rank in the world -- port_offset_base = 2 * rank -+ if self.config.is_kv_producer: -+ port_offset_base = 2 * rank + 1 -+ else: -+ port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1 - -+ self.local_kv_rank = rank % self.config.tensor_parallel_multiplier - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - if self.config.is_kv_producer: - - if self.config.kv_connector == "PyNcclConnector": - self.producer_data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.producer_signal_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, -@@ -111,11 +139,13 @@ class SimpleConnector(KVConnectorBase): - # its recv pipe to the send pipe of KV producder - if self.config.kv_connector == "PyNcclConnector": - self.consumer_data_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.consumer_signal_pipe = PyNcclPipe( -+ kv_group_rank=self.kv_group_rank, - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, -@@ -134,21 +164,25 @@ class SimpleConnector(KVConnectorBase): - self.config.kv_buffer_size, - ) - -- def select(self, input_tokens: Optional[torch.Tensor], -+ def select(self, source_rank: int, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - -+ logger.info("Selecting KV caches and hidden states for source rank %d", source_rank) -+ - assert self.consumer_buffer is not None, "Please initialize the "\ - "consumer buffer before calling select." 
-- return self.consumer_buffer.drop_select(input_tokens, roi) -+ return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi) - -- def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, -+ def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - -+ logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank) -+ - assert self.producer_buffer is not None, "Please initialize the "\ - "producer buffer before calling insert." - -- self.producer_buffer.insert(input_tokens, roi, key, value, hidden) -+ self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden) - - def send_kv_caches_and_hidden_states( - self, -@@ -165,6 +199,7 @@ class SimpleConnector(KVConnectorBase): - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) - - model_config = model_executable.model.config - num_heads = int(model_config.num_key_value_heads / self.tp_size) -@@ -207,11 +242,11 @@ class SimpleConnector(KVConnectorBase): - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ _, decode_kv_rank = self.parse_request_id(current_request_id) -+ starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config) - -- keys, values = [], [] -- -- for layer_id in range(start_layer, end_layer): -- kv_cache = kv_caches[layer_id - start_layer] -+ for target_rank in range(self.config.tensor_parallel_multiplier): - - if self.is_deepseek_mla and self.use_mla_opt: - key_cache = kv_cache.reshape(-1, num_heads, head_size) -@@ -220,18 +255,32 @@ class SimpleConnector(KVConnectorBase): - key_cache = kv_cache[0].reshape(-1, num_heads, head_size) - value_cache = kv_cache[1].reshape(-1, num_heads, head_size) - -- current_slot_mapping = slot_mapping_flat[start_pos:end_pos] -+ for layer_id in range(start_layer, end_layer): -+ kv_cache = kv_caches[layer_id - start_layer] -+ -+ current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - -- keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -- values.append(value_cache[current_slot_mapping].unsqueeze(0)) -+ num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier -+ head_start = target_rank * num_heads_per_rank -+ head_end = head_start + num_heads_per_rank - -- keys = torch.cat(keys, dim=0) -- values = torch.cat(values, dim=0) -+ if not is_deepseek: -+ key_cache = kv_cache[0].reshape(-1, num_heads, head_size) -+ value_cache = kv_cache[1].reshape(-1, num_heads, head_size) -+ keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0)) -+ else: -+ key_cache = kv_cache -+ keys.append(key_cache[current_slot_mapping].unsqueeze(0)) -+ values.append(torch.empty(0)) - -- self.insert(current_tokens, -- torch.ones_like(current_tokens, -- dtype=bool), keys, values, -- hidden_or_intermediate_states[start_pos:end_pos]) -+ keys = torch.cat(keys, dim=0) -+ values = torch.cat(values, dim=0) -+ -+ self.insert(starting_kv_group_rank, target_rank, current_tokens, -+ torch.ones_like(current_tokens, -+ dtype=bool), keys, values, -+ hidden_or_intermediate_states[start_pos:end_pos]) - - 
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - -@@ -254,6 +303,7 @@ class SimpleConnector(KVConnectorBase): - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() -+ request_ids = list(model_input.request_ids_to_seq_ids.keys()) - - hidden_or_intermediate_states_for_one_req = [] - -@@ -261,6 +311,9 @@ class SimpleConnector(KVConnectorBase): - num_computed_tokens_list = [] - start_pos_list = [] - -+ model_config = model_executable.model.config -+ is_deepseek = "deepseek" in model_config.architectures[0].lower() -+ - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): -@@ -280,13 +333,15 @@ class SimpleConnector(KVConnectorBase): - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] -+ current_request_id = request_ids[idx] -+ prefill_rank, _ = self.parse_request_id(current_request_id) - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - -- ret = self.select(current_tokens, -+ ret = self.select(prefill_rank, current_tokens, - torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. -@@ -379,3 +434,77 @@ class SimpleConnector(KVConnectorBase): - # MooncakePipe reuses data_pipe for signal_pipe, so we only have to - # close the data_pipe. - pass -+ -+ @staticmethod -+ def parse_request_id(request_id): -+ # Regular expression to match the ranks -+ pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)" -+ -+ # Use re.search to find the pattern in the request_id -+ match = re.search(pattern, request_id) -+ -+ if match: -+ # Extract the ranks -+ prefill_rank = int(match.group(1)) -+ decode_rank = int(match.group(2)) -+ -+ return prefill_rank, decode_rank -+ else: -+ return None, None -+ -+ -+ -+ def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int: -+ if kv_rank < config.kv_producers_parallel_size: -+ return kv_rank -+ -+ kv_consumer_rank = kv_rank - config.kv_producers_parallel_size -+ return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier -+ -+ def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group): -+ if rank == 0: -+ if self.config.kv_connector == "PyNcclConnector": -+ config_group = StatelessProcessGroup.create( -+ host=self.config.kv_ip, -+ port=self.config.kv_port, -+ rank=self.config.kv_rank, -+ world_size=self.config.kv_parallel_size, -+ ) -+ parallel_configs = config_group.all_gather_obj({ -+ "kv_role": self.config.kv_role, -+ "tensor_parallel_size": config.parallel_config.tensor_parallel_size, -+ "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size, -+ }) -+ logger.debug("parallel_configs: %s", parallel_configs) -+ kv_config_enhanced = { -+ "kv_producers_tensor_parallel_size": None, -+ "kv_consumers_tensor_parallel_size": None, -+ "kv_producers_pipeline_parallel_size": None, -+ "kv_consumers_pipeline_parallel_size": None, -+ "kv_producers_parallel_size": 0, -+ } -+ for parallel_config in parallel_configs: -+ kv_role = parallel_config["kv_role"] -+ assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances" -+ -+ if kv_role == "kv_producer": -+ 
kv_config_enhanced["kv_producers_parallel_size"] += 1 -+ if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None: -+ kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"] -+ kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"] -+ else: -+ assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size" -+ assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size" -+ world_group.broadcast_object(kv_config_enhanced) -+ -+ else: -+ raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch") -+ else: -+ kv_config_enhanced = world_group.broadcast_object() -+ logger.info("kv_config_enhanced: %s", kv_config_enhanced) -+ -+ self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"] -+ self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"] -+ self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"] -+ self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"] -+ self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"] -diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -index 10bbfe1dd..8268bf3eb 100644 ---- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -+++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - Implements a distributed key-value (KV) cache transfer mechanism. 
- -@@ -11,7 +24,8 @@ - """ - import threading - from collections import deque --from typing import Deque, List, Optional, Union -+from concurrent.futures import ThreadPoolExecutor -+from typing import Deque, List, Optional, Union, Dict - - import torch - -@@ -45,7 +59,7 @@ class SimpleBuffer(KVLookupBufferBase): - self.buffer_cv = threading.Condition() - self.signal_pipe = signal_pipe - self.data_pipe = data_pipe -- self.request_handling_thread: Optional[threading.Thread] = None -+ self.request_handling_thread: Optional[ThreadPoolExecutor] = None - - self.normal_signal = torch.tensor([0], device="cpu") - self.end_signal = None -@@ -56,10 +70,16 @@ class SimpleBuffer(KVLookupBufferBase): - # tokens_roi_sender: tokens and roi of the producer (in the buffer) - # tokens_roi_recver: tokens and roi of the consumer (query) - -- tokens_sender = tokens_roi_sender[0] -- tokens_recver = tokens_roi_recver[0] -- roi_sender = tokens_roi_sender[1] -- roi_recver = tokens_roi_recver[1] -+ target_rank_sender = tokens_roi_sender[0] -+ target_rank_recver = tokens_roi_recver[0] -+ -+ if target_rank_sender.item() != target_rank_recver.item(): -+ return 0 -+ -+ tokens_sender = tokens_roi_sender[1] -+ tokens_recver = tokens_roi_recver[1] -+ roi_sender = tokens_roi_sender[2] -+ roi_recver = tokens_roi_recver[2] - - if tokens_recver is None: - # consumer sends an empty request -@@ -79,14 +99,14 @@ class SimpleBuffer(KVLookupBufferBase): - - return 0 - -- def _send_tensor_and_dec_size(self, -- tensor: Optional[torch.Tensor]) -> None: -+ def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor], -+ target_rank: int) -> None: - - assert tensor is not None, "Use self.data_pipe.send(None) instead" - self.buffer_size -= tensor.element_size() * tensor.numel() - if tensor.dtype == torch.bool: - tensor = tensor.float() -- self.data_pipe.send_tensor(tensor) -+ self.data_pipe.send_tensor(tensor, target_rank) - - def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]): - -@@ -99,7 +119,7 @@ class SimpleBuffer(KVLookupBufferBase): - - raise AssertionError(f"Unknown data type {type(data)}") - -- def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor, -+ def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor): - -@@ -132,23 +152,54 @@ class SimpleBuffer(KVLookupBufferBase): - def _is_end_signal(self, signal): - return signal is None - -- def drop_select_handler(self): -+ def drop_select_handler(self, rank: int): - - try: - -- while True: -- signal = self.signal_pipe.recv_tensor() -- if self._is_end_signal(signal): -- logger.info("Received end signal!") -- break -+ signal = self.signal_pipe.recv_tensor(rank) -+ if self._is_end_signal(signal): -+ logger.info("Received end signal!") -+ return -+ target_kv_rank = self.data_pipe.recv_tensor(rank) -+ # assert target_rank.item() == rank, "Target rank does not match"\ -+ # "the rank of the drop-select handler" -+ input_tokens = self.data_pipe.recv_tensor(rank) -+ roi = self.data_pipe.recv_tensor(rank) -+ assert roi is not None, "Please provide the roi when sending "\ -+ "drop-select request" -+ roi = (roi > 0.5) -+ tokens_roi_recver = [target_kv_rank, input_tokens, roi] -+ -+ matched_length = 0 -+ -+ # perform input tokens and roi matching -+ # FIXME: this matching is O(n), ideally it should be O(1) -+ # but this buffer size won't (and shouldn't) be too large so -+ # the fix is not urgent. 
-+ with self.buffer_lock: -+ -+ for _ in range(len(self.buffer)): - -- input_tokens = self.data_pipe.recv_tensor() -+ temp_length = self._matches(self.buffer[0], -+ tokens_roi_recver) -+ if temp_length > 0: -+ matched_length = temp_length -+ break -+ # rotate the element we just accessed to the end -+ self.buffer.rotate(-1) - -- roi = self.data_pipe.recv_tensor() -- assert roi is not None, "Please provide the roi when sending "\ -- "drop-select request" -- roi = (roi > 0.5) -- tokens_roi_recver = [input_tokens, roi] -+ if matched_length > 0: -+ # need to clone the tensor -+ # in case the tensor is freed before sending finishes -+ matched_item = self.buffer.popleft() -+ target_rank = matched_item[0].item() -+ for tensor in matched_item[1:]: -+ self._send_tensor_and_dec_size(tensor, rank) -+ -+ else: -+ # no match, just send None -+ for _ in range(5): -+ self.data_pipe.send_tensor(None, rank) - - def is_buffer_available( - tokens_roi_recver: List[torch.Tensor], ) -> bool: -@@ -182,11 +233,12 @@ class SimpleBuffer(KVLookupBufferBase): - - logger.debug("Closing drop_select_handler") - -+ - def drop_select( -- self, input_tokens: Optional[torch.Tensor], -+ self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]: - -- assert self.request_handling_thread is None, \ -+ assert not self.request_handling_thread, \ - "drop_select should be called by the KV cache consumer "\ - "(e.g. the decode vLLM instance)" - -@@ -195,19 +247,21 @@ class SimpleBuffer(KVLookupBufferBase): - if isinstance(roi, torch.Tensor): - roi = roi.clone().float() - -- self.signal_pipe.send_tensor(self.normal_signal) -- self.data_pipe.send_tensor(input_tokens) -- self.data_pipe.send_tensor(roi) -+ self.signal_pipe.send_tensor(self.normal_signal, rank) -+ -+ self.data_pipe.send_tensor(torch.tensor(kv_rank), rank) -+ self.data_pipe.send_tensor(input_tokens, rank) -+ self.data_pipe.send_tensor(roi, rank) - -- input_tokens = self.data_pipe.recv_tensor() -- roi = self.data_pipe.recv_tensor() -+ input_tokens = self.data_pipe.recv_tensor(rank) -+ roi = self.data_pipe.recv_tensor(rank) - if roi is not None: - # convert from float tensor to bool tensor - # as PyNccl does not support sending bool tensor - roi = (roi > 0.5) -- key = self.data_pipe.recv_tensor() -- value = self.data_pipe.recv_tensor() -- hidden = self.data_pipe.recv_tensor() -+ key = self.data_pipe.recv_tensor(rank) -+ value = self.data_pipe.recv_tensor(rank) -+ hidden = self.data_pipe.recv_tensor(rank) - - return [input_tokens, roi, key, value, hidden] - -@@ -220,15 +274,13 @@ class SimpleBuffer(KVLookupBufferBase): - # when calling the insert, the current process is a sender - # need to launch the request handler and start listening to request. 
- if self.request_handling_thread is None: -- self.request_handling_thread = threading.Thread( -- target=self.drop_select_handler) -- self.request_handling_thread.start() -+ self.request_handling_thread = ThreadPoolExecutor(max_workers=1) -+ self.request_handling_thread.submit(self.drop_select_handler) - - def close(self): - -- if hasattr(self, "request_handling_thread" -- ) and self.request_handling_thread is not None: -- self.request_handling_thread.join() -+ if hasattr(self, "request_handling_thread") and self.request_handling_thread: -+ self.request_handling_thread.shutdown() - - else: - # TODO: have a explicit close signal and have a explicit way to -diff --git a/vllm/distributed/kv_transfer/kv_pipe/base.py b/vllm/distributed/kv_transfer/kv_pipe/base.py -index 40589fb3e..a3991c39d 100644 ---- a/vllm/distributed/kv_transfer/kv_pipe/base.py -+++ b/vllm/distributed/kv_transfer/kv_pipe/base.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - This file defines an interface `KVPipeBase` - that provides an abstraction for sending and receiving tensors, or None, via -@@ -23,7 +36,7 @@ class KVPipeBase(ABC): - """ - - @abstractmethod -- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: -+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None: - """Send a tensor, or None, via the pipe. - - Need to support sending None -- important for error handling. -@@ -41,7 +54,7 @@ class KVPipeBase(ABC): - raise NotImplementedError - - @abstractmethod -- def recv_tensor(self) -> Optional[torch.Tensor]: -+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: - """Receive a tensor (can be None) from the pipeline. - - Returns: -diff --git a/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py -new file mode 100644 -index 000000000..ca5345359 ---- /dev/null -+++ b/vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py -@@ -0,0 +1,139 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
-+ -+import logging -+import threading -+import typing -+import zmq -+import socket -+import time -+import torch -+ -+from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe -+ -+ -+logger = logging.getLogger(__name__) -+ -+ -+class DynamoNcclDataPlane: -+ def __init__( -+ self, -+ data_pipe: PyNcclPipe, -+ hostname: str = "", -+ port: int = 0, -+ ) -> None: -+ -+ self.data_pipe = data_pipe -+ if not hostname: -+ hostname = socket.gethostname() -+ if port == 0: -+ raise ValueError("Port cannot be 0") -+ self._hostname = hostname -+ self._port = port -+ self.store = {} -+ self.context = zmq.Context() -+ self.rep_socket = self.context.socket(zmq.REP) -+ logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}") -+ self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}") -+ self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True) -+ self._listener_thread.start() -+ self.req_sockets = {} -+ logger.info(f"Rank {self.rank} connected to the server") -+ -+ @property -+ def rank(self): -+ return self.data_pipe.kv_group_rank -+ -+ def send_tensor( -+ self, -+ tensor: torch.Tensor, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ): -+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}") -+ return self._send_tensor(tensor, tensor_id, remote_address) -+ -+ def recv_tensor( -+ self, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ) -> torch.Tensor: -+ ret = self._recv_tensor(tensor_id, remote_address) -+ return ret -+ -+ def _send_tensor( -+ self, -+ tensor: torch.Tensor, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ): -+ logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}") -+ if remote_address is None: -+ self.store[tensor_id] = tensor -+ else: -+ # tensor_shape = "_".join(str(dim) for dim in tensor.shape) -+ # tensor_dtype = str(tensor.dtype) -+ if remote_address not in self.req_sockets: -+ self.req_sockets[remote_address] = self.context.socket(zmq.REQ) -+ self.req_sockets[remote_address].connect(f"tcp://{remote_address}") -+ -+ req_socket = self.req_sockets[remote_address] -+ # req_socket.connect(f"tcp://{remote_address}") -+ req_socket.send_string(f"PUT {self.rank} {tensor_id}") -+ dst_rank = req_socket.recv_string() -+ logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}") -+ self.data_pipe.send_tensor(tensor, int(dst_rank)) -+ -+ def _recv_tensor( -+ self, -+ tensor_id: str, -+ remote_address: typing.Optional[str] = None, -+ ) -> torch.Tensor: -+ logger.debug(f"Rank {self.rank} receiving tensor") -+ if remote_address is not None: -+ raise NotImplementedError("Getting tensor from remote rank not implemented") -+ if tensor_id in self.store: -+ logger.debug(f"Popping tensor {tensor_id} from store") -+ future = self.store.pop(tensor_id) -+ tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait -+ logger.debug(f"Rank {self.rank} received tensor") -+ return tensor -+ -+ logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}") -+ time.sleep(0.001) -+ return self._recv_tensor(tensor_id, remote_address) -+ # raise NotImplementedError("Tensor not found in store") -+ -+ def _receive_tensor( -+ self, -+ tensor_id: str, -+ rank: int, -+ ): -+ future = self.data_pipe.recv_tensor(rank) -+ logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store") -+ self.store[tensor_id] = future -+ -+ def 
listen_for_requests(self): -+ while True: -+ cmd, rank, tensor_id = self.rep_socket.recv_string().split() -+ logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}") -+ self.rep_socket.send_string(f"{self.rank}") -+ if cmd == "GET": -+ raise NotImplementedError("Getting tensor from remote rank not implemented") -+ elif cmd == "PUT": -+ rank = int(rank) -+ # shape = [int(dim) for dim in shape.split("_")] -+ # dtype = getattr(torch, dtype) -+ self._receive_tensor(tensor_id, rank) -diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -index e8bf607eb..fa5543fa9 100644 ---- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """ - This module implements a PyNccl pipe for sending and receiving - Optional[torch.Tensor] between distributed ranks with advanced -@@ -45,14 +58,16 @@ class PyNcclPipe(KVPipeBase): - METADATA_DTYPE = torch.int64 - - def __init__(self, -+ kv_group_rank: int, - local_rank: int, - config: KVTransferConfig, - device: Optional[str] = None, - port_offset: int = 0): - self.config = config - self.local_rank = local_rank -- self.kv_rank = self.config.kv_rank -+ self.kv_group_rank = kv_group_rank - self.kv_parallel_size = self.config.kv_parallel_size -+ self.kv_world_size = self.config.kv_world_size - if device is None: - self.device = self._select_device(self.config.kv_buffer_device) - else: -@@ -71,9 +86,6 @@ class PyNcclPipe(KVPipeBase): - self.group.barrier() - impl = self._get_device_send_recv_impl(self.group) - self.device_send_func, self.device_recv_func = impl -- # set target rank -- self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size -- self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size - - # transportation-related variables - self.transport_thread: Optional[ThreadPoolExecutor] = None -@@ -147,16 +159,16 @@ class PyNcclPipe(KVPipeBase): - dtype=metadata["dtype"], - device=self.device) - -- def _send_metadata(self, metadata: Metadata): -+ def _send_metadata(self, metadata: Metadata, target_rank: int): - """ - Send the metadata dictionary to the target rank. - - Parameters: - - metadata: A dictionary with keys "dtype" and "shape". - """ -- self.group.send_obj(metadata, self.target_rank_for_send) -+ self.group.send_obj(metadata, target_rank) - -- def _recv_metadata(self) -> Metadata: -+ def _recv_metadata(self, src_rank: int) -> Metadata: - """ - Receive the metadata dictionary from the target rank. - -@@ -164,9 +176,9 @@ class PyNcclPipe(KVPipeBase): - - metadata: A dictionary with keys "dtype" and "shape" describing - the tensor. 
- """ -- return self.group.recv_obj(self.target_rank_for_recv) -+ return self.group.recv_obj(src_rank) - -- def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: -+ def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: - """ - The actual implementation of sending the tensor and its metadata to the - target rank. -@@ -176,12 +188,12 @@ class PyNcclPipe(KVPipeBase): - being sent. - """ - metadata = self._make_metadata(tensor) -- self._send_metadata(metadata) -+ self._send_metadata(metadata, target_rank) - if tensor is not None: - self.device_send_func(tensor.to(self.device), -- self.target_rank_for_send) -+ target_rank) - -- def _recv_impl(self) -> Optional[torch.Tensor]: -+ def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]: - """ - The actual implementation of receiving a tensor and its metadata from - the target rank. -@@ -189,21 +201,22 @@ class PyNcclPipe(KVPipeBase): - Returns: - - buffer: The received tensor, or None if no tensor is received. - """ -- metadata = self._recv_metadata() -+ metadata = self._recv_metadata(src_rank) - if metadata["dtype"] is None: - return None - buffer = self._prepare_recv_buffer(metadata) -- self.device_recv_func(buffer, self.target_rank_for_recv) -+ self.device_recv_func(buffer, src_rank) - - return buffer - - def send_tensor_wrapper(self, tensor: Optional[torch.Tensor], -- tensor_size: int) -> None: -+ tensor_size: int, -+ target_rank: int) -> None: - """ - Wrapper for _send_impl to handle exceptions and update buffer size. - """ - try: -- self._send_impl(tensor) -+ self._send_impl(tensor, target_rank) - - with self.buffer_size_lock: - self.buffer_size -= tensor_size -@@ -222,7 +235,7 @@ class PyNcclPipe(KVPipeBase): - logger.debug("KV cache transfer pipe is full. Waiting...") - time.sleep(0.05) - -- def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: -+ def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None: - """ - Sends a tensor and its metadata to the destination rank in a - non-blocking way. -@@ -230,6 +243,7 @@ class PyNcclPipe(KVPipeBase): - Parameters: - - tensor: The tensor to send, or None if no tensor is being sent. - """ -+ logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank) - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - -@@ -243,32 +257,39 @@ class PyNcclPipe(KVPipeBase): - with self.buffer_size_lock: - self.buffer_size += tensor_size - -- self.transport_thread.submit(self.send_tensor_wrapper, tensor, -- tensor_size) -+ future = self.transport_thread.submit(self.send_tensor_wrapper, tensor, -+ tensor_size, -+ target_rank) -+ return future - -- def recv_tensor(self) -> Optional[torch.Tensor]: -+ def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]: - """ - Receives a tensor and its metadata from the source rank. Blocking call. - - Returns: - - tensor: The received tensor, or None if no tensor is received. 
- """ -+ -+ logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank) -+ - if self.transport_thread is None: - self.transport_thread = ThreadPoolExecutor(max_workers=1) - -- future = self.transport_thread.submit(self._recv_impl) -+ future = self.transport_thread.submit(self._recv_impl, src_rank) - -- try: -- tensor = future.result() -- except Exception as e: -- logger.error("Encountering exception in KV receiving thread") -- logger.error("%s", e) -- logger.error("My device: %s", self.device) -- import traceback -- traceback.print_exc() -- raise e -+ return future -+ -+ # try: -+ # tensor = future.result() -+ # except Exception as e: -+ # logger.error("Encountering exception in KV receiving thread") -+ # logger.error("%s", e) -+ # logger.error("My device: %s", self.device) -+ # import traceback -+ # traceback.print_exc() -+ # raise e - -- return tensor -+ # return tensor - - def close(self): - """ -diff --git a/vllm/distributed/kv_transfer/kv_transfer_agent.py b/vllm/distributed/kv_transfer/kv_transfer_agent.py -index 1e80e0bd7..f06c7a5f6 100644 ---- a/vllm/distributed/kv_transfer/kv_transfer_agent.py -+++ b/vllm/distributed/kv_transfer/kv_transfer_agent.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - """A centralized entrypoint to perform distributed KV cache transfer. - - This implementation is a shim wrapper on two APIs exposed by `kv_connector`: -@@ -35,6 +48,7 @@ class KVTransferAgent: - rank: int, - local_rank: int, - config: "VllmConfig", -+ world_group, - ): - - self.config = config -@@ -47,7 +61,7 @@ class KVTransferAgent: - "TransferAgent should only be used when kv_connector is set." - - self.connector = KVConnectorFactory.create_connector( -- rank, local_rank, config) -+ rank, local_rank, config, world_group) - - def send_kv_caches_and_hidden_states( - self, -diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py -index e0eeeffb8..9010c6966 100644 ---- a/vllm/distributed/parallel_state.py -+++ b/vllm/distributed/parallel_state.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - # Copyright 2023 The vLLM team. 
- # Adapted from -@@ -979,7 +992,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: - _KV_TRANSFER = kv_transfer.KVTransferAgent( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, -- config=vllm_config) -+ config=vllm_config, -+ world_group=get_world_group()) - - - def ensure_model_parallel_initialized( -diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py -index 975afe5ad..2208abea0 100644 ---- a/vllm/engine/arg_utils.py -+++ b/vllm/engine/arg_utils.py -@@ -1159,7 +1159,7 @@ class EngineArgs: - # features and raise error for unsupported features. - # * If VLLM_USE_V1=0, we disable V1. - use_v1 = False -- try_v1 = envs.VLLM_USE_V1 or not envs.is_set("VLLM_USE_V1") -+ try_v1 = envs.VLLM_USE_V1 - if try_v1 and self._is_v1_supported_oracle(model_config): - use_v1 = True - -diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py -index 54f7b8fb6..9c1c2635f 100644 ---- a/vllm/engine/llm_engine.py -+++ b/vllm/engine/llm_engine.py -@@ -1,11 +1,28 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import copy - import time -+import pickle -+import uuid - from collections import Counter as collectionsCounter - from collections import deque -+from collections import defaultdict - from contextlib import contextmanager - from dataclasses import dataclass -+from concurrent.futures import ThreadPoolExecutor - from functools import partial - from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, - Iterable, List, Literal, Mapping, NamedTuple, Optional) -@@ -62,6 +79,9 @@ from vllm.utils import (Counter, Device, deprecate_kwargs, - resolve_obj_by_qualname, weak_bind) - from vllm.version import __version__ as VLLM_VERSION - from vllm.worker.model_runner_base import InputProcessingError -+from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType -+from vllm.distributed.device_communicators.nixl import NixlMetadata -+ - - logger = init_logger(__name__) - _LOCAL_LOGGING_INTERVAL_SEC = 5 -@@ -93,7 +113,7 @@ class OutputData(NamedTuple): - # outputs from multiple steps. 
- is_first_step_output: Optional[bool] - skip: List[int] -- -+ remote_prefill_requests: Optional[List[RemotePrefillRequest]] - - class SchedulerContext: - -@@ -107,11 +127,14 @@ class SchedulerContext: - - self.multi_step_stream_outputs: bool = multi_step_stream_outputs - -+ self.remote_prefill_requests: List[RemotePrefillRequest] = [] -+ - def append_output(self, outputs: List[SamplerOutput], - seq_group_metadata_list: List[SequenceGroupMetadata], - scheduler_outputs: SchedulerOutputs, is_async: bool, - is_last_step: bool, -- is_first_step_output: Optional[bool]): -+ is_first_step_output: Optional[bool], -+ remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None): - self.output_queue.append( - OutputData(outputs=outputs, - seq_group_metadata_list=seq_group_metadata_list, -@@ -119,7 +142,9 @@ class SchedulerContext: - is_async=is_async, - is_last_step=is_last_step, - is_first_step_output=is_first_step_output, -- skip=[])) -+ skip=[], -+ remote_prefill_requests=remote_prefill_requests)) -+ - - - class LLMEngine: -@@ -362,7 +387,7 @@ class LLMEngine: - Scheduler = self.vllm_config.scheduler_config.scheduler_cls - self.scheduler = [ - Scheduler( -- self.scheduler_config, self.cache_config, self.lora_config, -+ self.model_config, self.scheduler_config, self.cache_config, self.lora_config, - self.parallel_config.pipeline_parallel_size, - self.async_callbacks[v_id] - if self.model_config.use_async_output_proc else None) -@@ -422,6 +447,39 @@ class LLMEngine: - # Flag to set when an input fails to process and the engine should run - # the next step without re-scheduling. - self._skip_scheduling_next_step = False -+ self.engine_id = str(uuid.uuid4()) -+ self._nixl_agents_names: Optional[List[str]] = None -+ if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ self._nixl_agents_names = self._initialize_nixl() -+ -+ self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) -+ self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size) -+ self._finished_prefills = set() -+ self._finished_transfers = set() -+ -+ @property -+ def is_nixl_initialized(self) -> bool: -+ return getattr(self, "_nixl_agents_names", None) is not None -+ -+ def get_nixl_metadata(self) -> NixlMetadata: -+ if not self.is_nixl_initialized: -+ raise RuntimeError("Nixl is not initialized") -+ agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata") -+ kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr") -+ return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks) -+ -+ def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]: -+ if not self.is_nixl_initialized: -+ raise RuntimeError("Nixl is not initialized") -+ engine_id = nixl_metadata.engine_id -+ agents_metadata = nixl_metadata.agent_metadata -+ kv_caches_base_addr = nixl_metadata.kv_caches_base_addr -+ num_blocks = nixl_metadata.num_blocks -+ return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks)) -+ -+ def _initialize_nixl(self) -> List[bytes]: -+ agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,)) -+ return agents_names - - def _initialize_kv_caches(self) -> None: - """Initialize the KV cache in the worker(s). 
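The _request_notif_counter and _request_done_counter fields initialized above implement a simple all-ranks-reported check: each counter starts at -tensor_parallel_size, is incremented as worker notifications arrive, and the request is treated as finished once its counter exceeds -1 (i.e. every tensor-parallel worker has reported). A minimal, dependency-free sketch of that counting scheme, with illustrative names that are not part of the patch:

    from collections import defaultdict

    def make_completion_tracker(tensor_parallel_size: int):
        # Start each request at -tp_size; add one per worker notification.
        counters = defaultdict(lambda: -tensor_parallel_size)
        finished: set[str] = set()

        def record(request_id: str, notif_count: int = 1) -> None:
            counters[request_id] += notif_count
            if counters[request_id] > -1:
                # All tensor-parallel workers have reported for this request.
                finished.add(request_id)
                del counters[request_id]

        return record, finished

    record, finished = make_completion_tracker(tensor_parallel_size=4)
    for _ in range(4):
        record("req-0")
    print("req-0" in finished)  # True once all 4 workers have notified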
-@@ -535,6 +593,8 @@ class LLMEngine: - # Shutdown model executor when engine is garbage collected - # Use getattr since __init__ can fail before the field is set - if model_executor := getattr(self, "model_executor", None): -+ if self.is_nixl_initialized: -+ model_executor.collective_rpc("shutdown_nixl") - model_executor.shutdown() - - def get_tokenizer_group( -@@ -587,11 +647,14 @@ class LLMEngine: - prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Mapping[str, str]] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Optional[SequenceGroup]: - """Add a processed request to the engine's request pool. - return the created sequence group. - """ - if isinstance(params, SamplingParams) and params.n > 1: -+ if remote_prefill_params is not None: -+ raise ValueError("Remote prefill params are not supported for multi-step sampling") - ParallelSampleSequenceGroup.add_request( - request_id, - self, -@@ -609,12 +672,14 @@ class LLMEngine: - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) -+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: -+ next(self.seq_counter) # empty sequence for staging - eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) - - encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) - - seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id, -- lora_request, prompt_adapter_request) -+ lora_request, prompt_adapter_request, remote_prefill_params) - - encoder_seq = (None if encoder_inputs is None else Sequence( - seq_id, encoder_inputs, block_size, eos_token_id, lora_request, -@@ -631,8 +696,12 @@ class LLMEngine: - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, -- priority=priority) -+ priority=priority, -+ remote_prefill_params=remote_prefill_params, -+ ) - elif isinstance(params, PoolingParams): -+ if remote_prefill_params is not None: -+ raise ValueError("Remote prefill params are not supported for pooling") - seq_group = self._create_sequence_group_with_pooling( - request_id, - seq, -@@ -703,6 +772,7 @@ class LLMEngine: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: -@@ -794,6 +864,7 @@ class LLMEngine: - prompt_adapter_request=prompt_adapter_request, - trace_headers=trace_headers, - priority=priority, -+ remote_prefill_params=remote_prefill_params, - ) - - def _validate_token_prompt(self, prompt: PromptType, -@@ -828,6 +899,7 @@ class LLMEngine: - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - encoder_seq: Optional[Sequence] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> SequenceGroup: - """Creates a SequenceGroup with SamplingParams.""" - max_logprobs = self.get_model_config().max_logprobs -@@ -863,7 +935,9 @@ class LLMEngine: - prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq, - priority=priority, -- draft_size=draft_size) -+ draft_size=draft_size, -+ remote_prefill_params=remote_prefill_params, -+ ) - - return seq_group - -@@ -1030,11 +1104,11 @@ class LLMEngine: - # When we process only one request, no pop is required - # (since later we will process all of the rest) - (outputs, 
seq_group_metadata_list, scheduler_outputs, is_async, -- is_last_step, is_first_step_output, skip) = ctx.output_queue[0] -+ is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0] - else: - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step, is_first_step_output, -- skip) = ctx.output_queue.popleft() -+ skip, remote_prefill_requests) = ctx.output_queue.popleft() - - # Sanity check - assert len(seq_group_metadata_list) == len( -@@ -1360,6 +1434,12 @@ class LLMEngine: - - # Clear outputs for each new scheduler iteration - ctx.request_outputs.clear() -+ ctx.remote_prefill_requests.clear() -+ -+ remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = [] -+ running_seq_group_metadata_list: List[SequenceGroupMetadata] = [] -+ remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] -+ running_scheduled_seq_groups: List[ScheduledSequenceGroup] = [] - - # Skip the scheduler if there are any remaining steps in the seq groups. - # This ensures that the scheduler is only called again when the current -@@ -1372,7 +1452,43 @@ class LLMEngine: - # Schedule iteration - (seq_group_metadata_list, scheduler_outputs, - allow_async_output_proc -- ) = self.scheduler[virtual_engine].schedule() -+ ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers) -+ -+ -+ # Separate remote prefill and running seq groups -+ for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups): -+ if seq_group_metadata.do_remote_prefill: -+ remote_prefill_seq_group_metadata_list.append(seq_group_metadata) -+ remote_prefill_scheduled_seq_groups.append(scheduled_seq_group) -+ else: -+ running_seq_group_metadata_list.append(seq_group_metadata) -+ running_scheduled_seq_groups.append(scheduled_seq_group) -+ -+ seq_group_metadata_list = running_seq_group_metadata_list -+ scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups -+ -+ # Send remote prefill requests before model execution -+ for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups): -+ assert len(scheduled_seq_group.seq_group.seqs) == 1 -+ assert self._nixl_agents_names -+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id -+ block_table = seq_group_metadata.block_tables[seq_id] -+ if len(block_table) == len(seq_group_metadata.computed_block_nums): -+ logger.debug("No blocks to prefill") -+ self._finished_prefills.add(seq_group_metadata.request_id) -+ continue -+ -+ remote_prefill_request = RemotePrefillRequest( -+ request_id=seq_group_metadata.request_id, -+ # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway -+ prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks) -+ sampling_params=scheduled_seq_group.seq_group.sampling_params, -+ block_ids=block_table, -+ engine_id=self.engine_id, -+ computed_block_ids=seq_group_metadata.computed_block_nums, -+ multimodal_data_source=scheduled_seq_group.seq_group.remote_prefill_params.multimodal_data_source -+ ) -+ scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request) - - ctx.seq_group_metadata_list = seq_group_metadata_list - ctx.scheduler_outputs = scheduler_outputs -@@ -1427,8 +1543,46 @@ class 
LLMEngine: - execute_model_req.async_callback = self.async_callbacks[ - virtual_engine] - -+ # After model execution, we need to transfer the memory from the prefill to the decode -+ memory_transfer_reqs = [] -+ for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list): -+ remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params -+ if remote_prefill_params is not None and remote_prefill_params.is_remote_decode: -+ assert len(scheduled_seq_group.seq_group.seqs) == 1 -+ req_id = scheduled_seq_group.seq_group.request_id -+ seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id -+ block_table = seq_group_metadata.block_tables[seq_id] -+ staging_block_ids = seq_group_metadata.block_tables[seq_id + 1] -+ -+ num_computed_blocks = len(seq_group_metadata.computed_block_nums) -+ computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks] -+ -+ if computed_decode_block_ids: -+ kv_recv_req = MemoryTransferRequest( -+ request_id=req_id, -+ local_block_ids=block_table[:num_computed_blocks], -+ staging_block_ids=staging_block_ids[:num_computed_blocks], -+ remote_block_ids=computed_decode_block_ids, -+ remote_engine_id=remote_prefill_params.decode_engine_id, -+ notify_msg=req_id, -+ op_type=MemoryOpType.READ -+ ) -+ memory_transfer_reqs.append(kv_recv_req) -+ -+ kv_send_req = MemoryTransferRequest( -+ request_id=req_id, -+ local_block_ids=block_table[num_computed_blocks:], -+ staging_block_ids=staging_block_ids[num_computed_blocks:], -+ remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:], -+ remote_engine_id=remote_prefill_params.decode_engine_id, -+ notify_msg=req_id, -+ op_type=MemoryOpType.WRITE -+ ) -+ memory_transfer_reqs.append(kv_send_req) -+ execute_model_req.memory_transfer_requests = memory_transfer_reqs -+ - try: -- outputs = self.model_executor.execute_model( -+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( - execute_model_req=execute_model_req) - self._skip_scheduling_next_step = False - except InputProcessingError as e: -@@ -1444,7 +1598,6 @@ class LLMEngine: - allow_async_output_proc=allow_async_output_proc) - # Raise so the caller is notified that this request failed - raise -- - # We need to do this here so that last step's sampled_token_ids can - # be passed to the next iteration for PP. - if self.scheduler_config.is_multi_step: -@@ -1455,7 +1608,26 @@ class LLMEngine: - if len(ctx.output_queue) > 0: - self._process_model_outputs(ctx=ctx) - # No outputs in this case -- outputs = [] -+ execute_model_req = ExecuteModelRequest( -+ seq_group_metadata_list=[], -+ blocks_to_swap_in=[], -+ blocks_to_swap_out=[], -+ blocks_to_copy=[]) -+ -+ outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model( -+ execute_model_req=execute_model_req) -+ -+ for req_id, notif_count in request_notif_counter.items(): -+ self._request_notif_counter[req_id] += notif_count -+ if self._request_notif_counter[req_id] > -1: -+ self._finished_prefills.add(req_id) -+ del self._request_notif_counter[req_id] -+ -+ for req_id, done_count in request_done_counter.items(): -+ self._request_done_counter[req_id] += done_count -+ if self._request_done_counter[req_id] > -1: -+ self._finished_transfers.add(req_id) -+ del self._request_done_counter[req_id] - - # Finish the current step for all the sequence groups. 
- if self.scheduler_config.is_multi_step: -@@ -1515,7 +1687,7 @@ class LLMEngine: - # queued control plane messages, such as add/remove lora adapters. - logger.debug("Stopping remote worker execution loop.") - self.model_executor.stop_remote_worker_execution_loop() -- -+ - return ctx.request_outputs - - def _abort_and_cache_schedule( -diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py -index cafd8150b..6a5e45b4e 100644 ---- a/vllm/engine/multiprocessing/__init__.py -+++ b/vllm/engine/multiprocessing/__init__.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import uuid - from dataclasses import dataclass, field -@@ -14,13 +27,17 @@ from vllm.outputs import RequestOutput - from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import SamplingParams - from vllm.utils import Device, deprecate_kwargs -- -+from vllm.remote_prefill import RemotePrefillParams -+from vllm.distributed.device_communicators.nixl import NixlMetadata - VLLM_RPC_SUCCESS_STR = "SUCCESS" - - IPC_INPUT_EXT = "_input_socket" - IPC_OUTPUT_EXT = "_output_socket" - IPC_HEALTH_EXT = "_health_socket" - IPC_DATA_EXT = "_data_socket" -+IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket" -+IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket" -+IPC_METRICS_EXT = "_metrics_socket" - - - class MQEngineDeadError(RuntimeError): -@@ -36,6 +53,7 @@ class RPCProcessRequest: - trace_headers: Optional[Mapping[str, str]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - priority: int = 0 -+ remote_prefill_params: Optional[RemotePrefillParams] = None - - @overload - def __init__( -@@ -78,6 +96,7 @@ class RPCProcessRequest: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None, # DEPRECATED - ) -> None: -@@ -95,7 +114,7 @@ class RPCProcessRequest: - self.trace_headers = trace_headers - self.prompt_adapter_request = prompt_adapter_request - self.priority = priority -- -+ self.remote_prefill_params = remote_prefill_params - - @dataclass - class RPCError: -@@ -113,9 +132,21 @@ class RPCStartupRequest(Enum): - IS_SERVER_READY = 1 - - -+@dataclass -+class RPCHasUnfinishedRequestsRequest: -+ request_id: str = field(default_factory=lambda: str(uuid.uuid4())) -+ -+ - @dataclass - class RPCStartupResponse: - tracing_enabled: bool -+ nixl_metadata: Optional[bytes] = None -+ -+ -+@dataclass -+class RPCHasUnfinishedRequestsResponse: -+ has_unfinished_requests: bool -+ request_id: str - - - class RPCUProfileRequest(Enum): -@@ -165,10 +196,10 @@ class RPCAdapterLoadedResponse: - RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, - RPCUProfileRequest, 
RPCLoadAdapterRequest, - RPCResetPrefixCacheRequest, RPCSleepRequest, -- RPCWakeUpRequest, RPCIsSleepingRequest] -+ RPCWakeUpRequest, RPCIsSleepingRequest, RPCHasUnfinishedRequestsRequest] - - REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, -- RPCIsSleepingResponse, RPCError] -+ RPCIsSleepingResponse, RPCError, RPCHasUnfinishedRequestsResponse] - - - def ENGINE_DEAD_ERROR( -@@ -181,3 +212,13 @@ def ENGINE_DEAD_ERROR( - return MQEngineDeadError( - "Engine loop is not running. Inspect the stacktrace to " - f"find the original error: {repr(error)}.") -+ -+@dataclass -+class KvMetrics: -+ request_active_slots: int -+ request_total_slots: int -+ kv_active_blocks: int -+ kv_total_blocks: int -+ num_requests_waiting: int -+ gpu_cache_usage_perc: float -+ gpu_prefix_cache_hit_rate: float -diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py -index f058b1329..fd5610a3c 100644 ---- a/vllm/engine/multiprocessing/client.py -+++ b/vllm/engine/multiprocessing/client.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import asyncio - import copy -@@ -8,6 +21,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, - Optional, Union, cast, overload) - - import cloudpickle -+import msgspec - import psutil - import zmq - import zmq.asyncio -@@ -18,14 +32,17 @@ from zmq.asyncio import Socket - from vllm import PoolingParams - from vllm.config import DecodingConfig, ModelConfig, VllmConfig - from vllm.core.scheduler import SchedulerOutputs -+from vllm.engine.metrics import Stats - # yapf conflicts with isort for this block - # yapf: disable - from vllm.engine.async_llm_engine import ( - build_guided_decoding_logits_processor_async) - from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, -- IPC_OUTPUT_EXT, RPC_REQUEST_T, -- VLLM_RPC_SUCCESS_STR, RPCAbortRequest, -+ IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT, -+ RPC_REQUEST_T, -+ VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest, -+ IPC_METRICS_EXT, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, -@@ -33,8 +50,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - RPCProcessRequest, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, -- RPCStartupResponse, -- RPCUProfileRequest, RPCWakeUpRequest) -+ RPCStartupResponse, RPCHasUnfinishedRequestsRequest, -+ RPCHasUnfinishedRequestsResponse, -+ RPCUProfileRequest, KvMetrics, RPCWakeUpRequest) - from vllm.engine.protocol import EngineClient - # yapf: enable - from vllm.envs import VLLM_RPC_TIMEOUT -@@ -48,6 +66,17 @@ from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import SamplingParams - from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - from 
vllm.utils import Device, deprecate_kwargs -+from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback -+from vllm.distributed.device_communicators.nixl import NixlMetadata -+ -+# Import ForwardPassMetrics and related classes from dynamo -+try: -+ from dynamo.llm import ForwardPassMetrics, WorkerStats, KvStats -+except ImportError: -+ # Fallback if dynamo imports are not available -+ ForwardPassMetrics = None -+ WorkerStats = None -+ KvStats = None - - logger = init_logger(__name__) - -@@ -93,6 +122,7 @@ class MQLLMEngineClient(EngineClient): - self._errored_with: Optional[BaseException] = None - - # Get the configs. -+ self.vllm_config = engine_config - self.model_config = engine_config.model_config - self.decoding_config = engine_config.decoding_config - -@@ -117,6 +147,10 @@ class MQLLMEngineClient(EngineClient): - self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) - self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") - -+ # Metrics. -+ self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL) -+ self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}") -+ - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - -@@ -131,8 +165,27 @@ class MQLLMEngineClient(EngineClient): - # Loop to check health of the LLMEngine periodically. - # Started after the MQLLMEngine is ready. - self.health_loop: Optional[asyncio.Task] = None -+ -+ # Loop to check metrics of the LLMEngine periodically. -+ # Started after the MQLLMEngine is ready. -+ self.metrics_loop: Optional[asyncio.Task] = None -+ self.metrics_publisher = None -+ - self._engine_process = psutil.Process(engine_pid) - -+ self.nixl_metadata: Optional[NixlMetadata] = None -+ self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL) -+ self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH) -+ self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {} -+ if self.using_nixl_connector: -+ self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") -+ self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") -+ -+ -+ @property -+ def using_nixl_connector(self) -> bool: -+ return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector" -+ - @staticmethod - def is_unsupported_config(vllm_config: VllmConfig): - # Pipeline parallel not yet supported -@@ -182,6 +235,76 @@ class MQLLMEngineClient(EngineClient): - except Exception as e: - self._set_errored(e) - -+ async def run_remote_prefill_request_handler_loop(self): -+ try: -+ while True: -+ if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT): -+ frames = await self.remote_prefill_request_socket.recv(copy=False) -+ remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest) -+ await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request) -+ except asyncio.CancelledError: -+ logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.") -+ -+ async def run_metrics_loop(self, timeout: int): -+ """Background loop that continually checks to ensure the engine process -+ is still alive. 
-+ """ -+ try: -+ while True: -+ # Check if the engine process is running: -+ if not self._engine_process.is_running() or ( -+ self._engine_process.status() == psutil.STATUS_ZOMBIE): -+ # NB: is_running() returns True for zombies -+ self._set_errored( -+ RuntimeError( -+ f"Engine process (pid {self._engine_process.pid}) " -+ "died.")) -+ break -+ -+ if await self.metrics_socket.poll(timeout=timeout): -+ # Metrics received- check the message -+ message: Frame = await self.metrics_socket.recv(copy=False) -+ metrics = pickle.loads(message.buffer) -+ if self.metrics_publisher is not None and isinstance( -+ metrics, KvMetrics -+ ): -+ # Construct structured metrics objects -+ worker_stats = WorkerStats( -+ request_active_slots=metrics.request_active_slots, -+ request_total_slots=metrics.request_total_slots, -+ num_requests_waiting=metrics.num_requests_waiting, -+ data_parallel_rank=None -+ ) -+ -+ kv_stats = KvStats( -+ kv_active_blocks=metrics.kv_active_blocks, -+ kv_total_blocks=metrics.kv_total_blocks, -+ gpu_cache_usage_perc=metrics.gpu_cache_usage_perc, -+ gpu_prefix_cache_hit_rate=metrics.gpu_prefix_cache_hit_rate -+ ) -+ -+ forward_pass_metrics = ForwardPassMetrics( -+ worker_stats=worker_stats, -+ kv_stats=kv_stats, -+ spec_decode_stats=None -+ ) -+ -+ self.metrics_publisher.publish(forward_pass_metrics) -+ logger.debug("Metrics successful.") -+ -+ # TODO: Investigate sending whole stats object -+ -+ except asyncio.CancelledError: -+ logger.debug("Shutting down MQLLMEngineClient check metrics loop.") -+ -+ except psutil.NoSuchProcess: -+ self._set_errored( -+ RuntimeError( -+ f"Engine process (pid {self._engine_process.pid}) died.")) -+ -+ except Exception as e: -+ self._set_errored(e) -+ - async def run_output_handler_loop(self): - """Get RequestOutputs from Engine and stream to Request Queues""" - -@@ -250,7 +373,7 @@ class MQLLMEngineClient(EngineClient): - # Put each output into the appropriate queue. - elif isinstance( - request_outputs, -- (RPCAdapterLoadedResponse, RPCIsSleepingResponse)): -+ (RPCAdapterLoadedResponse, RPCIsSleepingResponse, RPCHasUnfinishedRequestsResponse)): - self._add_output(request_outputs) - else: - for request_output in request_outputs: -@@ -261,7 +384,7 @@ class MQLLMEngineClient(EngineClient): - - def _add_output(self, request_output: Union[RequestOutput, - RPCAdapterLoadedResponse, -- RPCIsSleepingResponse]): -+ RPCIsSleepingResponse, RPCHasUnfinishedRequestsResponse]): - queue = self.output_queues.get(request_output.request_id) - if queue is not None: - queue.put_nowait(request_output) -@@ -283,12 +406,25 @@ class MQLLMEngineClient(EngineClient): - # Wait until server is ready. - response = await self._wait_for_server_rpc(socket) - -+ if response.nixl_metadata is not None: -+ assert self.using_nixl_connector -+ self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata) -+ - self.tracing_flag = response.tracing_enabled - - # Start health_loop. - if self.health_loop is None: - self.health_loop = asyncio.create_task( - self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) -+ -+ if self.using_nixl_connector: -+ self.remote_prefill_loop = asyncio.create_task( -+ self.run_remote_prefill_request_handler_loop()) -+ -+ # Start metrics_loop. -+ if self.metrics_loop is None: -+ self.metrics_loop = asyncio.create_task( -+ self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT)) - - def close(self): - """Destroy the ZeroMQ Context.""" -@@ -298,6 +434,8 @@ class MQLLMEngineClient(EngineClient): - # Cancel background tasks. 
- if self.health_loop is not None: - self.health_loop.cancel() -+ if self.metrics_loop is not None: -+ self.metrics_loop.cancel() - if self.output_loop is not None: - self.output_loop.cancel() - -@@ -420,6 +558,9 @@ class MQLLMEngineClient(EngineClient): - """ - if self._errored_with is not None: - raise self._errored_with -+ -+ async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata): -+ await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False) - - @property - def is_running(self) -> bool: -@@ -478,6 +619,7 @@ class MQLLMEngineClient(EngineClient): - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - *, - inputs: Optional[PromptType] = None # DEPRECATED - ) -> AsyncGenerator[RequestOutput, None]: -@@ -507,7 +649,8 @@ class MQLLMEngineClient(EngineClient): - - return self._process_request(prompt, sampling_params, request_id, - lora_request, trace_headers, -- prompt_adapter_request, priority) -+ prompt_adapter_request, priority, -+ remote_prefill_params) - - @overload - def encode( -@@ -591,6 +734,7 @@ class MQLLMEngineClient(EngineClient): - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ - PoolingRequestOutput, None]]: - """Send an RPCGenerateRequest to the RPCServer and stream responses.""" -@@ -636,6 +780,12 @@ class MQLLMEngineClient(EngineClient): - else: - lp_bytes = None - -+ if remote_prefill_params is not None: -+ self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback -+ remote_prefill_params.remote_prefill_request_callback = None -+ else: -+ remote_prefill_request_callback = None -+ - request_bytes = pickle.dumps( - RPCProcessRequest( - prompt=prompt, -@@ -645,11 +795,11 @@ class MQLLMEngineClient(EngineClient): - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, -+ remote_prefill_params=remote_prefill_params, - )) - - # 3) Send the RPCGenerateRequest to the MQLLMEngine. -- parts = (request_bytes, -- lp_bytes) if lp_bytes else (request_bytes, ) -+ parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,) - await self.input_socket.send_multipart(parts, copy=False) - - # 4) Stream the RequestOutputs from the output queue. 
Note -@@ -740,3 +890,22 @@ class MQLLMEngineClient(EngineClient): - # Raise on error, otherwise happily return None - if isinstance(request_output, BaseException): - raise request_output -+ -+ def set_metrics_publisher(self, metrics_publisher): -+ self.metrics_publisher = metrics_publisher -+ -+ async def has_unfinished_requests(self) -> bool: -+ logger.info("Checking if there are unfinished requests") -+ if "has_unfinished_requests" not in self.output_queues: -+ logger.info("Creating has unfinished requests queue") -+ -+ request = RPCHasUnfinishedRequestsRequest() -+ queue: asyncio.Queue[Union[BaseException, RPCHasUnfinishedRequestsResponse]] = asyncio.Queue() -+ self.output_queues[request.request_id] = queue -+ request_bytes = pickle.dumps(request) -+ await self.input_socket.send_multipart((request_bytes, ), copy=False) -+ response = await queue.get() -+ self.output_queues.pop(request.request_id) -+ if isinstance(response, BaseException): -+ raise response -+ return response.has_unfinished_requests -diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py -index 6ed5ae0a9..3a320c42c 100644 ---- a/vllm/engine/multiprocessing/engine.py -+++ b/vllm/engine/multiprocessing/engine.py -@@ -1,13 +1,27 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- - import pickle - import signal - from contextlib import contextmanager --from typing import Iterator, List, Optional, Union -+from typing import Iterator, List, Optional, Union, Dict - - import cloudpickle -+import time - import zmq -- -+import msgspec - from vllm import AsyncEngineArgs, SamplingParams - from vllm.config import VllmConfig - from vllm.engine.llm_engine import LLMEngine -@@ -15,8 +29,10 @@ from vllm.engine.llm_engine import LLMEngine - # yapf: disable - from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - IPC_HEALTH_EXT, IPC_INPUT_EXT, -- IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, -- VLLM_RPC_SUCCESS_STR, RPCAbortRequest, -+ REQUEST_OUTPUTS_T, -+ VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT, -+ RPCAbortRequest, -+ IPC_OUTPUT_EXT, IPC_METRICS_EXT, - RPCAdapterLoadedResponse, RPCError, - RPCIsSleepingRequest, - RPCIsSleepingResponse, -@@ -25,13 +41,21 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, - RPCResetPrefixCacheRequest, - RPCSleepRequest, RPCStartupRequest, - RPCStartupResponse, -- RPCUProfileRequest, RPCWakeUpRequest) -+ RPCHasUnfinishedRequestsRequest, -+ RPCHasUnfinishedRequestsResponse, -+ RPCUProfileRequest, RPCWakeUpRequest, KvMetrics, -+ IPC_REMOTE_NIXL_METADATA_EXT) - # yapf: enable - from vllm.logger import init_logger - from vllm.outputs import RequestOutput - from vllm.transformers_utils.config import ( - maybe_register_config_serialize_by_value) - from vllm.usage.usage_lib import UsageContext -+from vllm.remote_prefill import RemotePrefillRequest -+from vllm.distributed.device_communicators.nixl import NixlMetadata -+ -+from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo -+from dataclasses import dataclass, field - from vllm.worker.model_runner_base import InputProcessingError - - logger = init_logger(__name__) -@@ -39,6 +63,77 @@ logger = init_logger(__name__) - POLLING_TIMEOUT_MS = 10000 - HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) - -+class KvStatLogger(StatLoggerBase): -+ def __init__( -+ self, -+ max_num_seqs: int, -+ num_total_gpu_blocks: int, -+ metrics_socket -+ ): -+ # Must query initialized scheduler for max infos -+ self.request_total_slots = max_num_seqs -+ self.kv_total_blocks = num_total_gpu_blocks -+ self.metrics_socket = metrics_socket -+ -+ # KV metrics -+ self._send_kv_metrics(0, 0, 0, 0.0, 0.0) -+ -+ def log(self, stats: Stats) -> None: -+ self._send_kv_metrics( -+ stats.num_running_sys, -+ int(stats.gpu_cache_usage_sys * self.kv_total_blocks), -+ stats.num_waiting_sys, -+ stats.gpu_cache_usage_sys, -+ stats.gpu_prefix_cache_hit_rate -+ ) -+ -+ def info(self, type: str, obj: SupportsMetricsInfo) -> None: -+ pass -+ -+ def _send_kv_metrics( -+ self, -+ active_slots, -+ active_kv_blocks, -+ num_requests_waiting, -+ gpu_cache_usage_perc, -+ gpu_prefix_cache_hit_rate, -+ ): -+ if not self.metrics_socket.closed: -+ metrics_bytes = pickle.dumps( -+ KvMetrics( -+ active_slots, -+ self.request_total_slots, -+ active_kv_blocks, -+ self.kv_total_blocks, -+ num_requests_waiting, -+ gpu_cache_usage_perc, -+ gpu_prefix_cache_hit_rate, -+ ) -+ ) -+ self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) -+ -+# TODO: Send entire stats object to the client -+# class StatLogger(StatLoggerBase): -+# def __init__( -+# self, -+# metrics_socket -+# ): -+# self.metrics_socket = metrics_socket -+ -+# def log(self, stats: Stats) -> None: -+# self._send_metrics(stats) -+ -+# def info(self, type: str, obj: SupportsMetricsInfo) -> None: -+# pass -+ -+# def 
_send_metrics(self, stats: Stats): -+# if not self.metrics_socket.closed: -+# metrics_bytes = pickle.dumps(stats) -+# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False) -+ -+ -+ -+ - - class MQLLMEngine: - """A multiprocessing wrapper for :class:`LLMEngine`. -@@ -101,12 +196,37 @@ class MQLLMEngine: - self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) - self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") - -+ # Send metrics back to client. -+ self.metrics_socket = self.ctx.socket(zmq.constants.PUSH) -+ self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}") -+ - # IPC path for the data socket. - self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" - - # Error state. - self._errored_with: Optional[BaseException] = None - -+ self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH) -+ self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL) -+ if self.engine.is_nixl_initialized: -+ self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}") -+ self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}") -+ -+ -+ # Attach logger for continuous metrics publishing -+ self.kv_stat_logger = KvStatLogger( -+ self.engine.scheduler_config.max_num_seqs, -+ self.engine.cache_config.num_gpu_blocks, -+ self.metrics_socket -+ ) -+ self.engine.add_logger("kv_metrics", self.kv_stat_logger) -+ -+ # TODO investigate sending whole stats object -+ # self.general_stat_logger = StatLogger( -+ # self.metrics_socket -+ # ) -+ # self.engine.add_logger("general_metrics", self.general_stat_logger) -+ - @property - def dead_error(self) -> BaseException: - if self._errored_with is not None: -@@ -192,8 +312,17 @@ class MQLLMEngine: - # Handle the query from the Client. - if request == RPCStartupRequest.IS_SERVER_READY: - tracing_enabled = self.engine.is_tracing_enabled() -- response = RPCStartupResponse( -- tracing_enabled=tracing_enabled) -+ -+ # Send nixl metadata to the client -+ if self.engine.is_nixl_initialized: -+ nixl_metadata = self.engine.get_nixl_metadata() -+ encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata) -+ response = RPCStartupResponse( -+ tracing_enabled=tracing_enabled, -+ nixl_metadata=encoded_nixl_metadata) -+ else: -+ response = RPCStartupResponse( -+ tracing_enabled=tracing_enabled) - - except Exception as e: - response = e -@@ -206,6 +335,7 @@ class MQLLMEngine: - - while True: - if not self.engine.has_unfinished_requests(): -+ logger.debug("No unfinished requests") - # Poll until there is work to do. 
- while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: - # When there's no work, check on engine health and send -@@ -249,6 +379,13 @@ class MQLLMEngine: - def handle_new_input(self): - """Handle new input from the socket""" - try: -+ if self.engine.is_nixl_initialized: -+ while self.remote_nixl_metadata_socket.poll(timeout=0) != 0: -+ frames = self.remote_nixl_metadata_socket.recv(copy=False) -+ nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata) -+ logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id) -+ self.engine.add_remote_nixl_metadata(nixl_metadata) -+ - while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) -@@ -277,6 +414,8 @@ class MQLLMEngine: - self.wake_up(request.tags) - elif isinstance(request, RPCIsSleepingRequest): - self._handle_is_sleeping_request(request) -+ elif isinstance(request, RPCHasUnfinishedRequestsRequest): -+ self._handle_has_unfinished_requests_request(request) - else: - raise ValueError("Unknown RPCRequest Type: " - f"{type(request)}") -@@ -297,6 +436,11 @@ class MQLLMEngine: - self._send_outputs(rpc_err) - - try: -+ if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill: -+ def remote_prefill_request_callback(request: RemotePrefillRequest): -+ logger.debug("Sending remote prefill request: %s", request.request_id) -+ self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False) -+ request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback - self.engine.add_request( - request_id=request_id, - prompt=request.prompt, -@@ -304,7 +448,9 @@ class MQLLMEngine: - lora_request=request.lora_request, - trace_headers=request.trace_headers, - prompt_adapter_request=request.prompt_adapter_request, -- priority=request.priority) -+ priority=request.priority, -+ remote_prefill_params=request.remote_prefill_params, -+ ) - - if self.log_requests: - logger.info("Added request %s.", request.request_id) -@@ -348,6 +494,10 @@ class MQLLMEngine: - self._send_outputs( - RPCIsSleepingResponse(request_id=request.request_id, - is_sleeping=is_sleeping)) -+ -+ def _handle_has_unfinished_requests_request(self, request: RPCHasUnfinishedRequestsRequest): -+ response = RPCHasUnfinishedRequestsResponse(request_id=request.request_id, has_unfinished_requests=self.engine.has_unfinished_requests()) -+ self._send_outputs(response) - - def _health_check(self): - # Send unhealthy if engine has already errored -diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py -index dd0b67df4..f436b0752 100644 ---- a/vllm/entrypoints/openai/serving_chat.py -+++ b/vllm/entrypoints/openai/serving_chat.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. 
- - import asyncio - import json -@@ -41,6 +54,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer - from vllm.transformers_utils.tokenizers import (maybe_serialize_tool_calls, - truncate_tool_call_ids, - validate_request_params) -+from vllm.remote_prefill import RemotePrefillParams - - logger = init_logger(__name__) - -@@ -122,6 +136,7 @@ class OpenAIServingChat(OpenAIServing): - self, - request: ChatCompletionRequest, - raw_request: Optional[Request] = None, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, - ErrorResponse]: - """ -@@ -247,6 +262,7 @@ class OpenAIServingChat(OpenAIServing): - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=request.priority, -+ remote_prefill_params=remote_prefill_params, - ) - - generators.append(generator) -diff --git a/vllm/envs.py b/vllm/envs.py -index f80bf878f..f64c49fe8 100644 ---- a/vllm/envs.py -+++ b/vllm/envs.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import hashlib - import os -@@ -73,7 +86,7 @@ if TYPE_CHECKING: - VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False - VLLM_SKIP_P2P_CHECK: bool = False - VLLM_DISABLED_KERNELS: list[str] = [] -- VLLM_USE_V1: bool = True -+ VLLM_USE_V1: bool = False - VLLM_ROCM_USE_AITER: bool = False - VLLM_ROCM_USE_AITER_LINEAR: bool = True - VLLM_ROCM_USE_AITER_MOE: bool = True -@@ -107,6 +120,10 @@ if TYPE_CHECKING: - VLLM_TPU_BUCKET_PADDING_GAP: int = 0 - VLLM_USE_DEEP_GEMM: bool = False - VLLM_XGRAMMAR_CACHE_MB: int = 0 -+ VLLM_KV_CAPI_PATH: Optional[str] = None -+ VLLM_KV_NAMESPACE: Optional[str] = None -+ VLLM_KV_COMPONENT: Optional[str] = None -+ VLLM_WORKER_ID: Optional[int] = None - - - def get_default_cache_root(): -@@ -525,7 +542,7 @@ environment_variables: dict[str, Callable[[], Any]] = { - - # If set, use the V1 code path. - "VLLM_USE_V1": -- lambda: bool(int(os.getenv("VLLM_USE_V1", "1"))), -+ lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), - - # Disable aiter ops unless specifically enabled. - # Acts as a parent switch to enable the rest of the other operations. -@@ -704,6 +721,21 @@ environment_variables: dict[str, Callable[[], Any]] = { - # It can be changed with this variable if needed for some reason. 
- "VLLM_XGRAMMAR_CACHE_MB": - lambda: int(os.getenv("VLLM_XGRAMMAR_CACHE_MB", "512")), -+ -+ # Path to the C API Library -+ "VLLM_KV_CAPI_PATH": -+ lambda: os.environ.get("VLLM_KV_CAPI_PATH", None), -+ -+ # Identifiers to publish KV related information -+ "VLLM_KV_NAMESPACE": -+ lambda: os.environ.get("VLLM_KV_NAMESPACE", None), -+ "VLLM_KV_COMPONENT": -+ lambda: os.environ.get("VLLM_KV_COMPONENT", None), -+ -+ # Worker ID used for identifying workers in distributed settings -+ "VLLM_WORKER_ID": -+ lambda: int(os.getenv("VLLM_WORKER_ID", "0")) -+ if "VLLM_WORKER_ID" in os.environ else None, - } - - # end-env-vars-definition -diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py -index 23b450aed..23fe6d7b8 100644 ---- a/vllm/model_executor/models/deepseek_v2.py -+++ b/vllm/model_executor/models/deepseek_v2.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - # Adapted from - # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -diff --git a/vllm/outputs.py b/vllm/outputs.py -index 014e8d5d8..3ffc0f354 100644 ---- a/vllm/outputs.py -+++ b/vllm/outputs.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import time - from collections.abc import MutableSequence -@@ -6,16 +19,16 @@ from collections.abc import Sequence as GenericSequence - from dataclasses import dataclass - from typing import Generic, Optional, Union - -+import msgspec - import torch - from typing_extensions import TypeVar, deprecated - - from vllm.lora.request import LoRARequest - from vllm.multimodal.inputs import MultiModalPlaceholderDict --from vllm.sampling_params import RequestOutputKind -+from vllm.sampling_params import RequestOutputKind, SamplingParams - from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, - SequenceGroup, SequenceGroupBase, SequenceStatus) - -- - @dataclass - class CompletionOutput: - """The output data of one completion output of a request. 
-diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py -index 0ed221043..08dbc0e78 100644 ---- a/vllm/platforms/__init__.py -+++ b/vllm/platforms/__init__.py -@@ -20,7 +20,8 @@ def vllm_version_matches_substr(substr: str) -> bool: - """ - from importlib.metadata import PackageNotFoundError, version - try: -- vllm_version = version("vllm") -+ logger.warning("Using ai_dynamo_vllm") -+ vllm_version = version("ai_dynamo_vllm") - except PackageNotFoundError as e: - logger.warning( - "The vLLM package was not found, so its version could not be " -diff --git a/vllm/remote_prefill.py b/vllm/remote_prefill.py -new file mode 100644 -index 000000000..0a063f1ca ---- /dev/null -+++ b/vllm/remote_prefill.py -@@ -0,0 +1,84 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -+# SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ -+from dataclasses import dataclass -+from typing import Callable, Optional, List -+from enum import Enum -+ -+import msgspec -+ -+from vllm.sampling_params import SamplingParams -+ -+ -+class RemotePrefillRequest( -+ msgspec.Struct, -+ omit_defaults=True, # type: ignore[call-arg] -+ # required for @cached_property. -+ dict=True): -+ """The request data of one remote prefill output of a request. -+ -+ Args: -+ engine_id: The unique ID of the engine. -+ request_id: The unique ID of the request. -+ prompt_token_ids: The token IDs of the prompt. -+ sampling_params: The sampling parameters. -+ block_ids: The block IDs of the request. -+ computed_block_ids: The computed block IDs of the request. -+ """ -+ engine_id: str -+ request_id: str -+ prompt_token_ids: List[int] -+ sampling_params: SamplingParams -+ block_ids: List[int] -+ computed_block_ids: List[int] -+ multimodal_data_source: Optional[dict[str, str]] = None -+ -+ -+class MemoryOpType(str, Enum): -+ WRITE = "WRITE" -+ READ = "READ" -+ -+ -+class MemoryTransferRequest( -+ msgspec.Struct, -+ array_like=True, # type: ignore[call-arg] -+ omit_defaults=True): # type: ignore[call-arg] -+ """The request data of one memory transfer output of a request. -+ -+ Args: -+ request_id: The unique ID of the request. 
-+ """ -+ request_id: str -+ local_block_ids: List[int] -+ staging_block_ids: List[int] -+ remote_block_ids: List[int] -+ remote_engine_id: str -+ notify_msg: str -+ op_type: MemoryOpType -+ -+ -+RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None] -+ -+ -+@dataclass -+class RemotePrefillParams: -+ """Remote prefill parameters for text generation.""" -+ is_remote_prefill: bool = False -+ is_remote_decode: bool = False -+ decode_block_ids: Optional[List[int]] = None -+ decode_computed_block_ids: Optional[List[int]] = None -+ decode_engine_id: Optional[str] = None -+ remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None -+ multimodal_data_source: Optional[dict[str, str]] = None -diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py -index 68ed99664..5b0b7e6dc 100644 ---- a/vllm/sampling_params.py -+++ b/vllm/sampling_params.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - """Sampling parameters for text generation.""" - import copy - from dataclasses import dataclass -@@ -103,7 +117,7 @@ class RequestOutputKind(Enum): - DELTA = 1 - # Do not return intermediate RequestOutput - FINAL_ONLY = 2 -- -+ - - class SamplingParams( - msgspec.Struct, -diff --git a/vllm/sequence.py b/vllm/sequence.py -index 61867b025..8a07cf39e 100644 ---- a/vllm/sequence.py -+++ b/vllm/sequence.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - """Sequence and its related classes.""" - import copy - import enum -@@ -20,6 +34,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict - from vllm.pooling_params import PoolingParams - from vllm.prompt_adapter.request import PromptAdapterRequest - from vllm.sampling_params import RequestOutputKind, SamplingParams -+from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest - - VLLM_TOKEN_ID_ARRAY_TYPE = "l" - -@@ -59,13 +74,14 @@ class SequenceStatus(enum.IntEnum): - """Status of a sequence.""" - WAITING = 0 - RUNNING = 1 -- SWAPPED = 2 -- # Note: anything after SWAPPED (2) will be considered -+ REMOTE_PREFILLING = 2 -+ SWAPPED = 3 -+ # Note: anything after SWAPPED (3) will be considered - # as a finished status. 
-- FINISHED_STOPPED = 3 -- FINISHED_LENGTH_CAPPED = 4 -- FINISHED_ABORTED = 5 -- FINISHED_IGNORED = 6 -+ FINISHED_STOPPED = 4 -+ FINISHED_LENGTH_CAPPED = 5 -+ FINISHED_ABORTED = 6 -+ FINISHED_IGNORED = 7 - - @staticmethod - def is_finished(status: "SequenceStatus") -> bool: -@@ -417,6 +433,7 @@ class Sequence: - eos_token_id: Optional[int] = None, - lora_request: Optional[LoRARequest] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, -+ remote_prefill_params: Optional[RemotePrefillParams] = None, - ) -> None: - self.seq_id = seq_id - self.inputs = SingletonInputsAdapter(inputs) -@@ -424,7 +441,7 @@ class Sequence: - self.eos_token_id = eos_token_id - self.lora_request = lora_request - self.prompt_adapter_request = prompt_adapter_request -- -+ self.remote_prefill_params = remote_prefill_params - self.data = SequenceData.from_seqs(self.prompt_token_ids) - self.output_logprobs: SampleLogprobs = [] - self.output_text = "" -@@ -651,6 +668,7 @@ class SequenceGroup: - model; equal to max number of tokens a step can generate - for single-draft speculative decoding but larger than - that for multi-draft SD (currently not supported). -+ remote_prefill_params: Remote prefill parameters. - """ - - def __init__(self, -@@ -665,7 +683,8 @@ class SequenceGroup: - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -- draft_size: int = 1) -> None: -+ draft_size: int = 1, -+ remote_prefill_params: Optional[RemotePrefillParams] = None) -> None: - self.request_id = request_id - self.seqs = seqs - self.first_seq = seqs[0] -@@ -691,7 +710,7 @@ class SequenceGroup: - self.encoder_seq = encoder_seq - self.trace_headers = trace_headers - self.priority = priority -- -+ self.remote_prefill_params = remote_prefill_params - self.cached_request_output = None - - @property -@@ -940,6 +959,9 @@ class SequenceGroupMetadata( - query tokens for prefill, we don't need sampling. - token_chunk_size: The number of tokens to be processed (per sequence). - None if chunking is not required. -+ do_remote_prefill: True if remote prefill is required. -+ do_remote_decode: True if remote decode is required. -+ decode_memory_desc: The memory descriptor for the decoder blocks. - lora_request: LoRA request. - computed_block_nums: The block numbers that are already computed, - used in prefix caching. -@@ -979,6 +1001,9 @@ class SequenceGroupMetadata( - cross_block_table: Optional[list[int]] = None - prompt_adapter_request: Optional[PromptAdapterRequest] = None - token_chunk_size: Optional[int] = None -+ do_remote_prefill: bool = False -+ do_remote_decode: bool = False -+ decode_memory_desc: Optional[bytes] = None - - ### Stateful fields that are lazily defined. ### - # The number of speculative tokens adopted in this request. -@@ -1329,6 +1354,8 @@ class ExecuteModelRequest( - last_sampled_token_ids: Optional[torch.Tensor] = None - # Async callback - async_callback: Optional[Callable] = None -+ # The memory transfer requests. -+ memory_transfer_requests: Optional[list[MemoryTransferRequest]] = None - - @property - def is_first_multi_step(self) -> bool: -diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py -index 9524a69f6..c314fbe8f 100644 ---- a/vllm/worker/model_runner.py -+++ b/vllm/worker/model_runner.py -@@ -1,4 +1,17 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. - - import dataclasses - import gc -@@ -1875,6 +1888,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - - if self.vllm_config.kv_transfer_config is None: - return False -+ -+ if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - -@@ -1900,6 +1916,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): - - if self.vllm_config.kv_transfer_config is None: - return False -+ -+ if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector": -+ return False - - prefill_meta = model_input.attn_metadata.prefill_metadata - -diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py -index d59f20f49..4c78301a9 100644 ---- a/vllm/worker/worker.py -+++ b/vllm/worker/worker.py -@@ -1,8 +1,22 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. -+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - """A GPU worker class.""" - import gc - import os --from typing import Dict, List, Optional, Set, Tuple, Type, Union -+from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any - - import torch - import torch.distributed -@@ -31,6 +45,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner - from vllm.worker.pooling_model_runner import PoolingModelRunner - from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase, - WorkerInput) -+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector -+from vllm.remote_prefill import MemoryOpType -+ - - logger = init_logger(__name__) - -@@ -307,6 +324,46 @@ class Worker(LocalOrDistributedWorkerBase): - self._init_cache_engine() - self._warm_up_model() - -+ def initialize_nixl(self, engine_id: str) -> List[bytes]: -+ -+ # TODO ptarasiewicz nixl can also support DRAM -+ assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector" -+ -+ self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank? 
-+ assert len(self.cache_engine) == 1, "Only one cache engine is supported for now" -+ self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache) -+ return self.nixl_connector.agent_name -+ -+ def get_nixl_agent_metadata(self) -> bytes: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ return self.nixl_connector.get_agent_metadata() -+ -+ def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank? -+ return agent_name -+ -+ def get_nixl_kv_caches_base_addr(self) -> List[bytes]: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id] -+ -+ def _read_blocks(self, worker_input: WorkerInput) -> None: -+ for i, op_type in enumerate(worker_input.op_type): -+ if op_type == MemoryOpType.READ: -+ self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i]) -+ -+ def _write_blocks(self, worker_input: WorkerInput) -> None: -+ if not self.is_driver_worker: -+ torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize -+ -+ for i, op_type in enumerate(worker_input.op_type): -+ if op_type == MemoryOpType.WRITE: -+ self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i]) -+ -+ def shutdown_nixl(self) -> None: -+ assert self.nixl_connector is not None, "Nixl connector is not initialized" -+ self.nixl_connector.shutdown() -+ - def _init_cache_engine(self): - assert self.cache_config.num_gpu_blocks is not None - self.cache_engine = [ -@@ -368,6 +425,8 @@ class Worker(LocalOrDistributedWorkerBase): - blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, - device=self.device, - dtype=torch.int64).view(-1, 2) -+ -+ mem_transfer_reqs = execute_model_req.memory_transfer_requests or [] - - return WorkerInput( - num_seq_groups=num_seq_groups, -@@ -376,6 +435,12 @@ class Worker(LocalOrDistributedWorkerBase): - blocks_to_copy=blocks_to_copy, - virtual_engine=virtual_engine, - num_steps=num_steps, -+ local_block_ids=[r.local_block_ids for r in mem_transfer_reqs], -+ staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs], -+ remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs], -+ remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs], -+ notify_msg=[r.notify_msg for r in mem_transfer_reqs], -+ op_type=[r.op_type for r in mem_transfer_reqs], - ) - - @torch.inference_mode() -diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py -index e5662e693..ffcf1193a 100644 ---- a/vllm/worker/worker_base.py -+++ b/vllm/worker/worker_base.py -@@ -1,4 +1,18 @@ -+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - # SPDX-License-Identifier: Apache-2.0 -+# -+# Licensed under the Apache License, Version 2.0 (the "License"); -+# you may not use this file except in compliance with the License. 
-+# You may obtain a copy of the License at -+# -+# http://www.apache.org/licenses/LICENSE-2.0 -+# -+# Unless required by applicable law or agreed to in writing, software -+# distributed under the License is distributed on an "AS IS" BASIS, -+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -+# See the License for the specific language governing permissions and -+# limitations under the License. -+ - - import dataclasses - import os -@@ -9,6 +23,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union - import cloudpickle - import torch - import torch.nn as nn -+from collections import defaultdict - - from vllm.config import (ObservabilityConfig, VllmConfig, - set_current_vllm_config) -@@ -24,6 +39,9 @@ from vllm.utils import (enable_trace_function_call_for_thread, - from vllm.worker.model_runner_base import (BroadcastableModelInput, - ModelRunnerBase, - ModelRunnerInputBase) -+from vllm.distributed.device_communicators.nixl import DynamoNixlConnector -+from vllm.remote_prefill import MemoryOpType -+ - - logger = init_logger(__name__) - -@@ -55,6 +73,9 @@ class WorkerBase: - from vllm.platforms import current_platform - self.current_platform = current_platform - -+ self.nixl_connector: Optional[DynamoNixlConnector] = None -+ -+ @abstractmethod - def init_device(self) -> None: - """Initialize device state, such as loading the model or other on-device - memory allocations. -@@ -221,6 +242,13 @@ class WorkerInput: - virtual_engine: int = 0 - num_steps: int = 1 - -+ local_block_ids: Optional[List[List[int]]] = None -+ staging_block_ids: Optional[List[List[int]]] = None -+ remote_block_ids: Optional[List[List[int]]] = None -+ remote_engine_id: Optional[List[str]] = None -+ notify_msg: Optional[List[str]] = None -+ op_type: Optional[List[MemoryOpType]] = None -+ - @classmethod - def from_broadcasted_tensor_dict( - cls: Type["WorkerInput"], -@@ -237,6 +265,12 @@ class WorkerInput: - blocks_to_copy=tensor_dict.pop("blocks_to_copy"), - virtual_engine=tensor_dict["virtual_engine"], - num_steps=tensor_dict.pop("num_steps"), -+ local_block_ids=tensor_dict.pop("local_block_ids"), -+ staging_block_ids=tensor_dict.pop("staging_block_ids"), -+ remote_block_ids=tensor_dict.pop("remote_block_ids"), -+ remote_engine_id=tensor_dict.pop("remote_engine_id"), -+ notify_msg=tensor_dict.pop("notify_msg"), -+ op_type=tensor_dict.pop("op_type"), - ) - - def as_broadcastable_tensor_dict( -@@ -251,6 +285,12 @@ class WorkerInput: - "blocks_to_copy": self.blocks_to_copy, - "virtual_engine": self.virtual_engine, - "num_steps": self.num_steps, -+ "local_block_ids": self.local_block_ids, -+ "staging_block_ids": self.staging_block_ids, -+ "remote_block_ids": self.remote_block_ids, -+ "remote_engine_id": self.remote_engine_id, -+ "notify_msg": self.notify_msg, -+ "op_type": self.op_type, - } - - return tensor_dict -@@ -321,13 +361,16 @@ class LocalOrDistributedWorkerBase(WorkerBase): - return None - - worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) -- model_input = ( -- self.model_runner.make_model_input_from_broadcasted_tensor_dict( -- broadcast_data)) -+ if worker_input.num_seq_groups > 0: -+ model_input = ( -+ self.model_runner.make_model_input_from_broadcasted_tensor_dict( -+ broadcast_data)) - -- kwargs = extract_previous_hidden_states(broadcast_data) -+ kwargs = extract_previous_hidden_states(broadcast_data) - -- return model_input, worker_input, kwargs -+ return model_input, worker_input, kwargs -+ else: -+ return None, worker_input, {} - - def 
_get_driver_input_and_broadcast( - self, execute_model_req: ExecuteModelRequest -@@ -403,49 +446,88 @@ class LocalOrDistributedWorkerBase(WorkerBase): - self.execute_worker(worker_input) - - # If there is no input, we don't need to execute the model. -- if worker_input.num_seq_groups == 0: -- return [] -+ if worker_input.num_seq_groups > 0: -+ -+ self._read_blocks(worker_input) -+ -+ intermediate_tensors = None -+ orig_model_execute_time = 0.0 -+ if not get_pp_group().is_first_rank: -+ intermediate_tensors = IntermediateTensors( -+ get_pp_group().recv_tensor_dict( -+ all_gather_group=get_tp_group())) -+ if (self.observability_config is not None -+ and self.observability_config.collect_model_execute_time): -+ orig_model_execute_time = intermediate_tensors.tensors.get( -+ "model_execute_time", torch.tensor(0)).item() -+ -+ output = self.model_runner.execute_model( -+ model_input=model_input, -+ kv_caches=self.kv_cache[worker_input.virtual_engine] -+ if self.kv_cache is not None else None, -+ intermediate_tensors=intermediate_tensors, -+ num_steps=num_steps, -+ **kwargs, -+ ) - -- intermediate_tensors = None -- orig_model_execute_time = 0.0 -- if not get_pp_group().is_first_rank: -- intermediate_tensors = IntermediateTensors( -- get_pp_group().recv_tensor_dict( -- all_gather_group=get_tp_group())) -+ model_execute_time = time.perf_counter() - start_time -+ if not get_pp_group().is_last_rank: -+ # output is IntermediateTensors -+ assert isinstance(output, IntermediateTensors) -+ if (self.observability_config is not None -+ and self.observability_config.collect_model_execute_time): -+ output.tensors["model_execute_time"] = torch.tensor( -+ model_execute_time + orig_model_execute_time) -+ get_pp_group().send_tensor_dict(output.tensors, -+ all_gather_group=get_tp_group()) -+ return [None] - if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time): -- orig_model_execute_time = intermediate_tensors.tensors.get( -- "model_execute_time", torch.tensor(0)).item() -- -- output = self.model_runner.execute_model( -- model_input=model_input, -- kv_caches=self.kv_cache[worker_input.virtual_engine] -- if self.kv_cache is not None else None, -- intermediate_tensors=intermediate_tensors, -- num_steps=num_steps, -- **kwargs, -- ) -+ and self.observability_config.collect_model_execute_time -+ and output is not None): -+ for o in output: -+ o.model_execute_time = (orig_model_execute_time + -+ model_execute_time) - -- model_execute_time = time.perf_counter() - start_time -- if not get_pp_group().is_last_rank: -- # output is IntermediateTensors -- assert isinstance(output, IntermediateTensors) -- if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time): -- output.tensors["model_execute_time"] = torch.tensor( -- model_execute_time + orig_model_execute_time) -- get_pp_group().send_tensor_dict(output.tensors, -- all_gather_group=get_tp_group()) -- return [None] -- if (self.observability_config is not None -- and self.observability_config.collect_model_execute_time -- and output is not None): -- for o in output: -- o.model_execute_time = (orig_model_execute_time + -- model_execute_time) -+ self._write_blocks(worker_input) - -+ else: -+ output = [] -+ -+ # collect kv transfer notifications from non driver workers -+ -+ if self.nixl_connector is not None: -+ new_notifs = self.nixl_connector.get_new_notifs() -+ rank = get_tp_group().rank -+ all_new_notifs = [new_notifs] -+ if rank > 0: -+ get_tp_group().send_object(new_notifs, 
dst=0) -+ else: -+ for i in range(1, get_tp_group().world_size): -+ all_new_notifs.append(get_tp_group().recv_object(src=i)) -+ -+ request_notif_counter = defaultdict(int) -+ for notifs in all_new_notifs: -+ for req_ids in notifs.values(): -+ for req_id in req_ids: -+ request_notif_counter[req_id.decode("utf-8")] += 1 -+ -+ if request_notif_counter: -+ logger.debug("Request notif counter: %s", request_notif_counter) -+ -+ request_done_counter = defaultdict(int) -+ for req_id in self.nixl_connector.get_done_tranfers(): -+ request_done_counter[req_id] += 1 -+ else: -+ request_notif_counter = {} -+ request_done_counter = {} - # output is List[SamplerOutput] -- return output -+ return output, request_notif_counter, request_done_counter -+ -+ def _read_blocks(self, worker_input: WorkerInput) -> None: -+ pass -+ -+ def _write_blocks(self, worker_input: WorkerInput) -> None: -+ pass - - def _execute_model_spmd( - self, diff --git a/container/run.sh b/container/run.sh index a576e968b6..376ad13db0 100755 --- a/container/run.sh +++ b/container/run.sh @@ -24,7 +24,7 @@ RUN_PREFIX= # dependencies are specified in the /container/deps folder and # installed within framework specific sections of the Dockerfile. -declare -A FRAMEWORKS=(["VLLM"]=1 ["TENSORRTLLM"]=2 ["SGLANG"]=3 ["VLLM_V1"]=4) +declare -A FRAMEWORKS=(["VLLM"]=1 ["TENSORRTLLM"]=2 ["SGLANG"]=3) DEFAULT_FRAMEWORK=VLLM SOURCE_DIR=$(dirname "$(readlink -f "$0")") diff --git a/deploy/metrics/README.md b/deploy/metrics/README.md index b37b28373e..979974feb2 100644 --- a/deploy/metrics/README.md +++ b/deploy/metrics/README.md @@ -25,7 +25,7 @@ graph TD The dcgm-exporter service in the Docker Compose network is configured to use port 9401 instead of the default port 9400. This adjustment is made to avoid port conflicts with other dcgm-exporter instances that may be running simultaneously. Such a configuration is typical in distributed systems like SLURM. -As of Q2 2025, Dynamo HTTP Frontend metrics are exposed when you build containers with `--framework VLLM_V1` or `--framework TENSORRTLLM`. +As of Q2 2025, Dynamo HTTP Frontend metrics are exposed when you build containers with `--framework VLLM` or `--framework TENSORRTLLM`. ## Getting Started diff --git a/examples/vllm/README.md b/examples/vllm/README.md index ef5fe2b4e6..d3a0224a09 100644 --- a/examples/vllm/README.md +++ b/examples/vllm/README.md @@ -36,11 +36,11 @@ docker compose -f deploy/metrics/docker-compose.yml up -d ### Build and Run docker ```bash -./container/build.sh --framework VLLM_V1 +./container/build.sh ``` ```bash -./container/run.sh -it --framework VLLM_V1 [--mount-workspace] +./container/run.sh -it [--mount-workspace] ``` This includes the specific commit [vllm-project/vllm#19790](https://github.com/vllm-project/vllm/pull/19790) which enables support for external control of the DP ranks. @@ -129,9 +129,9 @@ For Kubernetes deployment, YAML manifests are provided in the `deploy/` director - **Dynamo Cloud**: Follow the [Quickstart Guide](../../docs/guides/dynamo_deploy/quickstart.md) to deploy Dynamo Cloud first. -- **Container Images**: The deployment files currently require access to `nvcr.io/nvidian/nim-llm-dev/vllm_v1-runtime`. If you don't have access, build and push your own image: +- **Container Images**: The deployment files currently require access to `nvcr.io/nvidian/nim-llm-dev/vllm-runtime`. 
If you don't have access, build and push your own image:
   ```bash
-  ./container/build.sh --framework VLLM_V1
+  ./container/build.sh --framework VLLM
   # Tag and push to your container registry
   # Update the image references in the YAML files
   ```
diff --git a/tests/serve/test_vllm.py b/tests/serve/test_vllm.py
index d1b12ca6c8..c0c4e94a91 100644
--- a/tests/serve/test_vllm.py
+++ b/tests/serve/test_vllm.py
@@ -186,7 +186,7 @@ def wait_for_ready(self, payload, logger=logging.getLogger()):
 vllm_configs = {
     "aggregated": VLLMConfig(
         name="aggregated",
-        directory="/workspace/examples/llm",
+        directory="/workspace/examples/vllm",
         script_name="agg.sh",
         marks=[pytest.mark.gpu_1, pytest.mark.vllm],
         endpoints=["v1/chat/completions", "v1/completions"],
@@ -199,7 +199,7 @@ def wait_for_ready(self, payload, logger=logging.getLogger()):
     ),
     "disaggregated": VLLMConfig(
         name="disaggregated",
-        directory="/workspace/examples/llm",
+        directory="/workspace/examples/vllm",
         script_name="disagg.sh",
         marks=[pytest.mark.gpu_2, pytest.mark.vllm],
         endpoints=["v1/chat/completions", "v1/completions"],