-
-
Notifications
You must be signed in to change notification settings - Fork 15.7k
Add FULL CUDA-Graph support for KV Connector path #27026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,9 +7,7 @@ | |
| import copy | ||
| from collections.abc import Generator | ||
| from contextlib import AbstractContextManager, contextmanager, nullcontext | ||
| from typing import ( | ||
| TYPE_CHECKING, # noqa: UP035 | ||
| ) | ||
| from typing import Final | ||
|
|
||
| from vllm.config import VllmConfig | ||
| from vllm.distributed.kv_transfer import ( | ||
|
|
@@ -21,17 +19,29 @@ | |
| from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats | ||
| from vllm.forward_context import get_forward_context, set_forward_context | ||
| from vllm.logger import init_logger | ||
| from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput | ||
| from vllm.v1.outputs import ( | ||
| EMPTY_MODEL_RUNNER_OUTPUT, | ||
| KVConnectorOutput, | ||
| ModelRunnerOutput, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.v1.core.sched.output import SchedulerOutput | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| _EMPTY_SCHEDULER_OUTPUT: Final[SchedulerOutput] = SchedulerOutput( | ||
| scheduled_new_reqs=[], | ||
| scheduled_cached_reqs=CachedRequestData.make_empty(), | ||
| num_scheduled_tokens={}, | ||
| total_num_scheduled_tokens=0, | ||
| scheduled_spec_decode_tokens={}, | ||
| scheduled_encoder_inputs={}, | ||
| num_common_prefix_blocks=[], | ||
| finished_req_ids=set(), | ||
| free_encoder_mm_hashes=[], | ||
| structured_output_request_ids=[], | ||
| grammar_bitmask=None, | ||
| ) | ||
|
|
||
|
|
||
| # Defined as a kv connector functionality mixin for ModelRunner (GPU, TPU) | ||
| class KVConnectorModelRunnerMixin: | ||
|
|
@@ -73,7 +83,7 @@ def get_finished_kv_transfers( | |
|
|
||
| @staticmethod | ||
| def kv_connector_no_forward( | ||
| scheduler_output: "SchedulerOutput", vllm_config: VllmConfig | ||
| scheduler_output: SchedulerOutput, vllm_config: VllmConfig | ||
| ) -> ModelRunnerOutput: | ||
| # KV send/recv even if no work to do. | ||
| with ( | ||
|
|
@@ -93,14 +103,29 @@ def kv_connector_no_forward( | |
|
|
||
| @staticmethod | ||
| def maybe_get_kv_connector_output( | ||
| scheduler_output: "SchedulerOutput", | ||
| scheduler_output: SchedulerOutput, | ||
| ) -> AbstractContextManager[KVConnectorOutput | None]: | ||
| return ( | ||
| KVConnectorModelRunnerMixin._get_kv_connector_output(scheduler_output) | ||
| if has_kv_transfer_group() | ||
| else nullcontext() | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def maybe_get_kv_connector_dummy_run_output() -> AbstractContextManager[ | ||
| KVConnectorOutput | None | ||
| ]: | ||
| global _EMPTY_SCHEDULER_OUTPUT | ||
| if has_kv_transfer_group(): | ||
| kv_connector = get_kv_transfer_group() | ||
| meta = kv_connector.build_connector_meta(_EMPTY_SCHEDULER_OUTPUT) | ||
| _EMPTY_SCHEDULER_OUTPUT.kv_connector_metadata = meta | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: behavior does not (rightfully) match the other EMPTY_* objects, which are copied to guarantee they remain EMPTY_*; this might be misleading to some extent
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. True — unlike other EMPTY_* objects, this one is only used internally for this specific path, and the leading underscore marks it as private to this module. |
||
| return KVConnectorModelRunnerMixin._get_kv_connector_output( | ||
| _EMPTY_SCHEDULER_OUTPUT | ||
| ) | ||
|
|
||
| return nullcontext() | ||
|
|
||
| # This context manager must be used within an active forward context. | ||
| # It encapsulates the entire KV connector lifecycle within execute_model | ||
| @staticmethod | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not a fan of having to do this, also we probably should not maintain this structure here as it belongs with the scheduler "namespace".
Furthermore, I am not sure this pre-init has any benefit given it should only be called once during the dummy run.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given that
`SchedulerOutput` is already an input to `GPUModelRunner`, I think it’s acceptable for the runner to construct a dummy instance for its own internal use. Also, the dummy run may be invoked multiple times (e.g., for model warm-up or for capturing different CUDA graphs), so having a pre-initialized SchedulerOutput keeps this path simple and consistent.