
Commit d7c1e52

[mxfp8 moe training] mxfp8 all_to_all_vdev_2d kernel
1 parent: d2fae7a
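For orientation: an all_to_all_v ("v" for variable splits) shuffles a variable number of token rows between expert-parallel ranks; the kernel added here fuses mxfp8 quantization into that exchange, and the "_2d" layout carries experts_per_rank split counts per peer rather than one. The sketch below is a plain bf16 reference for the basic (non-2D, unquantized) semantics, written against the stock dist.all_to_all_single API with one split per peer rank. It is illustrative only and is not the kernel in this commit.

import torch
import torch.distributed as dist

def reference_all_to_all_v(inp: torch.Tensor, input_splits: list[int]) -> torch.Tensor:
    # Illustrative reference only, NOT the fused mxfp8 kernel: one split
    # per peer rank, bf16 payload, no quantization.
    # 1) Exchange split metadata so each rank learns its receive counts.
    in_splits = torch.tensor(input_splits, device=inp.device)
    out_splits = torch.empty_like(in_splits)
    dist.all_to_all_single(out_splits, in_splits)
    output_splits = out_splits.tolist()
    # 2) Exchange the variable-length row chunks themselves.
    out = inp.new_empty((sum(output_splits), inp.shape[1]))
    dist.all_to_all_single(out, inp, output_splits, input_splits)
    return out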

File tree: 4 files changed, +754 −0 lines changed
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)

from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)


@instantiate_parametrized_tests
class TritonAllReduceTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    def _init_device(self):
        symm_mem.set_backend("NVSHMEM")

    @skip_if_lt_x_gpu(4)
    def test_a2a_fwd_bwd(self):
        self._init_process()
        try:
            torch.manual_seed(42 + self.rank)
            self._init_device()

            group_name = dist.group.WORLD.group_name
            symm_mem.enable_symm_mem_for_group(group_name)

            experts_per_rank = 2
            num_splits = experts_per_rank * self.world_size

            # Tokens routed to each expert are random in [0, tokens_per_ep_rank)
            tokens_per_ep_rank = 1024
            dim = 2048
            input_tensor = torch.randn(
                tokens_per_ep_rank, dim, device=self.device, dtype=torch.bfloat16
            )
            input_splits = torch.randint(
                tokens_per_ep_rank, (num_splits,), dtype=torch.int64, device=self.device
            )

            max_output_len_per_rank = tokens_per_ep_rank  # Alias for clarity

            # Test forward
            output, output_splits = mxfp8_on_device_all_to_all_v(
                input_tensor,
                input_splits,
                max_output_len_per_rank,
                group_name,
            )
        finally:
            dist.destroy_process_group()


if __name__ == "__main__":
    run_tests()
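The forward call above is exercised but its result is never checked. A hypothetical sanity check one could append inside the try block (not part of this commit; it assumes input_splits is laid out by destination rank, evenly chunked across peers, and that output_splits mirrors that layout on the receiving side):

            # Hypothetical follow-up checks, under the layout assumptions
            # stated above. Exchange the split metadata over plain NCCL and
            # compare it against what the kernel reported.
            expected_splits = torch.empty_like(input_splits)
            dist.all_to_all_single(expected_splits, input_splits)
            torch.testing.assert_close(output_splits, expected_splits)
            # Rows received must fit the preallocated per-rank output buffer.
            assert output_splits.sum().item() <= output.shape[0]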
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    compute_blocked_scale_offsets_for_K_groups,  # noqa: F401
    compute_blocked_scale_offsets_for_M_groups,  # noqa: F401
    mxfp8_quantize_cuda_3d,  # noqa: F401
    torch_to_blocked_2d_K_groups,  # noqa: F401
    torch_to_blocked_2d_M_groups,  # noqa: F401
    torch_to_blocked_per_group_3d,  # noqa: F401
    triton_mx_block_rearrange_2d_K_groups,  # noqa: F401
    triton_mx_block_rearrange_2d_M_groups,  # noqa: F401
    triton_mx_block_rearrange_per_group_3d,  # noqa: F401
)
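This second file is a thin re-export shim: each name is imported from the quant submodule and marked # noqa: F401 so linters don't flag the apparently unused imports, which makes the helpers importable from the shorter package path. Assuming the file is the package's __init__.py (its name is not visible in this extract), usage would look like:

# Hypothetical usage; assumes the re-export file above is
# torchao/prototype/moe_training/kernels/mxfp8/__init__.py.
from torchao.prototype.moe_training.kernels.mxfp8 import mxfp8_quantize_cuda_3d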
