Default to 'align' mamba cache mode for Mamba-based models when speculative decoding is enabled (#40454)
Signed-off-by: Roi Koren <roik@nvidia.com>
Code Review
This pull request updates the configuration logic in vllm/model_executor/models/config.py to default the Mamba cache mode to 'align' when both prefix caching and speculative decoding are enabled. A critical issue was identified where defaulting to 'align' mode without also enabling chunked prefill results in a server crash due to internal assertions. The reviewer recommends automatically enabling chunked prefill whenever 'align' mode is selected as the default to prevent this regression.
if (
    model_config.supports_mamba_prefix_caching
    and vllm_config.speculative_config is not None
):
    cache_config.mamba_cache_mode = "align"
    logger.warning(
        "Mamba cache mode is set to 'align' for %s by default "
        "when prefix caching and speculative decoding are enabled",
        model_config.architecture,
    )
else:
    cache_config.mamba_cache_mode = (
        "all" if model_config.supports_mamba_prefix_caching else "align"
    )
    logger.warning(
        "Mamba cache mode is set to '%s' for %s by default "
        "when prefix caching is enabled",
        cache_config.mamba_cache_mode,
        model_config.architecture,
    )
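The branching in this diff reduces to a small decision table. The following is a minimal sketch of that table, assuming a hypothetical standalone function (`default_mamba_cache_mode` is not a vLLM API) that takes the two inputs the diff inspects:

```python
from typing import Optional


def default_mamba_cache_mode(
    supports_mamba_prefix_caching: bool,
    speculative_config: Optional[object],
) -> str:
    """Hypothetical stand-in that mirrors the branching in the diff."""
    if supports_mamba_prefix_caching and speculative_config is not None:
        # New branch from this PR: avoid 'all' when speculative decoding is on.
        return "align"
    # Pre-existing default when speculative decoding is off.
    return "all" if supports_mamba_prefix_caching else "align"


# Enumerate the four input combinations to make the table explicit.
for supports in (True, False):
    for spec in (None, object()):
        spec_enabled = spec is not None
        mode = default_mamba_cache_mode(supports, spec)
        print(f"supports={supports} spec_enabled={spec_enabled} -> {mode}")
```

Under these assumptions, 'all' remains the default only when the model supports Mamba prefix caching and speculative decoding is disabled; every other combination resolves to 'align'.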
Defaulting the Mamba cache mode to align will crash the server when chunked prefill is not enabled, because of the strict assertion at line 359. Since this PR widens the set of cases where align is the default (specifically, when speculative decoding is enabled), enable_chunked_prefill should be turned on automatically to avoid this usability regression.
Note that this requirement applies whenever mamba_cache_mode is set to align. It would be ideal to handle this enablement consistently for all paths that lead to align mode.
if (
    model_config.supports_mamba_prefix_caching
    and vllm_config.speculative_config is not None
):
    cache_config.mamba_cache_mode = "align"
    vllm_config.scheduler_config.enable_chunked_prefill = True
    logger.warning(
        "Mamba cache mode is set to 'align' for %s by default "
        "when prefix caching and speculative decoding are enabled. "
        "Chunked prefill has been enabled as it is required for 'align' mode.",
        model_config.architecture,
    )
else:
    cache_config.mamba_cache_mode = (
        "all" if model_config.supports_mamba_prefix_caching else "align"
    )
    if cache_config.mamba_cache_mode == "align":
        vllm_config.scheduler_config.enable_chunked_prefill = True
    logger.warning(
        "Mamba cache mode is set to '%s' for %s by default "
        "when prefix caching is enabled",
        cache_config.mamba_cache_mode,
        model_config.architecture,
    )
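One way to apply the chunked-prefill requirement consistently on every path that leads to 'align' is to resolve the mode first and then enable chunked prefill in a single place. The sketch below illustrates that shape only; the `Stub*` dataclasses and `apply_mamba_cache_defaults` are hypothetical stand-ins, not vLLM's actual config classes or functions, and only the attribute names are taken from the code under review.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StubModelConfig:  # hypothetical stand-in for model_config
    supports_mamba_prefix_caching: bool


@dataclass
class StubCacheConfig:  # hypothetical stand-in for cache_config
    mamba_cache_mode: Optional[str] = None


@dataclass
class StubSchedulerConfig:  # hypothetical stand-in for scheduler_config
    enable_chunked_prefill: bool = False


def apply_mamba_cache_defaults(
    model_config: StubModelConfig,
    cache_config: StubCacheConfig,
    scheduler_config: StubSchedulerConfig,
    speculative_config: Optional[object],
) -> None:
    """Pick the default mode, then enforce the 'align' requirement once."""
    use_all = (
        model_config.supports_mamba_prefix_caching
        and speculative_config is None
    )
    cache_config.mamba_cache_mode = "all" if use_all else "align"

    # Single enforcement point: any path that ends in 'align' gets
    # chunked prefill, so no branch can forget it.
    if cache_config.mamba_cache_mode == "align":
        scheduler_config.enable_chunked_prefill = True


# Usage: speculative decoding enabled on a model that supports prefix caching.
cache = StubCacheConfig()
scheduler = StubSchedulerConfig()
apply_mamba_cache_defaults(
    StubModelConfig(supports_mamba_prefix_caching=True),
    cache,
    scheduler,
    speculative_config=object(),
)
print(cache.mamba_cache_mode, scheduler.enable_chunked_prefill)
```

Keeping the enforcement after mode selection means a future branch that also resolves to 'align' inherits the requirement without further changes. Note that this sketch leaves chunked prefill untouched when the mode is 'all'.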
…lative decoding is enabled (vllm-project#40454) Signed-off-by: Roi Koren <roik@nvidia.com>
…en speculative decoding is enabled (vllm-project#40454)" This reverts commit f819265. Signed-off-by: Roi Koren <roik@nvidia.com>
Purpose
The 'all' Mamba cache mode currently appears to be buggy when combined with speculative decoding, at least for Nemotron models (see #39809, for example). This PR defaults the mode to 'align', which may make prefix caching less efficient but works consistently, at least until 'all' mode is fixed in combination with speculative decoding.
Test Plan
All current tests pass.
Test Result
All current tests pass.