Commit 69b8f27

[mxfp8 moe training] mxfp8 all_to_all_vdev_2d kernel
1 parent d2fae7a commit 69b8f27
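
The test in this commit exercises the new collective end to end. A minimal sketch of its call shape, mirroring that test (parameter roles are inferred from the test code below, not from separate API docs):

# Sketch of the new collective's usage, mirroring the test in this commit.
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)

# input_tensor: (tokens, dim) float32 rows to exchange from this rank.
# input_splits: int64 tensor, one entry per destination rank, summing to tokens.
# max_output_tokens_per_rank sizes the fixed output buffer (worst case: one
# rank receives every token), so callers slice off the padding using the
# returned output_splits.
output, output_splits = mxfp8_on_device_all_to_all_v(
    input_tensor,
    input_splits,
    max_output_tokens_per_rank,
    group_name,  # name of a symmetric-memory-enabled process group
)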

6 files changed: +823 −1 lines changed
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch.distributed._functional_collectives import (
    all_to_all_single_autograd,
)
from torch.nn import functional as F
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
)

from torchao.float8.float8_utils import (
    compute_error,
)
from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)


@instantiate_parametrized_tests
class TritonAllReduceTest(MultiProcessTestCase):
    def setUp(self) -> None:
        super().setUp()
        self._spawn_processes()

    @property
    def world_size(self) -> int:
        return 2

    @property
    def device(self) -> torch.device:
        return torch.device(f"cuda:{self.rank}")

    def _init_process(self):
        torch.cuda.set_device(self.device)
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="nccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        torch.manual_seed(42 + self.rank)

    def _init_device(self):
        symm_mem.set_backend("NVSHMEM")

    def test_a2a_fwd_bwd(self):
        self._init_process()
        try:
            torch.manual_seed(42 + self.rank)
            self._init_device()

            group_name = dist.group.WORLD.group_name
            symm_mem.enable_symm_mem_for_group(group_name)

            tokens_per_ep_rank = 8192
            dim = 2048
            input_tensor = torch.randn(
                tokens_per_ep_rank,
                dim,
                device=self.device,
                dtype=torch.float32,
                requires_grad=True,
            )
            ref_input_tensor = input_tensor.detach().clone().requires_grad_(True)

            # Generate random input splits that sum to tokens_per_ep_rank.
            num_splits = self.world_size
            input_splits = generate_split_sizes(
                num_splits, tokens_per_ep_rank, self.device
            )

            # Max output tokens per rank is the worst case, where one rank
            # receives all tokens.
            max_output_tokens_per_rank = tokens_per_ep_rank * self.world_size

            # Test forward.
            output, output_splits = mxfp8_on_device_all_to_all_v(
                input_tensor,
                input_splits,
                max_output_tokens_per_rank,
                group_name,
            )

            # Reference all_to_all_single_autograd to compare against: first
            # exchange the split sizes, then the tokens themselves.
            output_splits_ref = torch.empty_like(output_splits)
            dist.all_to_all_single(output_splits_ref, input_splits)
            total_tokens_on_rank_after_a2a = output_splits_ref.sum()
            ref_output = all_to_all_single_autograd(
                ref_input_tensor,
                output_splits_ref.tolist(),
                input_splits.tolist(),
                dist.group.WORLD,
            )

            # Compare outputs. The mxfp8 kernel writes into a fixed-size
            # buffer, so strip the padding before comparing.
            assert torch.equal(output_splits, output_splits_ref), (
                "output_splits mismatch"
            )
            out_no_padding = output[:total_tokens_on_rank_after_a2a]
            sqnr = compute_error(ref_output, out_no_padding)
            min_sqnr = 30.0
            assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}"

            # Test backward.
            labels = torch.ones_like(out_no_padding)
            loss = F.mse_loss(out_no_padding, labels)
            ref_loss = F.mse_loss(ref_output, labels)
            loss.backward()
            ref_loss.backward()

            # Compare grads.
            grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad)
            min_grad_sqnr = 28.0
            assert grad_sqnr > min_grad_sqnr, (
                f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}"
            )

        finally:
            dist.destroy_process_group()


def generate_split_sizes(K: int, N: int, device: str = "cpu") -> torch.Tensor:
    """
    Generates a tensor of K random non-negative integers that sum to N.
    """
    if K <= 0:
        raise ValueError("K must be a positive integer.")
    if N < 0:
        raise ValueError("N must be a non-negative integer.")

    if K == 1:
        return torch.tensor([N], dtype=torch.long, device=device)

    # Generate K-1 random "dividers" in the range [0, N].
    dividers = torch.randint(0, N + 1, (K - 1,), device=device)

    # Add 0 and N to the set of dividers to form the boundaries.
    boundaries = torch.cat(
        [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)]
    )

    # Sort the boundaries to ensure they are in order.
    sorted_boundaries = torch.sort(boundaries).values

    # The K integers are the differences between consecutive boundaries
    # (they sum to N by telescoping).
    result = sorted_boundaries[1:] - sorted_boundaries[:-1]

    return result.to(dtype=torch.int64)


if __name__ == "__main__":
    run_tests()
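
For reference, a quick standalone check of the splits helper defined above; the specific K and N here are arbitrary:

# Hypothetical sanity check for generate_split_sizes: it should return K
# non-negative integers that sum to N.
splits = generate_split_sizes(4, 8192)
assert splits.numel() == 4
assert (splits >= 0).all()
assert splits.sum().item() == 8192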
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from torchao.prototype.moe_training.kernels.mxfp8.quant import (
    compute_blocked_scale_offsets_for_K_groups,  # noqa: F401
    compute_blocked_scale_offsets_for_M_groups,  # noqa: F401
    mxfp8_quantize_cuda_3d,  # noqa: F401
    torch_to_blocked_2d_K_groups,  # noqa: F401
    torch_to_blocked_2d_M_groups,  # noqa: F401
    torch_to_blocked_per_group_3d,  # noqa: F401
    triton_mx_block_rearrange_2d_K_groups,  # noqa: F401
    triton_mx_block_rearrange_2d_M_groups,  # noqa: F401
    triton_mx_block_rearrange_per_group_3d,  # noqa: F401
)
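
Assuming this second file is the package's __init__.py (the diff view omits the file path), the re-export flattens the import path, so callers can pull the quantization helpers from the package root rather than the quant submodule:

# Hypothetical import via the re-export; the package path is assumed from
# the module path in the diff above.
from torchao.prototype.moe_training.kernels.mxfp8 import (
    mxfp8_quantize_cuda_3d,
)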
