Transpose one of the MLP matrices + add Triton kernel for symmetric matmul (new WR)#109
Conversation
|
This looks awesome. I expect to retime it on the same hardware I used to time the previous record, and then announce the new record, roughly at the end of the week. |
|
Just linking relevant literature regarding symmetric matmul: https://arxiv.org/pdf/2505.09814 Not sure if it provides a speedup at this scale though. |
"Note that the accelerations not only holds asymptotically for large matrices" — I think we should definitely try this. I can't find any implementations of it online though, and the benchmarking for the paper was done on CPU. |
you said EOW, just checking |
For people looking: https://github.com/microsoft/dion |
|
Did you test the NS Triton kernel's end-to-end speedup? Putting 24 FFN matrices into 3 batches instead of 4 batches is effective cuz it significantly reduces the communication volume. |
|
I tried ablations with only batching / only triton kernel. The MLP matrix batching is responsible for most of the speedup here (maybe around 75% of total) but the triton kernel does help, even at this very small scale with aggressive overlapping. By "AllGather time is dominant" are you saying that Muon is communication bound here? That would surprise me, since NVlink is very fast, and NS has fairly high compute/comms ratio with 15 matmuls to 1 all-gather. |
|
My memory might be wrong (I can't exactly remember the AG vs. NS time in the trace). Here is my reasoning: cuz the matrices are small, I observed (IIRC) like ~200GB/s AG bandwidth (nominal unidirectional bandwidth is 450GB/s) on H100s. If NS can achieve 600TFLOPS, we need 3000 FLOPs/byte arithmetic intensity (relative to network) to be compute bound. An important factor is that each GPU only does NS for 1/8 of the matrices but it needs to gather the other 7/8. Plz correct me if I'm wrong. Ofc, 600TFLOPS might be too optimistic... |
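To make the estimate above concrete, here is a back-of-envelope sketch. The 600 TFLOPS and 200 GB/s figures are taken from the comment; the 5-iteration Newton-Schulz count and bf16 element size are my own assumptions.

```python
# Back-of-envelope check of the compute-vs-communication argument.
m, n = 3072, 768             # one MLP matrix
bytes_per_el = 2             # bf16 (assumed)
ns_iters = 5                 # assumed Newton-Schulz iteration count

# FLOPs per iteration on an (m, n) matrix with m > n:
#   A = X^T X (2*m*n*n), B = b*A + c*A@A (2*n^3), X = a*X + X@B (2*m*n*n)
ns_flops = ns_iters * (4 * m * n * n + 2 * n**3)

# Arithmetic intensity needed to be compute bound at the quoted rates:
needed = 600e12 / 200e9      # 3000 FLOPs per gathered byte

# Each GPU runs NS on 1/8 of the matrices but gathers the other 7/8:
actual = ns_flops / (7 * m * n * bytes_per_el)
print(f"needed {needed:.0f}, actual {actual:.0f}")  # actual well below needed
```

Under these assumptions the achieved intensity comes out around 1200 FLOPs/byte, well below the ~3000 threshold, which supports the communication-bound hypothesis.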
|
You may be right that we are communication bottlenecked. Here's another estimate. For matrix shape (m, n) with
Before the batching change, we have 2 batches of (4, 768, 768) for attention + 4 batches of (3072, 768) for MLP. After, the MLP part is reduced to 3 batches of (3072, 768).
This is 2x the FLOPS from batching. However, I observed that most speedup comes from batching alone. Either the Triton kernel is extremely inefficient, or we are communication bound. If we are communication bound, Triton only helps with the first batch of NS, which doesn't have an all-gather to overlap with. (but maybe it overlaps with previous reduce-scatter? not entirely sure here)
This is much closer to the observed ~75% of the speedup coming from one fewer batch. I think the first batch is attention, since it's defined first in the model (but I haven't verified this). Then the comms time saving would be roughly equal to the time it takes to perform
This seems like a reasonable estimate of H100 TFLOPS for Newton-Schulz on small matrices. I included a plot above with benchmark results on A100. Triton NS gets about 150 TFLOPS and Torch NS is a bit over 100 TFLOPS for a batch of (4, 768, 768) matrices. The H100 has about 3x the theoretical FLOPS and 1.7x the memory bandwidth of the A100. I think your estimate of 600 TFLOPS does seem rather high, but 329 is quite realistic and in line with extrapolating from the A100 benchmark results. |
|
Your analysis looks great! |
|
@KellerJordan any news? i'm wondering why all those WRs aren't being merged |
|
Because |
|
@byronxu99 You may find it interesting that Microsoft appears to have fully used your kernels in their repo dion: https://github.com/microsoft/dion/blob/main/dion/newton_schulz_triton.py |
|
@varunneal lol you can check the author list of dion, Byron Xu is one of the authors of dion! |
|
@YouJiacheng 🤦♂️ My mistake |
New world record (169042 ms, 2.817 min)
Transposing an MLP matrix
This is an extremely simple change that results in processing one fewer batch of MLP parameters per optimizer step.
Our model has 12 transformer blocks. Each block has two MLP matrices of shape (768, 4*768) and (4*768, 768), which gives 12 matrices of each shape. The optimizer groups parameters by each unique shape. For each shape, it creates batches of world_size (equal to 8) parameters and assigns one to each GPU. Any partial batches are filled with padding matrices up to the next multiple of 8. We need four total batches to process 12+12 matrices.
The change here is to create both of the MLP matrices with shape (768, 4*768). Now we have 24 matrices of identical shape, evenly divisible into three batches of 8.
During the forward pass, we simply transpose the "wrong" shape before use.
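As a rough sketch of the batch-count arithmetic (the names and structure here are illustrative, not the actual optimizer code):

```python
from collections import defaultdict

world_size = 8
n_blocks = 12

def count_batches(shapes):
    """Group parameters by shape; each group is padded up to a
    multiple of world_size, one parameter per GPU per batch."""
    groups = defaultdict(int)
    for s in shapes:
        groups[s] += 1
    return sum(-(-count // world_size) for count in groups.values())  # ceil div

# Before: two distinct MLP shapes, 12 matrices each -> 2 + 2 = 4 batches
before = [(768, 3072)] * n_blocks + [(3072, 768)] * n_blocks
# After: both MLP matrices created as (768, 3072); the "wrong" one is
# transposed during the forward pass -> 24 identical shapes -> 3 batches
after = [(768, 3072)] * (2 * n_blocks)

print(count_batches(before), count_batches(after))  # 4 3
```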
Triton kernels
Multiplying two matrices with shape (m, k) and (k, n) requires 2*m*k*n FLOPS. However, multiplying a matrix (m, k) with its own transpose (k, m) can be done with only m*k*m FLOPS. The result is symmetric, so we only need to compute half of it and copy the result across the diagonal.
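The tiling idea can be sketched in NumPy as follows. This only illustrates the compute-half-and-mirror scheme; it is not the actual Triton kernel.

```python
import numpy as np

def matmul_xxt(x, tile=64):
    """Compute x @ x.T by evaluating only tiles on or below the
    diagonal, then mirroring each result block across the diagonal."""
    m, _ = x.shape
    out = np.empty((m, m), dtype=x.dtype)
    for i in range(0, m, tile):
        for j in range(0, i + tile, tile):      # only tiles with j <= i
            block = x[i:i + tile] @ x[j:j + tile].T
            out[i:i + tile, j:j + tile] = block
            out[j:j + tile, i:i + tile] = block.T  # diagonal tiles are symmetric
    return out

x = np.random.randn(256, 96)
assert np.allclose(matmul_xxt(x), x @ x.T)
```

Roughly half of the inner-loop tiles are skipped, which is where the m*k*m FLOP count comes from.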
The first two lines of the Newton-Schulz loop can be optimized this way. The third line is a standard (non-symmetric) matmul. Theoretically, this gives a 1.5x end-to-end speedup for square matrices: three matmuls at 2*d*d*d each become d*d*d + d*d*d + 2*d*d*d = 4*d*d*d FLOPS per iteration.
PyTorch does not natively provide a symmetric matrix multiplication function, so I implemented it using Triton. The two applicable Newton-Schulz steps use separate kernels, because the addition operation required for the second line leads to a slight slowdown. The code is roughly based on the Triton tutorial.
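For reference, here is a NumPy sketch of the quintic Newton-Schulz iteration (coefficients as in modded-nanogpt's Muon optimizer), with comments marking which matmuls the symmetric kernels cover:

```python
import numpy as np

def newton_schulz(G, steps=5):
    """Quintic Newton-Schulz iteration approximating the orthogonal
    factor of G. NumPy sketch; the real code runs batched on GPU."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)   # Frobenius scaling -> spectral norm <= 1
    transposed = X.shape[0] > X.shape[1]
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T               # symmetric matmul -> kernel 1
        B = b * A + c * (A @ A)   # symmetric matmul with add -> kernel 2
        X = a * X + B @ X         # general matmul, unchanged
    return X.T if transposed else X

# Singular values of the result are pushed toward 1:
s = np.linalg.svd(newton_schulz(np.diag([3.0, 1.0, 0.5])), compute_uv=False)
```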
Comparing NS with Triton kernels vs. the PyTorch implementation with torch.compile for a batch of (4, d, d) matrices. Y-axis is "effective TFLOPS" benchmarked on an A100 GPU.
Related work
This new speedrun record came out of my work at Microsoft Research on scaling orthonormal optimization to larger models. So far, the Triton kernel is the only thing from that work that yields an improvement at the NanoGPT scale. For those who are curious, I would recommend checking out our related work:
Minor changes
Included in the PR are some minor changes made when developing and testing the code.
8 // world_size iterations and modified data loading code to produce the equivalent batches of data.
DISABLE_FP8=1. Useful for running on older devices that don't support FP8.
DATA_PATH environment variable.