2 changes: 1 addition & 1 deletion tools/pre_commit/mypy.py
@@ -27,6 +27,7 @@
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/distributed",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging_utils",
@@ -43,7 +44,6 @@
"tests",
"vllm/attention",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/executor",
"vllm/inputs",
36 changes: 24 additions & 12 deletions vllm/distributed/device_communicators/all2all.py
@@ -15,9 +15,11 @@
from .base_device_communicator import All2AllManagerBase, Cache

if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
from flashinfer.comm import Mapping # type: ignore[import-not-found]
from flashinfer.comm.mnnvl import MnnvlConfig # type: ignore[import-not-found]
from flashinfer.comm.trtllm_alltoall import (
MnnvlMoe, # type: ignore[import-not-found]
)

logger = init_logger(__name__)

@@ -65,6 +67,7 @@ def dispatch(
) -> tuple[torch.Tensor, torch.Tensor]:
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

hidden_states = self.naive_multicast(
@@ -81,6 +84,7 @@ def combine(
ep_rank = self.rank if is_sequence_parallel else self.dp_rank

dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
cu_tokens_across_sp_cpu = dp_metadata.cu_tokens_across_sp(sp_size)

@@ -113,7 +117,10 @@ def dispatch(
"""
Gather hidden_states and router_logits from all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
@@ -130,7 +137,10 @@ def combine(
"""
Reduce-scatter hidden_states across all dp ranks.
"""
sizes = get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None

dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
hidden_states = dist_group.reduce_scatterv(hidden_states, dim=0, sizes=sizes)
@@ -155,7 +165,7 @@ def __init__(self, cpu_group):
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly
from pplx_kernels.nvshmem import (
from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
nvshmem_alloc_empty_unique_id,
nvshmem_get_unique_id,
nvshmem_init,
@@ -182,7 +192,7 @@ def __init__(self, cpu_group):
self.handle_cache = Cache()

def get_handle(self, kwargs):
import pplx_kernels as pplx
import pplx_kernels as pplx # type: ignore[import-not-found]

return self.handle_cache.get_or_create(
kwargs,
@@ -208,7 +218,9 @@ def destroy(self):
handle.destroy()

if self.internode:
from pplx_kernels.nvshmem import nvshmem_finalize
from pplx_kernels.nvshmem import (
nvshmem_finalize, # type: ignore[import-not-found]
)

logger.debug("PPLX NVSHMEM finalize")
nvshmem_finalize()
@@ -288,7 +300,7 @@ def get_handle(self, kwargs):
"args are computed in the Manager itself."
)

import deep_ep
import deep_ep # type: ignore[import-not-found]

buffer_kwargs = self._make_all2all_kwargs()
logger.debug("DeepEP all2all args %s", buffer_kwargs)
@@ -298,7 +310,7 @@ def get_handle(self, kwargs):
return handle

def set_num_sms(self, num_sms: int):
import deep_ep
import deep_ep # type: ignore[import-not-found]

# Right now the buffers are sized for only what the kernels were
# created with. So we can only reduce the number of SMS used
@@ -332,7 +344,7 @@ def _make_all2all_kwargs(
num_global_experts: Number of experts in the model.
num_local_experts: Number of experts in an EP rank.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]

# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
@@ -358,7 +370,7 @@ def get_handle(self, kwargs):
The kwargs for DeepEPLLAll2AllManager is dictated by
_make_all2all_kwargs.
"""
import deep_ep
import deep_ep # type: ignore[import-not-found]

buffer_kwargs = self._make_all2all_kwargs(**kwargs)
logger.debug("DeepEP all2all args %s", buffer_kwargs)
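The recurring fix in this file is narrowing Optional values before attribute access, plus silencing optional third-party imports for mypy. A minimal sketch of the narrowing pattern, with placeholder classes standing in for vLLM's ForwardContext and DPMetadata:

from typing import Optional


class DPMetadata:  # placeholder for vLLM's real metadata object
    def cu_tokens_across_sp(self, sp_size: int) -> list[int]:
        return [sp_size]


class ForwardContext:  # placeholder; the real context is built per forward pass
    dp_metadata: Optional[DPMetadata] = None


def dispatch(ctx: ForwardContext, sp_size: int) -> list[int]:
    dp_metadata = ctx.dp_metadata
    # Without this assert, mypy reports that "None" has no attribute
    # "cu_tokens_across_sp"; at runtime it also fails fast with a clear error.
    assert dp_metadata is not None
    return dp_metadata.cu_tokens_across_sp(sp_size)

The `# type: ignore[import-not-found]` comments serve the complementary purpose: flashinfer, pplx_kernels, and deep_ep are optional dependencies, so mypy is told not to require their stubs.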
27 changes: 18 additions & 9 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -119,15 +119,18 @@ def __init__(
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
device_capability = current_platform.get_device_capability().as_version_str()
device_capability = current_platform.get_device_capability()
if (
current_platform.is_cuda()
and symm_mem_enabled
and device_capability in CUSTOM_ALL_REDUCE_MAX_SIZES
and device_capability is not None
):
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[device_capability][world_size], max_size
)
cap_str = device_capability.as_version_str()
if cap_str in CUSTOM_ALL_REDUCE_MAX_SIZES:
max_size = min(
CUSTOM_ALL_REDUCE_MAX_SIZES[cap_str][world_size],
max_size,
)
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
@@ -214,17 +217,23 @@ def register_graph_buffers(self):
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
all_data: list[list[object]] = [
[None, None] for _ in range(dist.get_world_size(group=self.group))
]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(
all_data[i], src=rank, group=self.group, device="cpu"
)
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
handles = [d[0] for d in all_data]
offsets = [d[1] for d in all_data]
ops.register_graph_buffers(
self._ptr,
handles, # type: ignore[arg-type]
offsets, # type: ignore[arg-type]
)

def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
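The change above replaces a direct `.as_version_str()` call with a None check, since `get_device_capability()` may return None. A hedged sketch of that lookup pattern with made-up names and sizes:

from typing import Optional


class DeviceCapability:  # placeholder for the platform's capability object
    def __init__(self, major: int, minor: int) -> None:
        self.major, self.minor = major, minor

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"


# Made-up table; the real CUSTOM_ALL_REDUCE_MAX_SIZES is keyed the same way.
MAX_SIZES = {"9.0": {8: 2 * 1024 * 1024}}


def resolve_max_size(
    capability: Optional[DeviceCapability], world_size: int, max_size: int
) -> int:
    # Skip the lookup entirely when the capability cannot be determined,
    # instead of calling .as_version_str() on None.
    if capability is None:
        return max_size
    cap_str = capability.as_version_str()
    if cap_str in MAX_SIZES and world_size in MAX_SIZES[cap_str]:
        max_size = min(MAX_SIZES[cap_str][world_size], max_size)
    return max_size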
11 changes: 8 additions & 3 deletions vllm/distributed/device_communicators/symm_mem.py
@@ -53,9 +53,14 @@ def __init__(
self.device = device
self.group = group
self.world_size = dist.get_world_size(self.group)
self.device_capability = (
current_platform.get_device_capability().as_version_str()
)
capability = current_platform.get_device_capability()
if capability is None:
logger.warning(
"SymmMemCommunicator: device capability is unknown, "
"communicator is not available."
)
return
self.device_capability = capability.as_version_str()
if self.device_capability not in SYMM_MEM_ALL_REDUCE_MAX_SIZES:
logger.warning(
"SymmMemCommunicator: Device capability %s not supported, "
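Same root cause as in custom_all_reduce.py, but here the fix bails out of __init__ early and leaves the communicator unavailable. A toy version of that guard, with a hypothetical availability flag:

import logging
from typing import Optional

logger = logging.getLogger(__name__)


def get_capability_str() -> Optional[str]:
    return None  # stand-in for current_platform.get_device_capability()


class GuardedCommunicator:  # illustrative only, not the real SymmMemCommunicator
    def __init__(self) -> None:
        self.available = False
        capability = get_capability_str()
        if capability is None:
            # Leave the communicator disabled rather than crashing on None.
            logger.warning("device capability is unknown, communicator disabled")
            return
        self.device_capability = capability
        self.available = True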
12 changes: 10 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, cast

import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import (
@@ -47,6 +47,7 @@ def create_connector(
)

kv_transfer_config = config.kv_transfer_config
assert kv_transfer_config is not None
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
@@ -69,14 +70,21 @@ def get_connector_class(
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
connector_cls_any = getattr(connector_module, connector_name, None)
if connector_cls_any is None:
raise AttributeError(
f"Class {connector_name} not found in {connector_module_path}"
)
connector_cls = cast(type[KVConnectorBaseType], connector_cls_any)
return connector_cls


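Beyond the None checks, the factory now validates the dynamically imported class before returning it. A self-contained sketch of that importlib pattern (the base class here is a stand-in):

import importlib
from typing import cast


class KVConnectorBase:  # stand-in for vLLM's connector base type
    pass


def load_connector_class(module_path: str, class_name: str) -> type[KVConnectorBase]:
    module = importlib.import_module(module_path)
    # getattr with a default lets us raise a clearer error than the generic
    # AttributeError when the class does not exist in the module.
    connector_cls_any = getattr(module, class_name, None)
    if connector_cls_any is None:
        raise AttributeError(f"Class {class_name} not found in {module_path}")
    # getattr() is typed as Any, so cast to the expected class type for mypy.
    return cast(type[KVConnectorBase], connector_cls_any)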
14 changes: 7 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -151,21 +151,21 @@ def update_finished_set(
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
output = model_runner_output.kv_connector_output
if not output:
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
update_finished_set(
output.finished_sending, self._send_remaining_count, finished_sending
kv_output.finished_sending, self._send_remaining_count, finished_sending
)
update_finished_set(
output.finished_recving, self._recv_remaining_count, finished_recving
kv_output.finished_recving, self._recv_remaining_count, finished_recving
)

# Aggregate kv_connector_stats from all workers.
if aggregated_kv_connector_stats is None:
# Use the first worker's kv_connector_stats as accumulator.
aggregated_kv_connector_stats = output.kv_connector_stats
elif kv_connector_stats := output.kv_connector_stats:
aggregated_kv_connector_stats = kv_output.kv_connector_stats
elif kv_connector_stats := kv_output.kv_connector_stats:
if aggregated_kv_connector_stats is None:
aggregated_kv_connector_stats = kv_connector_stats
else:
@@ -176,7 +176,7 @@ def update_finished_set(
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)

invalid_block_ids |= output.invalid_block_ids
invalid_block_ids |= kv_output.invalid_block_ids

# select output of the worker specified by output_rank
output = outputs[output_rank]
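The rename from `output` to `kv_output` exists because the same name is later rebound to `outputs[output_rank]`, which has a different type, and mypy infers a variable's type from its first assignment. A toy illustration:

from typing import Optional


def pick_first(items: list[Optional[int]], fallback: list[Optional[int]]) -> int:
    for maybe_item in items:  # loop binding has type Optional[int]
        if maybe_item is not None:
            return maybe_item
    # Reusing `maybe_item` here would be flagged as an incompatible assignment
    # (list[Optional[int]] vs Optional[int]); a new name keeps each variable
    # at a single, consistent type.
    selected = fallback
    return len(selected)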
@@ -86,6 +86,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
@@ -296,6 +297,7 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> Optional[str]
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
21 changes: 14 additions & 7 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -291,6 +291,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
+ vllm_config.parallel_config.data_parallel_rank
* vllm_config.parallel_config.tensor_parallel_size
)
assert vllm_config.kv_transfer_config is not None
self.use_host_buffer = vllm_config.kv_transfer_config.kv_buffer_device == "cpu"
logger.info("Initializing NIXL Scheduler %s", engine_id)

@@ -334,7 +335,8 @@ def get_num_new_matched_tokens(

if params is not None and params.get("do_remote_prefill"):
# Remote prefill: get all prompt blocks from remote.
count = len(request.prompt_token_ids) - num_computed_tokens
token_ids = request.prompt_token_ids or []
count = len(token_ids) - num_computed_tokens
if count > 0:
return count, True

@@ -515,6 +517,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size

assert vllm_config.kv_transfer_config is not None
self.nixl_backends = vllm_config.kv_transfer_config.get_from_extra_config(
"backends", ["UCX"]
)
@@ -555,7 +558,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):

# KV Caches and nixl tracking data.
self.device_type = current_platform.device_type
self.kv_buffer_device: str = vllm_config.kv_transfer_config.kv_buffer_device
assert vllm_config.kv_transfer_config is not None
self.kv_buffer_device: str = (
vllm_config.kv_transfer_config.kv_buffer_device or "cuda"
)
if self.device_type not in _NIXL_SUPPORTED_DEVICE:
raise RuntimeError(f"{self.device_type} is not supported.")
elif self.kv_buffer_device not in _NIXL_SUPPORTED_DEVICE[self.device_type]:
@@ -571,17 +577,18 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.use_host_buffer = self.kv_buffer_device == "cpu"
# support for oot platform which can't register nixl memory
# type based on kv_buffer_device
self.nixl_memory_type = current_platform.get_nixl_memory_type()
if self.nixl_memory_type is None:
nixl_memory_type = current_platform.get_nixl_memory_type()
if nixl_memory_type is None:
if self.kv_buffer_device == "cuda":
self.nixl_memory_type = "VRAM"
nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
self.nixl_memory_type = "DRAM"
if self.nixl_memory_type is None:
nixl_memory_type = "DRAM"
if nixl_memory_type is None:
raise RuntimeError(
f"{self.device_type} with {self.kv_buffer_device} kv_buffer "
"is not supported."
)
self.nixl_memory_type = nixl_memory_type

# Note: host xfer buffer ops when use_host_buffer is True
self.copy_blocks: Optional[CopyBlocksOp] = None
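The nixl_memory_type change follows a related pattern: resolve the possibly-None value in a local variable, fail fast, and only then store it, so the attribute's type never includes None. A sketch with placeholder names:

from typing import Optional


def detect_memory_type() -> Optional[str]:
    return None  # stand-in for current_platform.get_nixl_memory_type()


class NixlWorkerSketch:  # illustrative only
    def __init__(self, kv_buffer_device: str) -> None:
        memory_type = detect_memory_type()
        if memory_type is None:
            if kv_buffer_device == "cuda":
                memory_type = "VRAM"
            elif kv_buffer_device == "cpu":
                memory_type = "DRAM"
        if memory_type is None:
            raise RuntimeError(f"{kv_buffer_device} kv_buffer is not supported")
        # memory_type is now narrowed to str, so the attribute stays non-Optional.
        self.nixl_memory_type: str = memory_type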