This issue tracks single node performance of MX training and inference: fast gemm, fast fused kernels. When this issue is complete, we should be able to train on a single node (8 GPUs) at SOTA performance with MXFP8, and run inference (details TBD) with MXFP8 and MXFP4.
individual components
system overview (for training)
```python
# There are three gemms in a forward + backward of a Linear layer:
#
#   1. input @ weight_t = output (forward pass)
#   2. grad_output @ weight = grad_input (backward pass)
#   3. input_t @ grad_output = grad_weight (backward pass)
#
# in Python pseudocode, we want the following (for mxfp8):

# forward pass
# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass
# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
```
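For intuition, here is a minimal, unoptimized sketch of what `to_mx` does conceptually, assuming an e4m3 element dtype, a block size of 32 along the last dim, and fp32 scales; the actual torchao implementation differs (e.g. it stores scales as e8m0 and handles arbitrary shapes):

```python
import torch

def to_mx_sketch(x_hp: torch.Tensor, block_size: int = 32):
    # Sketch only: blocks along the last dim, which is assumed to be divisible
    # by block_size; scales are kept in fp32 (a real kernel would store e8m0).
    orig_shape = x_hp.shape
    x = x_hp.float().reshape(-1, block_size)

    # per-block absolute maximum, clamped away from zero
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=2**-126)

    # power-of-two scale so that the block max lands near the e4m3 max (448 ≈ 2^8.8)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)

    x_mx = (x / scale).to(torch.float8_e4m3fn).reshape(orig_shape)
    scales = scale.reshape(*orig_shape[:-1], orig_shape[-1] // block_size)
    return x_mx, scales
```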
We want:
the mx gemm to be fast
the cast from high precision to mx (to_mx in pseudocode above) to be fast
the cast from high precision to mx to be fused to preceding/subsequent ops where possible
gemm kernel
Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPS for bf16, 4.5 petaFLOPS for fp8/fp6 (2x over bf16), 9.0 petaFLOPS for fp4 (4x over bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)
The relevant gemm kernels are torch._scaled_mm, torchao.ops.mx_fp8_bf16, and torchao.ops.mx_fp4_bf16. Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and record the observed TFLOPS for each of these kernels.
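To make "easily reproducible gemm benchmarks" concrete, a measurement harness could look like the sketch below: it measures achieved TFLOPS for a bf16 baseline gemm with torch.utils.benchmark, and the mx gemms listed above would be timed the same way once their calling conventions are wired up (their exact signatures are intentionally not shown here):

```python
import torch
import torch.utils.benchmark as benchmark

def achieved_tflops(fn, flop_count: int) -> float:
    # median runtime of fn() -> achieved TFLOPS (Timer handles CUDA sync)
    sec = benchmark.Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange().median
    return flop_count / sec / 1e12

M = K = N = 8192
a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

# bf16 baseline, to be compared against the 2.25 petaFLOPS peak quoted above
print("bf16 mm:", achieved_tflops(lambda: a @ b, 2 * M * K * N), "TFLOPS")
```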
scaling/casting kernels
Our current plan is to use torch.compile, same as we are doing with float8.
the float8_e8m0fnu dtype was added to PyTorch in pytorch#147466 ("add the torch.float8_e8m0fnu dtype to PyTorch"); we need to update torchao to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue.
once we have a single fused kernel, we should make sure it's bandwidth bound. As of 2025-02-24, the casting to MX code is numerically correct but researchy and has not been optimized for performance. TODO issue.
given an MXLinear (fwd + bwd), we should expect at most six scale+cast kernels: two for each of input, weight, grad_output. The kernels for input and grad_output should be fused with preceding/subsequent ops as appropriate. TODO issue.
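As a rough way to check the "bandwidth bound" goal above, something like the following compares the achieved bytes/sec of a compiled stand-in cast against the ~8 TB/s B200 peak; the cast below is a simplified placeholder (same assumptions as the to_mx sketch earlier in this issue), not the torchao kernel:

```python
import torch
import torch.utils.benchmark as benchmark

BLOCK = 32

def cast_to_mxfp8(x: torch.Tensor):
    # simplified block-wise cast, stand-in only, to exercise torch.compile fusion
    xr = x.float().reshape(-1, BLOCK)
    amax = xr.abs().amax(dim=1, keepdim=True).clamp(min=2**-126)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    return (xr / scale).to(torch.float8_e4m3fn), scale

compiled_cast = torch.compile(cast_to_mxfp8)

x = torch.randn(16384, 16384, device="cuda", dtype=torch.bfloat16)
compiled_cast(x)  # warmup / compile

sec = benchmark.Timer(
    stmt="compiled_cast(x)", globals={"compiled_cast": compiled_cast, "x": x}
).blocked_autorange().median

# minimum traffic for this cast: read bf16 (2 B/elem) + write fp8 (1 B/elem); scales ignored
bytes_moved = x.numel() * 3
print(f"achieved {bytes_moved / sec / 1e12:.2f} TB/s vs ~8 TB/s B200 peak")
```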
e2e training performance
From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).
we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical (a rough roofline sketch follows this list)
[blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: [CUDA][Blackwell] Blackwell Tracking Issue pytorch#145949
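Until we have that roofline, here is a hypothetical back-of-the-envelope version for a single Linear fwd+bwd, using the B200 numbers quoted above (2.25 petaFLOPS bf16, 4.5 petaFLOPS fp8, 8 TB/s) and assuming the only mxfp8 overhead is the six casts reading bf16 and writing fp8:

```python
def mxfp8_linear_speedup(M: int, K: int, N: int) -> float:
    # best-case speedup estimate of mxfp8 over bf16 for one Linear fwd+bwd;
    # illustrative only, ignores scale traffic and kernel launch overheads
    bf16_peak, fp8_peak, mem_bw = 2.25e15, 4.5e15, 8e12

    # three gemms (output, grad_input, grad_weight), 2*M*K*N FLOPs each
    gemm_flops = 3 * 2 * M * K * N
    t_bf16 = gemm_flops / bf16_peak
    t_fp8_gemm = gemm_flops / fp8_peak

    # six casts: input, weight, grad_output, each cast along two dims;
    # minimum traffic per element is 2 B read (bf16) + 1 B write (fp8)
    cast_bytes = 2 * (M * K + K * N + M * N) * 3
    t_cast = cast_bytes / mem_bw

    return t_bf16 / (t_fp8_gemm + t_cast)

for M, K, N in [(1024, 1024, 1024), (8192, 8192, 8192), (16384, 16384, 16384)]:
    print(f"{M}x{K}x{N}: {mxfp8_linear_speedup(M, K, N):.2f}x")
```

With these assumptions, small shapes are cast-bound and see no speedup, while large shapes approach the 2x gemm speedup; the real benchmark should replace these estimates with measured numbers.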
e2e inference performance
need an inference roofline
need to decide where to benchmark