
[P/D] Prefill compute optimizations with bi-directional KV cache transfers between P and D nodes#32553

Open
snadampal wants to merge 2 commits into vllm-project:main from snadampal:prefill_optimizations

Conversation

@snadampal snadampal commented Jan 18, 2026

Prefill worker can pull KV cache from remote engines to eliminate redundant prefill computation

Benefits:
(1) Multi-turn conversations: Prefill worker loads larger KV cache from decode worker
(2) Cache eviction recovery: Prefill worker retrieves evicted blocks still available on decode worker

Purpose

Avoid redundant compute cycles on prefill nodes and thereby improve prefill throughput.
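The pull-vs-recompute trade-off described above can be sketched as a simple threshold decision on the prefill side. This is an illustrative sketch only: the names `RemoteKVInfo`, `should_pull_remote_kv`, and the default threshold are hypothetical, not vLLM's actual API.

```python
from dataclasses import dataclass


@dataclass
class RemoteKVInfo:
    """Metadata a prefill worker might learn about KV held on a decode worker.

    Both fields are illustrative; they do not correspond to vLLM class names.
    """
    remote_num_tokens: int    # tokens whose KV cache is available on the decode worker
    local_cached_tokens: int  # tokens already covered by the prefill worker's prefix cache


def should_pull_remote_kv(info: RemoteKVInfo, remote_pull_threshold: int = 256) -> bool:
    """Pull KV from the decode worker only when the remote cache covers
    meaningfully more tokens than the local prefix cache; for small gaps,
    recomputing locally is likely cheaper than the transfer."""
    extra_tokens = info.remote_num_tokens - info.local_cached_tokens
    return extra_tokens >= remote_pull_threshold
```

For a long multi-turn conversation (thousands of remote tokens, few local), the pull wins; for a near-complete local prefix cache, the worker would recompute the small remainder instead.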

Test Plan

Test Result

Unit test report:

configfile: pyproject.toml
plugins: anyio-4.12.1
collected 56 items                                                                                                                

tests/v1/kv_connector/unit/test_nixl_connector.py::test_basic_interface PASSED                                              [  1%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_prompt_less_than_block_size PASSED                                  [  3%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_kv_transfer_handshake PASSED                                        [  5%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_multi_xfer_one_engine PASSED                     [  7%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_async_load_kv[1-1] PASSED                        [  8%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_async_load_kv[2-1] PASSED                        [ 10%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_async_load_kv[4-2] PASSED                        [ 12%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_async_load_kv[4-4] PASSED                        [ 14%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_prefill_tp_size_greater_than_decode_tp_size[1] PASSED [ 16%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_prefill_tp_size_greater_than_decode_tp_size[2] PASSED [ 17%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_prefill_tp_size_greater_than_decode_tp_size_mla PASSED [ 19%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_concurrent_load_kv PASSED                        [ 21%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_handshake_fails_on_kv_cache_layout_mismatch PASSED [ 23%]
tests/v1/kv_connector/unit/test_nixl_connector.py::TestNixlHandshake::test_handshake_succeed_on_kv_cache_layout_mismatch_with_experimental PASSED [ 25%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_kv_connector_stats PASSED                                           [ 26%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_kv_connector_stats_aggregation PASSED                               [ 28%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_multi_kv_connector_stats_aggregation PASSED                         [ 30%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_scheduler_kv_connector_stats_aggregation PASSED                     [ 32%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_abort_timeout_on_prefiller[ray] PASSED                              [ 33%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_abort_timeout_on_prefiller[None] PASSED                             [ 35%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[FLASH_ATTN-False] PASSED                         [ 37%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[FLASH_ATTN-True] PASSED                          [ 39%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[ROCM_ATTN-False] SKIPPED (Attention backend ...) [ 41%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[ROCM_ATTN-True] SKIPPED (Attention backend R...) [ 42%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[TRITON_ATTN-False] PASSED                        [ 44%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_register_kv_caches[TRITON_ATTN-True] PASSED                         [ 46%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_kv_buffer_to_nixl_memory_types[oot-VRAM] PASSED                     [ 48%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_shutdown_cleans_up_resources PASSED                                 [ 50%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_aborted_request_removed_from_worker_in_batch PASSED                 [ 51%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[False-transfer_setup_failed-wrapper_config0-False] PASSED [ 53%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[False-handshake_failed-wrapper_config1-False] PASSED [ 55%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[False-notification_failed-wrapper_config2-False] PASSED [ 57%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[False-transfer_failed-wrapper_config3-True] PASSED [ 58%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[False-transfer_exception-wrapper_config4-True] PASSED [ 60%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[True-transfer_setup_failed-wrapper_config0-False] PASSED [ 62%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[True-handshake_failed-wrapper_config1-False] PASSED [ 64%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[True-notification_failed-wrapper_config2-False] PASSED [ 66%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[True-transfer_failed-wrapper_config3-True] PASSED [ 67%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_failure_logging[True-transfer_exception-wrapper_config4-True] PASSED             [ 69%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_handshake_failure_returns_finished PASSED                                                              [ 71%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_transfer_setup_failure_returns_finished PASSED                                                         [ 73%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[vllm_version-config_overrides0-version_override0-True-True] PASSED       [ 75%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[nixl_connector_version-config_overrides1-version_override1-True-True] PASSED [ 76%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[model_name-config_overrides2-version_override2-True-True] PASSED         [ 78%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[dtype-config_overrides3-version_override3-True-True] PASSED              [ 80%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[cache_dtype-config_overrides4-version_override4-True-True] PASSED        [ 82%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[num_kv_heads-config_overrides5-version_override5-True-True] PASSED       [ 83%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[num_hidden_layers-config_overrides6-version_override6-True-True] PASSED  [ 85%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[hidden_size-config_overrides7-version_override7-True-True] PASSED        [ 87%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[block_size-config_overrides8-version_override8-False-True] PASSED        [ 89%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[matching_config-config_overrides9-version_override9-False-True] PASSED   [ 91%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_compatibility_hash_validation[escape_hatch-config_overrides10-version_override10-False-False] PASSED   [ 92%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_handshake_decode_errors[handshake_decode_error] PASSED                                                 [ 94%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_handshake_decode_errors[handshake_validation_error] PASSED                                             [ 96%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_handshake_decode_errors[metadata_decode_error] PASSED                                                  [ 98%]
tests/v1/kv_connector/unit/test_nixl_connector.py::test_handshake_decode_errors[metadata_validation_error] PASSED                                              [100%]

HMA unit tests:


tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_sw_sizes[True-expected_sw_sizes0] PASSED                                                           [  6%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_sw_sizes[False-expected_sw_sizes1] PASSED                                                          [ 13%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_logical_to_kernel_block_ids_with_hma PASSED                                                        [ 20%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_fewer_blocks_with_hma[google/gemma-3-1b-it-512] PASSED                                             [ 26%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_nixl_metadata_hma_block_ids_structure PASSED                                                       [ 33%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_get_block_descs_ids_hybrid_ssm PASSED                                                              [ 40%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_get_block_descs_ids_kernel_block_mismatch PASSED                                                   [ 46%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_nixl_metadata_hybrid_ssm_block_ids PASSED                                                          [ 53%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_mamba_n1_d_side[mamba] PASSED                                                                      [ 60%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_mamba_n1_d_side[fa_only] PASSED                                                                    [ 66%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_mamba_n1_d_side[swa_only] PASSED                                                                   [ 73%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_mamba_n1_p_side_truncation PASSED                                                                  [ 80%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_has_mamba_init[fa_swa_mamba] PASSED                                                                [ 86%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_has_mamba_init[fa_swa_only] PASSED                                                                 [ 93%]
tests/v1/kv_connector/unit/test_nixl_connector_hma.py::test_has_mamba_init[fa_only] PASSED                                                                     [100%]

Performance benchmarking report:
Tested the following multi-turn conversations benchmark on AWS Trainium (the platform where the feature was initially developed) and AWS p5en.
https://github.com/vllm-project/vllm/tree/main/benchmarks/multi_turn#benchmark-kv-cache-offloading-with-multi-turn-conversations

Here is the data from p5en; similar gains were observed on AWS Trainium as well.

  1. With KV pull from the D node when prefix caching is disabled on the prefill node (--no-enable-prefix-caching), TTFT improves by more than 2x.

(i) without the PR

Parameters:
  model                          = /models/llama-3.3-70b-instruct
  num_clients                    = 2
  num_conversations              = 24
  active_conversations           = 24
  seed                           = 0

Conversations Generation Parameters:
  text_files                     = pg1184.txt
  input_num_turns                = UniformDistribution[12, 18]
  input_common_prefix_num_tokens = Constant[500]
  input_prefix_num_tokens        = LognormalDistribution[4000, 0.85, 10000]
  input_num_tokens               = UniformDistribution[120, 160]
  output_num_tokens              = UniformDistribution[800, 1200]

Statistics summary:
  runtime_sec      = 1483.861
  requests_per_sec = 0.087

                    count     mean      std       min       25%       50%       75%       90%       99%       max
ttft_ms             129.0  1049.77   441.35    383.66    723.82    958.59   1266.52   1639.54   2489.80   2603.19
tpot_ms             129.0    21.88     0.20     21.54     21.72     21.86     22.02     22.17     22.25     22.33
latency_ms          129.0 22886.20  2736.89  18037.89  20712.63  22689.81  25056.06  26621.98  27781.96  28012.94
input_num_turns     129.0     5.42     3.14      1.00      3.00      5.00      9.00      9.00     11.00     11.00
input_num_tokens    129.0  6959.82  2857.70   2007.00   4798.00   6731.00   8680.00  10916.40  14096.08  15488.00
output_num_tokens   129.0   999.01   116.98    803.00    909.00    994.00   1090.00   1161.80   1193.44   1200.00
output_num_chunks   129.0   997.96   116.95    802.00    908.00    993.00   1089.00   1160.80   1191.00   1199.00


(ii) with the PR

Parameters:
  model                          = /models/llama-3.3-70b-instruct
  num_clients                    = 2
  num_conversations              = 24
  active_conversations           = 24
  seed                           = 0

Conversations Generation Parameters:
  text_files                     = pg1184.txt
  input_num_turns                = UniformDistribution[12, 18]
  input_common_prefix_num_tokens = Constant[500]
  input_prefix_num_tokens        = LognormalDistribution[4000, 0.85, 10000]
  input_num_tokens               = UniformDistribution[120, 160]
  output_num_tokens              = UniformDistribution[800, 1200]

Statistics summary:
  runtime_sec      = 1445.109
  requests_per_sec = 0.089

                    count     mean      std       min       25%       50%       75%       90%       99%       max
ttft_ms             129.0   498.15   300.91    181.89    284.78    460.48    564.72    739.08   1746.15   2043.56
tpot_ms             129.0    21.89     0.21     21.55     21.73     21.88     22.03     22.18     22.33     22.34
latency_ms          129.0 22348.51  2622.72  17671.81  20231.96  22232.89  24704.40  25982.67  26935.93  27037.89
input_num_turns     129.0     5.42     3.14      1.00      3.00      5.00      9.00      9.00     11.00     11.00
input_num_tokens    129.0  6959.91  2857.69   2007.00   4798.00   6731.00   8680.00  10916.40  14096.08  15488.00
output_num_tokens   129.0   999.03   116.97    803.00    909.00    994.00   1090.00   1161.80   1193.44   1200.00
output_num_chunks   129.0   998.02   116.96    802.00    908.00    993.00   1089.00   1160.80   1192.44   1199.00
  2. With KV pull from the D node when the prefill node has prefix caching enabled, mean TTFT improves by 18%, and:
    TTFT median (p50) improved by 28.6%: 219.61 ms (KV pull) vs 307.86 ms (recompute)
    TTFT p75 improved by 28.6%: 244.65 ms vs 342.86 ms
    TTFT min improved by 41%: 143.76 ms vs 243.50 ms

(i) without the PR

Parameters:
  model                          = /models/llama-3.3-70b-instruct
  num_clients                    = 2
  num_conversations              = 24
  active_conversations           = 24
  seed                           = 0

Conversations Generation Parameters:
  text_files                     = pg1184.txt
  input_num_turns                = UniformDistribution[12, 18]
  input_common_prefix_num_tokens = Constant[500]
  input_prefix_num_tokens        = LognormalDistribution[4000, 0.85, 10000]
  input_num_tokens               = UniformDistribution[120, 160]
  output_num_tokens              = UniformDistribution[800, 1200]

Statistics summary:
  runtime_sec      = 1429.684
  requests_per_sec = 0.090

                    count     mean      std       min       25%       50%       75%       90%       99%       max
ttft_ms             129.0   396.73   333.86    243.50    282.47    307.86    342.86    535.37   2186.20   2709.38
tpot_ms             129.0    21.81     0.20     21.48     21.64     21.80     21.96     22.09     22.24     22.26
latency_ms          129.0 22166.47  2621.66  17667.78  20104.41  21986.13  24585.03  25809.75  26698.04  27417.87
input_num_turns     129.0     5.42     3.14      1.00      3.00      5.00      9.00      9.00     11.00     11.00
input_num_tokens    129.0  6959.91  2857.69   2007.00   4798.00   6731.00   8680.00  10916.40  14096.08  15488.00
output_num_tokens   129.0   999.03   116.97    803.00    909.00    994.00   1090.00   1161.80   1193.44   1200.00
output_num_chunks   129.0   998.01   116.96    802.00    908.00    993.00   1089.00   1160.80   1192.44   1199.00


(ii) with the PR

Parameters:
  model                          = /models/llama-3.3-70b-instruct
  num_clients                    = 2
  num_conversations              = 24
  active_conversations           = 24
  seed                           = 0

Conversations Generation Parameters:
  text_files                     = pg1184.txt
  input_num_turns                = UniformDistribution[12, 18]
  input_common_prefix_num_tokens = Constant[500]
  input_prefix_num_tokens        = LognormalDistribution[4000, 0.85, 10000]
  input_num_tokens               = UniformDistribution[120, 160]
  output_num_tokens              = UniformDistribution[800, 1200]

Statistics summary:
  runtime_sec      = 1406.542
  requests_per_sec = 0.090

                    count     mean      std       min       25%       50%       75%       90%       99%       max
ttft_ms             126.0   334.39   338.86    143.76    203.95    219.61    244.65    517.37   1975.91   2355.37
tpot_ms             126.0    21.81     0.21     21.45     21.66     21.80     21.97     22.12     22.24     22.37
latency_ms          126.0 22167.56  2577.02  17762.80  20067.89  22007.78  24561.08  25689.76  26501.85  27484.72
input_num_turns     126.0     5.29     3.05      1.00      3.00      5.00      7.00      9.00     11.00     11.00
input_num_tokens    126.0  6923.63  2881.37   2007.00   4764.25   6719.00   8670.75  10935.00  14124.25  15488.00
output_num_tokens   126.0  1001.81   115.64    803.00    909.25    995.50   1090.75   1163.00   1193.50   1200.00
output_num_chunks   126.0  1000.80   115.64    802.00    908.25    994.00   1089.75   1162.00   1192.50   1199.00




@snadampal snadampal changed the title [P/D] Enable KV cache queries and eliminate redundant prefill computation [P/D] Add KV cache queries and eliminate redundant prefill computation Jan 18, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant enhancements for KV cache management, enabling query-only requests and remote KV cache pulling for prefill workers. These features will improve performance for multi-turn conversations and cache eviction recovery. The implementation is well-structured, with logical changes across the scheduler and the NIXL KV connector. I've identified one potential robustness issue that could lead to a crash under specific configurations, and I've provided a suggestion to address it.

Comment on lines +1300 to +1301
kv_transfer_params["remote_block_ids"] = \
cached_computed_blocks.get_block_ids()[0]
Contributor


high

The code cached_computed_blocks.get_block_ids()[0] assumes that get_block_ids() will always return a non-empty tuple, allowing it to be indexed at [0]. However, if a model is configured with a KV connector but has no KV cache groups (e.g., an encoder-only model), get_block_ids() could return an empty tuple (), leading to an IndexError and a crash. While this might be a misconfiguration, the code should be more robust to prevent this.

                kv_transfer_params["remote_block_ids"] = (cached_computed_blocks.get_block_ids() or [[]])[0]
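The suggested `or [[]]` guard above can be exercised in isolation. `FakeBlocks` below is a minimal stand-in for illustration only, not vLLM's actual `KVCacheBlocks` class.

```python
# Stand-in illustrating why the `or [[]]` guard avoids an IndexError when
# get_block_ids() returns an empty tuple (e.g. a model with no KV cache groups).
class FakeBlocks:
    def __init__(self, groups):
        self._groups = groups

    def get_block_ids(self):
        return self._groups


empty = FakeBlocks(())             # no KV cache groups
normal = FakeBlocks(([1, 2, 3],))  # one group of block ids

# Unguarded, empty.get_block_ids()[0] would raise IndexError.
guarded_empty = (empty.get_block_ids() or [[]])[0]    # falls back to []
guarded_normal = (normal.get_block_ids() or [[]])[0]  # normal path: [1, 2, 3]
```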

@robertgshaw2-redhat
Collaborator

I think this feature is a good idea; however, I'm a bit worried about having an implicit implementation that overloads the meaning of the fields. I wonder if we can come up with a design that enables this explicitly. Ping me on Slack if you want to discuss further.

@snadampal
Author

@robertgshaw2-redhat, thanks for the prompt review! Sure, will discuss with you on Slack.

@markmc
Member

markmc commented Jan 19, 2026

It would be very useful to have a much more detailed description of the problem being solved, background context including which project will use this new API, how this approach solves the problem, and alternatives considered. Perhaps the context is obvious to others, but not to me at least.

For inspiration on what I mean by more details, take a look at #24256 and #24520, which aim to:

  1. Ensure preempted requests don't have prefill recomputed in decode workers
  2. Add a "decode-first" flow, whereby the decode instance first gets to decide whether it needs remote prefill done

@snadampal
Author

Hi @robertgshaw2-redhat , @markmc , I have created the following RFC with more details on use cases, challenges, and proposed solution.
#32733

@mergify
Contributor

mergify bot commented Feb 13, 2026

Hi @snadampal, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@snadampal snadampal changed the title [P/D] Add KV cache queries and eliminate redundant prefill computation [P/D] Prefill compute optimizations with bi-directional KV cache transfers between P and D nodes Feb 17, 2026
@snadampal snadampal force-pushed the prefill_optimizations branch from d988d3b to e7b030c Compare March 12, 2026 04:01
@mergify mergify bot added the frontend label Mar 12, 2026
@mergify
Contributor

mergify bot commented Mar 12, 2026

Hi @snadampal, the pre-commit checks have failed. Please run pre-commit locally, then commit the changes and push to your branch.

@snadampal snadampal force-pushed the prefill_optimizations branch from e7b030c to d170329 Compare March 12, 2026 04:56
@snadampal
Author

During benchmarking, I identified that the additional round trip to the decode node to query KV block metadata introduces latency overhead. To address this, I added support for returning kv_transfer_params as part of the streaming response from the decode node, eliminating the need for a separate query. The KV params query feature is retained, as it will be useful for initial node pairing with vLLM engine-specific details.

@mergify
Contributor

mergify bot commented Mar 12, 2026

Hi @snadampal, the pre-commit checks have failed. Please run pre-commit locally, then commit the changes and push to your branch.

@snadampal snadampal force-pushed the prefill_optimizations branch from d170329 to 0754908 Compare March 12, 2026 17:10
@mergify
Contributor

mergify bot commented Mar 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @snadampal.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 19, 2026
@snadampal snadampal force-pushed the prefill_optimizations branch from 0754908 to ee248b9 Compare March 23, 2026 05:35
@mergify mergify bot removed the needs-rebase label Mar 23, 2026
@snadampal
Author

I've rebased the PR onto the latest commit, which brings in HMA support. So far I have tested full-attention models; next I'm testing models with hybrid KV (full attention + sliding-window attention).

Prefill worker can pull KV cache from remote engines
to eliminate redundant prefill computation

Benefits:
(1) Multi-turn conversations: Prefill worker loads larger KV cache from decode worker
(2) Cache eviction recovery: Prefill worker retrieves evicted blocks still available on decode worker

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
@snadampal snadampal force-pushed the prefill_optimizations branch from ee248b9 to a8905ec Compare March 23, 2026 17:48
@snadampal
Author

Hi @NickLucche, I have added support for hybrid KV (full + sliding-window attention) for bi-directional pull, fixed the pre-commit warnings, and ran the unit tests (reports added to the test section of the PR description above).

Please let me know if there are any other prerequisites to get this PR reviewed and merged. Thank you!

Collaborator

@NickLucche NickLucche left a comment


@snadampal One alternative to modifying public API interface

class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[CompletionResponseStreamChoice]
    usage: UsageInfo | None = Field(default=None)
+    # vLLM-specific: KV transfer params returned with the final chunk
+    kv_transfer_params: dict[str, Any] | None = None

could be to use our inner tokens-in-tokens-out interface (/inference/v1/generate with kv_transfer support #38094).
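If `kv_transfer_params` were carried on the final streaming chunk as suggested above, a client would only need to inspect the last chunk. A minimal sketch of that client-side extraction follows; the payload shape and field values are illustrative assumptions, not an actual vLLM response.

```python
import json

# Hypothetical final streaming chunk carrying kv_transfer_params, mirroring the
# suggested CompletionStreamResponse field; all values here are made up.
final_chunk = json.dumps({
    "id": "cmpl-abc123",
    "object": "text_completion",
    "choices": [{"index": 0, "text": "", "finish_reason": "stop"}],
    "kv_transfer_params": {
        "remote_engine_id": "decode-0",
        "remote_host": "10.0.0.2",
        "remote_port": 5557,
        "remote_block_ids": [12, 13, 14],
    },
})


def extract_kv_transfer_params(chunk_json: str):
    """Return kv_transfer_params if present; earlier chunks simply omit it."""
    return json.loads(chunk_json).get("kv_transfer_params")
```

Because intermediate chunks omit the field, the client can apply the same extraction to every chunk and keep the last non-None result.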

Member

@tlrmchlsmth tlrmchlsmth left a comment


Pinging a couple of folks on the llm-d and dynamo teams: @vMaroon @alec-flowers FYI, this PR has some interface changes to the KV transfer interface to allow for bi-directional transfer, so it would be great if you could take a look (and forward to relevant teammates). Thank you!

…eature

Signed-off-by: Sunita Nadampalli <nadampal@amazon.com>
@snadampal
Author

I have added a new unit test for this feature:
pytest -x -v -s tests/v1/kv_connector/unit/test_bidirectional_kv_transfer.py

Next, I will update the inner tokens-in-tokens-out interface (/inference/v1/generate) along with the public API.

@snadampal
Author

Thanks for the front-end API suggestion @NickLucche. I see the token generator interface (/inference/v1/generate) already supports kv transfer params for non-streaming mode, and when it supports streaming mode, we can add kv transfer params as part of response chunk metadata.

@vMaroon , @alec-flowers, appreciate your review/feedback on this PR. Please find the design details along with router here: #32733

dhruv-2604 added a commit to dhruv-2604/vllm that referenced this pull request Apr 13, 2026
Exploratory work for a from-scratch bidirectional KV transfer
implementation, before discovering PR vllm-project#32553 already
ships the feature. Kept on this branch for record only. The active
benchmarking work lives on the blackwell-benchmark branch.

- DESIGN.md: component diagrams and interface design doc
- nixl_connector.py: adds remote_num_tokens support in
  get_num_new_matched_tokens and a report_decode_kv path in
  request_finished

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
dhruv-2604 added a commit to dhruv-2604/vllm that referenced this pull request Apr 14, 2026
Multi-turn proxy, TTFT sweep, NVLink-vs-recompute crossover,
SLURM jobs for PACE ICE, analysis script, and PR comment template.
Collaborator

@NickLucche NickLucche left a comment


Left some comments, thanks @snadampal !

Comment on lines +1060 to +1067
delay_free = any(len(group) > 0 for group in block_ids)
if delay_free:
# Track in _reqs_in_batch so the worker adds it to
# _reqs_to_process, which is required for _reqs_to_send
# to be accepted by the worker.
self._reqs_in_batch.add(request.request_id)
self._reqs_need_send[request.request_id] = (
time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
Collaborator


I think holding blocks on D complicates the workflow when P dies, or when many requests end up delaying the execution of the read back from P to D.

Can we make this feature optional for now using self.kv_transfer_config.get_from_extra_config?
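The opt-in the reviewer suggests could look like the sketch below. `TinyKVTransferConfig` is a stand-in mirroring the shape of vLLM's `KVTransferConfig.get_from_extra_config`, and the key name `enable_prefill_kv_pull` is an illustrative assumption, not an existing flag.

```python
class TinyKVTransferConfig:
    """Stand-in for KVTransferConfig's extra-config lookup, for illustration only."""

    def __init__(self, kv_connector_extra_config=None):
        self.kv_connector_extra_config = kv_connector_extra_config or {}

    def get_from_extra_config(self, key, default):
        return self.kv_connector_extra_config.get(key, default)


# Opt in explicitly via the connector's extra config; the hypothetical
# "enable_prefill_kv_pull" flag defaults to False so existing deployments
# keep the current behavior (D frees blocks immediately).
cfg = TinyKVTransferConfig({"enable_prefill_kv_pull": True})
enable_pull = cfg.get_from_extra_config("enable_prefill_kv_pull", False)

default_cfg = TinyKVTransferConfig()
default_pull = default_cfg.get_from_extra_config("enable_prefill_kv_pull", False)
```

With such a flag, the delay-free path (and its block-holding cost on D) only applies to deployments that explicitly enable the feature.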

Comment on lines +1088 to +1089
do_remote_prefill=False,
do_remote_decode=True,
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure these do_remote* fields matter to the router at this point, given we're adding remote-num-tokens here to signal a "special" D-to-P request.

Comment on lines +838 to +843
logger.info(
"Skipping remote pull for %s: %d remote tokens < threshold %d",
request.request_id,
count,
self.remote_pull_threshold,
)
Collaborator


probably a debug log?


# Remote pull threshold: minimum number of remote tokens
# before P pulls KV from D instead of recomputing locally.
self.remote_pull_threshold: int = int(
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: what other names could we use for this variable to make it more informative?

Comment on lines +882 to +897
elif all(
p in params
for p in ("remote_engine_id", "remote_host", "remote_port")
):
unhashed_local_block_ids_p: BlockIds = (
blocks.get_unhashed_block_ids_all_groups()
)
local_block_ids = self.get_sw_clipped_blocks(
unhashed_local_block_ids_p
)

# Get unhashed blocks to pull into from remote.
self._reqs_need_recv[request.request_id] = (
request,
local_block_ids,
)
Collaborator


I think this branch can be unified with what's already happening for D, so that ideally we have the same code with two conditions: if D, or if P with multi-turn pull mode.

