From a376cf40d92109e89345ca4e08875dc4a7ab9bda Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 1 Dec 2025 18:55:13 +0000 Subject: [PATCH 1/7] enable hma+kv_conn Signed-off-by: NickLucche --- tests/v1/kv_connector/unit/test_config.py | 11 ++++++++ vllm/config/vllm.py | 32 ++++++++++++++--------- vllm/envs.py | 5 ++++ 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_config.py b/tests/v1/kv_connector/unit/test_config.py index 6cf86f3d5c4a..35a5c1720095 100644 --- a/tests/v1/kv_connector/unit/test_config.py +++ b/tests/v1/kv_connector/unit/test_config.py @@ -63,3 +63,14 @@ def test_kv_connector( assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes # Existing config should be replaced assert "existing_key" not in kv_connector_extra_config + + +def test_kv_connector_hma_enabled(monkeypatch): + # Default is False + vllm_config = VllmConfig(kv_transfer_config=KVTransferConfig()) + assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is True + + # Optionally enable HMA for KV connector as experimental feature + monkeypatch.setenv("VLLM_USE_HMA_FOR_KV_CONNECTOR", "1") + vllm_config = VllmConfig(kv_transfer_config=KVTransferConfig()) + assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a74413536407..c28be4239b71 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -897,19 +897,25 @@ def has_blocked_weights(): # Hybrid KV cache manager is not supported on non-GPU platforms. self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. 
This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention. If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." - ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True + if envs.VLLM_USE_HMA_FOR_KV_CONNECTOR: + logger.info( + "Hybrid kv cache manager is enabled for KV connector." + "This is an experimental feature." + ) + else: + # NOTE(Kuntai): turn HMA off for connector for now. + # TODO(Kuntai): have a more elegent solution to check and + # turn off HMA for connector that does not support HMA. + logger.warning( + "Turning off hybrid kv cache manager because " + "`--kv-transfer-config` is set. This will reduce the " + "performance of vLLM on LLMs with sliding window attention " + "or Mamba attention. If you are a developer of kv connector" + ", please consider supporting hybrid kv cache manager for " + "your connector by making sure your connector is a subclass" + " of `SupportsHMA` defined in kv_connector/v1/base.py." + ) + self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True diff --git a/vllm/envs.py b/vllm/envs.py index bda9e6e42335..3b040bf27102 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -241,6 +241,7 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_USE_HMA_FOR_KV_CONNECTOR: bool = False def get_default_cache_root(): @@ -1579,6 +1580,10 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Flag to enable HMA for KV connector (experimental). 
+ "VLLM_USE_HMA_FOR_KV_CONNECTOR": lambda: bool( + int(os.getenv("VLLM_USE_HMA_FOR_KV_CONNECTOR", "0")) + ), } # --8<-- [end:env-vars-definition] From 95d4971b9cdaa94e596604a1948085b68c34f007 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 4 Dec 2025 12:36:28 +0000 Subject: [PATCH 2/7] Revert "enable hma+kv_conn" This reverts commit 250ea870740e99ced7b496d61d4c1056d6b98156. Signed-off-by: NickLucche --- tests/v1/kv_connector/unit/test_config.py | 11 -------- vllm/config/vllm.py | 32 +++++++++-------------- vllm/envs.py | 5 ---- 3 files changed, 13 insertions(+), 35 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_config.py b/tests/v1/kv_connector/unit/test_config.py index 35a5c1720095..6cf86f3d5c4a 100644 --- a/tests/v1/kv_connector/unit/test_config.py +++ b/tests/v1/kv_connector/unit/test_config.py @@ -63,14 +63,3 @@ def test_kv_connector( assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes # Existing config should be replaced assert "existing_key" not in kv_connector_extra_config - - -def test_kv_connector_hma_enabled(monkeypatch): - # Default is False - vllm_config = VllmConfig(kv_transfer_config=KVTransferConfig()) - assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is True - - # Optionally enable HMA for KV connector as experimental feature - monkeypatch.setenv("VLLM_USE_HMA_FOR_KV_CONNECTOR", "1") - vllm_config = VllmConfig(kv_transfer_config=KVTransferConfig()) - assert vllm_config.scheduler_config.disable_hybrid_kv_cache_manager is False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index c28be4239b71..a74413536407 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -897,25 +897,19 @@ def has_blocked_weights(): # Hybrid KV cache manager is not supported on non-GPU platforms. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_transfer_config is not None: - if envs.VLLM_USE_HMA_FOR_KV_CONNECTOR: - logger.info( - "Hybrid kv cache manager is enabled for KV connector." - "This is an experimental feature." - ) - else: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention. If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." - ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True + # NOTE(Kuntai): turn HMA off for connector for now. + # TODO(Kuntai): have a more elegent solution to check and + # turn off HMA for connector that does not support HMA. + logger.warning( + "Turning off hybrid kv cache manager because " + "`--kv-transfer-config` is set. This will reduce the " + "performance of vLLM on LLMs with sliding window attention " + "or Mamba attention. If you are a developer of kv connector" + ", please consider supporting hybrid kv cache manager for " + "your connector by making sure your connector is a subclass" + " of `SupportsHMA` defined in kv_connector/v1/base.py." + ) + self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True diff --git a/vllm/envs.py b/vllm/envs.py index 3b040bf27102..bda9e6e42335 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -241,7 +241,6 @@ VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False - VLLM_USE_HMA_FOR_KV_CONNECTOR: bool = False def get_default_cache_root(): @@ -1580,10 +1579,6 @@ def get_vllm_port() -> int | None: "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), - # Flag to enable HMA for KV connector (experimental). - "VLLM_USE_HMA_FOR_KV_CONNECTOR": lambda: bool( - int(os.getenv("VLLM_USE_HMA_FOR_KV_CONNECTOR", "0")) - ), } # --8<-- [end:env-vars-definition] From 2142e462a553ad4d84bc8127b2494420b2e60f1f Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 4 Dec 2025 13:18:47 +0000 Subject: [PATCH 3/7] optional bool flag Signed-off-by: NickLucche --- vllm/config/scheduler.py | 4 +++- vllm/config/vllm.py | 51 +++++++++++++++++++++++++++++----------- vllm/engine/arg_utils.py | 2 +- 3 files changed, 41 insertions(+), 16 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 8da3ae538d67..8abbe8ba0103 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -122,10 +122,12 @@ class SchedulerConfig: the default scheduler. Can be a class directly or the path to a class of form "mod.custom_class".""" - disable_hybrid_kv_cache_manager: bool = False + disable_hybrid_kv_cache_manager: bool | None = None """If set to True, KV cache manager will allocate the same size of KV cache for all attention layers even if there are multiple type of attention layers like full attention and sliding window attention. + If set to None, the default value will be determined based on the environment + and starting configuration. 
""" async_scheduling: bool = False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index a74413536407..992b19886ff0 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -889,27 +889,17 @@ def has_blocked_weights(): if not self.instance_id: self.instance_id = random_uuid()[:5] + # Runtime-dependent disable of hybrid kv cache manager logic. if not self.scheduler_config.disable_hybrid_kv_cache_manager: + # Forcely disable HMA even if explicitly enabled by user (None/False). + prev_disable_hma = self.scheduler_config.disable_hybrid_kv_cache_manager + # logger should only print warning message for hybrid models. As we # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. if not current_platform.support_hybrid_kv_cache(): # Hybrid KV cache manager is not supported on non-GPU platforms. self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_transfer_config is not None: - # NOTE(Kuntai): turn HMA off for connector for now. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention. If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." - ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True @@ -935,6 +925,39 @@ def has_blocked_weights(): # local attention. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True + if ( + prev_disable_hma is False + and self.scheduler_config.disable_hybrid_kv_cache_manager is True + ): + logger.info( + "Hybrid KV cache manager explicitly enabled but not supported in " + "this configuration; falling back to standard manager. Consider " + "omitting this setting to let vLLM decide automatically." + ) + + if ( + self.scheduler_config.disable_hybrid_kv_cache_manager is None + and self.kv_transfer_config is not None + ): + # Disable HMA logic but only if the user didn't express a preference. + # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. + # TODO(Kuntai): have a more elegent solution to check and + # turn off HMA for connector that does not support HMA. + logger.warning( + "Turning off hybrid kv cache manager because " + "`--kv-transfer-config` is set. This will reduce the " + "performance of vLLM on LLMs with sliding window attention " + "or Mamba attention. If you are a developer of kv connector" + ", please consider supporting hybrid kv cache manager for " + "your connector by making sure your connector is a subclass" + " of `SupportsHMA` defined in kv_connector/v1/base.py." + ) + self.scheduler_config.disable_hybrid_kv_cache_manager = True + + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # Default to enable HMA if not explicitly disabled by user or logic above. 
+ self.scheduler_config.disable_hybrid_kv_cache_manager = False + if self.compilation_config.debug_dump_path: self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ceac5407af6e..7ea2b60e7689 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -494,7 +494,7 @@ class EngineArgs: enable_chunked_prefill: bool | None = None disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input - disable_hybrid_kv_cache_manager: bool = ( + disable_hybrid_kv_cache_manager: bool | None = ( SchedulerConfig.disable_hybrid_kv_cache_manager ) From 664db48abfb4bd919f350d1c5973cbc7a422d7a5 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 5 Dec 2025 16:36:45 +0000 Subject: [PATCH 4/7] comment Signed-off-by: NickLucche --- vllm/config/vllm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 992b19886ff0..3240b276668f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -950,7 +950,8 @@ def has_blocked_weights(): "or Mamba attention. If you are a developer of kv connector" ", please consider supporting hybrid kv cache manager for " "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py." + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." 
) self.scheduler_config.disable_hybrid_kv_cache_manager = True From e85444851cf496e7a6724ba80faba89e4a9e9c04 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 9 Dec 2025 17:31:21 +0000 Subject: [PATCH 5/7] rasing Signed-off-by: NickLucche --- vllm/config/vllm.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 3240b276668f..35c80dd9e982 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -929,10 +929,11 @@ def has_blocked_weights(): prev_disable_hma is False and self.scheduler_config.disable_hybrid_kv_cache_manager is True ): - logger.info( - "Hybrid KV cache manager explicitly enabled but not supported in " - "this configuration; falling back to standard manager. Consider " - "omitting this setting to let vLLM decide automatically." + raise ValueError( + "Hybrid KV cache manager was explicitly enabled but is not " + "supported in this configuration. Consider omitting the " + "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" + " automatically." ) if ( From 4d595659eaa42b45acce0cc13391333cdc92b31c Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 9 Dec 2025 17:42:31 +0000 Subject: [PATCH 6/7] comment Signed-off-by: NickLucche --- vllm/config/vllm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 35c80dd9e982..0cf372223459 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -891,7 +891,6 @@ def has_blocked_weights(): # Runtime-dependent disable of hybrid kv cache manager logic. if not self.scheduler_config.disable_hybrid_kv_cache_manager: - # Forcely disable HMA even if explicitly enabled by user (None/False). prev_disable_hma = self.scheduler_config.disable_hybrid_kv_cache_manager # logger should only print warning message for hybrid models. 
As we From c965044d79fce4e117a9c3320f4314f437a74de2 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 11 Dec 2025 16:25:58 +0000 Subject: [PATCH 7/7] chen review Signed-off-by: NickLucche --- vllm/config/vllm.py | 118 ++++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0cf372223459..a956044fabab 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -889,71 +889,71 @@ def has_blocked_weights(): if not self.instance_id: self.instance_id = random_uuid()[:5] - # Runtime-dependent disable of hybrid kv cache manager logic. - if not self.scheduler_config.disable_hybrid_kv_cache_manager: - prev_disable_hma = self.scheduler_config.disable_hybrid_kv_cache_manager - - # logger should only print warning message for hybrid models. As we - # can't know whether the model is hybrid or not now, so we don't log - # warning message here and will log it later. - if not current_platform.support_hybrid_kv_cache(): - # Hybrid KV cache manager is not supported on non-GPU platforms. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.kv_events_config is not None: - # Hybrid KV cache manager is not compatible with KV events. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - if ( - self.model_config is not None - and self.model_config.attention_chunk_size is not None - ): - if ( - self.speculative_config is not None - and self.speculative_config.use_eagle() - ): - # Hybrid KV cache manager is not yet supported with chunked - # local attention + eagle. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: - logger.warning( - "There is a latency regression when using chunked local" - " attention with the hybrid KV cache manager. Disabling" - " it, by default. To enable it, set the environment " - "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1." 
- ) - # Hybrid KV cache manager is not yet supported with chunked - # local attention. - self.scheduler_config.disable_hybrid_kv_cache_manager = True - + # Hybrid KV cache manager (HMA) runtime rules: + # - Explicit enable (--no-disable-hybrid-kv-cache-manager): error if runtime + # disables it + # - No preference: auto-disable for unsupported features (e.g. kv connector) + # - Explicit disable (--disable-hybrid-kv-cache-manager): always respect it + need_disable_hybrid_kv_cache_manager = False + # logger should only print warning message for hybrid models. As we + # can't know whether the model is hybrid or not now, so we don't log + # warning message here and will log it later. + if not current_platform.support_hybrid_kv_cache(): + # Hybrid KV cache manager is not supported on non-GPU platforms. + need_disable_hybrid_kv_cache_manager = True + if self.kv_events_config is not None: + # Hybrid KV cache manager is not compatible with KV events. + need_disable_hybrid_kv_cache_manager = True + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): if ( - prev_disable_hma is False - and self.scheduler_config.disable_hybrid_kv_cache_manager is True + self.speculative_config is not None + and self.speculative_config.use_eagle() ): - raise ValueError( - "Hybrid KV cache manager was explicitly enabled but is not " - "supported in this configuration. Consider omitting the " - "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" - " automatically." + # Hybrid KV cache manager is not yet supported with chunked + # local attention + eagle. + need_disable_hybrid_kv_cache_manager = True + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + logger.warning( + "There is a latency regression when using chunked local" + " attention with the hybrid KV cache manager. Disabling" + " it, by default. To enable it, set the environment " + "VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE=1."
) + # Hybrid KV cache manager is not yet supported with chunked + # local attention. + need_disable_hybrid_kv_cache_manager = True - if ( - self.scheduler_config.disable_hybrid_kv_cache_manager is None - and self.kv_transfer_config is not None + if self.scheduler_config.disable_hybrid_kv_cache_manager is None: + # No user preference: disable HMA by default when a KV connector is set. + if self.kv_transfer_config is not None: + # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. + need_disable_hybrid_kv_cache_manager = True + logger.warning( + "Turning off hybrid kv cache manager because " + "`--kv-transfer-config` is set. This will reduce the " + "performance of vLLM on LLMs with sliding window attention " + "or Mamba attention. If you are a developer of kv connector" + ", please consider supporting hybrid kv cache manager for " + "your connector by making sure your connector is a subclass" + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." + ) + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager + ) + elif ( + self.scheduler_config.disable_hybrid_kv_cache_manager is False + and need_disable_hybrid_kv_cache_manager ): - # Disable HMA logic but only if the user didn't express a preference. - # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. - # TODO(Kuntai): have a more elegent solution to check and - # turn off HMA for connector that does not support HMA. - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention.
If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py and" - " use --no-disable-hybrid-kv-cache-manager to start vLLM." + raise ValueError( + "Hybrid KV cache manager was explicitly enabled but is not " + "supported in this configuration. Consider omitting the " + "--no-disable-hybrid-kv-cache-manager flag to let vLLM decide" + " automatically." ) - self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.scheduler_config.disable_hybrid_kv_cache_manager is None: # Default to enable HMA if not explicitly disabled by user or logic above.