diff --git a/vllm/config.py b/vllm/config.py
index 5b5ac40f6aa2..d841eeb7a474 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1245,22 +1245,70 @@ def is_matryoshka(self) -> bool:
                 or getattr(self.hf_config, "is_matryoshka", False))
 
 
+BlockSize = Literal[8, 16, 32, 64, 128]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+PrefixCachingHashAlgo = Literal["builtin", "sha256"]
+
+
+@config
+@dataclass
 class CacheConfig:
-    """Configuration for the KV cache.
+    """Configuration for the KV cache."""
 
-    Args:
-        block_size: Size of a cache block in number of tokens.
-        gpu_memory_utilization: Fraction of GPU memory to use for the
-            vLLM execution.
-        swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data type for kv cache storage.
-        is_attention_free: Whether the model is attention-free.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
-            profiled num_gpu_blocks if specified. Does nothing if None.
-        sliding_window: Sliding window size for the KV cache.
-        enable_prefix_caching: Whether to enable prefix caching.
-        cpu_offload_gb: Size of the CPU offload buffer in GiB.
+    block_size: Optional[BlockSize] = None
+    """Size of a contiguous cache block in number of tokens. This is ignored on
+    neuron devices and set to `--max-model-len`. On CUDA devices, only block
+    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+    """
+    gpu_memory_utilization: float = 0.9
+    """The fraction of GPU memory to be used for the model executor, which can
+    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+    utilization. If unspecified, will use the default value of 0.9. This is a
+    per-instance limit, and only applies to the current vLLM instance. It does
+    not matter if you have another vLLM instance running on the same GPU. For
+    example, if you have two vLLM instances running on the same GPU, you can
+    set the GPU memory utilization to 0.5 for each instance."""
+    swap_space: float = 4
+    """Size of the CPU swap space per GPU (in GiB)."""
+    cache_dtype: CacheDType = "auto"
+    """Data type for kv cache storage. If "auto", will use model data type.
+    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+    fp8 (=fp8_e4m3)."""
+    is_attention_free: bool = False
+    """Whether the model is attention-free. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    num_gpu_blocks_override: Optional[int] = None
+    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+    if specified. Does nothing if `None`. Used for testing preemption."""
+    sliding_window: Optional[int] = None
+    """Sliding window size for the KV cache. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    enable_prefix_caching: Optional[bool] = None
+    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+    default for V1."""
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+    """Set the hash algorithm for prefix caching:\n
+    - "builtin" is Python's built-in hash.\n
+    - "sha256" is collision resistant but with certain overheads."""
+    cpu_offload_gb: float = 0
+    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+    no offloading. Intuitively, this argument can be seen as a virtual way to
+    increase the GPU memory size. For example, if you have one 24 GB GPU and
+    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+    Note that this requires fast CPU-GPU interconnect, as part of the model is
+    loaded from CPU memory to GPU memory on the fly in each model forward pass.
     """
+    calculate_kv_scales: bool = False
+    """This enables dynamic calculation of `k_scale` and `v_scale` when
+    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+    checkpoint if available. Otherwise, the scales will default to 1.0."""
+
+    # Will be set after profiling.
+    num_gpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for GPU memory."""
+    num_cpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for CPU memory."""
 
     def compute_hash(self) -> str:
         """
@@ -1281,43 +1329,13 @@ def compute_hash(self) -> str:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __init__(
-        self,
-        block_size: int,
-        gpu_memory_utilization: float,
-        swap_space: float,
-        cache_dtype: str,
-        is_attention_free: bool = False,
-        num_gpu_blocks_override: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        enable_prefix_caching: bool = False,
-        prefix_caching_hash_algo: str = "builtin",
-        cpu_offload_gb: float = 0,
-        calculate_kv_scales: Optional[bool] = None,
-    ) -> None:
-        self.block_size = block_size
-        self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * GiB_bytes
-        self.num_gpu_blocks_override = num_gpu_blocks_override
-        self.cache_dtype = cache_dtype
-        self.is_attention_free = is_attention_free
-        self.sliding_window = sliding_window
-        self.enable_prefix_caching = enable_prefix_caching
-        self.prefix_caching_hash_algo = prefix_caching_hash_algo
-        self.cpu_offload_gb = cpu_offload_gb
-        self.calculate_kv_scales = calculate_kv_scales
+    def __post_init__(self) -> None:
+        self.swap_space_bytes = self.swap_space * GiB_bytes
+
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
 
-        # Will be set after profiling.
-        self.num_gpu_blocks: Optional[int] = None
-        self.num_cpu_blocks: Optional[int] = None
-
-        # Set calculate_kv_scales to False if the value is unset.
-        if self.calculate_kv_scales is None:
-            self.calculate_kv_scales = False
-
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
@@ -1336,7 +1354,7 @@ def _verify_args(self) -> None:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
+        elif self.cache_dtype in get_args(CacheDType):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
@@ -1354,12 +1372,12 @@ def _verify_prefix_caching(self) -> None:
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")
 
-        if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
-                "builtin", "sha256"):
+        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+                not in get_args(PrefixCachingHashAlgo)):
             raise ValueError(
                 "Unknown prefix caching hash algorithm: "
-                f"{self.prefix_caching_hash_algo}. Must be either "
-                "'builtin' or 'sha256'.")
+                f"{self.prefix_caching_hash_algo}. Must be one of "
+                f"{get_args(PrefixCachingHashAlgo)}.")
 
     def verify_with_parallel_config(
         self,
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1f719392bd9f..4c3477ddf5b3 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -16,16 +16,16 @@
 
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
-                         DecodingConfig, Device, DeviceConfig,
-                         DistributedExecutorBackend, HfOverrides,
+from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
+                         Config, ConfigFormat, DecodingConfig, Device,
+                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PoolType, PromptAdapterConfig, SchedulerConfig,
-                         SchedulerPolicy, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
-                         get_field)
+                         PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerPoolConfig, VllmConfig,
+                         get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -138,7 +138,7 @@ class EngineArgs:
     load_format: str = LoadConfig.load_format
     config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
-    kv_cache_dtype: str = 'auto'
+    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
    seed: Optional[int] = None
     max_model_len: Optional[int] = None
     # Note: Specifying a custom executor backend by passing a class
@@ -154,15 +154,16 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
-    block_size: Optional[int] = None
-    enable_prefix_caching: Optional[bool] = None
-    prefix_caching_hash_algo: str = "builtin"
+    block_size: Optional[BlockSize] = CacheConfig.block_size
+    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
+        CacheConfig.prefix_caching_hash_algo
     disable_sliding_window: bool = False
     disable_cascade_attn: bool = False
     use_v2_block_manager: bool = True
-    swap_space: float = 4  # GiB
-    cpu_offload_gb: float = 0  # GiB
-    gpu_memory_utilization: float = 0.90
+    swap_space: float = CacheConfig.swap_space
+    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
+    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
     max_num_batched_tokens: Optional[
         int] = SchedulerConfig.max_num_batched_tokens
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@@ -211,7 +212,8 @@ class EngineArgs:
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
-    num_gpu_blocks_override: Optional[int] = None
+    num_gpu_blocks_override: Optional[
+        int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = \
         get_field(LoadConfig, "model_loader_extra_config")
@@ -250,7 +252,7 @@ class EngineArgs:
 
     enable_sleep_mode: bool = False
     model_impl: str = "auto"
-    calculate_kv_scales: Optional[bool] = None
+    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
 
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
@@ -306,12 +308,19 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
-                name = field.name
+                # Get the default value of the field
                 default = field.default
-                # This will only be True if default is MISSING
                 if field.default_factory is not MISSING:
                     default = field.default_factory()
-                kwargs[name] = {"default": default, "help": cls_docs[name]}
+
+                # Get the help text for the field
+                name = field.name
+                help = cls_docs[name]
+                # Escape % for argparse
+                help = help.replace("%", "%%")
+
+                # Initialise the kwargs dictionary for the field
+                kwargs[name] = {"default": default, "help": help}
 
                 # Make note of if the field is optional and get the actual
                 # type of the field if it is
@@ -319,6 +328,8 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
                 field_type = get_args(
                     field.type)[0] if optional else field.type
 
+                # Set type, action and choices for the field depending on the
+                # type of the field
                 if can_be_type(field_type, bool):
                     # Creates --no- and -- flags
                     kwargs[name]["action"] = argparse.BooleanOptionalAction
@@ -463,14 +474,6 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
            '* "float32" for FP32 precision.')
-        parser.add_argument(
-            '--kv-cache-dtype',
-            type=str,
-            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
-            default=EngineArgs.kv_cache_dtype,
-            help='Data type for kv cache storage. If "auto", will use model '
-            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
-            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument('--max-model-len',
                             type=human_readable_int,
                             default=EngineArgs.max_model_len,
@@ -544,33 +547,30 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
         parallel_group.add_argument(
             '--disable-custom-all-reduce',
             **parallel_kwargs["disable_custom_all_reduce"])
-        # KV cache arguments
-        parser.add_argument('--block-size',
-                            type=int,
-                            default=EngineArgs.block_size,
-                            choices=[8, 16, 32, 64, 128],
-                            help='Token block size for contiguous chunks of '
-                            'tokens. This is ignored on neuron devices and '
-                            'set to ``--max-model-len``. On CUDA devices, '
-                            'only block sizes up to 32 are supported. '
-                            'On HPU devices, block size defaults to 128.')
-        parser.add_argument(
-            "--enable-prefix-caching",
-            action=argparse.BooleanOptionalAction,
-            default=EngineArgs.enable_prefix_caching,
-            help="Enables automatic prefix caching. "
-            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
-        )
-        parser.add_argument(
-            "--prefix-caching-hash-algo",
-            type=str,
-            choices=["builtin", "sha256"],
-            default=EngineArgs.prefix_caching_hash_algo,
-            help="Set the hash algorithm for prefix caching. "
-            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
-            "(collision resistant but with certain overheads).",
+        # KV cache arguments
+        cache_kwargs = get_kwargs(CacheConfig)
+        cache_group = parser.add_argument_group(
+            title="CacheConfig",
+            description=CacheConfig.__doc__,
         )
+        cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
+        cache_group.add_argument('--gpu-memory-utilization',
+                                 **cache_kwargs["gpu_memory_utilization"])
+        cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
+        cache_group.add_argument('--kv-cache-dtype',
+                                 **cache_kwargs["cache_dtype"])
+        cache_group.add_argument('--num-gpu-blocks-override',
+                                 **cache_kwargs["num_gpu_blocks_override"])
+        cache_group.add_argument("--enable-prefix-caching",
+                                 **cache_kwargs["enable_prefix_caching"])
+        cache_group.add_argument("--prefix-caching-hash-algo",
+                                 **cache_kwargs["prefix_caching_hash_algo"])
+        cache_group.add_argument('--cpu-offload-gb',
+                                 **cache_kwargs["cpu_offload_gb"])
+        cache_group.add_argument('--calculate-kv-scales',
+                                 **cache_kwargs["calculate_kv_scales"])
+
 
         parser.add_argument('--disable-sliding-window',
                             action='store_true',
                             help='Disables sliding window, '
@@ -588,43 +588,6 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
                             type=int,
                             default=EngineArgs.seed,
                             help='Random seed for operations.')
-        parser.add_argument('--swap-space',
-                            type=float,
-                            default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU.')
-        parser.add_argument(
-            '--cpu-offload-gb',
-            type=float,
-            default=0,
-            help='The space in GiB to offload to CPU, per GPU. '
-            'Default is 0, which means no offloading. Intuitively, '
-            'this argument can be seen as a virtual way to increase '
-            'the GPU memory size. For example, if you have one 24 GB '
-            'GPU and set this to 10, virtually you can think of it as '
-            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
-            'which requires at least 26GB GPU memory. Note that this '
-            'requires fast CPU-GPU interconnect, as part of the model is '
-            'loaded from CPU memory to GPU memory on the fly in each '
-            'model forward pass.')
-        parser.add_argument(
-            '--gpu-memory-utilization',
-            type=float,
-            default=EngineArgs.gpu_memory_utilization,
-            help='The fraction of GPU memory to be used for the model '
-            'executor, which can range from 0 to 1. For example, a value of '
-            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
-            'will use the default value of 0.9. This is a per-instance '
-            'limit, and only applies to the current vLLM instance.'
-            'It does not matter if you have another vLLM instance running '
-            'on the same GPU. For example, if you have two vLLM instances '
-            'running on the same GPU, you can set the GPU memory utilization '
-            'to 0.5 for each instance.')
-        parser.add_argument(
-            '--num-gpu-blocks-override',
-            type=int,
-            default=None,
-            help='If specified, ignore GPU profiling result and use this number'
-            ' of GPU blocks. Used for testing preemption.')
         parser.add_argument(
             '--max-logprobs',
             type=int,
             default=EngineArgs.max_logprobs,
@@ -994,15 +957,6 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
                             help="Enable sleep mode for the engine. "
                             "(only cuda platform is supported)")
 
-        parser.add_argument(
-            '--calculate-kv-scales',
-            action='store_true',
-            help='This enables dynamic calculation of '
-            'k_scale and v_scale when kv-cache-dtype is fp8. '
-            'If calculate-kv-scales is false, the scales will '
-            'be loaded from the model checkpoint if available. '
-            'Otherwise, the scales will default to 1.0.')
-
         parser.add_argument(
             "--additional-config",
             type=json.loads,
@@ -1625,9 +1579,7 @@ def _set_default_args_v0(self, model_config: ModelConfig) -> None:
             self.enable_prefix_caching = False
 
         # VLLM_V0 only supports builtin hash algo for prefix caching.
-        if self.prefix_caching_hash_algo is None:
-            self.prefix_caching_hash_algo = "builtin"
-        elif self.prefix_caching_hash_algo == "sha256":
+        if self.prefix_caching_hash_algo == "sha256":
             raise ValueError(
                 "sha256 is not supported for prefix caching in V0 engine. "
                 "Please use 'builtin'.")
@@ -1646,10 +1598,6 @@ def _set_default_args_v1(self, usage_context: UsageContext) -> None:
         if self.enable_prefix_caching is None:
             self.enable_prefix_caching = True
 
-        # if using prefix caching, we must set a hash algo
-        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
-            self.prefix_caching_hash_algo = "builtin"
-
         # V1 should use the new scheduler by default.
         # Swap it only if this arg is set to the original V0 default
         if self.scheduler_cls == EngineArgs.scheduler_cls:
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index c1f426e5b880..e37a3a578cf2 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -50,7 +50,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
-                vllm_config.model_config.max_model_len
+                vllm_config.model_config.max_model_len  # type: ignore
 
     @classmethod
     def is_pin_memory_available(cls) -> bool:
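
Illustrative sketch (not part of the diff above): the CacheConfig hunk replaces a hand-written __init__ with a @dataclass whose Literal aliases (BlockSize, CacheDType, PrefixCachingHashAlgo) are reused at runtime through typing.get_args, so the set of allowed values is declared exactly once. A minimal, self-contained Python sketch of that idiom follows; ToyCacheConfig and its fields are invented for this example and are not vLLM code.

    from dataclasses import dataclass
    from typing import Literal, Optional, get_args

    BlockSize = Literal[8, 16, 32, 64, 128]
    CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]


    @dataclass
    class ToyCacheConfig:
        """Toy stand-in for CacheConfig, used only for this sketch."""

        block_size: Optional[BlockSize] = None
        cache_dtype: CacheDType = "auto"

        def __post_init__(self) -> None:
            # get_args() turns the Literal alias into a plain tuple of the
            # allowed values, so the runtime check never drifts from the
            # annotation.
            if self.cache_dtype not in get_args(CacheDType):
                raise ValueError(
                    f"Unknown kv cache dtype: {self.cache_dtype!r}. "
                    f"Must be one of {get_args(CacheDType)}.")
            if (self.block_size is not None
                    and self.block_size not in get_args(BlockSize)):
                raise ValueError(f"Invalid block size: {self.block_size}. "
                                 f"Must be one of {get_args(BlockSize)}.")


    ToyCacheConfig(block_size=16, cache_dtype="fp8")   # accepted
    try:
        ToyCacheConfig(cache_dtype="int8")             # rejected at construction
    except ValueError as exc:
        print(exc)

Because the annotation itself carries the allowed values, the error messages stay in sync with the type, which is presumably also why the diff can drop the hard-coded choices=[...] lists from the removed parser.add_argument calls.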
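Illustrative sketch (not the actual get_kwargs()/get_attr_docs() implementation): the arg_utils.py hunks replace per-flag parser.add_argument calls with an argument group generated from CacheConfig via get_kwargs(CacheConfig). The self-contained Python sketch below shows the general approach of deriving argparse options from dataclass fields; ToyConfig, the HELP mapping and add_config_args are assumptions made for this example only.

    import argparse
    from dataclasses import MISSING, dataclass, fields
    from typing import Literal, get_args, get_origin

    CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]


    @dataclass
    class ToyConfig:
        """Toy config used only for this sketch."""

        cache_dtype: CacheDType = "auto"
        gpu_memory_utilization: float = 0.9
        swap_space: float = 4.0


    # Stand-in for vLLM's get_attr_docs(), which reads per-field docstrings.
    HELP = {
        "cache_dtype": "Data type for kv cache storage.",
        "gpu_memory_utilization": "Fraction of GPU memory to use (0 to 1).",
        "swap_space": "CPU swap space per GPU, in GiB.",
    }


    def add_config_args(parser: argparse.ArgumentParser, cls) -> None:
        group = parser.add_argument_group(title=cls.__name__,
                                          description=cls.__doc__)
        for f in fields(cls):
            # Resolve the default, honouring default_factory when present.
            default = f.default
            if f.default_factory is not MISSING:
                default = f.default_factory()

            kwargs = {
                "default": default,
                # Escape % so argparse's help formatter does not choke on it.
                "help": HELP.get(f.name, "").replace("%", "%%"),
            }
            if get_origin(f.type) is Literal:
                # A Literal annotation doubles as the argparse choices list.
                kwargs["choices"] = get_args(f.type)
            elif f.type in (int, float):
                kwargs["type"] = f.type

            group.add_argument(f"--{f.name.replace('_', '-')}", **kwargs)


    parser = argparse.ArgumentParser()
    add_config_args(parser, ToyConfig)
    args = parser.parse_args(["--cache-dtype", "fp8_e5m2", "--swap-space", "8"])
    print(args.cache_dtype, args.swap_space)

A related effect visible in the diff is that EngineArgs now takes its defaults from CacheConfig (for example swap_space: float = CacheConfig.swap_space), so the CLI defaults and the config defaults cannot drift apart.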