
[XPU] Support AgRsAll2AllManager on XPU device #32654

Merged
jikunshang merged 1 commit into vllm-project:main from ys950902:yis_AgRs
Jan 20, 2026

Conversation

ys950902 (Contributor) commented Jan 20, 2026

Purpose

This PR enables AgRsAll2AllManager on the XPU device, using allgather + reduce_scatter to replace the naive implementation.
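For context, here is a minimal, hypothetical sketch of the allgather + reduce_scatter pattern for expert-parallel MoE dispatch/combine. This is not the vLLM implementation; the function names, equal per-rank token counts, and shapes are illustrative assumptions:

```python
import torch
import torch.distributed as dist

def agrs_dispatch(tokens: torch.Tensor, group) -> torch.Tensor:
    # Dispatch: gather every rank's tokens so each rank can run its
    # local experts over the full batch (assumes equal counts per rank).
    world_size = dist.get_world_size(group)
    gathered = torch.empty(
        (tokens.shape[0] * world_size, *tokens.shape[1:]),
        dtype=tokens.dtype, device=tokens.device,
    )
    dist.all_gather_into_tensor(gathered, tokens, group=group)
    return gathered

def agrs_combine(expert_out: torch.Tensor, group) -> torch.Tensor:
    # Combine: reduce-scatter the expert outputs back so each rank keeps
    # only its own slice of tokens, summed across ranks.
    world_size = dist.get_world_size(group)
    out = torch.empty(
        (expert_out.shape[0] // world_size, *expert_out.shape[1:]),
        dtype=expert_out.dtype, device=expert_out.device,
    )
    dist.reduce_scatter_tensor(out, expert_out, group=group)
    return out
```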

Test Plan

```bash
VLLM_ALL2ALL_BACKEND=allgather_reducescatter VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_USE_V1=1 \
python3 -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-Coder-30B-A3B-Instruct --enforce-eager \
    --port 8000 --host 0.0.0.0 --trust-remote-code --gpu-memory-util=0.9 \
    --no-enable-prefix-caching --max-num-batched-tokens=8192 \
    --disable-log-requests --max-model-len=8192 --block-size 64 \
    --enable-expert-parallel --quantization fp8 --data-parallel-size 2 \
    --dtype=float16 -tp=2
```

```bash
lm_eval --model local-chat-completions --tasks gsm8k --num_fewshot 1 --batch_size 1 \
    --model_args "model=Qwen/Qwen3-Coder-30B-A3B-Instruct,base_url=http://0.0.0.0:8000/v1/chat/completions,max_gen_toks=4096,num_concurrent=64" \
    --apply_chat_template --output_path ./lm_eval_output --log_samples
```

Test Result

| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 1      | exact_match | 0.9363 | ± 0.0067 |
|       |         | strict-match     | 1      | exact_match | 0.9371 | ± 0.0067 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

gemini-code-assist bot (Contributor) left a comment

Code Review

This PR adds support for AgRsAll2AllManager on XPU devices by implementing reduce_scatter, reduce_scatterv, and all_gatherv in XpuCommunicator. The changes introduce a few critical issues. There's a logic bug in the __init__ method that can lead to a crash when using an unsupported all2all backend. Additionally, several of the new distributed communication calls are missing the required group parameter, which would cause them to use the wrong process group. There is also an incorrect usage of dist.all_gather and a performance issue due to lack of communication batching.

Comment on lines 27 to 43:

```python
if self.all2all_backend == "allgather_reducescatter":
    from .all2all import AgRsAll2AllManager

    self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
    logger.info("Using AgRs manager on XPU device.")

elif self.all2all_backend != "naive":  # type: ignore[has-type]
    logger.warning(
        "`%s` all2all manager is not supported on XPU. "
        "Falling back to `naive` all2all manager for XPU.",
        self.all2all_backend,  # type: ignore[has-type]
    )
    self.all2all_backend = "naive"
elif self.all2all_backend == "naive":
    from .all2all import NaiveAll2AllManager

    self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
```

critical

The logic for selecting the all2all_manager is flawed. The if/elif/elif structure is mutually exclusive. If an unsupported all2all_backend is provided, the code will enter the elif self.all2all_backend != "naive": block, set self.all2all_backend = "naive", and then exit the conditional chain. The final elif self.all2all_backend == "naive": block will not be executed. This leaves self.all2all_manager uninitialized, which will cause a runtime error later.

Suggested change:

```diff
 if self.all2all_backend == "allgather_reducescatter":
     from .all2all import AgRsAll2AllManager
     self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
     logger.info("Using AgRs manager on XPU device.")
-elif self.all2all_backend != "naive":  # type: ignore[has-type]
-    logger.warning(
-        "`%s` all2all manager is not supported on XPU. "
-        "Falling back to `naive` all2all manager for XPU.",
-        self.all2all_backend,  # type: ignore[has-type]
-    )
-    self.all2all_backend = "naive"
-elif self.all2all_backend == "naive":
-    from .all2all import NaiveAll2AllManager
-    self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+else:
+    if self.all2all_backend != "naive":  # type: ignore[has-type]
+        logger.warning(
+            "`%s` all2all manager is not supported on XPU. "
+            "Falling back to `naive` all2all manager for XPU.",
+            self.all2all_backend,  # type: ignore[has-type]
+        )
+        self.all2all_backend = "naive"
+    from .all2all import NaiveAll2AllManager
+    self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
```

```python
    output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)

dist.reduce_scatter_tensor(output, input_tensor)
```

critical

The call to dist.reduce_scatter_tensor is missing the group argument. This will cause it to use the default process group, which is likely incorrect for tensor parallelism. You should pass group=self.device_group to ensure communication occurs within the intended process group.

Suggested change:

```diff
-dist.reduce_scatter_tensor(output, input_tensor)
+dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
```

Comment on lines +99 to +108:

```python
if sizes is not None and sizes.count(sizes[0]) != len(sizes):
    # if inputs shape in different ranks is not the same using reduce_scatter
    input_splits = list(input_tensor.split(sizes, dim=0))
    dist.reduce_scatter(output, input_splits)
else:
    dist.reduce_scatter_tensor(output, input_tensor)
```

critical

The dist.reduce_scatter and dist.reduce_scatter_tensor calls are missing the group argument. This will cause them to use the default process group instead of the specific device_group for this communicator, which can lead to incorrect behavior.

Suggested change:

```diff
 if sizes is not None and sizes.count(sizes[0]) != len(sizes):
     # if inputs shape in different ranks is not the same using reduce_scatter
     input_splits = list(input_tensor.split(sizes, dim=0))
-    dist.reduce_scatter(output, input_splits)
+    dist.reduce_scatter(output, input_splits, group=self.device_group)
 else:
-    dist.reduce_scatter_tensor(output, input_tensor)
+    dist.reduce_scatter_tensor(output, input_tensor, group=self.device_group)
```

```python
all_gather_list = []
for size in sizes:
    all_gather_list.append(
        torch.empty((size,) + input_.shape[1:], dtype=input_.dtype, device=input_.device)
    )
dist.all_gather(all_gather_list, input_)
```

critical

The dist.all_gather call is missing the group argument. This will cause it to use the default process group, which is likely incorrect. You should pass group=self.device_group to ensure communication occurs within the intended process group.

Suggested change:

```diff
-dist.all_gather(all_gather_list, input_)
+dist.all_gather(all_gather_list, input_, group=self.device_group)
```

```python
    dist.all_gather(all_gather_list, input_)
    output_tensor = torch.cat(all_gather_list, dim=0)
else:
    dist.all_gather([output_tensor], input_)
```

critical

The use of dist.all_gather([output_tensor], input_) is incorrect. dist.all_gather expects a list of tensors (one for each rank) to gather data into, but it's being passed a list with a single large tensor. This will fail for world_size > 1. You should use dist.all_gather_into_tensor instead. Additionally, the group argument is missing.

Suggested change:

```diff
-dist.all_gather([output_tensor], input_)
+dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
```

Comment on lines +151 to +164:

```python
output_list = []
for inp in input_:
    output_list.append(_all_gather_single(inp, sizes=sizes))
return output_list
```

high

When input_ is a list of tensors, the communication calls inside _all_gather_single are performed sequentially for each tensor. This is inefficient. For better performance, these collective operations should be batched. For the common case where sizes is None, you can achieve this by using dist.all_gather_into_tensor with async_op=True for each tensor and then waiting for all the returned handles to complete. The case where sizes is not None is harder to batch with the standard torch.distributed API as dist.all_gather does not support asynchronous operations.
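A minimal sketch of that batching idea for the sizes is None case. The helper name batched_all_gather is hypothetical, and the equal per-rank output shape and the group argument are assumptions carried over from this review:

```python
import torch
import torch.distributed as dist

def batched_all_gather(tensors: list[torch.Tensor], group) -> list[torch.Tensor]:
    # Launch one asynchronous all-gather per tensor, then wait on all of
    # the handles, so the collectives can overlap instead of running
    # strictly one after another.
    world_size = dist.get_world_size(group)
    outputs, handles = [], []
    for inp in tensors:
        out = torch.empty(
            (inp.shape[0] * world_size, *inp.shape[1:]),
            dtype=inp.dtype, device=inp.device,
        )
        handles.append(
            dist.all_gather_into_tensor(out, inp, group=group, async_op=True)
        )
        outputs.append(out)
    for handle in handles:
        handle.wait()
    return outputs
```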


mergify bot commented Jan 20, 2026

Hi @ys950902, the pre-commit checks have failed. Please run:

```bash
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```bash
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

jikunshang (Collaborator) commented

the default all2all manager will be naive or ag_rs?



ys950902 (Contributor, Author) commented Jan 20, 2026

> the default all2all manager will be naive or ag_rs?

The default will change to allgather_reducescatter, which is the backend vLLM recommends; see https://docs.vllm.com.cn/en/latest/serving/expert_parallel_deployment/#backend-selection-guide.
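For reference, the backend is selected with the VLLM_ALL2ALL_BACKEND environment variable, as in the test plan above; a minimal example, with the server flags trimmed down from that command:

```bash
# With this PR, XPU supports allgather_reducescatter in addition to naive.
VLLM_ALL2ALL_BACKEND=allgather_reducescatter \
    python3 -m vllm.entrypoints.openai.api_server \
    --model Qwen/Qwen3-Coder-30B-A3B-Instruct --enable-expert-parallel
```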

@ys950902 ys950902 closed this Jan 20, 2026
@ys950902 ys950902 reopened this Jan 20, 2026
Signed-off-by: yisheng <yi.sheng@intel.com>
jikunshang (Collaborator) left a comment


LGTM. thanks for fixing!

@jikunshang jikunshang enabled auto-merge (squash) January 20, 2026 12:44
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 20, 2026
@jikunshang jikunshang merged commit 13f6630 into vllm-project:main Jan 20, 2026
54 checks passed
gopalsarda pushed a commit to gopalsarda/vllm that referenced this pull request Jan 20, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: yisheng <yi.sheng@intel.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
aykoppol pushed a commit to aykoppol/vllm that referenced this pull request Jan 21, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 21, 2026
daje0601 pushed a commit to daje0601/vllm that referenced this pull request Jan 22, 2026
Signed-off-by: yisheng <yi.sheng@intel.com>
Signed-off-by: daje0601 <englishmt4118@gmail.com>
monajafi-amd pushed a commit to monajafi-amd/vllm that referenced this pull request Jan 23, 2026
Signed-off-by: yisheng <yi.sheng@intel.com>
Signed-off-by: mohammad najafi <mohammad.najafi@amd.com>
@ys950902 ys950902 deleted the yis_AgRs branch January 26, 2026 06:57
yma11 pushed a commit to yma11/vllm that referenced this pull request Jan 27, 2026
Josephasafg pushed a commit to Josephasafg/vllm that referenced this pull request Jan 27, 2026
Signed-off-by: yisheng <yi.sheng@intel.com>
Signed-off-by: Josephasafg <ajgard7@gmail.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026