
Transpose one of the MLP matrices + add Triton kernel for symmetric matmul (new WR)#109

Merged
ClassicLarry merged 1 commit into KellerJordan:master from byronxu99:master
Oct 15, 2025
Conversation

@byronxu99
Contributor

New world record (169042 ms, 2.817 min)

Transposing an MLP matrix

This is an extremely simple change that results in processing one fewer batch of MLP parameters per optimizer step.

Our model has 12 transformer blocks. Each block has two MLP matrices of shape (768, 4*768) and (4*768, 768), which gives 12 matrices of each shape. The optimizer groups parameters by each unique shape. For each shape, it creates batches of world_size (equal to 8) parameters and assigns one to each GPU. Any partial batches are filled with padding matrices up to the next multiple of 8. We need four total batches to process 12+12 matrices.

The change here is to create both of the MLP matrices with shape (768, 4*768). Now we have 24 matrices of identical shape, evenly divisible into three batches of 8.

During the forward pass, we simply transpose the "wrong" shape before use.
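To make the batching effect concrete, here is a minimal sketch (the `num_batches` helper and the variable names are illustrative, not the actual speedrun optimizer code) of how grouping by unique shape and padding to a multiple of world_size determines the batch count:

```python
import math
from collections import Counter

def num_batches(shapes, world_size=8):
    """Batches needed when parameters are grouped by unique shape and
    each group is padded up to a multiple of world_size."""
    counts = Counter(map(tuple, shapes))
    return sum(math.ceil(n / world_size) for n in counts.values())

# Before: 12 matrices of (768, 3072) and 12 of (3072, 768) -> 2 + 2 = 4 batches
before = [(768, 3072)] * 12 + [(3072, 768)] * 12
# After: all 24 MLP matrices stored as (768, 3072) -> exactly 3 full batches
after = [(768, 3072)] * 24

print(num_batches(before))  # 4
print(num_batches(after))   # 3
```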

Triton kernels

Multiplying two matrices of shapes (m, k) and (k, n) requires 2*m*k*n FLOPs. However, multiplying a matrix (m, k) with its own transpose (k, m) can be done with only m*k*m FLOPs: the result is symmetric, so we only need to compute half of it and copy the result across the diagonal.
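As an illustration of the half-compute idea (this is not the actual Triton kernel), here is a block-wise NumPy sketch that computes only the upper-triangle tiles of X @ X.T and mirrors them across the diagonal:

```python
import numpy as np

# Illustrative only: a real kernel skips launching the lower-triangle tiles
# entirely, halving the matmul work; here we just show the tile structure.
def symm_matmul(X, block=2):
    m = X.shape[0]
    out = np.empty((m, m), dtype=X.dtype)
    for i in range(0, m, block):
        for j in range(i, m, block):              # j >= i: upper triangle only
            tile = X[i:i+block] @ X[j:j+block].T
            out[i:i+block, j:j+block] = tile
            out[j:j+block, i:i+block] = tile.T    # mirror across the diagonal
    return out

X = np.random.default_rng(0).standard_normal((6, 4))
assert np.allclose(symm_matmul(X), X @ X.T)
```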

The first two steps in the Newton-Schulz loop can be optimized this way; the third line is a standard (non-symmetric) matmul. Theoretically, this can give a 1.5x end-to-end speedup for square matrices.
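A rough sketch of one quintic Newton-Schulz step, marking which matmuls produce symmetric outputs. The coefficients are the commonly cited Muon values and may differ from the speedrun's exact constants:

```python
import numpy as np

# Commonly cited quintic Newton-Schulz coefficients (assumption, may differ
# from the speedrun's exact values).
a, b, c = 3.4445, -4.7750, 2.0315

def ns_step(X):
    A = X @ X.T              # line 1: symmetric output, half the work suffices
    B = b * A + c * (A @ A)  # line 2: A @ A is also symmetric (A == A.T)
    return a * X + B @ X     # line 3: standard, non-symmetric matmul

X = np.random.default_rng(1).standard_normal((8, 8)) * 0.1
Y = ns_step(X)
assert Y.shape == (8, 8)
```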

PyTorch does not natively provide a symmetric matrix multiplication function, so I implemented it using Triton. The two applicable Newton-Schulz steps use separate kernels, because the addition operation required for the second line leads to a slight slowdown. The code is roughly based on the Triton tutorial.

Comparing NS with Triton kernels vs. PyTorch implementation with torch.compile for a batch of (4, d, d) matrices. Y-axis is "effective TFLOPS" benchmarked on an A100 GPU.

Related work

This new speedrun record came out of my work at Microsoft Research on scaling orthonormal optimization to larger models. So far, the Triton kernel is the only part that has yielded an improvement at the NanoGPT scale. For those who are curious, I would recommend checking out our related work:

  • My team has released a paper on Dion, an alternative to Newton-Schulz iteration that can be directly applied to sharded matrices (and could potentially give slightly better orthonormalization accuracy)
  • In the paper linked above, we have some promising results with using Lion instead of AdamW for non-matrix parameters. There's a systematic method for sharing the same base learning rate between two different optimizer algorithms using per-parameter LR scale factors.
  • We have optimized implementations of Muon and Dion for distributed training. Our Muon supports PyTorch FSDP2 (using all-to-all to efficiently unshard) and Dion supports combined FSDP2+TP.
  • We are currently awaiting approval for open-sourcing the code, but a public release is on the project roadmap.

Minor changes

Included in the PR are some minor changes made when developing and testing the code.

  • Support training on machines with fewer than 8 GPUs. Added a gradient accumulation loop for 8 // world_size iterations and modified data loading code to produce the equivalent batches of data.
  • Allow disabling FP8 by setting environment variable DISABLE_FP8=1. Useful for running on older devices that don't support FP8.
  • Allow specifying a custom path to data directory with DATA_PATH environment variable.
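A hypothetical sketch of how these three switches might fit together (helper names and the default data path are illustrative, not the actual speedrun code):

```python
import os

# Hypothetical sketch only; names and defaults are illustrative.
world_size = int(os.environ.get("WORLD_SIZE", "1"))
assert 8 % world_size == 0, "world_size must divide 8"

disable_fp8 = os.environ.get("DISABLE_FP8") == "1"  # skip FP8 on older GPUs
data_path = os.environ.get("DATA_PATH", "./data")   # custom data directory

# Fewer than 8 GPUs: accumulate over 8 // world_size micro-steps so the
# effective batch matches the 8-GPU run.
grad_accum_steps = 8 // world_size

def training_step(get_micro_batch, model_step):
    for _ in range(grad_accum_steps):
        model_step(get_micro_batch())  # forward/backward, gradients accumulate
    # optimizer step happens once per accumulated batch
```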

@KellerJordan
Owner

This looks awesome. I expect to retime it on the same hardware I used to time the previous record, and then announce the new record, roughly at the end of the week.

@linux-leo

Just linking relevant literature regarding symmetric matmul: https://arxiv.org/pdf/2505.09814

Not sure if it provides a speedup at this scale though.

@kiankyars
Contributor

kiankyars commented Jul 30, 2025

Just linking relevant literature regarding symmetric matmul: https://arxiv.org/pdf/2505.09814

Not sure if it provides a speedup at this scale though.

"Note that the accelerations not only holds asymptotically for large matrices with n → ∞, but also for small matrices including n = 4."

I think we should definitely try this. I can't find any implementations of it online, though, and the benchmarking for the paper was done on CPU.

@kiankyars
Contributor

This looks awesome. I expect to retime it on the same hardware I used to time the previous record, and then announce the new record, roughly at the end of the week.

you said EOW, just checking

@MarktHart

We are currently awaiting approval for open-sourcing the code, but a public release is on the project roadmap.

For people looking: https://github.com/microsoft/dion

@YouJiacheng
Contributor

YouJiacheng commented Aug 23, 2025

Did you test the NS Triton kernel's end-to-end speedup?
IIRC, AllGather time is dominant here and we overlap AG with NS, so the NS speedup might not translate to an end-to-end speedup.

Putting 24 FFN matrices into 3 batches instead of 4 is effective because it significantly reduces the communication volume.
But faster NS might not be that effective.

@byronxu99
Contributor Author

I tried ablations with only the batching and with only the Triton kernel. The MLP matrix batching is responsible for most of the speedup here (maybe around 75% of the total), but the Triton kernel does help, even at this very small scale with aggressive overlapping.

By "AllGather time is dominant", are you saying that Muon is communication bound here? That would surprise me, since NVLink is very fast, and NS has a fairly high compute-to-comms ratio with 15 matmuls to 1 all-gather.

@YouJiacheng
Contributor

YouJiacheng commented Aug 24, 2025

My memory might be wrong. (I can't exactly remember the AG vs. NS time in the trace)

Here is my reasoning:

Because the matrices are small, I observed (IIRC) around ~200 GB/s AG bandwidth (nominal unidirectional bandwidth is 450 GB/s) on H100s.

If NS can achieve 600 TFLOPS, we need 3000 FLOPs/byte of arithmetic intensity (relative to the network) to be compute bound.

An important factor is that each GPU only does NS for 1/8 of the matrices, but it needs to gather 7/8 of them.

For a (3072, 768) matrix, NS consumes 2*5*(3072*2+768)*768^2 ≈ 4.08e10 FLOPs, while AG needs to gather the other 7 matrices, i.e. 2*3072*768*7 = 33030144 bytes, so the intensity is about 1234 FLOPs/byte.
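The arithmetic above can be checked directly:

```python
# Checking the numbers: per-matrix NS cost over 5 iterations vs. bytes gathered.
m, n = 3072, 768
ns_flops = 2 * 5 * (2 * m + n) * n**2   # 20*m*n^2 + 10*n^3, rearranged
ag_bytes = 2 * m * n * 7                # 7 other bf16 matrices, 2 bytes/element
intensity = ns_flops / ag_bytes

print(ns_flops)             # 40768634880  (~4.08e10)
print(ag_bytes)             # 33030144
print(round(intensity, 1))  # 1234.3
```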

Plz correct me if I'm wrong.

Of course, 600 TFLOPS might be too optimistic...

@byronxu99
Contributor Author

You may be right that we are communication bottlenecked. Here's another estimate. For matrix shape (m, n) with m >= n we have

  • NS FLOPs (default): 20 m n^2 + 10 n^3
  • NS FLOPs with Triton: 15 m n^2 + 5 n^3
  • Triton FLOPs reduction: 5 m n^2 + 5 n^3

Before the batching change, we have 2 batches of (4, 768, 768) for attention + 4 batches of (3072, 768) for MLP. After, the MLP part is reduced to 3 batches of (3072, 768).

  • The FLOPs saved by batching alone (no Triton) would be 20*3072*768^2 + 10*768^3 = 4.08e10.
  • If we are compute bound, the FLOPs saved by the Triton matmul alone would be 2*4*(5*768^3 + 5*768^3) + 4*(5*3072*768^2 + 5*768^3) = 8.15e10.

This is 2x the FLOPs saved by batching. However, I observed that most of the speedup comes from batching alone. Either the Triton kernel is extremely inefficient, or we are communication bound.

If we are communication bound, Triton only helps with the first batch of NS, which doesn't have an all-gather to overlap with. (but maybe it overlaps with previous reduce-scatter? not entirely sure here)

  • If the first batch is MLP, we save 5*3072*768^2 + 5*768^3 = 1.13e10.
  • If it's attention, we save 4*(5*768^3 + 5*768^3) = 1.81e10.
  • The communication time savings from one fewer batch of MLP matrices would be greater than the 4.08e10 FLOPS savings.

This is much more in line with the observed ~75% of the speedup coming from one fewer batch.

I think the first batch is attention since it's defined first in the model (but haven't verified this). Then the comms time saving would be equal to roughly the time it takes to perform (75% / 25% * 1.81e10) = 5.43e10 FLOPS.

  • (2*3072*768*7 = 33030144 bytes) / (200 GB/s bandwidth) = 1.65e-4 seconds
  • 5.43e10 FLOPs / 1.65e-4 seconds = 329 TFLOP/s
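Checking this arithmetic:

```python
# Back out the implied NS throughput from the comms-time saving.
flops_saved = 3 * 1.81e10        # (75% / 25%) * attention-batch saving
ag_bytes = 2 * 3072 * 768 * 7    # bytes all-gathered per MLP matrix (bf16)
ag_time = ag_bytes / 200e9       # seconds at the observed ~200 GB/s
tflops = flops_saved / ag_time / 1e12

print(round(tflops))  # 329
```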

This seems like a reasonable estimate of H100 TFLOPS for Newton-Schulz on small matrices.

I included a plot above with benchmark results on an A100. Triton NS gets about 150 TFLOPS and Torch NS is a bit over 100 TFLOPS for a batch of (4, 768, 768) matrices. The H100 has about 3x the theoretical FLOPS and 1.7x the memory bandwidth of the A100. I think your estimate of 600 TFLOPS does seem rather high, but 329 is quite realistic and in line with extrapolating from the A100 benchmark results.

@YouJiacheng
Contributor

Your analysis looks great!
It should overlap with the previous reduce-scatter, so I didn't fully understand why the speedup is significant.
But in an ideal implementation, that previous RS can itself be overlapped with the backward pass, so the exposed computation indeed matters.
I want to add a point: the reduced number of batches also reduces the communication volume of the previous reduce-scatter.

@Gusarich
Contributor

@KellerJordan any news? I'm wondering why all those WRs aren't being merged.

@YouJiacheng
Contributor

Because eff_lr = group["lr"] * max(1, p.size(-2) / p.size(-1)) ** 0.5 * getattr(p, "lr_mul", 1.0), it seems that the transposition will change the effective lr (though it's true that this lr change won't have a significant impact).
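For concreteness, the scaling rule quoted above gives different factors for the two storage orientations (a minimal sketch of that one expression only):

```python
# The optimizer scales lr by sqrt(max(1, rows/cols)), so storage orientation
# matters: the tall (3072, 768) layout gets 2x lr, while the transposed
# (768, 3072) layout is clamped to 1x by the max(1, ...).
def lr_scale(shape):
    rows, cols = shape[-2], shape[-1]
    return max(1, rows / cols) ** 0.5

print(lr_scale((3072, 768)))  # 2.0
print(lr_scale((768, 3072)))  # 1.0
```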

@varunneal
Contributor

@byronxu99 You may find it interesting that Microsoft appears to have fully used your kernels in their repo dion: https://github.com/microsoft/dion/blob/main/dion/newton_schulz_triton.py

@YouJiacheng
Contributor

@varunneal lol, you can check the author list of Dion: Byron Xu is one of its authors!

@varunneal
Contributor

varunneal commented Oct 8, 2025

@YouJiacheng 🤦‍♂️ My mistake
