diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index de29028da618..7fdfdb37a0c0 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -26,6 +26,7 @@
 FILES = [
     "vllm/*.py",
     "vllm/assets",
+    "vllm/distributed",
     "vllm/entrypoints",
     "vllm/inputs",
     "vllm/logging_utils",
@@ -42,7 +43,6 @@
     "tests",
     "vllm/attention",
     "vllm/compilation",
-    "vllm/distributed",
     "vllm/engine",
     "vllm/executor",
     "vllm/inputs",
diff --git a/vllm/config/kv_transfer.py b/vllm/config/kv_transfer.py
index ba17d9a65f28..d7a9d5808319 100644
--- a/vllm/config/kv_transfer.py
+++ b/vllm/config/kv_transfer.py
@@ -27,7 +27,7 @@ class KVTransferConfig:
     engine_id: str | None = None
     """The engine id for KV transfers."""
 
-    kv_buffer_device: str | None = "cuda"
+    kv_buffer_device: str = "cuda"
     """The device used by kv connector to buffer the KV cache. Choices are
     'cuda' and 'cpu'."""
 
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 48673202c6cc..fae48cbe3374 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -15,9 +15,11 @@
 from .base_device_communicator import All2AllManagerBase, Cache
 
 if has_flashinfer_all2all():
-    from flashinfer.comm import Mapping
-    from flashinfer.comm.mnnvl import MnnvlConfig
-    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    from flashinfer.comm import Mapping  # type: ignore[import-not-found]
+    from flashinfer.comm.mnnvl import MnnvlConfig  # type: ignore[import-not-found]
+    from flashinfer.comm.trtllm_alltoall import (
+        MnnvlMoe,  # type: ignore[import-not-found]
+    )
 
 logger = init_logger(__name__)
 
@@ -65,6 +67,7 @@ def dispatch(
     ) -> tuple[torch.Tensor, torch.Tensor]:
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
 
         hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ def combine(
         ep_rank = self.rank if is_sequence_parallel else self.dp_rank
 
         dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)
 
@@ -113,7 +117,10 @@ def dispatch(
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
 
@@ -130,7 +137,10 @@ def combine(
         """
         Reduce-scatter hidden_states across all dp ranks.
         """
-        sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
+        dp_metadata = get_forward_context().dp_metadata
+        assert dp_metadata is not None
+        sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+        assert sizes is not None
         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
 
         hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ def __init__(self, cpu_group):
         if self.internode:
             # inter-node communication needs nvshmem,
             # intra-node communication uses p2p mapping directly
-            from pplx_kernels.nvshmem import (
+            from pplx_kernels.nvshmem import (  # type: ignore[import-not-found]
                 nvshmem_alloc_empty_unique_id,
                 nvshmem_get_unique_id,
                 nvshmem_init,
@@ -182,7 +192,7 @@ def __init__(self, cpu_group):
         self.handle_cache = Cache()
 
     def get_handle(self, kwargs):
-        import pplx_kernels as pplx
+        import pplx_kernels as pplx  # type: ignore[import-not-found]
 
         return self.handle_cache.get_or_create(
             kwargs,
@@ -208,7 +218,9 @@ def destroy(self):
             handle.destroy()
 
         if self.internode:
-            from pplx_kernels.nvshmem import nvshmem_finalize
+            from pplx_kernels.nvshmem import (
+                nvshmem_finalize,  # type: ignore[import-not-found]
+            )
 
             logger.debug("PPLX NVSHMEM finalize")
             nvshmem_finalize()
@@ -288,7 +300,7 @@ def get_handle(self, kwargs):
             "args are computed in the Manager itself."
         )
 
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         buffer_kwargs = self._make_all2all_kwargs()
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ def get_handle(self, kwargs):
         return handle
 
     def set_num_sms(self, num_sms: int):
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         # Right now the buffers are sized for only what the kernels were
         # created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ def _make_all2all_kwargs(
             num_global_experts: Number of experts in the model.
             num_local_experts: Number of experts in an EP rank.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ def get_handle(self, kwargs):
         The kwargs for DeepEPLLAll2AllManager is dictated by
         _make_all2all_kwargs.
         """
-        import deep_ep
+        import deep_ep  # type: ignore[import-not-found]
 
         buffer_kwargs = self._make_all2all_kwargs(**kwargs)
         logger.debug("DeepEP all2all args %s", buffer_kwargs)
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index 171e93ba53ee..4bc737494cb5 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import contextmanager
+from typing import cast
 
 import torch
 import torch.distributed as dist
@@ -118,15 +119,18 @@ def __init__(
         # now `device` is a `torch.device` object
         assert isinstance(device, torch.device)
         self.device = device
-        device_capability = current_platform.get_device_capability().as_version_str()
+        device_capability = current_platform.get_device_capability()
         if (
             current_platform.is_cuda()
             and symm_mem_enabled
-            and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
+            and device_capability is not None
         ):
-            max_size = min(
-                CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
-            )
+            device_capability_str = device_capability.as_version_str()
+            if device_capability_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
+                max_size = min(
+                    CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability_str][world_size],
+                    max_size,
+                )
         cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
         if cuda_visible_devices:
             device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -213,6 +217,7 @@ def register_graph_buffers(self):
         # We cannot directly use `dist.all_gather_object` here
         # because it is incompatible with `gloo` backend under inference mode.
        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+        all_data: list[list[list[int] | None]]
         all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
         all_data[self.rank] = [handle, offset]
         ranks = sorted(dist.get_process_group_ranks(group=self.group))
@@ -221,8 +226,8 @@ def register_graph_buffers(self):
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
         # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
+        handles = cast(list[list[int]], [d[0] for d in all_data])
+        offsets = cast(list[list[int]], [d[1] for d in all_data])
         ops.register_graph_buffers(self._ptr, handles, offsets)
 
     def should_custom_ar(self, inp: torch.Tensor):
diff --git a/vllm/distributed/device_communicators/symm_mem.py b/vllm/distributed/device_communicators/symm_mem.py
index aeea9b777b25..96f8e7b35535 100644
--- a/vllm/distributed/device_communicators/symm_mem.py
+++ b/vllm/distributed/device_communicators/symm_mem.py
@@ -52,9 +52,14 @@ def __init__(
         self.device = device
         self.group = group
         self.world_size = dist.get_world_size(self.group)
-        self.device_capability = (
-            current_platform.get_device_capability().as_version_str()
-        )
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            logger.warning(
+                "SymmMemCommunicator: device capability is unknown, "
+                "communicator is not available."
+            )
+            return
+        self.device_capability = capability.as_version_str()
         if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
             logger.warning(
                 "SymmMemCommunicator: Device capability %s not supported, "
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index aaf43842cf7c..ff806962028c 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -3,7 +3,7 @@
 
 import importlib
 from collections.abc import Callable
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, cast
 
 import vllm.envs as envs
 from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -48,6 +48,8 @@ def create_connector(
         )
 
         kv_transfer_config = config.kv_transfer_config
+        if kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set to create a connector")
         connector_cls = cls.get_connector_class(kv_transfer_config)
         logger.info(
             "Creating v1 connector with name: %s and engine_id: %s",
@@ -70,6 +72,8 @@ def get_connector_class(
     ) -> type[KVConnectorBaseType]:
         """Get the connector class by name."""
         connector_name = kv_transfer_config.kv_connector
+        if connector_name is None:
+            raise ValueError("Connector name is not set in KVTransferConfig")
         if connector_name in cls._registry:
             connector_cls = cls._registry[connector_name]()
         else:
@@ -77,7 +81,13 @@
             if connector_module_path is None:
                 raise ValueError(f"Unsupported connector type: {connector_name}")
             connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+            try:
+                connector_cls = getattr(connector_module, connector_name)
+            except AttributeError as e:
+                raise AttributeError(
+                    f"Class {connector_name} not found in {connector_module_path}"
+                ) from e
+        connector_cls = cast(type[KVConnectorBaseType], connector_cls)
 
         return connector_cls
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index b7e9daaa5b59..0fe678b9c615 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -151,21 +151,21 @@ def update_finished_set(
         aggregated_kv_connector_stats = None
         invalid_block_ids = set[int]()
         for model_runner_output in outputs:
-            output = model_runner_output.kv_connector_output
-            if not output:
+            kv_output = model_runner_output.kv_connector_output
+            if not kv_output:
                 continue
 
             update_finished_set(
-                output.finished_sending, self._send_remaining_count, finished_sending
+                kv_output.finished_sending, self._send_remaining_count, finished_sending
             )
             update_finished_set(
-                output.finished_recving, self._recv_remaining_count, finished_recving
+                kv_output.finished_recving, self._recv_remaining_count, finished_recving
             )
 
             # Aggregate kv_connector_stats from all workers.
             if aggregated_kv_connector_stats is None:
                 # Use the first worker's kv_connector_stats as accumulator.
-                aggregated_kv_connector_stats = output.kv_connector_stats
-            elif kv_connector_stats := output.kv_connector_stats:
+                aggregated_kv_connector_stats = kv_output.kv_connector_stats
+            elif kv_connector_stats := kv_output.kv_connector_stats:
                 if aggregated_kv_connector_stats is None:
                     aggregated_kv_connector_stats = kv_connector_stats
                 else:
@@ -176,7 +176,7 @@ def update_finished_set(
                         aggregated_kv_connector_stats.aggregate(kv_connector_stats)
                     )
 
-            invalid_block_ids |= output.invalid_block_ids
+            invalid_block_ids |= kv_output.invalid_block_ids
 
         # select output of the worker specified by output_rank
         output = outputs[output_rank]
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index c51e26ce2f44..ab5d2ecdc71b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -95,6 +95,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         )
         self._connector_metadata: KVConnectorMetadata | None = None
         self._vllm_config = vllm_config
+        if vllm_config.kv_transfer_config is not None:
+            self._kv_transfer_config = vllm_config.kv_transfer_config
+        else:
+            raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
         self._role = role
 
     @property
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 25625762f447..845ce320837d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -86,13 +86,11 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._connectors: list[KVConnectorBase_V1] = []
         self._ktc_kv_transfer_config = []
-        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
-            "connectors"
-        )
+        ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
         assert ktcs is not None
         for ktc in ktcs:
             temp_config = copy.copy(vllm_config)
-            engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
+            engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
             temp_config.kv_transfer_config = KVTransferConfig(
                 **ktc, engine_id=engine_id
             )
@@ -296,6 +294,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
             str: the required KV cache layout. e.g. HND, or NHD.
             None if the connector does not require a specific layout.
         """
+        assert vllm_config.kv_transfer_config is not None
         ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
             "connectors"
         )
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index a8730bf78987..a7054daa8d34 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -291,6 +291,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
             + vllm_config.parallel_config.data_parallel_rank
             * vllm_config.parallel_config.tensor_parallel_size
         )
+        assert vllm_config.kv_transfer_config is not None
         self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
 
         logger.info("Initializing NIXL Scheduler %s", engine_id)
@@ -334,7 +335,8 @@ def get_num_new_matched_tokens(
         if params is not None and params.get("do_remote_prefill"):
             # Remote prefill: get all prompt blocks from remote.
-            count = len(request.prompt_token_ids) - num_computed_tokens
+            token_ids = request.prompt_token_ids or []
+            count = len(token_ids) - num_computed_tokens
             if count > 0:
                 return count, True
 
@@ -515,6 +517,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
 
+        if vllm_config.kv_transfer_config is None:
+            raise ValueError("kv_transfer_config must be set for NixlConnector")
+
         self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
             "backends", ["UCX"]
         )
@@ -571,17 +576,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.use_host_buffer = self.kv_buffer_device == "cpu"
         # support for oot platform which can't register nixl memory
         # type based on kv_buffer_device
-        self.nixl_memory_type = current_platform.get_nixl_memory_type()
-        if self.nixl_memory_type is None:
+        nixl_memory_type = current_platform.get_nixl_memory_type()
+        if nixl_memory_type is None:
             if self.kv_buffer_device == "cuda":
-                self.nixl_memory_type = "VRAM"
+                nixl_memory_type = "VRAM"
             elif self.kv_buffer_device == "cpu":
-                self.nixl_memory_type = "DRAM"
-        if self.nixl_memory_type is None:
+                nixl_memory_type = "DRAM"
+        if nixl_memory_type is None:
             raise RuntimeError(
                 f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
                 "is not supported."
             )
+        self.nixl_memory_type = nixl_memory_type
 
         # Note: host xfer buffer ops when use_host_buffer is True
         self.copy_blocks: CopyBlocksOp | None = None
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
index c9fa9efeeb6f..e47cde2614fc 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py
@@ -75,9 +75,8 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Any] = {}
-        self.config = vllm_config.kv_transfer_config
-        self.is_producer = self.config.is_kv_producer
-        self.chunked_prefill: dict[str, Any] = {}
+        self.is_producer = self._kv_transfer_config.is_kv_producer
+        self.chunked_prefill: dict[str, tuple[list[int], list[int] | None]] = {}
 
         self._rank = get_world_group().rank if role == KVConnectorRole.WORKER else 0
         self._local_rank = (
@@ -87,7 +86,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         self.p2p_nccl_engine = (
             P2pNcclEngine(
                 local_rank=self._local_rank,
-                config=self.config,
+                config=self._kv_transfer_config,
                 hostname="",
                 port_offset=self._rank,
             )
@@ -346,7 +345,8 @@ def get_num_new_matched_tokens(
         if self.is_producer:
             return 0, False
 
-        num_external_tokens = len(request.prompt_token_ids) - 1 - num_computed_tokens
+        prompt_token_ids = request.prompt_token_ids or []
+        num_external_tokens = len(prompt_token_ids) - 1 - num_computed_tokens
 
         if num_external_tokens < 0:
             num_external_tokens = 0
@@ -387,7 +387,7 @@ def build_connector_meta(
                 ]
                 num_tokens = num_scheduled_tokens + new_req.num_computed_tokens
                 # the request's prompt is chunked prefill
-                if num_tokens < len(new_req.prompt_token_ids):
+                if num_tokens < len(new_req.prompt_token_ids or []):
                     # 'CachedRequestData' has no attribute 'prompt_token_ids'
                     self.chunked_prefill[new_req.req_id] = (
                         new_req.block_ids[0],
@@ -397,7 +397,7 @@
                 # the request's prompt is not chunked prefill
                 meta.add_request(
                     request_id=new_req.req_id,
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=new_req.prompt_token_ids or [],
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                 )
@@ -405,7 +405,7 @@
                 if new_req.req_id in self._requests_need_load:
                     meta.add_request(
                         request_id=new_req.req_id,
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=new_req.prompt_token_ids or [],
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                     )
@@ -421,10 +421,12 @@
             num_scheduled_tokens = (scheduler_output.num_scheduled_tokens)[req_id]
             num_tokens = num_scheduled_tokens + num_computed_tokens
             assert req_id in self.chunked_prefill
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
             if not resumed_from_preemption:
                 block_ids = self.chunked_prefill[req_id][0] + block_ids
             prompt_token_ids = self.chunked_prefill[req_id][1]
+            assert prompt_token_ids is not None
             # the request's prompt is chunked prefill again
             if num_tokens < len(prompt_token_ids):
                 self.chunked_prefill[req_id] = (block_ids, prompt_token_ids)
@@ -450,6 +452,7 @@
 
             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
 
             meta.add_request(
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
index a4beebecbe22..d0cd4b07c51d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py
@@ -90,11 +90,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         super().__init__(vllm_config=vllm_config, role=role)
         self._block_size = vllm_config.cache_config.block_size
         self._requests_need_load: dict[str, Request] = {}
-        transfer_config = vllm_config.kv_transfer_config
-        self._storage_path = transfer_config.get_from_extra_config(
+        self._storage_path = self._kv_transfer_config.get_from_extra_config(
             "shared_storage_path", "/tmp"
         )
-        logger.info(vllm_config.kv_transfer_config)
+        logger.info(self._kv_transfer_config)
         logger.info("Shared storage path is %s", self._storage_path)
 
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
@@ -277,9 +276,8 @@ def get_num_new_matched_tokens(
 
         # Now, first num_tokens_to_check tokens are hit, we need to prepare
         # the metadata for the worker connector to correctly load the KV
-        num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
-        )
+        token_ids = request.prompt_token_ids or []
+        num_tokens_to_check = align_to_block_size(len(token_ids) - 1, self._block_size)
 
         return num_tokens_to_check - num_computed_tokens, False
 
@@ -311,13 +309,15 @@ def build_connector_meta(
         total_need_load = 0
         for new_req in scheduler_output.scheduled_new_reqs:
+            token_ids = new_req.prompt_token_ids or []
+            mm_hashes = [f.identifier for f in new_req.mm_features]
             if new_req.req_id in self._requests_need_load:
                 meta.add_request(
-                    token_ids=new_req.prompt_token_ids,
+                    token_ids=token_ids,
                     block_ids=new_req.block_ids[0],
                     block_size=self._block_size,
                     is_store=False,
-                    mm_hashes=[f.identifier for f in new_req.mm_features],
+                    mm_hashes=mm_hashes,
                 )
                 total_need_load += 1
             else:
@@ -325,13 +325,13 @@
                 # but a single request can have both store and load.
                 # NOTE(rob): for this debug implementation, we only cache
                 # the original prompt tokens.
-                if not self._found_match_for_request(new_req):
+                if not self._found_match_for_prompt(token_ids, mm_hashes):
                     meta.add_request(
-                        token_ids=new_req.prompt_token_ids,
+                        token_ids=token_ids,
                         block_ids=new_req.block_ids[0],
                         block_size=self._block_size,
                         is_store=True,
-                        mm_hashes=[f.identifier for f in new_req.mm_features],
+                        mm_hashes=mm_hashes,
                     )
 
         cached_reqs = scheduler_output.scheduled_cached_reqs
@@ -355,6 +355,7 @@
 
             # NOTE(rob): For resumed req, new_block_ids is all
             # of the block_ids for the request.
+            assert new_block_ids is not None
             block_ids = new_block_ids[0]
 
             meta.add_request(
@@ -379,12 +380,22 @@ def _found_match_for_request(
         request: "Request",
     ) -> bool:
         """Check if the cache is hit for the request."""
+        return self._found_match_for_prompt(
+            list(request.prompt_token_ids or []),
+            [f.identifier for f in request.mm_features],
+        )
+
+    def _found_match_for_prompt(
+        self,
+        prompt_token_ids: list[int],
+        mm_hashes: list[str],
+    ) -> bool:
         num_tokens_to_check = align_to_block_size(
-            len(request.prompt_token_ids) - 1, self._block_size
+            len(prompt_token_ids) - 1, self._block_size
         )
         foldername = self._generate_foldername_debug(
-            torch.tensor(request.prompt_token_ids)[:num_tokens_to_check],
-            [f.identifier for f in request.mm_features],
+            torch.tensor(prompt_token_ids)[:num_tokens_to_check],
+            mm_hashes,
             create_folder=False,
         )
         return os.path.exists(foldername)
diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
index 8203c57e2dc6..d28ce20b609d 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py
@@ -236,6 +236,7 @@ def __init__(
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)
         else:
diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
index 4682eeee2768..526c5cd1d527 100644
--- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
+++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py
@@ -53,6 +53,7 @@ def __init__(
         self.config = config
         self.local_rank = local_rank
         self.kv_rank = self.config.kv_rank
+        assert self.kv_rank is not None
         self.kv_parallel_size = self.config.kv_parallel_size
         if device is None:
             self.device = self._select_device(self.config.kv_buffer_device)