[XPU] Support AgRsAll2AllManager on XPU device #32654
jikunshang merged 1 commit into vllm-project:main
Conversation
Code Review
This PR adds support for AgRsAll2AllManager on XPU devices by implementing reduce_scatter, reduce_scatterv, and all_gatherv in XpuCommunicator. The changes introduce a few critical issues. There's a logic bug in the __init__ method that can lead to a crash when using an unsupported all2all backend. Additionally, several of the new distributed communication calls are missing the required group parameter, which would cause them to use the wrong process group. There is also an incorrect usage of dist.all_gather and a performance issue due to lack of communication batching.
if self.all2all_backend == "allgather_reducescatter":
    from .all2all import AgRsAll2AllManager

    self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
    logger.info("Using AgRs manager on XPU device.")
elif self.all2all_backend != "naive":  # type: ignore[has-type]
    logger.warning(
        "`%s` all2all manager is not supported on XPU. "
        "Falling back to `naive` all2all manager for XPU.",
        self.all2all_backend,  # type: ignore[has-type]
    )
    self.all2all_backend = "naive"
elif self.all2all_backend == "naive":
    from .all2all import NaiveAll2AllManager

    self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
The logic for selecting the all2all_manager is flawed. The if/elif/elif structure is mutually exclusive. If an unsupported all2all_backend is provided, the code will enter the elif self.all2all_backend != "naive": block, set self.all2all_backend = "naive", and then exit the conditional chain. The final elif self.all2all_backend == "naive": block will not be executed. This leaves self.all2all_manager uninitialized, which will cause a runtime error later.
Suggested change:

if self.all2all_backend == "allgather_reducescatter":
    from .all2all import AgRsAll2AllManager

    self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
    logger.info("Using AgRs manager on XPU device.")
else:
    if self.all2all_backend != "naive":  # type: ignore[has-type]
        logger.warning(
            "`%s` all2all manager is not supported on XPU. "
            "Falling back to `naive` all2all manager for XPU.",
            self.all2all_backend,  # type: ignore[has-type]
        )
        self.all2all_backend = "naive"
    from .all2all import NaiveAll2AllManager

    self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
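For illustration, a minimal standalone sketch of why the else branch guarantees a manager is always chosen; the select_backend helper and the "pplx" value are hypothetical and are not part of the PR:

def select_backend(requested: str) -> str:
    # Mirrors the suggested control flow: anything other than
    # "allgather_reducescatter" falls into the else branch, and any
    # unsupported value is coerced to "naive" before a manager is picked,
    # so the manager can never be left unset.
    if requested == "allgather_reducescatter":
        return "AgRsAll2AllManager"
    else:
        if requested != "naive":
            requested = "naive"  # fall back instead of skipping selection
        return "NaiveAll2AllManager"

# An unsupported backend now resolves to the naive manager instead of crashing.
assert select_backend("pplx") == "NaiveAll2AllManager"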
    output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
dist.reduce_scatter_tensor(output, input_tensor)
The call to dist.reduce_scatter_tensor is missing the group argument. This will cause it to use the default process group, which is likely incorrect for tensor parallelism. You should pass group=self.device_group to ensure communication occurs within the intended process group.
Suggested change:

dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
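For context, a minimal sketch of how dist.reduce_scatter_tensor behaves when the group is passed explicitly; here group stands in for self.device_group and input_tensor is assumed to already be allocated on the XPU device:

import torch
import torch.distributed as dist

# Each rank contributes a tensor whose first dimension is world_size * chunk.
# After the call, every rank holds the element-wise sum of its own chunk
# across all ranks, i.e. the output has first dimension `chunk`.
world_size = dist.get_world_size(group)
chunk = input_tensor.shape[0] // world_size
output = torch.empty(
    (chunk,) + input_tensor.shape[1:],
    dtype=input_tensor.dtype,
    device=input_tensor.device,
)
dist.reduce_scatter_tensor(output, input_tensor, group=group)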
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
    # if input shapes differ across ranks, use reduce_scatter on per-rank splits
    input_splits = list(input_tensor.split(sizes, dim=0))
    dist.reduce_scatter(output, input_splits)
else:
    dist.reduce_scatter_tensor(output, input_tensor)
The dist.reduce_scatter and dist.reduce_scatter_tensor calls are missing the group argument. This will cause them to use the default process group instead of the specific device_group for this communicator, which can lead to incorrect behavior.
Suggested change:

if sizes is not None and sizes.count(sizes[0]) != len(sizes):
    # if input shapes differ across ranks, use reduce_scatter on per-rank splits
    input_splits = list(input_tensor.split(sizes, dim=0))
    dist.reduce_scatter(output, input_splits, group=self.device_group)
else:
    dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
all_gather_list = []
for size in sizes:
    all_gather_list.append(
        torch.empty((size,) + input_.shape[1:], dtype=input_.dtype, device=input_.device)
    )
dist.all_gather(all_gather_list, input_)
The dist.all_gather call is missing the group argument. This will cause it to use the default process group, which is likely incorrect. You should pass group=self.device_group to ensure communication occurs within the intended process group.
Suggested change:

dist.all_gather(all_gather_list, input_, group=self.device_group)
    dist.all_gather(all_gather_list, input_)
    output_tensor = torch.cat(all_gather_list, dim=0)
else:
    dist.all_gather([output_tensor], input_)
The use of dist.all_gather([output_tensor], input_) is incorrect. dist.all_gather expects a list of tensors (one for each rank) to gather data into, but it's being passed a list with a single large tensor. This will fail for world_size > 1. You should use dist.all_gather_into_tensor instead. Additionally, the group argument is missing.
Suggested change:

dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
output_list = []
for inp in input_:
    output_list.append(_all_gather_single(inp, sizes=sizes))
return output_list
When input_ is a list of tensors, the communication calls inside _all_gather_single are performed sequentially for each tensor, which is inefficient. For better performance, these collective operations should be batched. For the common case where sizes is None, you can achieve this by using dist.all_gather_into_tensor with async_op=True for each tensor and then waiting for all the returned handles to complete. The case where sizes is not None is harder to batch cleanly with the standard torch.distributed API.
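As a rough sketch of that batching idea for the sizes is None case (the batched_all_gather helper and the device_group argument are illustrative, not the PR's actual code):

import torch
import torch.distributed as dist

def batched_all_gather(tensors, world_size, device_group):
    # Launch one asynchronous all-gather per tensor so the collectives can
    # overlap, then wait on every handle before returning the outputs.
    outputs, handles = [], []
    for t in tensors:
        out = torch.empty(
            (world_size * t.shape[0],) + t.shape[1:], dtype=t.dtype, device=t.device
        )
        handles.append(
            dist.all_gather_into_tensor(out, t, group=device_group, async_op=True)
        )
        outputs.append(out)
    for h in handles:
        h.wait()
    return outputs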
Hi @ys950902, the pre-commit checks have failed. Please run:
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
Then, commit the changes and push to your branch.
Will the default all2all manager be naive or ag_rs?
The default will change to allgather_reducescatter, which is the default backend that vLLM suggests.
Signed-off-by: yisheng <yi.sheng@intel.com>
jikunshang left a comment:
LGTM. Thanks for fixing!
Purpose
This PR enables AgRsAll2AllManager on the XPU device, using allgather + reduce_scatter to replace the naive implementation.
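Conceptually, the allgather + reduce_scatter pattern can be sketched as below. This is only a simplified illustration of the idea, not the actual AgRsAll2AllManager code; the dispatch/combine/comm names are placeholders:

import torch
import torch.distributed as dist

def dispatch(hidden_states: torch.Tensor, comm) -> torch.Tensor:
    # Every rank gathers all tokens so it can run its locally held experts
    # on any token routed to them.
    world = dist.get_world_size(comm)
    out = torch.empty(
        (world * hidden_states.shape[0],) + hidden_states.shape[1:],
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    dist.all_gather_into_tensor(out, hidden_states, group=comm)
    return out

def combine(expert_output: torch.Tensor, comm) -> torch.Tensor:
    # Partial expert outputs are summed and scattered back so each rank
    # ends up with only its own tokens' results.
    world = dist.get_world_size(comm)
    out = torch.empty(
        (expert_output.shape[0] // world,) + expert_output.shape[1:],
        dtype=expert_output.dtype,
        device=expert_output.device,
    )
    dist.reduce_scatter_tensor(out, expert_output, group=comm)
    return out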
Test Plan
VLLM_ALL2ALL_BACKEND=allgather_reducescatter VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_USE_V1=1 python3 -m vllm.entrypoints.openai.api_server --model Qwen/Qwen3-Coder-30B-A3B-Instruct --enforce-eager --port 8000 --host 0.0.0.0 --trust-remote-code --gpu-memory-util=0.9 --no-enable-prefix-caching --max-num-batched-tokens=8192 --disable-log-requests --max-model-len=8192 --block-size 64 --enable-expert-parallel --quantization fp8 --data-parallel-size 2 --dtype=float16 -tp=2
lm_eval --model local-chat-completions --tasks gsm8k --num_fewshot 1 --batch_size 1 --model_args "model=Qwen/Qwen3-Coder-30B-A3B-Instruct,base_url=http://0.0.0.0:8000/v1/chat/completions,max_gen_toks=4096,num_concurrent=64" --apply_chat_template --output_path ./lm_eval_output --log_samples
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.