MX single node performance tracker #1768

vkuzo commented Feb 24, 2025

This issue tracks single-node performance of MX training and inference: fast gemms and fast fused kernels. Once this issue is complete, we can train on a single node (8 GPUs) at SOTA performance with MXFP8, and run inference (exact targets TBD) with MXFP8 and MXFP4.

individual components

system overview (for training)

# There are three gemms in a forward + backward of a Linear layer:
#
# 1.       input @ weight_t    = output     (forward pass)
# 2. grad_output @ weight      = grad_input (backward pass)
# 3.     input_t @ grad_output = grad_weight (backward pass)
# 
# in Python pseudocode, we want the following (for mxfp8):

# forward pass

# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass

# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)

We want:

  1. the mx gemm to be fast
  2. the cast from high precision to mx (to_mx in pseudocode above) to be fast
  3. the cast from high precision to mx to be fused with preceding/subsequent ops where possible

gemm kernel

Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPs for bf16, 4.5 petaFLOPs for fp8/fp6 (2x bf16), 9.0 petaFLOPs for fp4 (4x bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20).

| kernel | wrapper | current TFLOPs | peak TFLOPs | notes |
| --- | --- | --- | --- | --- |
| mxfp8 cuBLAS | torch._scaled_mm | TBD | 4.5 petaFLOPs | in progress, pytorch/pytorch#147548 |
| mxfp8 CUTLASS | torchao.ops.mx_fp8_bf16 | TBD | 4.5 petaFLOPs | landed, #1637 |
| mxfp4 CUTLASS | torchao.ops.mx_fp4_bf16 | TBD | 9.0 petaFLOPs | landed, #1661 |
| nvfp4 cuBLAS | torch._scaled_mm | TBD | 9.0 petaFLOPs | planned |

Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the current TFLOPs column in the table above.
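As a starting point, here is a minimal sketch of such a benchmark. The shapes and the bf16 `torch.mm` baseline below are placeholders chosen for illustration; the mx gemm wrappers from the table above would be passed in as `gemm_fn` once their argument layouts are finalized.

```python
import torch

def bench_gemm_tflops(gemm_fn, M, K, N, dtype=torch.bfloat16, n_iter=100):
    # placeholder operands; a real mx benchmark would pre-cast to mx and pass scales
    a = torch.randn(M, K, device="cuda", dtype=dtype)
    b = torch.randn(N, K, device="cuda", dtype=dtype).t()
    for _ in range(10):  # warmup
        gemm_fn(a, b)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        gemm_fn(a, b)
    end.record()
    torch.cuda.synchronize()
    sec_per_iter = start.elapsed_time(end) / 1e3 / n_iter
    return 2 * M * K * N / sec_per_iter / 1e12  # 2*M*K*N flops per gemm

if __name__ == "__main__":
    for M, K, N in ((4096, 4096, 4096), (8192, 8192, 8192), (16384, 8192, 1280)):
        print(f"M={M} K={K} N={N}: {bench_gemm_tflops(torch.mm, M, K, N):.0f} TFLOPs")
```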

scaling/casting kernels

Our current plan is to use torch.compile, same as we are doing with float8.
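For illustration, a simplified sketch of what a compiled cast might look like. This is not the actual torchao to_mx implementation: it uses a plain per-block float scale rather than the e8m0 power-of-two scales from the MX spec, and it assumes the innermost dimension is divisible by the block size of 32.

```python
import torch

BLOCK_SIZE = 32   # MX block size
E4M3_MAX = 448.0  # max magnitude representable in float8_e4m3fn

@torch.compile
def to_mx_simplified(x_hp: torch.Tensor):
    # per-block scaling along the innermost dim; assumes x_hp.shape[-1] % BLOCK_SIZE == 0
    x_blocks = x_hp.reshape(-1, BLOCK_SIZE)
    amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = amax / E4M3_MAX
    x_lp = (x_blocks / scale).to(torch.float8_e4m3fn).reshape(x_hp.shape)
    return x_lp, scale

x_hp = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
x_mx, x_scale = to_mx_simplified(x_hp)
```

The point of going through torch.compile is that inductor can fuse the abs/amax/divide/cast chain into a small number of kernels, and fuse the cast with the op that produces the high precision tensor where the graph allows it.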

e2e training performance

From https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20: on B200 the expected single-GPU memory bandwidth is 8 TB/s, the fp8/fp6 tensor core peak is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak is 9.0 petaFLOPS (without sparsity).

  • we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical performance (a rough sketch follows this list)
  • [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
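
To make the roofline bullet concrete, here is a rough back-of-envelope model for a single mxfp8 linear forward, using the B200 peak numbers quoted above. The cost model is an assumption for illustration: each cast reads bf16 and writes fp8 at full memory bandwidth, gemms hit peak FLOPS, and scale storage and launch overhead are ignored, so the output is an optimistic upper bound.

```python
# Rough roofline for one mxfp8 linear forward (y = x @ w_t), using the B200
# peak numbers quoted above; ignores scale storage, launch overhead, and the
# gap between peak and achievable bandwidth/FLOPS.
MEM_BW_BYTES_PER_S = 8e12   # 8 TB/s
BF16_PEAK_FLOPS = 2.25e15   # 2.25 petaFLOPS
FP8_PEAK_FLOPS = 4.5e15     # 4.5 petaFLOPS

def linear_fwd_speedup(M, K, N):
    gemm_flops = 2 * M * K * N
    t_gemm_bf16 = gemm_flops / BF16_PEAK_FLOPS
    t_gemm_fp8 = gemm_flops / FP8_PEAK_FLOPS
    # each cast reads bf16 (2 bytes/elem) and writes fp8 (1 byte/elem);
    # x is [M, K] and w is [N, K]
    t_cast = 3 * (M * K + N * K) / MEM_BW_BYTES_PER_S
    return t_gemm_bf16 / (t_gemm_fp8 + t_cast)

for M, K, N in ((1024, 1024, 1024), (4096, 4096, 4096), (16384, 8192, 8192)):
    print(f"M={M} K={K} N={N}: estimated fwd speedup {linear_fwd_speedup(M, K, N):.2f}x")
```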

e2e inference performance

  • need an inference roofline
  • need to decide where to benchmark