diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 083b23aef92c..7a69629f707c 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -595,6 +595,24 @@ class CompilationConfig:
     local_cache_dir: str = field(default=None, init=False)  # type: ignore
     """local cache dir for each rank"""
 
+    fast_moe_cold_start: bool = True
+    """Optimization for fast MOE cold start.
+
+    This is a bit of a hack that assumes that:
+    1. the only decoder forward pass being run is the current model
+    2. the decoder forward pass runs all of the MOEs in the order in which they
+       are initialized
+
+    When the above two conditions hold, this option greatly decreases cold start
+    time for MOE models.
+
+    If the above two conditions don't hold, then this option will lead to silent
+    incorrectness. The only condition in which this doesn't hold is speculative
+    decoding, where there is a draft model that may have MOEs in it.
+
+    NB: We're working on a longer-term solution that doesn't need these assumptions.
+    """
+
     # keep track of enabled and disabled custom ops
     enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
     """custom ops that are enabled"""
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index a5c833b5e433..e308c05bc669 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -286,9 +286,24 @@ def create_forward_context(
     additional_kwargs: dict[str, Any] | None = None,
     skip_compiled: bool = False,
 ):
+    # Only expose the static MoE layer list when fast_moe_cold_start is on and
+    # no speculative (draft) model is configured; otherwise pass None.
+    if vllm_config.compilation_config.fast_moe_cold_start:
+        if vllm_config.speculative_config is None:
+            all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
+        else:
+            logger.warning_once(
+                "vllm_config.compilation_config.fast_moe_cold_start is not "
+                "compatible with speculative decoding so we are ignoring "
+                "fast_moe_cold_start."
+            )
+            all_moe_layers = None
+    else:
+        all_moe_layers = None
+
     return ForwardContext(
         no_compile_layers=vllm_config.compilation_config.static_forward_context,
-        all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
+        all_moe_layers=all_moe_layers,
         virtual_engine=virtual_engine,
         attn_metadata=attn_metadata,
         slot_mapping=slot_mapping or {},