Skip to content

Commit 2b69b77

Browse files
[mxfp8 moe training] mxfp8 all_to_all_vdev_2d kernel
1 parent d2fae7a commit 2b69b77

File tree

4 files changed

+716
-0
lines changed

4 files changed

+716
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import torch
2+
import torch.distributed as dist
3+
import torch.distributed._symmetric_memory as symm_mem
4+
from torch.testing._internal.common_distributed import (
5+
MultiProcessTestCase,
6+
skip_if_lt_x_gpu,
7+
)
8+
from torch.testing._internal.common_utils import (
9+
instantiate_parametrized_tests,
10+
run_tests,
11+
)
12+
13+
14+
@instantiate_parametrized_tests
15+
class TritonAllReduceTest(MultiProcessTestCase):
16+
def setUp(self) -> None:
17+
super().setUp()
18+
self._spawn_processes()
19+
20+
@property
21+
def world_size(self) -> int:
22+
return 2
23+
24+
@property
25+
def device(self) -> torch.device:
26+
return torch.device(f"cuda:{self.rank}")
27+
28+
def _init_process(self):
29+
torch.cuda.set_device(self.device)
30+
store = dist.FileStore(self.file_name, self.world_size)
31+
dist.init_process_group(
32+
backend="nccl",
33+
world_size=self.world_size,
34+
rank=self.rank,
35+
store=store,
36+
)
37+
torch.manual_seed(42 + self.rank)
38+
39+
def _init_device(self):
40+
symm_mem.set_backend("NVSHMEM")
41+
42+
@skip_if_lt_x_gpu(4)
43+
def test_a2a(self):
44+
self._init_process()
45+
try:
46+
torch.manual_seed(42 + self.rank)
47+
self._init_device()
48+
49+
group_name = dist.group.WORLD.group_name
50+
symm_mem.enable_symm_mem_for_group(group_name)
51+
52+
experts_per_rank = 2
53+
nsplits = experts_per_rank * self.world_size
54+
55+
# Number of elements for an expert is random between [0, k)
56+
k = 10
57+
inp_splits = torch.randint(
58+
k, (nsplits,), dtype=torch.int64, device=self.device
59+
)
60+
61+
# TODO
62+
63+
finally:
64+
dist.destroy_process_group()
65+
66+
67+
# Entry point: dispatch to the PyTorch test runner when executed directly.
if __name__ == "__main__":
    run_tests()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
2+
compute_blocked_scale_offsets_for_K_groups, # noqa: F401
3+
compute_blocked_scale_offsets_for_M_groups, # noqa: F401
4+
mxfp8_quantize_cuda_3d, # noqa: F401
5+
torch_to_blocked_2d_K_groups, # noqa: F401
6+
torch_to_blocked_2d_M_groups, # noqa: F401
7+
torch_to_blocked_per_group_3d, # noqa: F401
8+
triton_mx_block_rearrange_2d_K_groups, # noqa: F401
9+
triton_mx_block_rearrange_2d_M_groups, # noqa: F401
10+
triton_mx_block_rearrange_per_group_3d, # noqa: F401
11+
)

0 commit comments

Comments
 (0)