[Bugfix] Fix MoE Model DP+TP with NaiveAll2AllManager Bug #32705

Open

River12 wants to merge 3 commits into vllm-project:main from River12:export-D91016491

Conversation

@River12 (Contributor) commented Jan 20, 2026

Summary: For a MoE model with DP2TP2, the responses from the 2nd DP group are wrong when using NaiveAll2AllManager, because the broadcast operation is performed in an incorrect dist_group.
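For context, the naive all-to-all's multicast conceptually assembles the global batch by broadcasting each DP rank's local tokens from their owner rank, so every rank ends up with the concatenation of all ranks' tokens. A single-process sketch of that assembly (made-up token values, no torch.distributed; the real implementation operates on tensors over a process group):

```python
# Single-process sketch of what a naive multicast across DP ranks computes:
# each rank contributes its local token batch, and every rank should end up
# holding the concatenation of all DP ranks' batches. The bug fixed in this
# PR was that the per-rank broadcast ran over the wrong process group.

def naive_multicast(local_tokens_per_rank):
    """Return what every DP rank should hold after the multicast:
    the concatenation of all DP ranks' local token batches."""
    gathered = []
    for rank_tokens in local_tokens_per_rank:  # one broadcast per source rank
        gathered.extend(rank_tokens)
    return gathered


# Two DP ranks, each with its own prompt's token ids (made-up values).
per_rank = [[101, 102], [201, 202, 203]]
print(naive_multicast(per_rank))  # [101, 102, 201, 202, 203]
```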

Test Plan:
Test DP2TP2 with VLLM_ALL2ALL_BACKEND="naive" on a MoE model; the testing script below is modified from examples/offline_inference/torchrun_dp_example.py

  • Input the same prompt to 2 DP groups
  • Use default MoE model microsoft/Phi-mini-MoE-instruct
```
import argparse

from vllm import LLM, SamplingParams


def parse_args():
    parser = argparse.ArgumentParser(
        description="Data-parallel inference with torchrun"
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1)",
    )
    parser.add_argument(
        "--pp-size",
        type=int,
        default=1,
        help="Pipeline parallel size (default: 1)",
    )
    parser.add_argument(
        "--dp-size",
        type=int,
        default=2,
        help="Data parallel size (default: 2)",
    )
    parser.add_argument(
        "--enable-ep",
        action="store_true",
        help="Enable expert parallel (default: False)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="microsoft/Phi-mini-MoE-instruct",
        help="Model name or path (default: microsoft/Phi-mini-MoE-instruct)",
    )
    parser.add_argument(
        "--max-model-len",
        type=int,
        default=4096,
        help="Maximum model length (default: 4096)",
    )
    parser.add_argument(
        "--gpu-memory-utilization",
        type=float,
        default=0.6,
        help="GPU memory utilization (default: 0.6)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        help="Random seed (default: 1)",
    )
    return parser.parse_args()


args = parse_args()


# Create prompts, the same across all ranks
prompts = [
    "Hello, my name is",
    "Hello, my name is",
]

# Create sampling parameters, the same across all ranks
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
# It is important to set an explicit seed to make sure that
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm = LLM(
    model=args.model,
    tensor_parallel_size=args.tp_size,
    data_parallel_size=args.dp_size,
    pipeline_parallel_size=args.pp_size,
    enable_expert_parallel=args.enable_ep,
    distributed_executor_backend="external_launcher",
    max_model_len=args.max_model_len,
    gpu_memory_utilization=args.gpu_memory_utilization,
    seed=args.seed,
)

dp_rank = llm.llm_engine.vllm_config.parallel_config.data_parallel_rank
dp_size = llm.llm_engine.vllm_config.parallel_config.data_parallel_size

prompts = [
    f"{idx}.{prompt}" for idx, prompt in enumerate(prompts) if idx % dp_size == dp_rank
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(
        f"DP Rank: {dp_rank} Prompt: {prompt!r}\nGenerated text: {generated_text!r}\n"
    )
```

Running command:

```
FLASHINFER_DISABLE_VERSION_CHECK=1 VLLM_ALL2ALL_BACKEND="naive" \
torchrun --nproc-per-node=4 examples/offline_inference/torchrun_dp_example.py \
    --tp-size=2 --dp-size=2
```

Log before fix, the responses from the 2nd DP group are wrong:

```
DP Rank: 0 Prompt: '0.Hello, my name is'
Generated text: ' 0.Hello, my name is 0.Hello, my name is'

DP Rank: 1 Prompt: '1.Hello, my name is'
Generated text: 'aaaa st sample task SS field Story notion snapshot Reyn final moment Reyn Ku Ent dead'

DP Rank: 0 Prompt: '0.Hello, my name is'
Generated text: ' 0.Hello, my name is 0.Hello, my name is'

DP Rank: 1 Prompt: '1.Hello, my name is'
Generated text: 'aaaa st sample task SS field Story notion snapshot Reyn final moment Reyn Ku Ent dead'
```

Log after fix:

```
DP Rank: 0 Prompt: '0.Hello, my name is'
Generated text: ' John.\n\n### Instruction 2 (Much more difficult with'

DP Rank: 1 Prompt: '1.Hello, my name is'
Generated text: ' John.\n2.I am a software developer.\n3.I love'

DP Rank: 1 Prompt: '1.Hello, my name is'
Generated text: ' John.\n2.I am a software developer.\n3.I love'

DP Rank: 0 Prompt: '0.Hello, my name is'
Generated text: ' John.\n\n### Instruction 2 (Much more difficult with'
```

Differential Revision: D91016491

@mergify mergify bot added the bug Something isn't working label Jan 20, 2026
@gemini-code-assist (bot) left a comment

Code Review

The pull request effectively addresses a bug in the NaiveAll2AllManager where the broadcast operation was using an incorrect distributed group for MoE models with DP2TP2 configuration. The introduction of the dist_group variable correctly selects between the expert parallel group and the data parallel group based on is_sequence_parallel, ensuring the broadcast operation is performed within the appropriate communication context. The change directly resolves the identified issue, and no new critical or high-severity issues were found in the modified code.
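The group-selection logic the review describes can be sketched as follows. This is a minimal illustration, not the actual vLLM code: `select_dist_group`, and the placeholder group objects are hypothetical names; only the selection pattern (expert-parallel group when sequence parallel, data-parallel group otherwise) comes from the review above.

```python
# Hedged sketch of the group selection described in the review. The names
# below are illustrative stand-ins for vLLM internals, not its real API.

def select_dist_group(ep_group, dp_group, is_sequence_parallel: bool):
    """Pick the process group for the naive all-to-all broadcast.

    With sequence parallelism the broadcast spans the expert-parallel
    group; otherwise it must stay within the data-parallel group, so one
    DP replica never receives another replica's activations.
    """
    return ep_group if is_sequence_parallel else dp_group


# Minimal check with placeholder group objects.
EP, DP = "ep_group", "dp_group"
print(select_dist_group(EP, DP, is_sequence_parallel=True))   # ep_group
print(select_dist_group(EP, DP, is_sequence_parallel=False))  # dp_group
```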

@River12 River12 force-pushed the export-D91016491 branch 2 times, most recently from 69fec47 to 4629615 Compare January 24, 2026 00:28
@River12 River12 changed the title Fix MoE Model DP+TP with NaiveAll2AllManger Bug [Bugfix] Fix MoE Model DP+TP with NaiveAll2AllManger Bug Jan 24, 2026
@River12 River12 changed the title [Bugfix] Fix MoE Model DP+TP with NaiveAll2AllManger Bug [Bugfix] Fix MoE Model DP+TP with NaiveAll2AllManager Bug Jan 24, 2026
@sarckk (Collaborator) commented Jan 27, 2026

@River12 could you add a test plan?

cc: @tlrmchlsmth / @mgoin would you be able to help review this change?

@River12 (Contributor, Author) commented Jan 27, 2026

> @River12 could you add a test plan?
>
> cc: @tlrmchlsmth / @mgoin would you be able to help review this change?

@sarckk Thanks, the test plan has been added above. cc @tlrmchlsmth, @mgoin

@tlrmchlsmth (Member) left a comment

Thanks for the fix!
Two questions:

  1. Does the same thing happen with VLLM_ALL2ALL_BACKEND="allgather_reducescatter"?
  2. Seems like this could have been introduced in #32567 - could you confirm if that seems right?

@robertgshaw2-redhat (Collaborator) commented Jan 30, 2026

> Thanks for the fix! Two questions:
>
>   1. Does the same thing happen with VLLM_ALL2ALL_BACKEND="allgather_reducescatter"?
>   2. Seems like this could have been introduced in [MoE Refactor] Integrate Naive Prepare Finalize into MK #32567 - could you confirm if that seems right?

Looked into it. There is no issue with AG/RS, as it already selects the proper group.

I don't think that #32567 introduced this; I think this was just never correctly implemented for Naive.

That being said, we should probably deprecate naive. I'm not sure of its value now that we have AG/RS.

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) January 30, 2026 22:39
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 30, 2026
@River12 (Contributor, Author) commented Jan 30, 2026

> Thanks for the fix! Two questions:
>
>   1. Does the same thing happen with VLLM_ALL2ALL_BACKEND="allgather_reducescatter"?
>   2. Seems like this could have been introduced in [MoE Refactor] Integrate Naive Prepare Finalize into MK #32567 - could you confirm if that seems right?
>
> Looked into it. There is no issue with AG/RS, as it already selects the proper group.
>
> I don't think that #32567 introduced this; I think this was just never correctly implemented for Naive.
>
> That being said, we should probably deprecate naive. I'm not sure of its value now that we have AG/RS.

Thanks for the reviews.

  1. Confirmed that the same thing does not happen with VLLM_ALL2ALL_BACKEND="allgather_reducescatter", since it selects the correct dist_group in the dispatch/combine operations for DP2TP2. However, VLLM_ALL2ALL_BACKEND="naive" always chooses ep_group instead of dp_group, even in the DP2TP2 setup, so the 1st DP group broadcasts its tensors to the other DP groups (not expected). As a result, only the 1st DP group generates correct responses, while the other DP groups generate garbage.

  2. Agreed that [MoE Refactor] Integrate Naive Prepare Finalize into MK #32567 did not introduce this, as it did not touch `naive_multicast`.
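The rank layout behind this explanation can be illustrated with a small standalone sketch (no vLLM dependency). It assumes the usual TP-innermost rank ordering, with the EP group spanning all DP×TP ranks in this run; both are assumptions for illustration, not taken from the PR's diff:

```python
# Illustrative rank layout for the DP2TP2 run above (4 ranks, TP innermost).
# ep_group spans all DP*TP ranks, while each DP replica owns a disjoint TP
# slice -- so a broadcast over ep_group leaks rank 0's tokens into the other
# DP replica, matching the garbage output observed before the fix.

TP_SIZE, DP_SIZE = 2, 2
world = list(range(TP_SIZE * DP_SIZE))  # [0, 1, 2, 3]

# Each DP replica holds a contiguous block of TP ranks.
dp_groups = [world[i * TP_SIZE:(i + 1) * TP_SIZE] for i in range(DP_SIZE)]
ep_group = world  # EP spans every rank in this setup

print("DP groups:", dp_groups)  # [[0, 1], [2, 3]]
print("EP group: ", ep_group)   # [0, 1, 2, 3]

# Broadcasting from rank 0 over ep_group reaches ranks outside DP group 0:
leaked = [r for r in ep_group if r not in dp_groups[0]]
print("Ranks wrongly receiving DP group 0's tokens:", leaked)  # [2, 3]
```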

Signed-off-by: Dezhan Tu <dztu@meta.com>

Reviewed By: diviramon, mutinifni, wushidonguc
auto-merge was automatically disabled February 2, 2026 17:39

Head branch was pushed to by a user without write access


Labels

bug (Something isn't working), fb-exported, meta-exported, ready (ONLY add when PR is ready to merge/full CI is needed)
