
[PD][Nixl] Add support for hybrid SSM-FA models#36687

Merged
NickLucche merged 22 commits into vllm-project:main from NickLucche:nixl-ssm-rebase
Mar 16, 2026

Conversation

@NickLucche
Collaborator

@NickLucche NickLucche commented Mar 10, 2026

For a comprehensive description of the changes proposed here, check out the corresponding RFC #36780.

This PR adds support for hybrid SSM-based models such as nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 with NixlConnector, enabling KVCache transfer of both FA and Mamba states in disaggregated setups.
Currently, only homogeneous TP sizes on the P and D sides are supported.

Note that we're only transferring actual mamba states and skipping the padding that may be present, as that might have non-trivial size.

UPDATE:
Regarding this change:

- curr_tensor_size_bytes = cache.numel() * cache.element_size()
+ curr_tensor_size_bytes = num_blocks * physical_page_size

In this PR I am also moving further away from relying on tensor views, unifying the code around kv_cache_config as the single source of truth.
This is also necessary for Mamba-like models, in which the tensor (cache above) reports the unpadded size; that does not reflect num_blocks * physical_page_size, as one would need to account for the padding manually.
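To make the size mismatch concrete, here is a small sketch with made-up numbers (none of these values come from the PR; they only show why deriving the region size from the tensor view under-counts a padded Mamba cache):

```python
# Illustrative numbers only: a Mamba state page padded up to the attention
# layers' page size. Variable names mirror the diff above; values are invented.
num_blocks = 1024
unpadded_page_size = 33_280    # bytes of actual Mamba state per block
physical_page_size = 40_960    # page size after padding

# Old approach: size derived from the tensor view misses the padding.
view_size_bytes = num_blocks * unpadded_page_size

# New approach: the kv_cache_config-derived size covers the full region.
region_size_bytes = num_blocks * physical_page_size

print(region_size_bytes - view_size_bytes)  # bytes of padding unaccounted for
```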

Important notes

  • TP > 1 currently requires --no-async-scheduling to run correctly. @ZhanqiuHu and I identified a synchronization issue where states may be transferred in a corrupted form, leading to high variance in evaluations. We will address this separately, as it is likely unrelated to SSMs.
  • @ZhanqiuHu identified an issue with the current PD workflow in which the first token is recomputed on D, burning that extra step into the SSM state in-place.

Test with

Enable HMA experimental support with --no-disable-hybrid-kv-cache-manager:

# usual P/D command
vllm serve nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 \
    --trust-remote-code \
    --block-size 64 \
    --no-enable-prefix-caching \
    --no-disable-hybrid-kv-cache-manager \
    --mamba-ssm-cache-dtype float16 \
    --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

# usual toy_proxy_server.py command

or

HYBRID_SSM=1 bash v1/kv_connector/nixl_integration/config_sweep_accuracy_test.sh

or check out unit tests added with this PR.

Results from running consecutive full lm-eval runs with no prefix caching:

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5444|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5345|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8340|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5398|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8355|±  |0.0102|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5428|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8332|±  |0.0103|

local-completions ({'model': 'nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8', 'base_url': 'http://127.0.0.1:55483/v1/completions', 'num_concurrent': 10, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5557|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.8506|±  |0.0098|

TODO

  • Address kernel<>logical block size mismatch
  • Benchmark

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces comprehensive support for hybrid SSM-FA models, a significant and complex feature. While the changes span test configurations, core KV connector logic, and scheduler behavior, a critical security vulnerability was identified in the scheduler's handling of invalid KV cache blocks for HMA-enabled requests. Specifically, the validation logic fails to check all relevant KV cache groups, which could lead to the use of uninitialized memory and potential PII leakage. Additionally, several debug print statements and potential logic errors in the Nixl connector were observed and should be addressed to ensure production readiness.


register_remote_blocks(blocks_data, mamba=False)
if self._is_mamba:
assert self.num_descs == len(blocks_data)
Contributor


critical

This assertion is critical for ensuring that the number of descriptors (self.num_descs) matches the actual number of blocks being registered (len(blocks_data)). An inconsistency here could lead to memory corruption or incorrect KV cache transfers. It's important to verify that self.num_descs is always accurately calculated to reflect all registered blocks, including those for Mamba layers and any logical duplications for K/V splits.

Comment on lines +2118 to +2128
all_req_block_ids = (
(block_id for group in req_block_ids for block_id in group)
if is_hma
else req_block_ids[0]
)
req_num_computed_blocks = (
req_num_computed_tokens + self.block_size - 1
) // self.block_size
-for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
+for idx, block_id in enumerate(all_req_block_ids):
if idx >= req_num_computed_blocks:
break
Contributor


security-high

The scheduler's logic for identifying requests affected by invalid KV cache blocks is incomplete when Hybrid Memory Allocator (HMA) is used. The code flattens all KV cache groups into a single list but only iterates through the first req_num_computed_blocks elements. In HMA mode, this typically corresponds to the blocks in the Full Attention group, causing the validation to skip blocks in other groups (e.g., Sliding Window). If a block in these skipped groups failed to load from a remote source, the scheduler will fail to detect it, potentially leading the model runner to use uninitialized or stale GPU memory, which could result in PII leakage or incorrect model outputs.

Modify the validation loop to ensure all blocks in all KV cache groups that are relevant to the computed tokens are checked against the invalid_block_ids set. Since HMA does not support partial recovery, any invalid block in any group should trigger a full eviction and recomputation for the request.

Suggested change
-all_req_block_ids = (
-    (block_id for group in req_block_ids for block_id in group)
-    if is_hma
-    else req_block_ids[0]
-)
-req_num_computed_blocks = (
-    req_num_computed_tokens + self.block_size - 1
-) // self.block_size
-for idx, block_id in enumerate(all_req_block_ids):
-    if idx >= req_num_computed_blocks:
-        break
+if is_hma:
+    all_req_block_ids = [
+        block_id for group in req_block_ids for block_id in group
+    ]
+    req_num_computed_blocks = len(all_req_block_ids)
+else:
+    all_req_block_ids = req_block_ids[0]
+    req_num_computed_blocks = (
+        req_num_computed_tokens + self.block_size - 1
+    ) // self.block_size
+for idx, block_id in enumerate(all_req_block_ids):
+    if idx >= req_num_computed_blocks:
+        break
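The block count in the snippet above is the usual ceil-division of computed tokens by the block size; for instance, with illustrative values:

```python
# Ceil-division used in the snippet above to turn a number of computed
# tokens into a number of (logical) blocks. Values are illustrative.
block_size = 64
req_num_computed_tokens = 130

req_num_computed_blocks = (req_num_computed_tokens + block_size - 1) // block_size
print(req_num_computed_blocks)  # 3 blocks cover 130 tokens at block size 64
```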

Comment on lines +1489 to +1492
print(f"{self.vllm_config.cache_config.mamba_page_size_padded=}\n\n")
# block size: 400, the one from the FA spec
print(f"block size: {self.block_size}\n\n")
print("NUM_BLOCKS: ", self.num_blocks, "\n\n", flush=True)
Contributor


high

These print statements appear to be for debugging purposes. They should be removed before merging to avoid unnecessary output in production environments.

if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes

print(f"{layer_name=}, {[v.shape for v in cache_list]}")
Contributor


high

This print statement appears to be for debugging purposes. It should be removed before merging to avoid unnecessary output in production environments.

local_block_len = self.get_backend_aware_kv_block_len(
layer_idx=i, first_split=True, mamba_view=mamba
)
print(f"Add agent {i=}, {local_block_len=}\n", flush=True)
Contributor


high

This print statement appears to be for debugging purposes. It should be removed before merging to avoid unnecessary output in production environments.

layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
Contributor


high

The block_len_per_layer list is populated conditionally for non-Mamba specs and then truncated based on seen_base_addresses. This approach can be fragile. If the order or count of seen_base_addresses does not perfectly align with the non-Mamba layers for which block_len_per_layer was intended, it could lead to incorrect block length assignments. Consider ensuring a more robust mapping or initialization of block_len_per_layer that directly corresponds to the registered regions.

@ZhanqiuHu
Contributor

ZhanqiuHu commented Mar 10, 2026

Tested the hybrid SSM P/D disaggregation on 2× H100 with nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 (--enforce-eager --block-size 128 --no-disable-hybrid-kv-cache-manager, NixlConnector kv_role=kv_both). Ran lm_eval gsm8k 5-shot (1319 questions) across four configs:

|Config|strict-match|flexible-extract|
|------|-----------:|---------------:|
|Direct GPU 0|0.8605 ± 0.0095|0.5603 ± 0.0137|
|Direct GPU 1|0.8469 ± 0.0099|0.5663 ± 0.0137|
|P/D GPU 0 → GPU 1|0.8431 ± 0.0100|0.5618 ± 0.0137|
|P/D GPU 1 → GPU 0|0.8355 ± 0.0102|0.5512 ± 0.0137|
  • All within the CI tolerance of 0.84 (RTOL=0.03).
  • Prometheus metrics report the external prefix-cache hits correctly.
  • Local prefix caching is disabled automatically, so the local prefix-cache hit count reports 0, as expected.
  • Differences exist but are within the range of cross-GPU non-determinism.
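One plausible reading of the tolerance gate mentioned above, sketched below. The comparison semantics (a one-sided relative threshold) are an assumption, not taken from the actual CI script:

```python
# Hedged sketch of the accuracy gate: each measured strict-match score must
# stay within a 3% relative tolerance of the 0.84 reference. The exact
# comparison used by the CI script may differ.
EXPECTED = 0.84
RTOL = 0.03

measured = [0.8605, 0.8469, 0.8431, 0.8355]
for score in measured:
    assert score >= EXPECTED * (1 - RTOL), f"{score} below tolerance"
print("all configs pass")
```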

@mergify

mergify bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @NickLucche.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify

mergify bot commented Mar 11, 2026

Hi @NickLucche, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

1 similar comment

"DP_EP=1 GPU_MEMORY_UTILIZATION=0.8 PREFILLER_TP_SIZE=2 DECODER_TP_SIZE=2 MODEL_NAMES=deepseek-ai/deepseek-vl2-tiny" # MLA+P-TP2, D-DPEP=2 (TP=1)
)
hybrid_ssm_configs=(
"ENABLE_HMA_FLAG=1 GPU_MEMORY_UTILIZATION=0.8 MODEL_NAMES=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8 VLLM_SERVE_EXTRA_ARGS=--max-model-len,8192,--trust-remote-code"
Collaborator Author


I hope this fits on CI

Comment on lines +999 to +1001
if self._is_mamba:
assert self._is_hma_required
mamba_spec = next(
Collaborator Author


I could probably wrap this bit to reduce bloat

Comment on lines +1565 to +1582
page_size = (
layer_spec.page_size_bytes
if isinstance(layer_spec, MambaSpec)
else layer_spec.page_size_bytes
// self._physical_blocks_per_logical_kv_block
)
num_blocks = (
self._logical_num_blocks
if isinstance(layer_spec, MambaSpec)
else self.num_blocks
)
# `page_size` accounts for physical blocks, st KVCache is always
# [`num_blocks` * `page_size`]
if not isinstance(layer_spec, MambaSpec):
self.block_len_per_layer.append(page_size)
curr_tensor_size_bytes = num_blocks * page_size
if tensor_size_bytes is None:
tensor_size_bytes = curr_tensor_size_bytes
Collaborator Author


This logic has been moved here from within the inner loop, extending the work to rely solely on KVCacheConfig rather than tensor views.
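To make the physical/logical split in the snippet above concrete, a small sketch with made-up numbers (the ratio name mirrors the snippet; the values are not from the PR):

```python
# Illustrative only: relating logical and physical page sizes the way the
# snippet above does for non-Mamba layers.
logical_page_size_bytes = 262_144      # one logical KV block, all layers' view
physical_blocks_per_logical = 4        # kernel block is 1/4 of the logical one

# Attention layers: descriptors are built per physical (kernel) block.
attn_page_size = logical_page_size_bytes // physical_blocks_per_logical

# Mamba layers: one state page per logical block, no sub-splitting, so the
# MambaSpec page size would be used directly (value invented here).
mamba_page_size = 131_072

print(attn_page_size)  # physical page size for the attention layers
```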

blocks_data: list[tuple[int, int, int]] = []
local_base_addresses = self.kv_caches_base_addr[self.engine_id][self.tp_rank]

def register_blocks(blocks_data: list[tuple[int, int, int]], mamba: bool):
Collaborator Author


Wrapping the whole block in a function to re-use it for the mamba descriptors, which are appended at the end.

# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
kv_topo = self.kv_topo
Collaborator Author


mypy was complaining


@orozery
Copy link
Copy Markdown
Collaborator

orozery commented Mar 11, 2026

I see the scheduler changes related to failure recovery.
Is that intended? I thought you were somehow disabling it altogether for the NixlConnector.

@NickLucche
Collaborator Author

@orozery yep, sorry, it took me a few minutes to realize I had cruft; removed in ffb48cf.

Member

@tdoublep tdoublep left a comment


Some initial comments (haven't finished reading it all yet).

"deepseek-ai/deepseek-vl2-tiny": 0.19,
"deepseek-ai/DeepSeek-V2-Lite-Chat": 0.65,
"google/gemma-3-4b-it": 0.74,
"nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-FP8": 0.84,
Member


Quite a big model for CI

Collaborator Author


I can switch to granite

Comment on lines +1578 to +1579
if not isinstance(layer_spec, MambaSpec):
self.block_len_per_layer.append(page_size)
Member


Maybe I miss something but where does self.block_len_per_layer get populated for the Mamba layers?

Collaborator Author


Good point, there's a comment here where I define the var

        # Enable different block lengths for different layers *only* when MLA is used.
        # This is not used for SSM layers, which use the counterpart `mamba_ssm_size`.

let me know if that should be expanded.
Basically UniformTypeKVCacheSpecs can allow for different page sizes. Currently that is only used for dsv32 Indexer afaik @heheda12345


@mergify mergify bot added the needs-rebase label Mar 11, 2026
# we just mock num_blocks to 1 for the dimension check below.
-self._is_kv_layout_blocks_first = (
+# Hybrid SSM models assume a single blocks_first layout
+self._is_kv_layout_blocks_first = self.is_mamba or (
Member


I wonder if _is_kv_layout_blocks_first could be a property of the attention backend rather than needing to compute it here?

Collaborator Author


Great point!
We can address it in a scoped PR

Comment on lines +534 to +535
# Regular case: backends like FA register K/V in separate regions
return cache if self.split_k_and_v else [cache]
Member


Again this is probably just my lack for familiarity with this part of the code, but how does returning the tensor vs. the tensor wrapped in a list relate to registering the K/V separately?

Collaborator Author


Because in the original register_kv_cache code we iterate over the returned value, as in for cache in cache_list.
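A minimal sketch of that pattern (the helper name and signature are illustrative, not the actual vLLM code):

```python
# Sketch: register_kv_cache iterates `for cache in cache_list`, so backends
# that keep K and V in separate regions return both tensors, while a single
# fused tensor is wrapped in a one-element list to keep the loop uniform.
def to_cache_list(cache, split_k_and_v: bool):
    return list(cache) if split_k_and_v else [cache]

k, v = object(), object()
assert to_cache_list((k, v), split_k_and_v=True) == [k, v]

fused = object()
assert to_cache_list(fused, split_k_and_v=False) == [fused]
```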

start refactoring
address kernel block size mismatch by handling 2 num_blocks

Signed-off-by: NickLucche <nlucches@redhat.com>

Signed-off-by: NickLucche <nlucches@redhat.com>

Signed-off-by: NickLucche <nlucches@redhat.com>

@NickLucche
Collaborator Author

/gemini review

Signed-off-by: NickLucche <nlucches@redhat.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for hybrid SSM-FA models to the NixlConnector, which is a significant feature enhancement. The changes are extensive, touching upon KV cache registration, descriptor management, and metadata handling to accommodate the specific requirements of Mamba-based models alongside traditional attention mechanisms. The addition of comprehensive unit tests is commendable. I've identified a critical issue related to the calculation of page sizes for Mamba layers in the presence of a kernel block size mismatch, which could lead to incorrect behavior. The logic is unnecessarily complex and error-prone. My review includes a suggestion to refactor this for correctness and improved clarity.

@NickLucche NickLucche added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 14, 2026
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche merged commit f5c081d into vllm-project:main Mar 16, 2026
52 of 53 checks passed
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
andylolu2 pushed a commit to andylolu2/vllm that referenced this pull request Mar 18, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
Signed-off-by: wendyliu235 <wenjun.liu@intel.com>
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026
vrdn-23 pushed a commit to vrdn-23/vllm that referenced this pull request Mar 30, 2026
EricccYang pushed a commit to EricccYang/vllm that referenced this pull request Apr 1, 2026
Signed-off-by: EricccYang <yangyang4991@gmail.com>

Labels

kv-connector ready ONLY add when PR is ready to merge/full CI is needed v1
