From 5028299978b182a068f86bece80e16c31865dd39 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Thu, 24 Apr 2025 21:25:53 +0200 Subject: [PATCH] Move missed `SchedulerConfig` args into scheduler config group in `EngineArgs` Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config.py | 3 ++- vllm/engine/arg_utils.py | 23 +++++++---------------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 41a30efea039..8551e0dd734a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1778,6 +1778,7 @@ def _verify_args(self) -> None: "worker_extension_cls must be a string (qualified class name).") +PreemptionMode = Literal["swap", "recompute"] SchedulerPolicy = Literal["fcfs", "priority"] @@ -1854,7 +1855,7 @@ class SchedulerConfig: NOTE: This is not currently configurable. It will be overridden by max_num_batched_tokens in case max multimodal embedding size is larger.""" - preemption_mode: Optional[str] = None + preemption_mode: Optional[PreemptionMode] = None """Whether to perform preemption by swapping or recomputation. If not specified, we determine the mode as follows: We use recomputation by default since it incurs lower overhead than diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9cb2aa797be5..68bea93b3354 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -753,12 +753,6 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ) device_group.add_argument("--device", **device_kwargs["device"]) - parser.add_argument('--num-scheduler-steps', - type=int, - default=1, - help=('Maximum number of forward steps per ' - 'scheduler call.')) - # Speculative arguments speculative_group = parser.add_argument_group( title="SpeculativeConfig", @@ -779,13 +773,6 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: help="The pattern(s) to ignore when loading the model." "Default to `original/**/*` to avoid repeated loading of llama's " "checkpoints.") - parser.add_argument( - '--preemption-mode', - type=str, - default=None, - help='If \'recompute\', the engine performs preemption by ' - 'recomputing; If \'swap\', the engine performs preemption by ' - 'block swapping.') parser.add_argument( "--served-model-name", @@ -865,14 +852,18 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: **scheduler_kwargs["num_lookahead_slots"]) scheduler_group.add_argument('--scheduler-delay-factor', **scheduler_kwargs["delay_factor"]) - scheduler_group.add_argument( - '--enable-chunked-prefill', - **scheduler_kwargs["enable_chunked_prefill"]) + scheduler_group.add_argument('--preemption-mode', + **scheduler_kwargs["preemption_mode"]) + scheduler_group.add_argument('--num-scheduler-steps', + **scheduler_kwargs["num_scheduler_steps"]) scheduler_group.add_argument( '--multi-step-stream-outputs', **scheduler_kwargs["multi_step_stream_outputs"]) scheduler_group.add_argument('--scheduling-policy', **scheduler_kwargs["policy"]) + scheduler_group.add_argument( + '--enable-chunked-prefill', + **scheduler_kwargs["enable_chunked_prefill"]) scheduler_group.add_argument( "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"])