Merged
137 changes: 130 additions & 7 deletions vllm/distributed/device_communicators/xpu_communicator.py
@@ -23,23 +23,146 @@ def __init__(
):
super().__init__(cpu_group, device, device_group, unique_name)
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
logger.warning(
"`%s` all2all manager is not supported on XPU. "
"Falling back to `naive` all2all manager for XPU.",
self.all2all_backend, # type: ignore[has-type]
)
self.all2all_backend = "naive"
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager

self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")

elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager

self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AgRs manager on XPU device.")

else: # type: ignore[has-type]
logger.warning(
"`%s` all2all manager is not supported on XPU. "
"Falling back to AgRs manager for XPU, "
"which is the Default backend",
self.all2all_backend, # type: ignore[has-type]
)
from .all2all import AgRsAll2AllManager

self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
logger.info("Using AgRs manager on XPU device.")

def all_reduce(self, input_) -> torch.Tensor:
dist.all_reduce(input_, group=self.device_group)
return input_

def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size,) + input_tensor.shape[1:]

output = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)

dist.reduce_scatter_tensor(output, input_tensor)
Contributor review comment (critical):

The call to dist.reduce_scatter_tensor is missing the group argument. This will cause it to use the default process group, which is likely incorrect for tensor parallelism. You should pass group=self.device_group to ensure communication occurs within the intended process group.

Suggested change:
-    dist.reduce_scatter_tensor(output, input_tensor)
+    dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
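
For context: torch.distributed collectives fall back to the default (WORLD) process group whenever group= is omitted, so with multiple parallel groups the reduction would span the wrong set of ranks. A minimal standalone sketch of the difference (not vLLM code; it assumes exactly 4 processes launched via torchrun, the gloo backend, and a hypothetical 2-way sub-group split):

import torch
import torch.distributed as dist

def main() -> None:
    dist.init_process_group(backend="gloo")  # backend choice is illustrative
    rank, world = dist.get_rank(), dist.get_world_size()
    assert world == 4, "sketch assumes exactly 4 ranks"

    # Every rank must create every sub-group, in the same order.
    g01 = dist.new_group([0, 1])
    g23 = dist.new_group([2, 3])
    my_group = g01 if rank < 2 else g23

    t = torch.full((2,), float(rank))
    dist.all_reduce(t)                  # no group= -> sums over all 4 ranks
    u = torch.full((2,), float(rank))
    dist.all_reduce(u, group=my_group)  # sums only within the 2-rank sub-group

    print(f"rank {rank}: world sum={t[0].item()}, sub-group sum={u[0].item()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=4 this_file.py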


# Reshape before returning
return output.movedim(0, dim).contiguous()

def reduce_scatterv(
self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
):
world_size = self.world_size

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()

# Note: This will produce an incorrect answer if we don't make
# the input_tensor contiguous. Possible bug in reduce_scatter_tensor?
input_tensor = input_.movedim(0, dim).contiguous()

if sizes is not None:
assert len(sizes) == world_size
assert input_tensor.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_tensor.shape[0] % world_size == 0
chunk_size = input_tensor.shape[0] // world_size
output_shape = (chunk_size,) + input_tensor.shape[1:]

output = torch.empty(
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
# If input shapes differ across ranks, fall back to list-based reduce_scatter.
input_splits = list(input_tensor.split(sizes, dim=0))
dist.reduce_scatter(output, input_splits)
else:
dist.reduce_scatter_tensor(output, input_tensor)
Contributor review comment on lines +103 to +108 (critical):

The dist.reduce_scatter and dist.reduce_scatter_tensor calls are missing the group argument. This will cause them to use the default process group instead of the specific device_group for this communicator, which can lead to incorrect behavior.

Suggested change:
     if sizes is not None and sizes.count(sizes[0]) != len(sizes):
         # If input shapes differ across ranks, fall back to list-based reduce_scatter.
         input_splits = list(input_tensor.split(sizes, dim=0))
-        dist.reduce_scatter(output, input_splits)
+        dist.reduce_scatter(output, input_splits, group=self.device_group)
     else:
-        dist.reduce_scatter_tensor(output, input_tensor)
+        dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
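
For reference, a single-process sketch of the tensor shapes the variable-size branch above sets up before calling dist.reduce_scatter (the sizes and dimensions are made-up example values; no process group is created here):

import torch

world_size, rank = 4, 2
sizes = [3, 5, 3, 5]                        # per-rank row counts, example values
x = torch.randn(sum(sizes), 16)             # the contiguous input_tensor after movedim

# What the variable-size branch hands to dist.reduce_scatter:
input_splits = list(x.split(sizes, dim=0))  # world_size tensors, rows = sizes[i]
output = torch.empty((sizes[rank],) + x.shape[1:], dtype=x.dtype)

assert len(input_splits) == world_size
assert input_splits[rank].shape == output.shape  # output matches this rank's own split
# dist.reduce_scatter(output, input_splits, group=...) would then reduce split i
# across ranks and deliver the reduced split `rank` into `output`.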

# Reshape before returning
return output.movedim(0, dim).contiguous()

def all_gatherv(
self,
input_: torch.Tensor | list[torch.Tensor],
dim: int = 0,
sizes: list[int] | None = None,
):
if dim != 0:
raise NotImplementedError("only dim 0 all-gatherv is supported")
world_size = self.world_size

# 'sizes' is not needed if all inputs in the same group have the same
# shape
if sizes is not None and all(s == sizes[0] for s in sizes):
sizes = None

def _all_gather_single(input_: torch.Tensor, sizes: list[int] | None = None):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[dim] == sizes[self.rank_in_group], (
f"{input_.shape[dim]} != {sizes[self.rank_in_group]}"
)
output_size = (sum(sizes),) + input_size[1:]
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)

if sizes is not None:
all_gather_list = []
for size in sizes:
all_gather_list.append(
torch.empty(
(size,) + input_.shape[1:],
dtype=input_.dtype,
device=input_.device,
)
)
dist.all_gather(all_gather_list, input_)
Contributor review comment (critical):

The dist.all_gather call is missing the group argument. This will cause it to use the default process group, which is likely incorrect. You should pass group=self.device_group to ensure communication occurs within the intended process group.

Suggested change:
-    dist.all_gather(all_gather_list, input_)
+    dist.all_gather(all_gather_list, input_, group=self.device_group)

output_tensor = torch.cat(all_gather_list, dim=0)
else:
dist.all_gather([output_tensor], input_)
Contributor review comment (critical):

The use of dist.all_gather([output_tensor], input_) is incorrect. dist.all_gather expects a list of tensors (one for each rank) to gather data into, but it's being passed a list with a single large tensor. This will fail for world_size > 1. You should use dist.all_gather_into_tensor instead. Additionally, the group argument is missing.

Suggested change:
-    dist.all_gather([output_tensor], input_)
+    dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
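
For reference, a minimal sketch contrasting the two call shapes (the helper is illustrative, not vLLM API; it assumes an already-initialized process group whose backend supports both collectives):

import torch
import torch.distributed as dist

def gather_rows(x: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather x along dim 0 with both APIs and check they agree."""
    ws = dist.get_world_size(group)

    # List form: all_gather needs ws buffers, each shaped exactly like x.
    bufs = [torch.empty_like(x) for _ in range(ws)]
    dist.all_gather(bufs, x, group=group)
    out_list = torch.cat(bufs, dim=0)

    # Flat form: one preallocated buffer with ws * x.shape[0] rows.
    out_flat = torch.empty((ws * x.shape[0],) + x.shape[1:],
                           dtype=x.dtype, device=x.device)
    dist.all_gather_into_tensor(out_flat, x, group=group)

    assert torch.equal(out_list, out_flat)
    return out_flat

Passing one large buffer wrapped in a single-element list to dist.all_gather mixes the two forms and only happens to line up when the world size is 1.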

return output_tensor

if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)

output_list = []
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
return output_list
Contributor review comment on lines +161 to +164 (high):

When input_ is a list of tensors, the communication calls inside _all_gather_single are performed sequentially for each tensor. This is inefficient. For better performance, these collective operations should be batched. For the common case where sizes is None, you can achieve this by using dist.all_gather_into_tensor with async_op=True for each tensor and then waiting for all the returned handles to complete. The case where sizes is not None is harder to batch cleanly, although dist.all_gather also accepts async_op=True.
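
A hedged sketch of the batching suggested above for the uniform-size, list-input case: launch one asynchronous all_gather_into_tensor per tensor, then wait on all the returned handles (the helper name is made up and this is not the vLLM implementation):

import torch
import torch.distributed as dist

def batched_all_gather(tensors: list[torch.Tensor], group=None) -> list[torch.Tensor]:
    ws = dist.get_world_size(group)
    outputs, handles = [], []
    for t in tensors:
        out = torch.empty((ws * t.shape[0],) + t.shape[1:],
                          dtype=t.dtype, device=t.device)
        # async_op=True returns a Work handle instead of blocking.
        handles.append(dist.all_gather_into_tensor(out, t, group=group, async_op=True))
        outputs.append(out)
    for h in handles:
        h.wait()  # synchronize once after all collectives are enqueued
    return outputs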


def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> torch.Tensor | None: