diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py new file mode 100644 index 0000000000..9b4e03cc49 --- /dev/null +++ b/benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +###################################################################### +# +# To run these benchmarks, use the following command: +# +# torchrun --nproc-per-node=8 --local-ranks-filter=0 benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py +# +####################################################################### +import os +import time +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from torch import distributed as dist +from torch.distributed._functional_collectives import ( + all_to_all_single_autograd, +) +from tqdm import tqdm + +from torchao.prototype.moe_training.kernels.mxfp8.comms import ( + mxfp8_on_device_all_to_all_v, +) + +device = torch.device("cuda") + + +@dataclass(frozen=True) +class ExperimentConfig: + input_shape: tuple[int] + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_us: float + mxfp8_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + # (batch_size, seq_len, dim) + input_shapes = [ + (8, 8192, 5120), + ] + configs = [] + for shape in input_shapes: + configs.append( + ExperimentConfig( + input_shape=shape, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + batch_size, seq_len, dim = config.input_shape + x = torch.randn( + (batch_size * seq_len, dim), + dtype=torch.bfloat16, + device=device, + ) + ref_x = x.detach().clone() + + # Max output tokens per rank is worst case where one rank receives all tokens + input_tokens_per_rank = batch_size * seq_len + max_output_tokens_per_rank = input_tokens_per_rank * dist.get_world_size() + + def using_bf16( + input_tensor: torch.Tensor, input_splits: torch.Tensor + ) -> torch.Tensor: + # Calculate output splits from input splits + output_splits = torch.empty_like(input_splits) + dist.all_to_all_single(output_splits, input_splits) + + # Perform all-to-all + out = all_to_all_single_autograd( + input_tensor, + output_splits.tolist(), + input_splits.tolist(), + dist.group.WORLD, + ) + out = torch.ops._c10d_functional.wait_tensor(out) + return out + + def using_mxfp8( + input_tensor: torch.Tensor, input_splits: torch.Tensor + ) -> torch.Tensor: + output, output_splits = mxfp8_on_device_all_to_all_v( + input_tensor, + input_splits, + max_output_tokens_per_rank, + dist.group.WORLD.group_name, + ) + output = torch.ops._c10d_functional.wait_tensor(output) + output_splits = torch.ops._c10d_functional.wait_tensor(output_splits) + return output + + def warmup(func_no_args): + for _ in range(2): + func_no_args() + + num_splits = dist.get_world_size() + input_splits = generate_split_sizes( + num_splits, input_tokens_per_rank, device=device + ) + + print( + "Benchmarking using bf16", + "batch_size", + batch_size, + "seq_len", + seq_len, + "dim", + dim, + "input_tokens_per_rank", + input_tokens_per_rank, + "max_output_tokens_per_rank", + max_output_tokens_per_rank, + ) + warmup(lambda: using_bf16(ref_x, 
input_splits)) + start_ns = time.perf_counter() + using_bf16(ref_x, input_splits) + end_ns = time.perf_counter() + bf16_us = (end_ns - start_ns) * 1e6 + + print( + "Benchmarking using_mxfp8", + "batch_size", + batch_size, + "seq_len", + seq_len, + "dim", + dim, + "input_tokens_per_rank", + input_tokens_per_rank, + "max_output_tokens_per_rank", + max_output_tokens_per_rank, + ) + warmup(lambda: using_mxfp8(x, input_splits)) + start_ns = time.perf_counter() + using_mxfp8(x, input_splits) + end_ns = time.perf_counter() + mxfp8_us = (end_ns - start_ns) * 1e6 + + return ExperimentResult( + bf16_us=bf16_us, + mxfp8_us=mxfp8_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "input_shape", + "num_splits", + "bf16_us", + "mxfp8_us", + ] + rows = [] + num_splits = dist.get_world_size() + for experiment in experiments: + rows.append( + [ + str(experiment.config.input_shape), + num_splits, + experiment.result.bf16_us, + experiment.result.mxfp8_us, + ] + ) + print(tabulate(rows, headers=headers)) + + +def generate_split_sizes(K: int, N: int, device: str = "cuda") -> torch.Tensor: + """ + Generates a tensor of K random non-negative integers that sum to N. + Used for testing mxfp8_all_to_all_v implementation. + """ + if K <= 0: + raise ValueError("K must be a positive integer.") + if N < 0: + raise ValueError("N must be a non-negative integer.") + + if K == 1: + return torch.tensor([N], dtype=torch.long, device=device) + + # Generate K-1 random "dividers" in the range [0, N]. + dividers = torch.randint(0, N + 1, (K - 1,), device=device) + + # Add 0 and N to the set of dividers to form the boundaries. + boundaries = torch.cat( + [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)] + ) + + # Sort the boundaries to ensure they are in order + sorted_boundaries = torch.sort(boundaries).values + + # The K integers are the differences between consecutive boundaries (will sum to N) + result = sorted_boundaries[1:] - sorted_boundaries[:-1] + + return result.to(dtype=torch.int64) + + +def main(): + torch.random.manual_seed(123) + + # Set up process group + setup_distributed() + + # Generate experiment configs + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + # Clean up process group + dist.destroy_process_group() + + +def setup_distributed(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/utils.py b/benchmarks/utils.py index c59142d571..2467534a32 100644 --- a/benchmarks/utils.py +++ b/benchmarks/utils.py @@ -72,5 +72,27 @@ def profile_fwd_bwd( print(f"Saved: {profile_name}.json") +def profile_fn(fn, *args, profile_name="profile", **kwargs): + wait, warmup, active = 1, 1, 1 + total_steps = wait + warmup + active + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule( + wait=wait, warmup=warmup, active=active, repeat=0 + ), + record_shapes=True, + ) as prof: + for _ in range(total_steps): + _ = fn(*args, **kwargs) + prof.step() + + # Save profiler results + prof.export_chrome_trace(f"{profile_name}.json") + print(f"Saved: {profile_name}.json") + + def 
benchmark_cuda_function_in_microseconds(f, *args, **kwargs): return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 diff --git a/test/prototype/moe_training/mxfp8/__init__.py b/test/prototype/moe_training/mxfp8/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py new file mode 100644 index 0000000000..d9c448c343 --- /dev/null +++ b/test/prototype/moe_training/mxfp8/test_mxfp8_a2a.py @@ -0,0 +1,147 @@ +import pytest +import torch + +if not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0): + pytest.skip("Test requires CUDA build on SM100", allow_module_level=True) + +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +from torch.distributed._functional_collectives import ( + all_to_all_single_autograd, +) +from torch.nn import functional as F +from torch.testing._internal.common_distributed import ( + MultiProcessTestCase, +) +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + run_tests, +) + +from torchao.float8.float8_utils import ( + compute_error, +) +from torchao.prototype.moe_training.kernels.mxfp8.comms import ( + mxfp8_on_device_all_to_all_v, +) + +from ..testing_utils import generate_split_sizes + + +@instantiate_parametrized_tests +class TritonAllReduceTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + return 4 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + torch.manual_seed(42 + self.rank) + + def _init_device(self): + symm_mem.set_backend("NVSHMEM") + + def test_a2a_fwd_bwd(self): + self._init_process() + try: + torch.manual_seed(42 + self.rank) + self._init_device() + + group_name = dist.group.WORLD.group_name + symm_mem.enable_symm_mem_for_group(group_name) + + tokens_per_ep_rank = 8192 + dim = 2048 + input_tensor = torch.randn( + tokens_per_ep_rank, + dim, + device=self.device, + dtype=torch.float32, + requires_grad=True, + ) + ref_input_tensor = input_tensor.detach().clone().requires_grad_(True) + + # Generate random input splits that sum to tokens_per_ep_rank + num_splits = self.world_size + input_splits = generate_split_sizes( + num_splits, tokens_per_ep_rank, self.device + ) + + # Max output tokens per rank is worst case where one rank receives all tokens + max_output_tokens_per_rank = tokens_per_ep_rank * self.world_size + + # Test forward + output, output_splits = mxfp8_on_device_all_to_all_v( + input_tensor, + input_splits, + max_output_tokens_per_rank, + group_name, + ) + + # Reference torch.all_to_all_single to compare against + output_splits_ref = torch.empty_like(output_splits) + + # Compute output splits from input splits + dist.all_to_all_single(output_splits_ref, input_splits) + + # Pre-allocate output buffer for reference a2a + total_tokens_on_rank_after_a2a = output_splits_ref.sum() + ref_output = torch.empty( + total_tokens_on_rank_after_a2a, + dim, + device=self.device, + dtype=torch.float32, + ) + + # Do the actual all_to_all_single + ref_output = all_to_all_single_autograd( + ref_input_tensor, + output_splits_ref.tolist(), + 
input_splits.tolist(), + dist.group.WORLD, + ) + + # Compare output + assert torch.equal(output_splits, output_splits_ref), ( + "output_splits mismatch" + ) + out_no_padding = output[:total_tokens_on_rank_after_a2a] + sqnr = compute_error(ref_output, out_no_padding) + min_sqnr = 30.0 + assert sqnr > min_sqnr, f"sqnr={sqnr} is less than min_sqnr={min_sqnr}" + + # Test backwards + labels = torch.ones_like(out_no_padding) + loss = F.mse_loss(out_no_padding, labels) + ref_loss = F.mse_loss(ref_output, labels) + loss.backward() + ref_loss.backward() + + # Compare grads + grad_sqnr = compute_error(ref_input_tensor.grad, input_tensor.grad) + min_grad_sqnr = 28.0 + assert grad_sqnr > min_grad_sqnr, ( + f"grad_sqnr={grad_sqnr} is less than min_grad_sqnr={min_grad_sqnr}" + ) + + finally: + dist.destroy_process_group() + + +if __name__ == "__main__": + run_tests() diff --git a/test/prototype/moe_training/testing_utils.py b/test/prototype/moe_training/testing_utils.py index cf13b81ae3..1d062b5b8b 100644 --- a/test/prototype/moe_training/testing_utils.py +++ b/test/prototype/moe_training/testing_utils.py @@ -1,3 +1,4 @@ +import torch from torch import nn from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor @@ -31,3 +32,33 @@ def _recursive_validate( _recursive_validate(child_module, child_fqn) _recursive_validate(root_module, "") + + +def generate_split_sizes(K: int, N: int, device: str = "cpu") -> torch.Tensor: + """ + Generates a tensor of K random non-negative integers that sum to N. + Used for testing mxfp8_all_to_all_v implementation. + """ + if K <= 0: + raise ValueError("K must be a positive integer.") + if N < 0: + raise ValueError("N must be a non-negative integer.") + + if K == 1: + return torch.tensor([N], dtype=torch.long, device=device) + + # Generate K-1 random "dividers" in the range [0, N]. + dividers = torch.randint(0, N + 1, (K - 1,), device=device) + + # Add 0 and N to the set of dividers to form the boundaries. 
+ boundaries = torch.cat( + [torch.tensor([0], device=device), dividers, torch.tensor([N], device=device)] + ) + + # Sort the boundaries to ensure they are in order + sorted_boundaries = torch.sort(boundaries).values + + # The K integers are the differences between consecutive boundaries (will sum to N) + result = sorted_boundaries[1:] - sorted_boundaries[:-1] + + return result.to(dtype=torch.int64) diff --git a/torchao/prototype/moe_training/kernels/mxfp8/__init__.py b/torchao/prototype/moe_training/kernels/mxfp8/__init__.py new file mode 100644 index 0000000000..42b485e374 --- /dev/null +++ b/torchao/prototype/moe_training/kernels/mxfp8/__init__.py @@ -0,0 +1,11 @@ +from torchao.prototype.moe_training.kernels.mxfp8.quant import ( + compute_blocked_scale_offsets_for_K_groups, # noqa: F401 + compute_blocked_scale_offsets_for_M_groups, # noqa: F401 + mxfp8_quantize_cuda_3d, # noqa: F401 + torch_to_blocked_2d_K_groups, # noqa: F401 + torch_to_blocked_2d_M_groups, # noqa: F401 + torch_to_blocked_per_group_3d, # noqa: F401 + triton_mx_block_rearrange_2d_K_groups, # noqa: F401 + triton_mx_block_rearrange_2d_M_groups, # noqa: F401 + triton_mx_block_rearrange_per_group_3d, # noqa: F401 +) diff --git a/torchao/prototype/moe_training/kernels/mxfp8/comms.py b/torchao/prototype/moe_training/kernels/mxfp8/comms.py new file mode 100644 index 0000000000..9593b02513 --- /dev/null +++ b/torchao/prototype/moe_training/kernels/mxfp8/comms.py @@ -0,0 +1,427 @@ +import torch +import torch.distributed as dist +import torch.distributed._symmetric_memory as symm_mem +import triton +import triton.language as tl + +from torchao.prototype.moe_training.kernels.triton_utils import ( + blockwise_barrier, + sync_threads, +) +from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx + + +# This performs dynamic mxfp8 quantization of the input tensor, +# followed by an on-device all-to-all-v operation as determined by the input_splits, implented via Triton + PyTorch symmetric memory. +# This kernel is an extension of the original bf16 version here: +# https://github.com/pytorch/torchtitan/blob/476a965f93432f4f1681bc1bac064d689a2d0cec/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_on_device_all_to_all_v.py#L1 +class MXFP8OnDeviceAllToAllV(torch.autograd.Function): + # A symmetric memory buffer for exchanging input rows/tokens during forward + input_sym_mem_buf = None + + # A symmetric memory for exchanging scales during both forward and backward + scales_sym_mem_buf = None + + # A symmetric memory for exchanging split sizes during both forward and backward + input_splits_sym_mem_buf = None + + # A symmetric memory buffer holding the grad_output during backward + grad_out_sym_mem_buf = None + + # Maximum output length (need to be set before use of MXFP8OnDeviceAllToAllV) + max_output_rows_per_rank = None + + @staticmethod + def forward( + ctx, + input: torch.Tensor, + input_splits: torch.Tensor, + max_output_rows_per_rank: int, + group: dist.ProcessGroup = dist.group.WORLD, + ): + """ + Args: + input: input float8_e4m3fn tensor with data for all ranks concatenated. + input_scales: float8_e8m0fnu scales for the input tensor. + input_splits: input splits of shape (group.world_size,) + max_output_rows_per_rank: maximum output rows/tokens per rank. + group: process group to scope the collective. 
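+        Returns:
+            hp_output: output rows dequantized back to the input's original dtype, allocated
+                with max_output_rows_per_rank rows (rows beyond the received token count are padding).
+            output_splits: number of rows received from each rank, shape (group.world_size,).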
+ """ + assert input.dtype in (torch.float32, torch.bfloat16) + + # Enable symm mem for the group if not already enabled + if not symm_mem.is_symm_mem_enabled_for_group(group): + symm_mem.enable_symm_mem_for_group(group) + + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank = max_output_rows_per_rank + + # Quantize input + block_size = 32 + to_mx_c = torch.compile(to_mx) + input_scales, input_data = to_mx_c( + input, + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Triton doesn't support float8_e8m0fnu yet, view as uint8 + input_scales = input_scales.view(torch.uint8) + + # Initialize sym mem buffer for float8 e4m3 input data (one time only) + if MXFP8OnDeviceAllToAllV.input_sym_mem_buf is None: + MXFP8OnDeviceAllToAllV.input_sym_mem_buf = symm_mem.empty( + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank, + *input_data.shape[1:], + dtype=input_data.dtype, + device=input_data.device, + ) + + # Initialize input splits buffer (one time only) + if MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf is None: + MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf = symm_mem.empty( + *input_splits.shape, + dtype=input_splits.dtype, + device=input_splits.device, + ) + + # Initialize symm mem buffer for float8 e8m0 scales (one time only) + if MXFP8OnDeviceAllToAllV.scales_sym_mem_buf is None: + MXFP8OnDeviceAllToAllV.scales_sym_mem_buf = symm_mem.empty( + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank, + *input_scales.shape[1:], + dtype=input_scales.dtype, + device=input_scales.device, + ) + + # Copy quantized input data to symm mem buffer + MXFP8OnDeviceAllToAllV.input_sym_mem_buf.narrow( + 0, 0, input_data.shape[0] + ).copy_(input_data) + + # Copy input splits to symm mem buffer + MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf.copy_(input_splits) + + # Copy input scales to symm mem buffer + MXFP8OnDeviceAllToAllV.scales_sym_mem_buf.narrow( + 0, 0, input_scales.shape[0] + ).copy_(input_scales) + + # Allocate buffers for output data, scales, and splits. + output = input_data.new_empty( + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank, *input_data.shape[1:] + ) + output_scales = input_scales.new_empty( + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank, *input_scales.shape[1:] + ) + output_splits = torch.empty_like(input_splits) + + # Shuffle input to output + _mxfp8_on_device_all_to_all_v( + MXFP8OnDeviceAllToAllV.input_sym_mem_buf, + MXFP8OnDeviceAllToAllV.scales_sym_mem_buf, + MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf, + output, + output_scales, + output_splits, + group=group, + ) + + # Dequantize output + lowp_dtype = output.dtype + highp_dtype = input.dtype + hp_output = to_dtype( + output, + output_scales.view(torch.float8_e8m0fnu), + lowp_dtype, + block_size, + highp_dtype, + ) + + # Saving for backward: output splits in forward is the input splits in backward + ctx.save_for_backward(output_splits) + ctx.group = group + ctx.input_shape = input_data.shape + ctx.input_scales_shape = input_scales.shape + ctx.highp_dtype = highp_dtype + + return hp_output, output_splits + + @staticmethod + def backward(ctx, grad_output, grad_splits): + """ + Backward is implemented as a shuffle of the output's gradients to the input. + Args: + `grad_output`: output's gradients passed from the downstream. + `grad_splits`: unused. + """ + # In backward, mxfp8_all_to_all_v input is `grad_output`, and output is `grad_input`. 
+ (grad_output_splits,) = ctx.saved_tensors + + # Initialize grad_output sym mem buffer (one time only) + if MXFP8OnDeviceAllToAllV.grad_out_sym_mem_buf is None: + assert MXFP8OnDeviceAllToAllV.max_output_rows_per_rank is not None, ( + "`max_output_rows_per_rank` not set" + ) + MXFP8OnDeviceAllToAllV.grad_out_sym_mem_buf = symm_mem.empty( + MXFP8OnDeviceAllToAllV.max_output_rows_per_rank, + *grad_output.shape[1:], + dtype=torch.float8_e4m3fn, + device=grad_output.device, + ) + + # Quantize grad_output + block_size = 32 + grad_out_scales, grad_out_data = to_mx( + grad_output, + elem_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + + # Triton doesn't support float8_e8m0fnu yet, view as uint8 + grad_out_scales = grad_out_scales.view(torch.uint8) + + # Copy in float8 grad out data to a symm mem buffer + MXFP8OnDeviceAllToAllV.grad_out_sym_mem_buf.narrow( + 0, 0, grad_out_data.shape[0] + ).copy_(grad_out_data) + + # Copy in grad out e8m0 scales to symm mem buffer + MXFP8OnDeviceAllToAllV.scales_sym_mem_buf.narrow( + 0, 0, grad_out_scales.shape[0] + ).copy_(grad_out_scales) + + # Copy in splits to symm mem buffer + MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf.copy_(grad_output_splits) + + # Allocate outputs. + grad_input = grad_out_data.new_empty(*ctx.input_shape) + grad_input_scales = torch.empty( + *ctx.input_scales_shape, + dtype=grad_out_scales.dtype, + device=grad_out_scales.device, + ) + grad_input_splits = torch.empty_like(grad_output_splits) + + # Shuffle gradients back to the input + _mxfp8_on_device_all_to_all_v( + MXFP8OnDeviceAllToAllV.grad_out_sym_mem_buf, # input + MXFP8OnDeviceAllToAllV.scales_sym_mem_buf, # input scales + MXFP8OnDeviceAllToAllV.input_splits_sym_mem_buf, # input splits + grad_input, # output + grad_input_scales, # output scales + grad_input_splits, # output splits + group=ctx.group, + ) + + # Dequantize grad_input + lowp_dtype = grad_out_data.dtype + grad_input_highp = to_dtype( + grad_input, + grad_input_scales.view(torch.float8_e8m0fnu), + lowp_dtype, + block_size, + ctx.highp_dtype, + ) + return grad_input_highp, None, None, None + + +# Alias +mxfp8_on_device_all_to_all_v = MXFP8OnDeviceAllToAllV.apply + + +# Triton launcher function +def _mxfp8_on_device_all_to_all_v( + input: torch.Tensor, + input_scales: torch.Tensor, + input_splits: torch.Tensor, + output: torch.Tensor, + output_scales: torch.Tensor, + output_splits: torch.Tensor, + group: dist.ProcessGroup = dist.group.WORLD, + BLOCKS_PER_REMOTE_RANK: int = 8, + BLOCK_SIZE: int = 16384, +): + assert input.dim() == 2, f"{input.shape}" + assert output.dim() == 2, f"{output.shape}" + assert output.shape[1] == input.shape[1] + + # Prepare symmetric memory managed buffers for input, input_splits, and input_scales. 
+ # - `input` shape (tokens, dim) -> to a sym mem managed buffer of shape (num_ranks, tokens, dim) + # - `input_splits` shape (num_ranks,) -> to a sym mem managed buffer of shape (num_ranks, num_ranks)` + # - `input_scales` shape (tokens, dim//block_size) -> to a sym mem managed buffer of shape (num_ranks, tokens, dim//block_size) + input_hdl = symm_mem.rendezvous(input, group=group) + input_splits_hdl = symm_mem.rendezvous(input_splits, group=group) + input_scales_hdl = symm_mem.rendezvous(input_scales, group=group) + + input_ptrs = input_hdl.buffer_ptrs_dev + input_splits_ptrs = input_splits_hdl.buffer_ptrs_dev + input_scales_ptrs = input_scales_hdl.buffer_ptrs_dev + signal_pad_ptrs = input_hdl.signal_pad_ptrs_dev + dim = output.shape[1] + dim_scaling_groups = input_scales.shape[-1] + num_blocks = input_hdl.world_size * BLOCKS_PER_REMOTE_RANK + + _mxfp8_all_to_all_v_kernel[(num_blocks, 1, 1)]( + input_ptrs, + input_scales_ptrs, + input_splits_ptrs, + output, + output_scales, + output_splits, + signal_pad_ptrs, + dim=dim, + dim_scaling_groups=dim_scaling_groups, + rank=input_hdl.rank, + world_size=input_hdl.world_size, + BLOCKS_PER_REMOTE_RANK=BLOCKS_PER_REMOTE_RANK, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + ) + + return output + + +@triton.jit +def _mxfp8_all_to_all_v_kernel( + input_ptrs, + input_scales_ptrs, + input_splits_ptr, + output_ptr, + output_scales_ptr, + output_splits_ptr, + signal_pad_ptrs, + dim: tl.constexpr, + dim_scaling_groups: tl.constexpr, + rank: tl.constexpr, + world_size: tl.constexpr, + BLOCKS_PER_REMOTE_RANK: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + sync_threads() + + remote_rank = tl.program_id(0) // BLOCKS_PER_REMOTE_RANK + block_offset = tl.program_id(0) % BLOCKS_PER_REMOTE_RANK + + # 1. Get input row to read from the given remote rank (to get data coming to this local rank), + # and how many rows we're reading. + # 2. Get the output row offset to write that data to. + input_row_offset, output_row_offset, num_rows_to_read = _exchange_row_offsets( + input_splits_ptr, + rank, + remote_rank, + world_size, + ) + + # One thread block per rank will update output_splits + if block_offset == 0: + tl.store(output_splits_ptr + remote_rank, num_rows_to_read) + + # Update input and output pointers to point to the specific row we're reading/writing. + # 1. `input` is symmetric memory managed buffer of shape [num_ranks, tokens, dim]. + # We increment the ptr by `+remote_rank` along the 0th dim to get to the remote rank ptr, + # then increment that ptr by `input_row_offset * dim (stride)` to get the + # start offset for this rank's data on that remote rank. + # 2. `output` is a regular local tensor, we can stride into it as usual. + input_ptr = ( + tl.load(input_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to( + tl.pointer_type(tl.float8e4nv) + ) + + input_row_offset * dim + ) + output_ptr = output_ptr + output_row_offset * dim + + # Update input_scales and output_scales pointers to point to the specific row we're reading/writing. + # 1. `input_scales` is symmetric memory managed buffer of shape [num_ranks, tokens, dim//block_size]. + # We increment the ptr by `+remote_rank` along the 0th dim to get to the remote rank ptr, + # then increment by `input_row_offset * dim_scaling_groups (stride)` to get to the start of the + # scales for this rank on that remote rank. + # 2. `output_scales` is a regular local tensor, we can stride into it as usual. 
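+    # Note: with the block_size=32 quantization performed by the callers above,
+    # dim_scaling_groups == dim // 32 (one e8m0 scale per 32-element group along dim).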
+ input_scale_ptr = ( + tl.load(input_scales_ptrs.to(tl.pointer_type(tl.uint64)) + remote_rank).to( + tl.pointer_type( + tl.uint8 + ) # Triton doesn't support float8_e8m0fnu yet, use uint8 instead + ) + + input_row_offset * dim_scaling_groups + ) + output_scale_ptr = output_scales_ptr + output_row_offset * dim_scaling_groups + + # Copy target region of remote rank input data to our local output buffer. + total_input_elems_to_read = num_rows_to_read * dim + num_input_blocks = tl.cdiv(total_input_elems_to_read, BLOCK_SIZE) + for block_idx in tl.range(num_input_blocks): + offs = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < total_input_elems_to_read + data = tl.load(input_ptr + offs, mask=mask, other=0.0) + tl.store(output_ptr + offs, data, mask=mask) + + # Copy input_scales (scales on remote rank) to output_scales local buffer. + total_input_scales_to_read = num_rows_to_read * dim_scaling_groups + num_input_scale_blocks = tl.cdiv(total_input_scales_to_read, BLOCK_SIZE) + for block_idx in tl.range(num_input_scale_blocks): + offs = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < total_input_scales_to_read + data = tl.load(input_scale_ptr + offs, mask=mask, other=0.0) + tl.store(output_scale_ptr + offs, data, mask=mask) + + sync_threads() + blockwise_barrier(signal_pad_ptrs, None, rank, world_size, sem="relaxed") + return + + +@triton.jit +def _exchange_row_offsets( + split_sizes_ptrs, + local_rank: tl.constexpr, + remote_rank: tl.constexpr, + world_size: tl.constexpr, +): + """ + Returns: + - `input_offset_for_remote_rank`: + - `output_offset_for_remote_rank`: + - `num_rows`: + """ + # split_sizes_ptr points to 2d tensor of stacked input split size vectors (one per rank). Example: + # rank 0 = [30, 10, 10, 20] + # rank 1 = [20, 20, 10, 20] + split_sizes_ptrs = split_sizes_ptrs.to(tl.pointer_type(tl.uint64)) + + # Get pointer to remote rank's input_split_sizes tensor. + remote_rank_input_splits_ptr = tl.load(split_sizes_ptrs + remote_rank).to( + tl.pointer_type(tl.int64) + ) + + # num_rows_to_read is the specific number of tokens to read from remote_rank. + num_rows_to_read = tl.load(remote_rank_input_splits_ptr + local_rank) + + # Calculate starting offset in symm mem buf to read data from remote_rank for this local_rank. + # + # Do this by computing prefix sum of remote split offsets prev ranks. + # Ex. remote_rank split sizes = [10, 20, 30] + # For local rank 1, masked load = [10, 0, 0] + # Starting offset = sum([10, 0, 0]) = 10 + offsets = tl.arange(0, world_size) + remote_split_sizes_prefix = tl.load( + remote_rank_input_splits_ptr + offsets, mask=offsets < local_rank, other=0 + ) + input_offset_for_remote_rank = tl.sum(remote_split_sizes_prefix) + + # Calculate offset in local output buffer to start writing data to, for data coming from the remote_rank to this local_rank. + # + # We add `offsets` arange to get a set of pointers to the start of each row (rank) in the split_sizes matrix. + # Then, we add the local rank to each pointer, incrementing it colwise to reach the value for this local rank. + # Each ptrs now all point to how many tokens/rows that device has for local rank. 
+ # + # torch equivalent: split_sizes_matrix[:, rank] + ptr_to_each_rank_split_sizes = tl.load(split_sizes_ptrs + offsets).to( + tl.pointer_type(tl.int64) + ) + output_split_sizes_ptrs = ptr_to_each_rank_split_sizes + local_rank + output_split_sizes = tl.load( + output_split_sizes_ptrs, mask=offsets < remote_rank, other=0 + ) + output_offset_for_remote_rank = tl.sum(output_split_sizes) + + return input_offset_for_remote_rank, output_offset_for_remote_rank, num_rows_to_read diff --git a/torchao/prototype/moe_training/kernels/mxfp8.py b/torchao/prototype/moe_training/kernels/mxfp8/quant.py similarity index 100% rename from torchao/prototype/moe_training/kernels/mxfp8.py rename to torchao/prototype/moe_training/kernels/mxfp8/quant.py diff --git a/torchao/prototype/moe_training/kernels/triton_utils.py b/torchao/prototype/moe_training/kernels/triton_utils.py new file mode 100644 index 0000000000..ec5908d90a --- /dev/null +++ b/torchao/prototype/moe_training/kernels/triton_utils.py @@ -0,0 +1,211 @@ +""" +Triton utility functions sourced from torchtitan: +https://github.com/pytorch/torchtitan/blob/3e1b843ecd91a80dbf56e16f2eae7b637209ebcb/torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py +""" + +import triton +import triton.language as tl + + +@triton.jit +def sync_threads(): + tl.inline_asm_elementwise( + "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 + ) + + +@triton.jit +def send_signal(addrs, sem: tl.constexpr): + if sem == "relaxed": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + send_signal: + atom.global.relaxed.sys.cas.b32 %tmp32_0, [$1], 0, 1; + setp.eq.u32 %p0, %tmp32_0, 0; + @!%p0 bra send_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + elif sem == "acq_rel": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + send_signal: + atom.global.release.sys.cas.b32 %tmp32_0, [$1], 0, 1; + setp.eq.u32 %p0, %tmp32_0, 0; + @!%p0 bra send_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + else: + raise RuntimeError(f"Unrecognized sem: {sem}") + + +@triton.jit +def wait_signal(addrs, sem: tl.constexpr): + if sem == "relaxed": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + wait_signal: + atom.global.sys.relaxed.cas.b32 %tmp32_0, [$1], 1, 0; + setp.eq.u32 %p0, %tmp32_0, 1; + @!%p0 bra wait_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + elif sem == "acq_rel": + tl.inline_asm_elementwise( + """ + { + .reg .u32 %tmp32_<1>; + .reg .pred %p<1>; + + wait_signal: + atom.global.sys.acquire.cas.b32 %tmp32_0, [$1], 1, 0; + setp.eq.u32 %p0, %tmp32_0, 1; + @!%p0 bra wait_signal; + } + """, + "=r, l", + [addrs], + dtype=tl.int32, + is_pure=False, + pack=1, + ) + else: + raise RuntimeError(f"Unrecognized sem: {sem}") + + +@triton.jit +def blockwise_barrier( + signal_pad_ptrs, + block_id, + rank: tl.constexpr, + world_size: tl.constexpr, + sem: tl.constexpr, +): + """ + Synchronizes blocks with matching block_id across participating devices. + + Note: the function itself is not a system level barrier/fence. It is a + building block for expressing different synchronization patterns. 
+ + Pattern 0: Ensures that all writes to symm_mem buffers from previous + kernels across all devices are visible to the current kernel: + + blockwise_barrier(..., sem="relaxed") + sync_threads() + + Pattern 1: Ensures that all writes to symm_mem buffers from the current + block are visible to all remote blocks with matching blockIdx: + + sync_threads() + blockwise_barrier(..., sem="acq_rel") + sync_threads() + + Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe + for writing by subsequent kernels across all devices. + + sync_threads() + blockwise_barrier(..., sem="relaxed") + + CUDA graph friendliness: + + This barrier operates through atomic operations on a zero-filled signal + pad, which resets to a zero-filled state after each successful + synchronization. This design eliminates the need for incrementing a + flag from host. + """ + if block_id is None: + block_id = get_flat_bid() + flat_tid = get_flat_tid() + + remote_ranks = tl.arange(0, world_size) + signal_pad_ptrs = signal_pad_ptrs.to(tl.pointer_type(tl.uint64)) + remote_signal_pad_addrs = tl.load(signal_pad_ptrs + remote_ranks).to( + tl.pointer_type(tl.uint32) + ) + send_addrs = remote_signal_pad_addrs + block_id * world_size + rank + + local_signal_pad_addr = tl.load(signal_pad_ptrs + rank).to( + tl.pointer_type(tl.uint32) + ) + wait_addrs = local_signal_pad_addr + block_id * world_size + remote_ranks + + if flat_tid < world_size: + send_signal(send_addrs, sem) + wait_signal(wait_addrs, sem) + + +@triton.jit +def get_flat_bid(): + return ( + tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) + + tl.program_id(1) * tl.num_programs(0) + + tl.program_id(0) + ) + + +@triton.jit +def get_tid(): + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %tid.x; + mov.u32 $1, %tid.y; + mov.u32 $2, %tid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_ntid(): + return tl.inline_asm_elementwise( + """ + mov.u32 $0, %ntid.x; + mov.u32 $1, %ntid.y; + mov.u32 $2, %ntid.z; + """, + "=r,=r,=r", + [], + dtype=(tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + + +@triton.jit +def get_flat_tid(): + tid_x, tid_y, tid_z = get_tid() + ntid_x, ntid_y, _ = get_ntid() + return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index b717462b4d..8532337477 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -346,7 +346,7 @@ def to_dtype( elem_dtype, block_size, target_dtype, - pack_fp6, + pack_fp6: bool = False, ): orig_shape = data_lp.shape is_transposed = not data_lp.is_contiguous()
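Usage sketch (not part of the patch): a minimal end-to-end call of mxfp8_on_device_all_to_all_v, condensed from test_mxfp8_a2a.py above. It assumes an SM100 GPU with the NVSHMEM symmetric-memory backend, a NCCL process group already initialized with torch.cuda.set_device(rank) applied, and an even token split chosen purely for illustration (the world size is assumed to divide tokens_per_rank).

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

from torchao.prototype.moe_training.kernels.mxfp8.comms import (
    mxfp8_on_device_all_to_all_v,
)

# Process group setup (dist.init_process_group("nccl", ...) and
# torch.cuda.set_device(rank)) is assumed done, as in setup_distributed() above.
symm_mem.set_backend("NVSHMEM")
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

tokens_per_rank, dim = 8192, 2048
x = torch.randn(tokens_per_rank, dim, device="cuda", dtype=torch.bfloat16)

# input_splits[i] = rows this rank sends to rank i; must sum to tokens_per_rank.
world_size = dist.get_world_size()
input_splits = torch.full(
    (world_size,), tokens_per_rank // world_size, dtype=torch.int64, device="cuda"
)

# Worst case output buffer: one rank receives every token.
max_output_tokens_per_rank = tokens_per_rank * world_size

# Quantize to mxfp8, shuffle on-device via symmetric memory, dequantize.
# `output` is padded to max_output_tokens_per_rank rows; strip the padding.
output, output_splits = mxfp8_on_device_all_to_all_v(
    x, input_splits, max_output_tokens_per_rank, group_name
)
received = output[: output_splits.sum()]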