-
-
Notifications
You must be signed in to change notification settings - Fork 16.6k
[KVConnector] MultiConnector SupportsHMA #39571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,9 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import copy | ||
| from collections.abc import Iterable | ||
| from collections.abc import Callable, Iterable | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any | ||
| from typing import TYPE_CHECKING, Any, cast | ||
|
|
||
| import torch | ||
|
|
||
|
|
@@ -18,6 +18,8 @@ | |
| KVConnectorMetadata, | ||
| KVConnectorRole, | ||
| KVConnectorWorkerMetadata, | ||
| SupportsHMA, | ||
| supports_hma, | ||
| ) | ||
| from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( | ||
| KVConnectorPromMetrics, | ||
|
|
@@ -123,7 +125,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): | |
| self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) | ||
|
|
||
|
|
||
| class MultiConnector(KVConnectorBase_V1): | ||
| class MultiConnector(KVConnectorBase_V1, SupportsHMA): | ||
| """ | ||
| A wrapper for using multiple KVConnectors at the same time. | ||
|
|
||
|
|
@@ -166,6 +168,12 @@ def __init__( | |
| self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) | ||
| self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) | ||
|
|
||
| self._all_support_hma = all(supports_hma(c) for c in self._connectors) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we want to add here:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @orozery I have a model-dependent check here which is more confined.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should assert on
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added the assert |
||
| assert ( | ||
| vllm_config.scheduler_config.disable_hybrid_kv_cache_manager | ||
| or self._all_support_hma | ||
| ), "HMA should not be enabled unless all sub-connectors support it" | ||
|
|
||
| # A mapping from request id to the index of the connector chosen to | ||
| # load the request from (if any). | ||
| self._requests_to_connector: dict[str, int] = {} | ||
|
|
@@ -436,15 +444,17 @@ def set_xfer_handshake_metadata( | |
| for c in self._connectors: | ||
| c.set_xfer_handshake_metadata(metadata) | ||
|
|
||
| def request_finished( | ||
| def _aggregate_request_finished( | ||
| self, | ||
| request: "Request", | ||
| blocks: list[int], | ||
| per_connector_fn: Callable[ | ||
| [KVConnectorBase_V1], tuple[bool, dict[str, Any] | None] | ||
| ], | ||
| ) -> tuple[bool, dict[str, Any] | None]: | ||
| async_saves = 0 | ||
| kv_txfer_params = None | ||
| for c in self._connectors: | ||
| async_save, txfer_params = c.request_finished(request, blocks) | ||
| async_save, txfer_params = per_connector_fn(c) | ||
| if async_save: | ||
| async_saves += 1 | ||
| if txfer_params is not None: | ||
|
|
@@ -458,11 +468,39 @@ def request_finished( | |
| if async_saves > 1: | ||
| self._extra_async_saves[request.request_id] = async_saves - 1 | ||
|
|
||
| # Clean up other state for this request. | ||
| self._requests_to_connector.pop(request.request_id, None) | ||
|
|
||
| return async_saves > 0, kv_txfer_params | ||
|
|
||
| def request_finished( | ||
| self, | ||
| request: "Request", | ||
| blocks: list[int], | ||
| ) -> tuple[bool, dict[str, Any] | None]: | ||
| return self._aggregate_request_finished( | ||
| request, | ||
| lambda c: c.request_finished(request, blocks), | ||
| ) | ||
|
|
||
| def request_finished_all_groups( | ||
| self, | ||
| request: "Request", | ||
| block_ids: tuple[list[int], ...], | ||
| ) -> tuple[bool, dict[str, Any] | None]: | ||
| if not self._all_support_hma: | ||
| assert len(block_ids) == 1, ( | ||
| "HMA with multiple kv_cache_groups requires all " | ||
| "sub-connectors to support HMA" | ||
| ) | ||
|
NickLucche marked this conversation as resolved.
|
||
| return self.request_finished(request, block_ids[0]) | ||
|
|
||
| return self._aggregate_request_finished( | ||
| request, | ||
| lambda c: cast(SupportsHMA, c).request_finished_all_groups( | ||
| request, block_ids | ||
| ), | ||
| ) | ||
|
|
||
| def take_events(self) -> Iterable["KVCacheEvent"]: | ||
| for c in self._connectors: | ||
| yield from c.take_events() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is problematic.
HMA will be enabled though you may have sub-connectors which don't support it.
Their
`update_state_after_alloc` and `build_connector_meta` may fail. I think we need to replace
`SupportsHMA` with a runtime check like `prefers_cross_layer_blocks`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@orozery I am doing a runtime check in this PR with
to fall back for the request_finished method.
I think the easiest solution might be to just ensure hma is disabled for such cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think hma will be enabled if and only if multiconnector SupportsHMA
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually let me rephrase that.
If the user explicitly sets
`--no-disable-hybrid-kv-cache-manager` and still packs a sub-connector with no HMA support, due to the runtime check above, any HMA-requiring model (Mamba/SW etc.) will crash here. Do you have another scenario in mind? @orozery
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By default, hma is enabled, unless using a connector which does not have SupportsHMA, right?
If so, this PR will enable HMA by default for the multi connector, even if sub connectors do not support HMA.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, if connector is configured, HMA will be disabled by default unless users explicitly set
`--no-disable-hybrid-kv-cache-manager` as @NickLucche mentioned. With that being said, I agree that having a connector API for
`supports_hma()` (and changing `kv_connector/v1/base.py::supports_hma()` to call the connector's API) is a more robust solution. One quick alternative for now is to change
`kv_connector/v1/base.py::supports_hma()` with a special path for the multi connector, which iteratively checks each sub-connector, and make the rest of the API changes in another follow-up PR. Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this means that if you use a connector you must set either:
`--disable-hybrid-kv-cache-manager` or `--no-disable-hybrid-kv-cache-manager`. Right?
This is weird IMO.
I think that using a connector which supports HMA should work (with HMA enabled) without requiring any more CLI, just as HMA works without explicit enablement if not using a connector.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nope it's just HMA is opt-in for Connector still.
This is the logic for HMA on/off
vllm/vllm/config/vllm.py
Lines 1274 to 1277 in 6f786f2
but I don't think that needs to be changed with this PR, imo the scope is narrower here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@orozery @ivanium
I also think a
A `kv_connector/v1/base.py::supports_hma()` interface change is also in order, but that should be a follow-up PR (which I can still put up in the meantime), where we can also look at changing the auto-on/off behavior. But I think right now, given the HMA feature is still experimental/opt-in, the changes proposed here are safe.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am okay with this PR and addressing the remaining UX issues in a follow-up PR