New WR 156s (1.25% better than PR #122): Optimize distributed training, improve skip connection gating, and enhance bfloat16 usage #125
Conversation
Merge PR 118
Nice work vectorizing muon! I believe that this optimization is unneeded since …

Glad to see that keeping the linear weights in bf16 is effective.
Also, heads up: this is PR #125, not #123, haha.
Insane turnaround time on this one, wow! In principle I think sigmoid seems reasonable to prevent a hypothetical where the weight oscillates around 0. My hunch is that in this case the main source of improvement is the updated initialization, and perhaps an indirect impact on the learning rate.

The way the skip gate is applied bounds the relative block contribution of the skip weight from 0 to 0.5, since at most the x:x_skip ratio can be 1:1 before the norm operation. In a prior test on an A100 w/ fp16 lmhead, I observed that the skip weights, which start at 1, finished the run at [0.2196, 0.4018, 0.2897, 0.2064, 0.3257, 0.3779]. The initialization of sigmoid(-1.5) = 0.18 is much closer to the expected final weights, which all conveniently sit in [0, 0.5]. In this range sigmoid is roughly linear anyway.

A similar concept of weighting identically shaped inputs applies to the value embeddings and the x0 skip. So, if there is something particular about the sigmoid form that is enabling improvements, perhaps it can be extended to these lambdas. I am curious whether the fp16 updates were basically free, or whether there was a loss increase that got counteracted by the skip sigmoid gate.
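To make the arithmetic in the comment above concrete, here is a quick standalone check (not code from the PR):

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

# The proposed raw init of -1.5 gives a gate close to the observed final weights
print(round(sigmoid(-1.5), 2))  # 0.18

# The skip weights observed at the end of the earlier A100/fp16 run
final_weights = [0.2196, 0.4018, 0.2897, 0.2064, 0.3257, 0.3779]

# All of them sit inside (0, 0.5), the range the gate is argued to target
print(all(0.0 < w < 0.5 for w in final_weights))  # True
```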
A month ago I tried initializing the skip lambda to 0 instead of 1, but the loss at step 125 was substantially higher on average, so I cancelled the run early. Motivated by this PR, I ran some more tests on this. Interestingly, the loss finishes meaningfully lower when the skip lambda is initialized to 0 instead of 1. Sigmoid appears roughly equivalent to initializing to 0.18 or 0. Neat example where the early loss curve is deceptive. My hypothesis is that initializing skip lambdas to 1 was encouraging a shallower architecture, which performed better only for the first portion of training. (Loss comparisons at step 125 and at step 1670 omitted.) I also tested initializing to zero while freezing the skip lambda for the first 50 steps, which gave slightly worse results.
Thanks, yes, there is a lot of activity lately. I've just changed it now.
Do you guys know if the GPU provider matters much? I tried to reproduce your result here with PrimeIntellect's servers, but I'm only getting about 157.5 s with this code.
That is expected. Bernard24's setup, an 8 × H100 SXM NVLink 80GB node on RunPod, is about 1.3 seconds faster than what I get from Lambda Labs. For comparing runs, I would recommend checking relative time, similar to what Bernard24 did.
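Since absolute seconds differ across providers, the relative-time comparison suggested above can be done with a tiny helper. This is a hypothetical sketch; the 159.5 s baseline below is illustrative, not a measured number:

```python
def relative_improvement(baseline_s: float, new_s: float) -> float:
    """Fractional speedup of new_s over baseline_s, both measured on the same node."""
    return (baseline_s - new_s) / baseline_s

# Illustrative: if PR #122 takes 159.5 s on your provider and this PR 157.5 s,
# the relative gain matches the ~1.25% reported here even though the absolute
# times are slower than on the RunPod node.
print(f"{relative_improvement(159.5, 157.5):.2%}")  # 1.25%
```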
This PR was actually the product of our AI system, which we have used to explore the nanoGPT benchmark. It automatically searched for algorithmic improvements and surfaced and implemented these three optimization ideas. |
That is extremely impressive. I would have expected an algorithmic approach to brute-force grid-search a large number of arbitrary tweaks to overfit the validation set. The changes here, in particular to the optimizer, are quite principled.
This is awesome @bernard24! I also built an RL env for it here. Would love to hear your thoughts about it!
I love it! I've been experimenting in this direction too; happy to see your result!
Because many WRs are not merged yet, the diffs GitHub shows for later PRs are too large (the diffs are not working well, since these PRs include changes from previous WRs). Maybe you could manually provide a diff of only the proposed changes as a comment in the PR (not code) for easier review.
The main benefit here is a larger message size in communication (higher achieved bandwidth!). In addition, I have a suggestion that …
@YouJiacheng In the "Files Changed" section you may find it helpful to filter by commits, e.g. you can see diffs from just the latest commit via this page |
thx |
This PR builds on all recent improvements, including PR #122 from today, and adds the following three ideas on top:
Replacing Python for-loops in the Muon optimizer with vectorized PyTorch tensor operations. This is done for improved gradient sharding, padding, and parameter synchronization.
Casting more tensors and buffers (embeddings, linear layers, optimizer state, positional encodings) to torch.bfloat16. This gives us faster experiments with minimal change in model accuracy.
Applying sigmoid gating to U-Net skip connections, with the raw skip weights initialized to -1.5 for better learnability. Instead of directly multiplying skip connections by a raw trainable parameter (which could be unbounded and unstable), the code now passes the skip weight through a sigmoid function. This constrains the gate value to the range (0, 1), making the effect of each skip connection smoothly adjustable and numerically stable.
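A minimal sketch of the gating idea described above; the module and parameter names are hypothetical and the shapes illustrative, not the PR's actual code:

```python
import torch
import torch.nn as nn

class GatedSkip(nn.Module):
    """Sigmoid-gated U-Net skip connection (illustrative sketch)."""

    def __init__(self) -> None:
        super().__init__()
        # Raw weight initialized to -1.5, so the gate starts at sigmoid(-1.5) ≈ 0.18
        self.skip_weight = nn.Parameter(torch.tensor(-1.5))

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.skip_weight)  # always in (0, 1), so the mix stays bounded
        return x + gate * x_skip
```

Because the gate is bounded, a large optimizer step on skip_weight cannot blow up the skip contribution the way it could with a raw multiplicative parameter.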
This improves the runtime by 2 seconds, i.e. 1.25%, see below.
Validation
I’ve used an 8 × H100 SXM NVLink 80GB node on RunPod. The results I’ve been getting when benchmarking PR #122 are a bit better than the ones reported there, so here I present the statistics of both PR #122 and this PR when using that node:
Validation for PR #122:
Validation for the current PR: