27 changes: 27 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -1041,3 +1041,30 @@ def load_weights(self, weights: List[Dict]):

    def post_load_weights(self):
        self.quant_method.post_load_weights(self)

    def forward_fake(
            self,
            x: Union[torch.Tensor, Fp4QuantizedTensor],
            router_logits: torch.Tensor,
            *,
            do_finalize: bool = True,
            output_dtype: Optional[torch.dtype] = None,
            all_rank_num_tokens: Optional[List[int]] = None,
            use_dp_padding: Optional[bool] = None,
            **kwargs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        moe_output = super().forward_fake(
            x,
            router_logits,
            do_finalize=do_finalize,
            output_dtype=torch.bfloat16,
            all_rank_num_tokens=all_rank_num_tokens,
            use_dp_padding=use_dp_padding,
            **kwargs)
        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
            shape = moe_output.shape
            top_k = self.routing_method.experts_per_token
            new_shape = [shape[0], top_k, shape[1]]
            return moe_output.new_empty(new_shape)
        else:
            return moe_output
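
For context: the MNNVL alltoall branch above reports a [num_tokens, top_k, hidden_size] tensor because, on that path, the real forward can hand back routed per-expert outputs that have not yet been combined; callers finalize by reducing over the top_k dimension, as the updated test below does. A minimal sketch of that shape relationship, using made-up sizes:

import torch

# Hypothetical sizes for illustration only.
num_tokens, top_k, hidden_size = 8, 6, 4096

# What the MNNVL branch of forward_fake reports: one slot per selected expert,
# not yet combined into the final per-token output.
unfinalized = torch.empty(num_tokens, top_k, hidden_size, dtype=torch.bfloat16)

# Finalization reduces over the top_k dimension, matching the
# torch.sum(output, dim=1) step in the test change below.
finalized = unfinalized.sum(dim=1)
assert finalized.shape == (num_tokens, hidden_size)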
69 changes: 34 additions & 35 deletions tests/unittest/_torch/modules/test_fused_moe.py
@@ -292,7 +292,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
assert r is None


@pytest.mark.skip(reason="https://nvbugs/5467531")
@pytest.mark.skipif(torch.cuda.device_count() < 4,
reason="needs 4 GPUs to run this test")
@pytest.mark.parametrize("alltoall_method_type", [
@@ -304,7 +303,7 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):

world_size = 4
dtype = torch.bfloat16
HIDDEN_SIZE = 2560
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 1536
NUM_EXPERTS = 72
TOP_K = 6
@@ -320,8 +319,8 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
x_list = []
m = MAX_NUM_TOKENS
while m >= 1:
x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda")
x_list.append(x.cuda(i))
x = torch.randn((m, HIDDEN_SIZE), dtype=dtype)
x_list.append(x)
m //= 2

x_abs_max = torch.cat([x.flatten() for x in x_list]).abs().max().float()
@@ -366,49 +365,37 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))

w1_input_scale = x_sf_global.cuda(i)
w2_input_scale = x_sf_global.cuda(i)
w3_input_scale = x_sf_global.cuda(i)
weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cpu()
weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cpu()
weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cpu()
weights[f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled
weights[f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled
weights[f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled

weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cuda(i)
weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cuda(i)
weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cuda(i)
weights[
f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.cuda(i)
weights[
f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.cuda(i)
weights[
f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.cuda(i)

weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale.cuda(
i)
weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale.cuda(
i)
weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale.cuda(
i)
weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
i)
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cuda(
i)
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
i)
weights[f"{expert_id}.w1.input_scale"] = 1.0 / x_sf_global
weights[f"{expert_id}.w2.input_scale"] = 1.0 / x_sf_global
weights[f"{expert_id}.w3.input_scale"] = 1.0 / x_sf_global
weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cpu()
weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cpu()
weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cpu()

x_list_world.append(x_list)
weights_world.append(weights)
torch.cuda.synchronize()

def per_rank_test_fused_moe_alltoall(job_id):
def per_rank_test_fused_moe_alltoall(job_id, weights, x_list):
routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
mapping = Mapping(world_size=world_size,
rank=mpi_rank(),
rank=job_id,
tp_size=world_size,
moe_ep_size=world_size,
moe_tp_size=1,
enable_attention_dp=True)
torch.cuda.set_device(mapping.rank)
torch.manual_seed(mapping.rank)

x_list = x_list_world[mapping.rank]
weights = weights_world[mapping.rank]
weights = {k: v.cuda() for k, v in weights.items()}
x_list = [x.cuda() for x in x_list]

quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
with mock.patch.object(WideEPMoE,
@@ -459,6 +446,16 @@ def per_rank_test_fused_moe_alltoall(job_id):
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
# Verify the fake impl is correct.
output_fake = alltoall_model.forward_fake(
x,
router_logits,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=False)
assert output_fake.shape == output.shape
assert output_fake.dtype == output.dtype
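# The MNNVL alltoall path returns unfinalized [num_tokens, top_k, hidden_size]
# outputs; reduce over the top_k dimension before comparing against the reference.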
if len(output.shape) == 3:
output = torch.sum(output, dim=1, keepdim=False)
ref_output = ref_model.forward(
x,
router_logits,
@@ -470,8 +467,10 @@ def per_rank_test_fused_moe_alltoall(job_id):
m //= 2

with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(per_rank_test_fused_moe_alltoall,
range(world_size))
results = executor.map(
per_rank_test_fused_moe_alltoall,
*zip(*[(i, weights_world[i], x_list_world[i])
for i in range(world_size)]))
for r in results:
assert r is None
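
Aside on the dispatch change above: executor.map accepts one iterable per positional parameter of the mapped function, so the zip(*[...]) idiom transposes the per-rank (job_id, weights, x_list) tuples into three parallel argument streams. A small self-contained sketch of the same pattern, using concurrent.futures as a stand-in for MPIPoolExecutor and placeholder data:

from concurrent.futures import ProcessPoolExecutor  # stand-in for MPIPoolExecutor


def per_rank(job_id, weights, x_list):
    # Each worker receives its own slice of the per-rank data.
    return job_id, len(weights), len(x_list)


if __name__ == "__main__":
    world_size = 4
    weights_world = [{"w": i} for i in range(world_size)]
    x_list_world = [[i] * 3 for i in range(world_size)]
    with ProcessPoolExecutor(max_workers=world_size) as executor:
        # zip(*[...]) turns [(0, w0, x0), (1, w1, x1), ...] into three
        # iterables: (0, 1, ...), (w0, w1, ...), (x0, x1, ...).
        results = executor.map(
            per_rank,
            *zip(*[(i, weights_world[i], x_list_world[i])
                   for i in range(world_size)]))
        for r in results:
            print(r)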
