diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index aaac2deb12ac..3d43a51e9cf2 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1021,7 +1021,7 @@ def test_kv_connector_unable_to_allocate(): """ # Setup Scheduler With Mock External Cache Hit. - BLOCK_SIZE = 4 + BLOCK_SIZE = 8 NUM_BLOCKS = 10 scheduler = create_scheduler( enable_prefix_caching=True, @@ -1103,7 +1103,7 @@ def test_kv_connector_handles_preemption(): """ # Setup Scheduler With Mock External Cache Hit. - BLOCK_SIZE = 2 + BLOCK_SIZE = 8 # NOTE: there is 1 null block, so this is 6 blocks. NUM_BLOCKS = 7 scheduler = create_scheduler( @@ -1124,8 +1124,8 @@ def test_kv_connector_handles_preemption(): # Both can be scheduled at first, but the second request # will be preempted and re-scheduled. NUM_REQUESTS = 2 - NUM_TOKENS = BLOCK_SIZE * 2 + 1 - MAX_TOKENS = BLOCK_SIZE * 2 + NUM_TOKENS = 3 * BLOCK_SIZE - 1 + MAX_TOKENS = 4 requests = create_requests( num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS, diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index a3aa54634725..3abddfbc997a 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -26,6 +26,7 @@ FILES = [ "vllm/*.py", "vllm/assets", + "vllm/engine", "vllm/distributed", "vllm/entrypoints", "vllm/executor", @@ -36,6 +37,7 @@ "vllm/transformers_utils", "vllm/triton_utils", "vllm/usage", + "vllm/utils", ] # After fixing errors resulting from changing follow_imports @@ -44,7 +46,6 @@ "tests", "vllm/attention", "vllm/compilation", - "vllm/engine", "vllm/inputs", "vllm/lora", "vllm/model_executor", diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 04b1e7bf2ac1..41537b56707e 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +from collections.abc import Callable from dataclasses import field from typing import TYPE_CHECKING, Any, Literal -from pydantic import Field, SkipValidation, field_validator +from pydantic import Field, field_validator from pydantic.dataclasses import dataclass from vllm.config.utils import config @@ -30,7 +31,7 @@ class CacheConfig: """Configuration for the KV cache.""" - block_size: SkipValidation[BlockSize] = None # type: ignore + block_size: BlockSize = None """Size of a contiguous cache block in number of tokens. On CUDA devices, only block sizes up to 32 are supported. @@ -150,6 +151,13 @@ def metrics_info(self): # metrics info return {key: str(value) for key, value in self.__dict__.items()} + @field_validator("block_size", mode="wrap") + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + if value is None: + return value + return handler(value) + @field_validator("cache_dtype", mode="after") @classmethod def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: diff --git a/vllm/config/model.py b/vllm/config/model.py index 6e5757ba037d..4c7bf99c26b4 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast, get_args import torch -from pydantic import ConfigDict, SkipValidation, field_validator, model_validator +from pydantic import ConfigDict, field_validator, model_validator from pydantic.dataclasses import dataclass from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -120,7 +120,7 @@ class ModelConfig: Note that the model may support other tasks using the same model runner. """ - tokenizer: SkipValidation[str] = None # type: ignore + tokenizer: str = None """Name or path of the Hugging Face tokenizer to use. If unspecified, model name or path will be used.""" tokenizer_mode: TokenizerMode = "auto" @@ -171,7 +171,7 @@ class ModelConfig: """The specific revision to use for the tokenizer on the Hugging Face Hub. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.""" - max_model_len: SkipValidation[int] = None # type: ignore + max_model_len: int = None """Model context length (prompt and output). If unspecified, will be automatically derived from the model config. @@ -182,7 +182,7 @@ class ModelConfig: - 25.6k -> 25,600""" spec_target_max_model_len: int | None = None """Specify the maximum length for spec decoding draft models.""" - quantization: SkipValidation[QuantizationMethods | None] = None + quantization: str | QuantizationMethods | None = None """Method used to quantize the weights. If `None`, we first check the `quantization_config` attribute in the model config file. If that is `None`, we assume the model weights are not quantized and use `dtype` to @@ -302,6 +302,13 @@ class ModelConfig: skip_mm_profiling: InitVar[bool | None] = None video_pruning_rate: InitVar[float | None] = None + @field_validator("tokenizer", "max_model_len", mode="wrap") + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + if value is None: + return value + return handler(value) + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 944a1e8666f4..155b4c528c8c 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -32,6 +32,14 @@ ExpertPlacementStrategy = Literal["linear", "round_robin"] DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"] DataParallelBackend = Literal["ray", "mp"] +All2allBackendType = Literal[ + "naive", + "pplx", + "deepep_high_throughput", + "deepep_low_latency", + "allgather_reducescatter", + "flashinfer_all2allv", +] @config @@ -113,17 +121,7 @@ class ParallelConfig: with 4 experts and 2 ranks, rank 0 will have experts [0, 2] and rank 1 will have experts [1, 3]. This strategy can help improve load balancing for grouped expert models with no redundant experts.""" - all2all_backend: ( - Literal[ - "naive", - "pplx", - "deepep_high_throughput", - "deepep_low_latency", - "allgather_reducescatter", - "flashinfer_all2allv", - ] - | None - ) = None + all2all_backend: All2allBackendType | None = None """All2All backend for MoE expert parallel communication. If not set, uses the value from VLLM_ALL2ALL_BACKEND environment variable. Available options: - "naive": Naive all2all implementation using broadcasts diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index d5eb07730923..f4015f43de8b 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -2,10 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import hashlib +from collections.abc import Callable from dataclasses import InitVar, field from typing import Any, Literal -from pydantic import SkipValidation, model_validator +from pydantic import field_validator, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -31,19 +32,19 @@ class SchedulerConfig: runner_type: RunnerType = "generate" """The runner type to launch for the model.""" - max_num_batched_tokens: SkipValidation[int] = None # type: ignore + max_num_batched_tokens: int = None """Maximum number of tokens to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_num_seqs: SkipValidation[int] = None # type: ignore + max_num_seqs: int = None """Maximum number of sequences to be processed in a single iteration. This config has no static default. If left unspecified by the user, it will be set in `EngineArgs.create_engine_config` based on the usage context.""" - max_model_len: SkipValidation[int] = None # type: ignore + max_model_len: int = None """Maximum length of a sequence (including prompt and generated text). This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -79,7 +80,7 @@ class SchedulerConfig: 3. more than one value (e.g. 1 2 128) is provided, then the capture list will follow the provided list.""" - enable_chunked_prefill: SkipValidation[bool] = None # type: ignore + enable_chunked_prefill: bool = None """If True, prefill requests can be chunked based on the remaining max_num_batched_tokens.""" @@ -169,6 +170,19 @@ def compute_hash(self) -> str: hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str + @field_validator( + "max_num_batched_tokens", + "max_num_seqs", + "max_model_len", + "enable_chunked_prefill", + mode="wrap", + ) + @classmethod + def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: + if value is None: + return value + return handler(value) + def __post_init__(self, is_encoder_decoder: bool) -> None: if self.max_model_len is None: self.max_model_len = 8192 diff --git a/vllm/config/utils.py b/vllm/config/utils.py index 5e7e7580c5a9..3452315b3e2e 100644 --- a/vllm/config/utils.py +++ b/vllm/config/utils.py @@ -39,7 +39,7 @@ def config(cls: ConfigT) -> ConfigT: return cls -def get_field(cls: ConfigType, name: str) -> Field: +def get_field(cls: ConfigType, name: str) -> Any: """Get the default factory field of a dataclass by name. Used for getting default factory fields in `EngineArgs`.""" if not is_dataclass(cls): diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7ee522ea9f0c..cd31fade025b 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -57,7 +57,7 @@ class VllmConfig: # TODO: use default_factory once default constructing ModelConfig doesn't # try to download a model - model_config: ModelConfig = Field(default=None) + model_config: ModelConfig = None """Model configuration.""" cache_config: CacheConfig = Field(default_factory=CacheConfig) """Cache configuration.""" @@ -77,7 +77,9 @@ class VllmConfig: default_factory=StructuredOutputsConfig ) """Structured outputs configuration.""" - observability_config: ObservabilityConfig | None = None + observability_config: ObservabilityConfig = Field( + default_factory=ObservabilityConfig + ) """Observability configuration.""" quant_config: QuantizationConfig | None = None """Quantization configuration.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 801c30dc9478..2575665e1ad2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -55,6 +55,7 @@ ) from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo from vllm.config.device import Device +from vllm.config.lora import LoRAExtraVocabSize, MaxLoRARanks from vllm.config.model import ( ConvertOption, HfOverrides, @@ -66,8 +67,14 @@ ) from vllm.config.multimodal import MMCacheType, MMEncoderTPMode from vllm.config.observability import DetailedTraceModules -from vllm.config.parallel import DistributedExecutorBackend, ExpertPlacementStrategy +from vllm.config.parallel import ( + All2allBackendType, + DataParallelBackend, + DistributedExecutorBackend, + ExpertPlacementStrategy, +) from vllm.config.scheduler import SchedulerPolicy +from vllm.config.structured_outputs import StructuredOutputsBackend from vllm.config.utils import get_field from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform @@ -214,11 +221,11 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, dict[str, Any]]: default = field.default # Handle pydantic.Field defaults if isinstance(default, FieldInfo): - default = ( - default.default - if default.default_factory is None - else default.default_factory() - ) + if default.default_factory is None: + default = default.default + else: + default_factory = cast(Callable[[], Any], default.default_factory) + default = default_factory() elif field.default_factory is not MISSING: default = field.default_factory() @@ -369,9 +376,9 @@ class EngineArgs: data_parallel_address: str | None = None data_parallel_rpc_port: int | None = None data_parallel_hybrid_lb: bool = False - data_parallel_backend: str = ParallelConfig.data_parallel_backend + data_parallel_backend: DataParallelBackend = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel - all2all_backend: str | None = ParallelConfig.all2all_backend + all2all_backend: All2allBackendType | None = ParallelConfig.all2all_backend enable_dbo: bool = ParallelConfig.enable_dbo dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold @@ -419,7 +426,7 @@ class EngineArgs: hf_token: bool | str | None = ModelConfig.hf_token hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides") tokenizer_revision: str | None = ModelConfig.tokenizer_revision - quantization: QuantizationMethods | None = ModelConfig.quantization + quantization: str | QuantizationMethods | None = ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce limit_mm_per_prompt: dict[str, int | dict[str, int]] = get_field( @@ -441,16 +448,16 @@ class EngineArgs: mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: str | None = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling - video_pruning_rate: float = MultiModalConfig.video_pruning_rate + video_pruning_rate: float | None = MultiModalConfig.video_pruning_rate # LoRA fields enable_lora: bool = False max_loras: int = LoRAConfig.max_loras - max_lora_rank: int = LoRAConfig.max_lora_rank + max_lora_rank: MaxLoRARanks = LoRAConfig.max_lora_rank default_mm_loras: dict[str, str] | None = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: int | None = LoRAConfig.max_cpu_loras lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype - lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size + lora_extra_vocab_size: LoRAExtraVocabSize = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override @@ -470,7 +477,7 @@ class EngineArgs: ) reasoning_parser: str = StructuredOutputsConfig.reasoning_parser # Deprecated guided decoding fields - guided_decoding_backend: str | None = None + guided_decoding_backend: StructuredOutputsBackend | None = None guided_decoding_disable_fallback: bool | None = None guided_decoding_disable_any_whitespace: bool | None = None guided_decoding_disable_additional_properties: bool | None = None @@ -506,7 +513,7 @@ class EngineArgs: ModelConfig, "override_generation_config" ) model_impl: str = ModelConfig.model_impl - override_attention_dtype: str = ModelConfig.override_attention_dtype + override_attention_dtype: str | None = ModelConfig.override_attention_dtype calculate_kv_scales: bool = CacheConfig.calculate_kv_scales mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype @@ -515,7 +522,7 @@ class EngineArgs: additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load - pt_load_map_location: str = LoadConfig.pt_load_map_location + pt_load_map_location: str | dict[str, str] = LoadConfig.pt_load_map_location # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False @@ -1113,7 +1120,7 @@ def create_model_config(self) -> ModelConfig: runner=self.runner, convert=self.convert, task=self.task, - tokenizer=self.tokenizer, + tokenizer=self.tokenizer, # type: ignore[arg-type] tokenizer_mode=self.tokenizer_mode, trust_remote_code=self.trust_remote_code, allowed_local_media_path=self.allowed_local_media_path, @@ -1127,7 +1134,7 @@ def create_model_config(self) -> ModelConfig: hf_token=self.hf_token, hf_overrides=self.hf_overrides, tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, + max_model_len=self.max_model_len, # type: ignore[arg-type] quantization=self.quantization, enforce_eager=self.enforce_eager, max_logprobs=self.max_logprobs, @@ -1314,7 +1321,7 @@ def create_engine_config( ) cache_config = CacheConfig( - block_size=self.block_size, + block_size=self.block_size, # type: ignore[arg-type] gpu_memory_utilization=self.gpu_memory_utilization, kv_cache_memory_bytes=self.kv_cache_memory_bytes, swap_space=self.swap_space, @@ -1501,9 +1508,9 @@ def create_engine_config( scheduler_config = SchedulerConfig( runner_type=model_config.runner_type, - max_num_batched_tokens=self.max_num_batched_tokens, - max_num_seqs=self.max_num_seqs, - max_model_len=model_config.max_model_len, + max_num_batched_tokens=self.max_num_batched_tokens, # type: ignore[arg-type] + max_num_seqs=self.max_num_seqs, # type: ignore[arg-type] + max_model_len=model_config.max_model_len, # type: ignore[arg-type] cuda_graph_sizes=self.cuda_graph_sizes, num_lookahead_slots=num_lookahead_slots, enable_chunked_prefill=self.enable_chunked_prefill, @@ -1555,17 +1562,15 @@ def create_engine_config( # Forward the deprecated CLI args to the StructuredOutputsConfig so_config = self.structured_outputs_config if self.guided_decoding_backend is not None: - so_config.guided_decoding_backend = self.guided_decoding_backend + so_config.backend = self.guided_decoding_backend if self.guided_decoding_disable_fallback is not None: - so_config.guided_decoding_disable_fallback = ( - self.guided_decoding_disable_fallback - ) + so_config.disable_fallback = self.guided_decoding_disable_fallback if self.guided_decoding_disable_any_whitespace is not None: - so_config.guided_decoding_disable_any_whitespace = ( + so_config.disable_any_whitespace = ( self.guided_decoding_disable_any_whitespace ) if self.guided_decoding_disable_additional_properties is not None: - so_config.guided_decoding_disable_additional_properties = ( + so_config.disable_additional_properties = ( self.guided_decoding_disable_additional_properties ) @@ -1606,6 +1611,13 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ) return False + # No Mamba or Encoder-Decoder so far. + if not getattr(model_config, "is_v1_compatible", True): + _raise_or_fallback( + feature_name=str(model_config.architectures), recommend_to_remove=False + ) + return False + # No Concurrent Partial Prefills so far. if ( self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills @@ -1690,7 +1702,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: return True def _set_default_args( - self, usage_context: UsageContext, model_config: ModelConfig + self, usage_context: UsageContext | None, model_config: ModelConfig ) -> None: """Set Default Arguments for V1 Engine.""" @@ -1718,6 +1730,7 @@ def _set_default_args( else: self.enable_prefix_caching = True else: + assert model_config.pooler_config is not None pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) incremental_prefill_supported = ( diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 64f1961dd849..a01b11e13c7d 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -513,8 +513,8 @@ def log(self, stats: Stats) -> None: def _reset(self, stats, prompt_throughput, generation_throughput) -> None: # Reset tracked stats for next interval. - self.num_prompt_tokens = [] - self.num_generation_tokens = [] + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] self.last_local_log = stats.now self.last_prompt_throughput = prompt_throughput self.last_generation_throughput = generation_throughput @@ -660,9 +660,9 @@ def log(self, stats: Stats): # Log locally every local_interval seconds. if local_interval_elapsed(stats.now, self.last_local_log, self.local_interval): # Reset tracked stats for next interval. - self.num_prompt_tokens = [] - self.num_generation_tokens = [] - self.last_local_log = stats.now + self.num_prompt_tokens: list[int] = [] + self.num_generation_tokens: list[int] = [] + self.last_local_log: float = stats.now def info(self, type: str, obj: SupportsMetricsInfo) -> None: # Info type metrics are syntactic sugar for a gauge permanently set to 1 diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 748355309521..2035c39cd9dd 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -422,6 +422,7 @@ def _parse_audio_data( if self._is_embeddings(data): return AudioEmbeddingItems(data) + data_items: list[AudioItem] if ( is_list_of(data, float) or isinstance(data, (np.ndarray, torch.Tensor)) @@ -485,6 +486,7 @@ def _parse_video_data( if self._is_embeddings(data): return VideoEmbeddingItems(data) + data_items: list[VideoItem] if ( is_list_of(data, PILImage.Image) or isinstance(data, (np.ndarray, torch.Tensor)) diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 99a9225cb6a4..5835218e4d24 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -870,7 +870,7 @@ def find_nccl_include_paths() -> list[str] | None: import importlib.util spec = importlib.util.find_spec("nvidia.nccl") - if spec and getattr(spec, "submodule_search_locations", None): + if spec is not None and spec.submodule_search_locations is not None: for loc in spec.submodule_search_locations: inc_dir = os.path.join(loc, "include") if os.path.exists(os.path.join(inc_dir, "nccl.h")): diff --git a/vllm/utils/asyncio.py b/vllm/utils/asyncio.py index b6c24e1ceeee..0f9bed81ce82 100644 --- a/vllm/utils/asyncio.py +++ b/vllm/utils/asyncio.py @@ -274,7 +274,7 @@ async def merge_async_iterators( loop = asyncio.get_running_loop() - awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} + awaits = {loop.create_task(anext(it)): (i, it) for i, it in enumerate(iterators)} # type: ignore[var-annotated, arg-type] try: while awaits: done, _ = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED) @@ -283,7 +283,7 @@ async def merge_async_iterators( try: item = await d i, it = pair - awaits[loop.create_task(anext(it))] = pair + awaits[loop.create_task(anext(it))] = pair # type: ignore[arg-type] yield i, item except StopAsyncIteration: pass diff --git a/vllm/utils/jsontree.py b/vllm/utils/jsontree.py index cde9aa6ff901..ddd1e1a7ec6a 100644 --- a/vllm/utils/jsontree.py +++ b/vllm/utils/jsontree.py @@ -80,7 +80,7 @@ def json_map_leaves( ) -> JSONTree[_U]: ... -def json_map_leaves( +def json_map_leaves( # type: ignore[misc] func: Callable[[_T], _U], value: "BatchedTensorInputs" | _JSONTree[_T], ) -> "BatchedTensorInputs" | _JSONTree[_U]: @@ -91,7 +91,7 @@ def json_map_leaves( for k, v in value.items() } elif isinstance(value, list): - return [json_map_leaves(func, v) for v in value] + return [json_map_leaves(func, v) for v in value] # type: ignore[return-value] elif isinstance(value, tuple): return tuple(json_map_leaves(func, v) for v in value) else: @@ -142,7 +142,7 @@ def json_reduce_leaves( def json_reduce_leaves( func: Callable[..., _T | _U], value: _JSONTree[_T], - initial: _U = cast(_U, ...), # noqa: B008 + initial: _U = cast(_U, ...), # type: ignore # noqa /, ) -> _T | _U: """ @@ -151,7 +151,7 @@ def json_reduce_leaves( sequence to a single value. """ if initial is ...: - return reduce(func, json_iter_leaves(value)) # type: ignore[arg-type] + return reduce(func, json_iter_leaves(value)) # type: ignore return reduce( func, # type: ignore[arg-type]