Merged
Commits
28 commits
983a5a2
Interleave shared experts with send/recv
bnellnm Aug 20, 2025
ba50625
use receiver callback
bnellnm Aug 20, 2025
b221237
fixes + add prepare_no_receive to deepep ll
bnellnm Aug 20, 2025
f3494bf
fixes
bnellnm Aug 21, 2025
261820d
update doc
bnellnm Aug 21, 2025
5d17858
wip cudagraphs
bnellnm Aug 21, 2025
89406ba
comments + tweaks
bnellnm Aug 21, 2025
83cc427
fix comments + tweak layer
bnellnm Aug 21, 2025
782b6d7
undo bogus changes
bnellnm Aug 21, 2025
f5e2ea9
simplify reduce step
bnellnm Aug 21, 2025
f294214
implement prepare_no_receive for deepep ht
bnellnm Aug 22, 2025
21e1f71
add SharedFusedExpert to deepseek_v2
bnellnm Aug 22, 2025
2a5b5a1
fix lint
bnellnm Aug 23, 2025
0f36c9f
lint
bnellnm Aug 23, 2025
f866d66
more lint
bnellnm Aug 23, 2025
fabc9d6
add unit test for shared/fused experts overlap
bnellnm Aug 27, 2025
a9548a1
use more realistic MLP for test
bnellnm Aug 27, 2025
31f2d05
fix cudagraphs
bnellnm Aug 27, 2025
79df8e7
undo debugging code in activation.py
bnellnm Aug 27, 2025
544b62a
review comments + cleanup deepseek changes
bnellnm Aug 28, 2025
d617c83
review comments
bnellnm Aug 28, 2025
46b2c3b
remove debugging env var
bnellnm Aug 29, 2025
cc11426
add all_reduce to fallback code
bnellnm Aug 29, 2025
8173209
fix data_parallel.py compilation config default
bnellnm Aug 29, 2025
9dc9205
add check for SharedFusedMoE layer
bnellnm Sep 2, 2025
2d5561f
ignore shared expert weights
bnellnm Sep 2, 2025
068ad88
ignore shared_expert parameters
bnellnm Sep 2, 2025
b94296f
make doc clearer
bnellnm Sep 3, 2025
8 changes: 6 additions & 2 deletions docs/design/fused_moe_modular_kernel.md
@@ -54,8 +54,8 @@ The `FusedMoEModularKernel` acts as a bridge between the `FusedMoEPermuteExperts

### FusedMoEPrepareAndFinalize

The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare` and `finalize` functions.
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. The `finalize` function is responsible for invoking the All2All Combine. Additionally the `finalize` function may or may not do the TopK weight application and reduction (Please refer to the TopKWeightAndReduce section)
The `FusedMoEPrepareAndFinalize` abstract class exposes `prepare`, `prepare_no_receive`, and `finalize` functions.
The `prepare` function is responsible for input activation Quantization and All2All Dispatch. If implemented, `prepare_no_receive` is like `prepare` except that it does not wait to receive results from the other workers; instead it returns a "receiver" callback that must be invoked to wait for the final results of the dispatch. Not every `FusedMoEPrepareAndFinalize` subclass is required to support this method, but when it is available it can be used to overlap other work with the initial all-to-all communication, e.g. interleaving shared experts with the fused experts. The `finalize` function is responsible for invoking the All2All Combine. Additionally, the `finalize` function may or may not do the TopK weight application and reduction (please refer to the TopKWeightAndReduce section).

![](../assets/design/fused_moe_modular_kernel/prepare_and_finalize_blocks.png "FusedMoEPrepareAndFinalize Blocks")
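The overlap described above can be sketched in miniature. This is illustrative only, not vLLM's actual API surface: the real `FusedMoEPrepareAndFinalize` methods take quantization state, topk ids/weights, and more, and all class/function names below are simplified stand-ins.

```python
# Toy sketch of overlapping shared experts with the all-to-all dispatch via
# the receiver callback returned by prepare_no_receive. Names are simplified
# stand-ins, not the real vLLM signatures.
from typing import Callable, Optional


class ToyPrepareAndFinalize:
    """Stand-in for a FusedMoEPrepareAndFinalize that supports overlap."""

    def has_prepare_no_receive(self) -> bool:
        return True

    def prepare_no_receive(
            self, tokens: list[float]) -> Callable[[], list[float]]:
        # Quantize + launch the all-to-all dispatch without waiting.
        dispatched = [t * 2.0 for t in tokens]  # pretend this is in flight

        def receiver() -> list[float]:
            # Block until the dispatch completes and return its results.
            return dispatched

        return receiver


def moe_forward(pf: ToyPrepareAndFinalize,
                tokens: list[float],
                shared_experts: Optional[Callable] = None):
    receiver = pf.prepare_no_receive(tokens)
    # The shared experts run while the dispatch is (conceptually) in flight.
    shared_out = shared_experts(tokens) if shared_experts else None
    hidden = receiver()
    fused_out = [h + 1.0 for h in hidden]  # stand-in for the fused experts
    return shared_out, fused_out
```

The point of the shape is that the caller controls when the wait happens, so any work independent of the dispatched activations can be slotted in between launching the dispatch and invoking the receiver.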

@@ -146,6 +146,10 @@ This section describes the significance of the various functions exposed by the

`FusedMoEPrepareAndFinalize::prepare()`: The prepare method implements the Quantization and All2All Dispatch. Typically the Dispatch function from the relevant All2All Manager is invoked.

`FusedMoEPrepareAndFinalize::has_prepare_no_receive()`: Indicates whether or not this subclass implements `prepare_no_receive`. Defaults to False.

`FusedMoEPrepareAndFinalize::prepare_no_receive()`: The prepare_no_receive method implements the Quantization and All2All Dispatch. It does not wait for the result of the dispatch operation but instead returns a thunk that can be invoked to wait for the final results. Typically the Dispatch function from the relevant All2All Manager is invoked.

`FusedMoEPrepareAndFinalize::finalize()`: May perform TopK Weight Application and Reduction, and the All2All Combine. Typically the Combine function from the relevant All2All Manager is invoked.

`FusedMoEPrepareAndFinalize::activation_format()`: Return `FusedMoEActivationFormat.BatchedExperts` if the output of the prepare method (i.e. the All2All dispatch) is Batched. Return `FusedMoEActivationFormat.Standard` otherwise.
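The `has_prepare_no_receive()`/`prepare_no_receive()` contract above can be sketched as follows. This is a toy illustration with assumed, simplified names, not the real base class:

```python
# Toy illustration of the default-False contract: subclasses that implement
# prepare_no_receive override has_prepare_no_receive to advertise it, and
# callers fall back to plain prepare() otherwise.
from typing import Callable


class BasePrepareFinalize:
    def has_prepare_no_receive(self) -> bool:
        return False  # default: prepare_no_receive is unsupported

    def prepare(self, x: list[int]) -> list[int]:
        return [v * 2 for v in x]  # quantize + dispatch + wait (conceptually)

    def prepare_no_receive(self, x: list[int]) -> Callable[[], list[int]]:
        raise NotImplementedError


class OverlappablePrepareFinalize(BasePrepareFinalize):
    def has_prepare_no_receive(self) -> bool:
        return True

    def prepare_no_receive(self, x: list[int]) -> Callable[[], list[int]]:
        result = self.prepare(x)  # launch the dispatch
        return lambda: result     # thunk: wait for + return the results


def dispatch(pf: BasePrepareFinalize, x: list[int]) -> list[int]:
    # Query support first; both paths yield the same dispatched result.
    if pf.has_prepare_no_receive():
        return pf.prepare_no_receive(x)()
    return pf.prepare(x)
```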
8 changes: 8 additions & 0 deletions examples/offline_inference/data_parallel.py
@@ -87,6 +87,11 @@ def parse_args():
default=0.8,
help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
)
parser.add_argument(
"--compilation-config",
type=int,
help=("Compilation optimization (O) level 0-3."),
)
parser.add_argument(
"--quantization",
type=str,
@@ -106,6 +111,7 @@ def main(
trust_remote_code,
max_num_seqs,
max_model_len,
compilation_config,
gpu_memory_utilization,
quantization,
):
@@ -162,6 +168,7 @@ def start(rank):
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
quantization=quantization,
compilation_config=compilation_config,
)
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
@@ -218,6 +225,7 @@ def start(rank):
args.trust_remote_code,
args.max_num_seqs,
args.max_model_len,
args.compilation_config,
args.gpu_memory_utilization,
args.quantization,
),
86 changes: 74 additions & 12 deletions tests/kernels/moe/test_pplx_moe.py
@@ -4,10 +4,11 @@

Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import copy
import itertools
import textwrap
import traceback
from typing import Callable, Optional
from typing import Callable, Optional, Union

import pytest
import torch
@@ -21,7 +22,10 @@
except ImportError:
has_pplx = False

from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
_set_vllm_config)
from tests.kernels.moe.utils import (make_shared_experts, make_test_weights,
naive_batched_moe)
from tests.kernels.quant_utils import dequant
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
@@ -511,7 +515,8 @@ def pplx_moe(
block_shape: Optional[list[int]] = None,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
shared_experts: Optional[torch.nn.Module] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:

num_tokens, hidden_dim = a.shape
num_experts = w1.shape[0]
@@ -546,6 +551,7 @@
fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
shared_experts,
)

# Note: workers with the same dp_rank must use the exact same inputs.
@@ -586,7 +592,11 @@
global_num_experts=num_experts)

if use_cudagraphs:
out.fill_(0)
if isinstance(out, tuple):
out[0].fill_(0)
out[1].fill_(0)
else:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
@@ -626,6 +636,7 @@ def _pplx_moe(
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
shared_experts: Optional[torch.nn.Module] = None,
):
try:
if use_internode:
@@ -666,6 +677,11 @@
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)

if shared_experts is not None:
shared_output = shared_experts(a)
else:
shared_output = None

torch_output = torch_experts(
a,
w1,
@@ -696,7 +712,7 @@
block_shape=block_shape,
)

pplx_output = pplx_moe(
pplx_outputs = pplx_moe(
group_name,
rank,
world_size,
@@ -713,8 +729,24 @@
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
shared_experts=shared_experts,
)

if shared_experts is None:
pplx_shared_output = None
pplx_output = pplx_outputs
assert isinstance(pplx_output, torch.Tensor)
else:
pplx_shared_output, pplx_output = pplx_outputs

if shared_output is not None:
assert pplx_shared_output is not None
chunked_shared_output = chunk_by_rank(
shared_output, pgi.rank,
pgi.world_size).to(pplx_shared_output.device)
else:
chunked_shared_output = None

chunked_batch_output = chunk_by_rank(
batched_output, pgi.rank, pgi.world_size).to(pplx_output.device)

@@ -727,6 +759,15 @@
chunked_batch_output,
atol=3e-2,
rtol=3e-2)

if shared_experts is not None:
assert chunked_shared_output is not None
assert pplx_shared_output is not None
torch.testing.assert_close(pplx_shared_output,
chunked_shared_output,
atol=3e-2,
rtol=3e-2)

finally:
if use_internode:
nvshmem_finalize()
@@ -788,7 +829,8 @@ def test_pplx_moe_slow(


def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
make_weights: bool, test_fn: Callable):
use_shared_experts: bool, make_weights: bool,
test_fn: Callable):

def format_result(msg, ex=None):
if ex is not None:
@@ -803,6 +845,14 @@ def format_result(msg, ex=None):
else:
print(f"PASSED {msg}")

if use_shared_experts:
# Note: this config is only needed for the non-naive shared experts.
new_vllm_config = copy.deepcopy(vllm_config)
new_vllm_config.parallel_config.data_parallel_size = pgi.world_size
new_vllm_config.parallel_config.enable_expert_parallel = True
_set_vllm_config(new_vllm_config, pgi.world_size, pgi.rank,
pgi.local_rank)

current_platform.seed_everything(7)
combos = itertools.product(PPLX_COMBOS, NUM_EXPERTS, TOP_KS, DTYPES,
[False, True], [None, [128, 128]])
@@ -819,9 +869,11 @@ def format_result(msg, ex=None):
use_fp8_w8a8 = False
quant_dtype = None

test_desc = (f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}")
test_desc = (
f"test_pplx_moe[mnk={mnk}, e={e}, topk={topk}, "
f"dtype={dtype}, per_act_token={per_act_token_quant}, "
f"block_shape={block_shape}, use_internode={use_internode}, "
f"use_shared_experts={use_shared_experts}")

if not use_fp8_w8a8 and (per_act_token_quant
or block_shape is not None):
@@ -852,6 +904,14 @@ def format_result(msg, ex=None):
args["w1_s"] = w1_s
args["w2_s"] = w2_s

if use_shared_experts:
args["shared_experts"] = make_shared_experts(
n,
k,
in_dtype=a.dtype,
quant_dtype=quant_dtype,
)

try:
test_fn(
pgi=pgi,
@@ -891,18 +951,20 @@ def test_pplx_prepare_finalize(
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size * dp_size, _pplx_test_loop, dp_size,
use_internode, False, _pplx_prepare_finalize)
use_internode, False, False, _pplx_prepare_finalize)


@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.parametrize("use_shared_experts", [False, True])
@requires_pplx
@multi_gpu_test(num_gpus=2)
def test_pplx_moe(
world_dp_size: tuple[int, int],
use_internode: bool,
use_shared_experts: bool,
):
current_platform.seed_everything(7)
world_size, dp_size = world_dp_size
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode, True,
_pplx_moe)
parallel_launch(world_size, _pplx_test_loop, dp_size, use_internode,
use_shared_experts, True, _pplx_moe)