Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
228 commits
Select commit Hold shift + click to select a range
4aeabf2
initial MoERunner refactor
bnellnm Jan 13, 2026
a4d3acb
fix lint
bnellnm Feb 12, 2026
5b7f133
rebase
bnellnm Feb 24, 2026
fad7f33
rebase + remove dead code
bnellnm Mar 5, 2026
ec88db3
fix gate overlap
bnellnm Mar 19, 2026
76aff0a
wip
bnellnm Feb 4, 2026
4fab915
fix
bnellnm Feb 9, 2026
d8a7f91
WIP DOUBLE CHECK THIS
bnellnm Feb 11, 2026
3dec78f
wip more refactoring
bnellnm Feb 19, 2026
e94b863
wip
bnellnm Feb 19, 2026
6cc5074
SharedExperts wip
bnellnm Feb 23, 2026
e8865e6
cleanups
bnellnm Feb 23, 2026
f83e0f5
fix circular import
bnellnm Feb 23, 2026
88e80b9
fixes
bnellnm Feb 24, 2026
781d4ea
renames
bnellnm Feb 24, 2026
3695016
add comment
bnellnm Feb 24, 2026
053f66f
more renames
bnellnm Feb 24, 2026
708dd2b
cleanup
bnellnm Feb 25, 2026
5748f7c
remove memoizing router, not needed yet
bnellnm Feb 26, 2026
9123f15
fix UBD bug
bnellnm Feb 27, 2026
04b430f
cleanup merge
bnellnm Mar 5, 2026
526db38
fix merge
bnellnm Mar 5, 2026
67bdab2
fix merge
bnellnm Mar 5, 2026
e9afbe6
fix typos
bnellnm Mar 5, 2026
453ab3d
fix merge
bnellnm Mar 18, 2026
48acc59
fix format
bnellnm Mar 18, 2026
c067844
fix gate overlap
bnellnm Mar 19, 2026
9f0e8d7
merge with main
bnellnm Mar 19, 2026
bc82978
renames, revert lora changes
bnellnm Mar 19, 2026
3dc9d4f
review comments + cleanup
bnellnm Mar 20, 2026
12bda3d
remove _must_reduce_shared_expert_outputs
bnellnm Mar 20, 2026
8aaddea
undo some changes + add Rob's changes
bnellnm Mar 23, 2026
bbaaca7
Merge remote-tracking branch 'origin/main' into moe-runner-2
bnellnm Mar 23, 2026
392f311
hacky fix for unquantized method
bnellnm Mar 23, 2026
7d5adbe
fix lint
bnellnm Mar 23, 2026
f345165
fix lint
bnellnm Feb 12, 2026
bdefdf5
fix merge
bnellnm Mar 25, 2026
377acc8
fix merge
bnellnm Mar 25, 2026
4fe1531
attempt to fix zero experts
bnellnm Feb 26, 2026
b6ba920
simplify ZeroExpertFusedMoE and add ZeroExpertRouter
bnellnm Feb 27, 2026
3dd1f27
add value test
bnellnm Feb 27, 2026
20832b5
move ZeroExpertRouter construction into router factory
bnellnm Feb 27, 2026
d5676bd
move zero expert handling into MoERunnerBase
bnellnm Feb 27, 2026
165cfe6
slightly improved test
bnellnm Feb 27, 2026
a7ac4c4
simplifications
bnellnm Feb 27, 2026
dabcd5f
better test
bnellnm Feb 27, 2026
3b950b0
remove ZeroExpertFusedMoE
bnellnm Feb 27, 2026
8f88991
Add comment
bnellnm Mar 2, 2026
3a9852b
fix lint
bnellnm Mar 18, 2026
b94eaca
move shared expert all gather to SharedExperts
bnellnm Mar 3, 2026
03de451
remove must_reduce_shared_expert_outputs external method
bnellnm Mar 3, 2026
d43fa50
wip moving experts epilog into MoERunnerBase
bnellnm Mar 3, 2026
b4fd107
cleanups
bnellnm Mar 3, 2026
4f05802
apply_scale_to_output flag
bnellnm Mar 4, 2026
72c0b5e
fix fp16 scaling factor stuff
bnellnm Mar 4, 2026
3239666
cleanups
bnellnm Mar 4, 2026
3a6cfd7
add claude generated comments
bnellnm Mar 4, 2026
14a395a
move stuff out ot custom op
bnellnm Mar 4, 2026
8beaca1
fix transformers/moe.py
bnellnm Mar 4, 2026
33a0109
tweak op registration
bnellnm Mar 18, 2026
3c4ce4c
fix rebase
bnellnm Mar 18, 2026
47a2260
fix merge
bnellnm Mar 18, 2026
d8cfa5a
fix merge
bnellnm Mar 25, 2026
9b45bdf
fix
bnellnm Mar 25, 2026
149947d
update layer test for MoE refactoring
bnellnm Mar 17, 2026
ce8465c
Add test files
bnellnm Mar 18, 2026
d813203
fix
bnellnm Mar 18, 2026
2d4eac0
remove cruft
bnellnm Mar 18, 2026
e1b52ed
add better skip msg
bnellnm Mar 18, 2026
8856700
fix merge
bnellnm Mar 25, 2026
73cddf9
fix test imports
bnellnm Mar 25, 2026
dff6766
part 1
bnellnm Mar 19, 2026
9dc5176
_get_quant_method
bnellnm Mar 19, 2026
77f473b
wip
bnellnm Mar 23, 2026
46101e7
hack fix for weight loading
bnellnm Mar 23, 2026
6b9a89a
make ctors more independent
bnellnm Mar 24, 2026
6d0ed32
make MoERunner into a torch.nn.Module
bnellnm Mar 24, 2026
b122141
separate out eplb + fix shared_experts
bnellnm Mar 24, 2026
fd361fd
one eplb fix
bnellnm Mar 24, 2026
018195f
fixes
bnellnm Mar 25, 2026
6664f6c
hack fixes for chunked runner
bnellnm Mar 25, 2026
0002540
fix lora
bnellnm Mar 25, 2026
d9e9495
fix eplb bug
bnellnm Mar 25, 2026
4270ba7
turn SharedExperts into a torch.nn.Module
bnellnm Mar 26, 2026
b807df1
move eplb state to router
bnellnm Mar 26, 2026
19c9fa3
simplify enable_eplb flag
bnellnm Mar 26, 2026
eaaceda
add ExpertMapManager
bnellnm Mar 26, 2026
a8c5280
comment
bnellnm Mar 26, 2026
7164218
reshuffle code
bnellnm Mar 26, 2026
2c07624
make RoutedExperts class
bnellnm Mar 26, 2026
8dba22b
move RoutedExperts to separate file
bnellnm Mar 26, 2026
b326564
CHECKPOINT WIP moving weight loading to RoutedExperts
bnellnm Mar 27, 2026
54f88b2
fix? weight loader
bnellnm Mar 28, 2026
1f385af
fix test
bnellnm Mar 28, 2026
f0ceac4
another fix
bnellnm Mar 28, 2026
dddfd90
incomplete lora fixes. runner/router cleanups
bnellnm Mar 28, 2026
e7e1ab1
revert hacks
bnellnm Mar 30, 2026
629f1ed
more cleanups
bnellnm Apr 2, 2026
0bf4b9d
fixes
bnellnm Apr 6, 2026
3fcc816
remove FusedMoE/SharedFusedMoE classes
bnellnm Apr 10, 2026
57c7223
lora/transformers tweaks
bnellnm Apr 10, 2026
3fbde20
lora tweaks. still not working
bnellnm Apr 10, 2026
028db8f
collapse runner
bnellnm Apr 22, 2026
09abcbd
fix annotation
bnellnm Apr 22, 2026
92f770a
mega merge
bnellnm Apr 23, 2026
f7b705d
revert some stuff
bnellnm Apr 23, 2026
2630038
tweaks
bnellnm Apr 23, 2026
86e8568
merge
bnellnm Apr 23, 2026
ca58511
trunc before
bnellnm Apr 25, 2026
df14236
fix
bnellnm Apr 25, 2026
9fbd5b7
don't store SharedExperts in MK
bnellnm Apr 23, 2026
e218fc7
fix cruft
bnellnm Apr 23, 2026
f37de99
fix broken imports
bnellnm Apr 23, 2026
71d27a0
fixes
bnellnm Apr 23, 2026
04858c2
fix up new quant method
bnellnm Apr 25, 2026
5446029
fix lint
bnellnm Apr 27, 2026
89395fa
fix lint
bnellnm Apr 27, 2026
ced1799
expert map manager
bnellnm Apr 23, 2026
ba52a86
wip
bnellnm Apr 23, 2026
64cf4ac
update
bnellnm Apr 27, 2026
42c7fc4
merge
bnellnm Apr 27, 2026
9fe1392
eplb manager
bnellnm Apr 27, 2026
819a2dd
fixes
bnellnm Apr 27, 2026
9331477
eplb manager
bnellnm Apr 27, 2026
2332fd7
fix num_local_expert update
bnellnm Apr 27, 2026
bb6b39c
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm Apr 28, 2026
6d3903f
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into layer-…
bnellnm Apr 28, 2026
ec9e8be
Merge remote-tracking branch 'nm-vllm/eplb-manager' into layer-refactor
bnellnm Apr 28, 2026
3498820
fix
bnellnm Apr 28, 2026
c89c5ee
Merge remote-tracking branch 'nm-vllm/eplb-manager' into layer-refactor
bnellnm Apr 28, 2026
692912e
fix
bnellnm Apr 28, 2026
6270c70
Merge remote-tracking branch 'nm-vllm/eplb-manager' into layer-refactor
bnellnm Apr 28, 2026
7dba7b9
fix merge
bnellnm Apr 28, 2026
afb772c
Merge branch 'main' into expert-map-manager
bnellnm Apr 28, 2026
f208e97
Merge remote-tracking branch 'origin/main' into layer-refactor-prep
bnellnm Apr 28, 2026
9c55e67
Merge remote-tracking branch 'origin/main' into layer-runner-refactor
bnellnm Apr 28, 2026
aa210d3
fix
bnellnm Apr 28, 2026
914f53c
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into expert…
bnellnm Apr 28, 2026
d5af506
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into layer-…
bnellnm Apr 28, 2026
91fd26f
Merge branch 'layer-runner-refactor' into layer-refactor
bnellnm Apr 28, 2026
796a384
remove debug print
bnellnm Apr 29, 2026
c74f285
try to fix doc
bnellnm Apr 29, 2026
4791779
some fixes
bnellnm Apr 30, 2026
0bfa201
Merge remote-tracking branch 'origin/main' into pass-shared-experts
bnellnm Apr 30, 2026
1efaddb
Merge branch 'main' into expert-map-manager
bnellnm Apr 30, 2026
c8a42c2
fixes
bnellnm May 1, 2026
522a8fc
weight loading fixes
bnellnm May 1, 2026
1b38417
more weight loading fixes
bnellnm May 1, 2026
3adc884
fix
bnellnm May 4, 2026
13d121e
loader fixes
bnellnm May 4, 2026
4256550
Merge remote-tracking branch 'origin/main' into pass-shared-experts
bnellnm May 4, 2026
fa69759
fix
bnellnm May 4, 2026
45e60f0
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm May 4, 2026
456668e
fix
bnellnm May 4, 2026
69c62e8
fix merge
bnellnm May 4, 2026
8c9f209
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm May 4, 2026
fa64b76
fix aria
bnellnm May 4, 2026
5f996a4
fix lint
bnellnm May 4, 2026
09a18e2
remove debugging code
bnellnm May 5, 2026
427aa7b
Merge branch 'main' into pass-shared-experts
bnellnm May 5, 2026
d5e2481
Merge remote-tracking branch 'origin/main' into pass-shared-experts
bnellnm May 5, 2026
e3ab482
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into pass-…
bnellnm May 5, 2026
eeeff4a
fixes
bnellnm May 5, 2026
7650918
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm May 5, 2026
960a842
fix merge
bnellnm May 5, 2026
90c74a8
move mapping fn back to FusedMoE
bnellnm May 5, 2026
2bc4adc
Merge remote-tracking branch 'origin/main' into eplb-manager
bnellnm May 5, 2026
eab0423
cleanups
bnellnm May 5, 2026
26ffc77
fix dbrx
bnellnm May 5, 2026
b33b31e
review comments
bnellnm May 5, 2026
2a686f3
fix
bnellnm May 5, 2026
fd5192e
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm May 5, 2026
893dfff
Merge remote-tracking branch 'origin/main' into pass-shared-experts
bnellnm May 6, 2026
0780907
review comments + redo stuff
bnellnm May 6, 2026
c1a332c
review comments
bnellnm May 6, 2026
4a3d996
cleanup routing table initialization and updating
bnellnm May 6, 2026
f96b5cb
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into expert…
bnellnm May 6, 2026
c038c7e
Merge remote-tracking branch 'origin/main' into expert-map-manager
bnellnm May 6, 2026
778c141
fix local_num_experts
bnellnm May 6, 2026
fefd17a
Merge remote-tracking branch 'nm-vllm/pass-shared-experts' into layer…
bnellnm May 6, 2026
9be870e
Merge remote-tracking branch 'nm-vllm/eplb-manager' into layer-refactor
bnellnm May 6, 2026
e69a213
fix device stuff
bnellnm May 6, 2026
3c21f32
tweak
bnellnm May 6, 2026
2dd2ea9
try to fix update_expert_map
bnellnm May 6, 2026
3fafc35
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into layer-…
bnellnm May 6, 2026
979dd65
remove unused arg
bnellnm May 6, 2026
56d8738
update comment
bnellnm May 6, 2026
16b43a5
cleanup
bnellnm May 6, 2026
4cd8a2c
Merge remote-tracking branch 'nm-vllm/expert-map-manager' into layer-…
bnellnm May 6, 2026
34652de
fix test
bnellnm May 7, 2026
f9806c1
fix lint
bnellnm May 7, 2026
dd14e4e
fix lora
bnellnm May 7, 2026
27e0f6b
fix lint
bnellnm May 7, 2026
121517c
fix doc
bnellnm May 7, 2026
446e805
move state around
bnellnm May 7, 2026
bd1b8cc
update test
bnellnm May 7, 2026
b1573e4
fix lint
bnellnm May 7, 2026
d2bc7df
fix lint
bnellnm May 7, 2026
408f217
Merge remote-tracking branch 'origin/main' into fix-capture
bnellnm May 11, 2026
64ccbf8
use layer_name instead of layer_id for map
bnellnm May 11, 2026
3de6da7
update doc
bnellnm May 11, 2026
3c52f2c
remove debug print
bnellnm May 11, 2026
dcfd80c
Merge branch 'main' into fix-capture
bnellnm May 12, 2026
95e09a0
Merge remote-tracking branch 'nm-vllm/fix-capture' into layer-refactor
bnellnm May 12, 2026
e305608
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 12, 2026
0430804
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 12, 2026
474bfc4
fix merge
bnellnm May 12, 2026
53884ac
cleaner fix for unpadded output
bnellnm May 12, 2026
37b076b
fix
bnellnm May 12, 2026
57fec3a
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 19, 2026
3e492d6
Merge remote-tracking branch 'nm-vllm/layer-refactor' into layer-refa…
bnellnm May 19, 2026
535d224
lora + misc merge fixes
bnellnm May 20, 2026
3882ec0
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 20, 2026
bdae82e
merge/lint fix
bnellnm May 20, 2026
1959f6d
update FusedMoE comment
bnellnm May 20, 2026
9028c9a
add missing type annotation
bnellnm May 21, 2026
68a8a72
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 28, 2026
0c7b074
cleanups
bnellnm May 28, 2026
dd8107f
fix
bnellnm May 29, 2026
1f26c1b
claude fix for test_gps_oss_tp2
bnellnm May 29, 2026
d3832f9
padding fix for cudagraph + test_gpt_oss_tp2
bnellnm May 29, 2026
ec79e2e
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 29, 2026
fdcab37
fix renamed func
bnellnm May 29, 2026
8342bae
Merge remote-tracking branch 'origin/main' into layer-refactor
bnellnm May 29, 2026
66800ed
cleanups + fixes
bnellnm May 29, 2026
df1a743
make extra parameters explicit + test cleanups
bnellnm May 29, 2026
6c4e0a6
cleaner truncation handling
bnellnm May 30, 2026
9d98e3f
add comment
bnellnm May 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def run():
num_experts=num_experts,
experts_per_token=topk,
hidden_dim=hidden_size,
intermediate_size_per_partition=shard_intermediate_size,
intermediate_size=shard_intermediate_size,
num_local_experts=num_experts,
num_logical_experts=num_experts,
activation=MoEActivation.SILU,
Expand Down
2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Modular kernels are supported by the following `FusedMoEMethodBase` classes.
- [`CompressedTensorsW4A4Nvfp4MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w4a4_nvfp4.CompressedTensorsW4A4Nvfp4MoEMethod]
- [`CompressedTensorsW8A8Fp8MoEMethod`][vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.compressed_tensors_moe_w8a8_fp8.CompressedTensorsW8A8Fp8MoEMethod]
- [`GptOssMxfp4MoEMethod`][vllm.model_executor.layers.quantization.mxfp4.GptOssMxfp4MoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod]
- [`UnquantizedFusedMoEMethod`][vllm.model_executor.layers.fused_moe.UnquantizedFusedMoEMethod]

## Fused Experts Kernels

Expand Down
38 changes: 25 additions & 13 deletions tests/distributed/test_eplb_fused_moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_eplb_group,
get_tp_group,
)
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
Expand Down Expand Up @@ -75,9 +77,9 @@ def make_fused_moe_layer(
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=test_config.weight_dtype,
)
re = fml.routed_experts

device = torch.device(f"cuda:{rank}")

Expand All @@ -90,12 +92,12 @@ def make_fused_moe_layer(
tensor_device=device,
)

assert isinstance(fml.w13_weight.data, torch.Tensor)
assert isinstance(fml.w2_weight.data, torch.Tensor)
fml.w13_weight.data = fml.w13_weight.data.to(device=device)
fml.w2_weight.data = fml.w2_weight.data.to(device=device)
w13_weight = fml.w13_weight.data
w2_weight = fml.w2_weight.data
assert isinstance(re.w13_weight.data, torch.Tensor)
assert isinstance(re.w2_weight.data, torch.Tensor)
re.w13_weight.data = re.w13_weight.data.to(device=device)
re.w2_weight.data = re.w2_weight.data.to(device=device)
w13_weight = re.w13_weight.data
w2_weight = re.w2_weight.data
assert w13_weight.size(0) == test_config.num_local_experts
for i in range(test_config.num_local_experts):
g_i = rank * test_config.num_local_experts + i
Expand Down Expand Up @@ -170,10 +172,10 @@ def block_quant_scales_shape(
assert not w2_weight_scale_inv.is_contiguous()

# Add scales to the parameter list
fml.w13_weight_scale_inv = torch.nn.Parameter(
re.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale_inv, requires_grad=False
)
fml.w2_weight_scale_inv = torch.nn.Parameter(
re.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale_inv, requires_grad=False
)

Expand All @@ -185,9 +187,12 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
# to expert parallel)
set_env_vars_and_device(env)

vllm_config = VllmConfig()
vllm_config.parallel_config.tensor_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
parallel_config = ParallelConfig(
tensor_parallel_size=world_size,
enable_expert_parallel=True,
enable_eplb=True,
)
vllm_config = VllmConfig(parallel_config=parallel_config)

with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
Expand All @@ -213,12 +218,19 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=vllm_config.parallel_config.eplb_config.communicator,
expert_weights=rank_expert_weights[0],
)

rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
communicator=communicator,
)

num_local_experts = test_config.num_local_experts
Expand Down
48 changes: 30 additions & 18 deletions tests/distributed/test_eplb_fused_moe_layer_dep_nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
import torch

from tests.kernels.moe.utils import make_test_quant_config
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed.eplb.eplb_communicator import create_eplb_communicator
from vllm.distributed.eplb.eplb_state import EplbLayerState
from vllm.distributed.eplb.rebalance_execute import rearrange_expert_weights_inplace
from vllm.distributed.parallel_state import (
ensure_model_parallel_initialized,
get_dp_group,
get_eplb_group,
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
Expand Down Expand Up @@ -59,7 +61,6 @@ def make_fused_moe_layer(
intermediate_size=test_config.intermediate_size,
prefix=f"dummy_layer_{layer_idx}",
activation="silu",
is_act_and_mul=True,
params_dtype=torch.bfloat16,
quant_config=quant_config,
)
Expand All @@ -85,21 +86,22 @@ def make_fused_moe_layer(
per_act_token_quant=False,
)

fml.w13_weight.data = w1_q
fml.w2_weight.data = w2_q
re = fml.routed_experts
re.w13_weight.data = w1_q
re.w2_weight.data = w2_q

fml.w2_input_scale.data = torch.randn_like(fml.w2_input_scale.data) / 5
fml.w13_input_scale.data = torch.randn_like(fml.w13_input_scale.data) / 5
fml.w2_weight_scale_2.data = torch.randn_like(fml.w2_weight_scale_2.data) / 5
fml.w13_weight_scale_2.data = torch.randn_like(fml.w13_weight_scale_2.data) / 5
fml.w2_weight_scale.data = (
torch.randn(fml.w2_weight_scale.data.shape, device=device) / 5
).to(fml.w2_weight_scale.data.dtype)
fml.w13_weight_scale.data = (
torch.randn(fml.w13_weight_scale.data.shape, device=device) / 5
).to(fml.w13_weight_scale.data.dtype)
re.w2_input_scale.data = torch.randn_like(re.w2_input_scale.data) / 5
re.w13_input_scale.data = torch.randn_like(re.w13_input_scale.data) / 5
re.w2_weight_scale_2.data = torch.randn_like(re.w2_weight_scale_2.data) / 5
re.w13_weight_scale_2.data = torch.randn_like(re.w13_weight_scale_2.data) / 5
re.w2_weight_scale.data = (
torch.randn(re.w2_weight_scale.data.shape, device=device) / 5
).to(re.w2_weight_scale.data.dtype)
re.w13_weight_scale.data = (
torch.randn(re.w13_weight_scale.data.shape, device=device) / 5
).to(re.w13_weight_scale.data.dtype)

nvfp4_fused_moe.process_weights_after_loading(fml)
nvfp4_fused_moe.process_weights_after_loading(fml.routed_experts)

fml.maybe_init_modular_kernel()

Expand All @@ -109,9 +111,12 @@ def make_fused_moe_layer(
def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
set_env_vars_and_device(env)

vllm_config = VllmConfig()
vllm_config.parallel_config.data_parallel_size = world_size
vllm_config.parallel_config.enable_expert_parallel = True
parallel_config = ParallelConfig(
data_parallel_size=world_size,
enable_expert_parallel=True,
enable_eplb=True,
)
vllm_config = VllmConfig(parallel_config=parallel_config)

with set_current_vllm_config(vllm_config):
ensure_model_parallel_initialized(
Expand Down Expand Up @@ -171,12 +176,19 @@ def _test_eplb_fml(env, world_size: int, test_config: TestConfig):
for lidx in range(test_config.num_layers):
shuffled_indices[lidx] = torch.randperm(test_config.num_experts)

communicator = create_eplb_communicator(
group_coordinator=get_eplb_group(),
backend=vllm_config.parallel_config.eplb_config.communicator,
expert_weights=rank_expert_weights[0],
)

rearrange_expert_weights_inplace(
indices,
shuffled_indices,
rank_expert_weights,
ep_group,
is_profile=False,
communicator=communicator,
)

num_global_experts = test_config.num_experts
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/moe/modular_kernel_tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def make_modular_kernel(
num_experts=config.E,
experts_per_token=config.topk,
hidden_dim=config.K,
intermediate_size_per_partition=config.N,
intermediate_size=config.N,
num_local_experts=config.num_local_experts,
num_logical_experts=config.E,
moe_parallel_config=moe_parallel_config,
Expand Down
2 changes: 2 additions & 0 deletions tests/kernels/moe/modular_kernel_tools/parallel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def _worker_parallel_launch(
rank = node_rank * world_local_size + local_rank
device = torch.device("cuda", local_rank)
torch.accelerator.set_device_index(device)
torch.set_default_device(device)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
Expand Down Expand Up @@ -116,6 +117,7 @@ def _worker_parallel_launch(
traceback.print_exc()
raise
finally:
torch.accelerator.synchronize()
if vllm_config is not None:
cleanup_dist_env_and_memory()
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/kernels/moe/test_cutlass_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def slice_experts():
moe_config = make_dummy_moe_config(
num_experts=w2.shape[0],
hidden_dim=w2.shape[1],
intermediate_size_per_partition=w2.shape[2],
intermediate_size=w2.shape[2],
in_dtype=a.dtype,
)
kernel = mk.FusedMoEKernel(
Expand Down Expand Up @@ -269,7 +269,7 @@ def run_8_bit(
moe_config = make_dummy_moe_config(
num_experts=moe_tensors.w2_q.shape[0], # type: ignore[union-attr]
hidden_dim=moe_tensors.w2_q.shape[1], # type: ignore[union-attr]
intermediate_size_per_partition=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
intermediate_size=moe_tensors.w2_q.shape[2], # type: ignore[union-attr]
in_dtype=moe_tensors.a.dtype,
)
kernel = mk.FusedMoEKernel(
Expand Down
6 changes: 2 additions & 4 deletions tests/kernels/moe/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,11 @@ def make_moe_tensors_8bit(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
intermediate_size=n,
num_local_experts=e,
num_logical_experts=e,
moe_parallel_config=layer.moe_parallel_config,
in_dtype=hidden_states.dtype,
is_act_and_mul=is_gated,
routing_method=layer.routing_method_type,
activation=activation,
device=w13_quantized.device,
Expand Down Expand Up @@ -339,14 +338,13 @@ def get_fused_moe_quant_config(n: torch.nn.Module) -> FusedMoEQuantConfig:
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
intermediate_size=n,
num_local_experts=e,
num_logical_experts=e,
activation=activation,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=torch.bfloat16,
is_act_and_mul=activation.is_gated,
routing_method=RoutingMethodType.TopK,
max_num_tokens=next_power_of_2(m),
)
Expand Down
2 changes: 1 addition & 1 deletion tests/kernels/moe/test_flashinfer_b12x_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_flashinfer_b12x_moe(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
intermediate_size=n,
in_dtype=dtype,
)

Expand Down
3 changes: 1 addition & 2 deletions tests/kernels/moe/test_flashinfer_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,13 @@ def test_flashinfer_fp4_moe_no_graph(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
intermediate_size=n,
num_local_experts=e,
num_logical_experts=e,
activation=activation,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
is_act_and_mul=is_gated_act,
routing_method=RoutingMethodType.TopK,
max_num_tokens=next_power_of_2(m),
)
Expand Down
3 changes: 1 addition & 2 deletions tests/kernels/moe/test_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,14 +1617,13 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
num_experts=e,
experts_per_token=topk,
hidden_dim=k,
intermediate_size_per_partition=n,
intermediate_size=n,
num_local_experts=e,
num_logical_experts=e,
activation=MoEActivation.SILU,
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=dtype,
is_act_and_mul=True,
routing_method=RoutingMethodType.Renormalize,
max_num_tokens=next_power_of_2(m),
)
Expand Down
Loading
Loading