Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,24 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""

fast_moe_cold_start = True
"""Optimization for fast MOE cold start.

This is a bit of a hack that assumes that:
1. the only decoder forward pass being run is that of the current model
2. the decoder forward pass runs all of the MOEs in the order in which they
are initialized

When the above two conditions hold, this option greatly decreases cold start
time for MOE models.

If the above two conditions don't hold, then this option will lead to silent
incorrectness. The only case in which they don't hold is speculative
decoding, where there is a draft model that may itself contain MOEs.

NB: We're working on a longer-term solution that doesn't need these assumptions.
"""

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""
Expand Down
15 changes: 14 additions & 1 deletion vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,22 @@ def create_forward_context(
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
all_moe_layers = vllm_config.compilation_config.static_all_moe_layers
else:
logger.warning_once(
"vllm_config.compilation_config.fast_moe_cold_start is not "
"compatible with speculative decoding so we are ignoring "
"fast_moe_cold_start."
)
all_moe_layers = None
else:
all_moe_layers = None

return ForwardContext(
no_compile_layers=vllm_config.compilation_config.static_forward_context,
all_moe_layers=vllm_config.compilation_config.static_all_moe_layers,
all_moe_layers=all_moe_layers,
virtual_engine=virtual_engine,
attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
Expand Down