New WR 156s (1.25% better than PR #122): Optimize distributed training, improve skip connection gating, and enhance bfloat16 usage#125

Merged
ClassicLarry merged 11 commits into KellerJordan:master from bernard24:new_wr
Oct 15, 2025
Conversation

@bernard24

@bernard24 bernard24 commented Sep 11, 2025

This PR builds on all recent improvements, including today's PR #122, and adds the following three ideas on top:

  • Replacing Python for-loops in the Muon optimizer with vectorized PyTorch tensor operations. This improves gradient sharding, padding, and parameter synchronization.

  • Casting more tensors and buffers (embeddings, linear layers, optimizer state, positional encodings) to torch.bfloat16. This speeds up experiments with minimal change in model accuracy.

  • Applying sigmoid gating to U-Net skip connections, with the skip weights initialized to -1.5 for better learnability. Instead of multiplying skip connections directly by a raw trainable parameter (which is unbounded and can be unstable), the code now passes the skip weight through a sigmoid function. This constrains the gate value to the range (0, 1), making the effect of each skip connection smoothly adjustable and numerically stable.

Together these changes improve the runtime by 2 seconds, i.e. 1.25%; see the validation below.
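To make the third idea concrete, a sigmoid-gated skip connection can be sketched as follows (a minimal illustration; `GatedSkip` and `skip_logit` are hypothetical names, not the PR's actual code):

```python
import torch
import torch.nn as nn

class GatedSkip(nn.Module):
    """Skip connection whose weight is constrained to (0, 1) by a sigmoid gate."""
    def __init__(self, init: float = -1.5):
        super().__init__()
        # Raw, unbounded parameter; sigmoid(-1.5) ~= 0.18 at initialization.
        self.skip_logit = nn.Parameter(torch.tensor(init))

    def forward(self, x: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.skip_logit)  # always in (0, 1), numerically stable
        return x + gate * x_skip

gate = GatedSkip()
out = gate(torch.zeros(4), torch.ones(4))  # each entry equals sigmoid(-1.5)
```

The gradient flows through the sigmoid, so the optimizer still tunes the gate freely while its effective value stays bounded.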

Validation

I used an 8 × H100 SXM NVLink 80GB node on RunPod. The results I get when benchmarking PR #122 on this node are a bit better than the ones reported there, so I present statistics for both PR #122 and this PR on that node:

Validation for PR #122:

import scipy.stats
import torch

accs = [3.2798, 3.2798, 3.2829, 3.2785, 3.2783, 3.2787, 3.2787, 3.2784, 3.2821, 3.2794, 3.2786, 3.2765, 3.2794, 3.2776, 3.2778, 3.2774, 3.2777]

times = [157.977, 157.889, 158.014, 158.103, 158.093, 158.001, 158.089, 157.981, 158.019, 157.963, 158.043, 157.957, 157.880, 157.687, 158.002, 157.947, 158.097]

print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue)
# p=0.0069

print("acc:", torch.std_mean(torch.tensor(accs)))
# acc: (tensor(0.0016), tensor(3.2789))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.1021), tensor(157.9848))

Validation for the current PR:

import scipy.stats
import torch

accs = [3.277, 3.2772, 3.2778, 3.2767, 3.2805, 3.2781, 3.2797, 3.2802, 3.2774, 3.2767, 3.2769, 3.2783]

times = [155.902, 155.956, 156.043, 155.987, 155.980, 155.717, 156.019, 156.077, 156.064, 156.100, 156.129, 155.799]

print("p=%.4f" % scipy.stats.ttest_1samp(accs, 3.28, alternative="less").pvalue)
# p=0.0002

print("acc:", torch.std_mean(torch.tensor(accs)))
# acc: (tensor(0.0014), tensor(3.2780))

print("time:", torch.std_mean(torch.tensor(times)))
# time: (tensor(0.1233), tensor(155.9811))

@varunneal
Contributor

Nice work vectorizing muon!

I believe this cast is unnecessary:

x = x0 = norm(self.embed(input_seq)[None]).to(torch.bfloat16)

since x should be in bf16 already due to this cast:

for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.bfloat16()

Glad to see that keeping the linear weights in bf16 is effective.

@varunneal
Contributor

Also, heads up: this is PR #125, not #123, haha.

@ClassicLarry
Collaborator

Insane turnaround time on this one wow!

In principle I think the sigmoid seems reasonable to prevent a hypothetical case where the weight oscillates around 0. My hunch is that the main source of improvement here is the updated initialization, and perhaps an indirect impact on the learning rate.

The way the skip gate is applied bounds the relative block contribution of the skip weight from 0 to 0.5, since at most the x:x_skip ratio can be 1:1 before the norm operation. In a prior test on an A100 w/ fp16 lmhead I observed that the skip weights, which start at 1, finished the run at [0.2196, 0.4018, 0.2897, 0.2064, 0.3257, 0.3779]. The initialization of sigmoid(-1.5)=0.18 is much closer to the expected final weights, which all conveniently sit in [0,0.5]. In this range sigmoid is roughly linear anyways.
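These two numeric claims, the initialization value and the near-linearity of the sigmoid on the relevant range, are easy to check:

```python
import torch

# sigmoid(-1.5) is about 0.18, close to the observed final skip weights.
g = torch.sigmoid(torch.tensor(-1.5))
print(f"{g.item():.4f}")  # 0.1824

# On the logit range that maps into (0, 0.5], sigmoid is roughly linear:
# its slope g*(1-g) only varies between about 0.15 and 0.25.
xs = torch.linspace(-1.5, 0.0, 7)
ys = torch.sigmoid(xs)
slopes = ys * (1 - ys)
print(f"{slopes.min().item():.3f} {slopes.max().item():.3f}")  # 0.149 0.250
```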

A similar concept of weighting identically shaped inputs applies to the value embeddings and the x0 skip. So, if there is something particular about the sigmoid form that is enabling improvements, perhaps it can be extended to these lambdas.

I am curious whether the bf16 updates were basically free, or whether there was a loss increase that got counteracted by the sigmoid skip gate.

@ClassicLarry
Collaborator

A month ago I tried initializing the skip lambda to 0 instead of 1, but the loss at step 125 was substantially higher on average, so I cancelled the run early.

Motivated by this PR, I ran some more tests on this. Interestingly, the loss finishes meaningfully lower when the skip lambda is initialized to 0 instead of 1. Sigmoid appears roughly equivalent to initializing to 0.18 or 0. This is a neat example where the early loss curve is deceptive. My hypothesis is that initializing the skip lambdas to 1 encouraged a shallower architecture, which performed better only for the first portion of training.

at step 125:
init_1_step_125: 4.2939 [4.2857, 4.2892, 4.2943, 4.3018, 4.2986]
init_.18_step_125: 4.3343 [4.3372, 4.3453, 4.32, 4.3261, 4.3428]
init_0_step_125: 4.3298 [4.3509, 4.3185, 4.3563, 4.3176, 4.3055]
sigmoid_step_125: 4.3106 [4.3155, 4.3007, 4.3131, 4.3265, 4.297]

at step 1670:
init_1_step_1670: 3.2787 [3.2785, 3.2806, 3.2779, 3.2789, 3.2778]
init_.18_step_1670: 3.2781 [3.2766, 3.2774, 3.2776, 3.2793, 3.2794]
init_0_step_1670: 3.2778 [3.2789, 3.2773, 3.2779, 3.2768, 3.278]
sigmoid_step_1670: 3.2779 [3.2784, 3.2787, 3.2805, 3.2764, 3.2755]

I also tested initializing to zero while freezing the skip lambda for the first 50 steps, which gave slightly worse results:
freeze_step_125: 4.367 [4.3788, 4.3765, 4.3854, 4.3772, 4.3173]
freeze_step_1670: 3.2791 [3.2801, 3.2795, 3.2805, 3.2782, 3.2772]

@bernard24
Author

Also heads up this is pr125 not 123 haha

Thanks, yes, there is a lot of activity lately. I've just changed it now.

@Gusarich
Contributor

Do you guys know if the GPU provider matters much? I tried to reproduce your result here with PrimeIntellect's servers, but I'm only getting about 157.5 s with this code.

@ClassicLarry
Collaborator

do you guys know if gpu provider matters much? i tried to reproduce your result here with primeintellect's servers, but i'm only getting about 157.5s with this code.

That is expected. Bernard24's setup, an 8 × H100 SXM NVLink 80GB node on RunPod, is about 1.3 seconds faster than what I get from Lambda Labs. For comparing runs, I would recommend checking relative time, similar to what Bernard24 did.

@bernard24
Author

This PR was actually the product of our AI system, which we have used to explore the nanoGPT benchmark. It automatically searched for algorithmic improvements and surfaced and implemented these three optimization ideas.
More information here.

@ClassicLarry
Collaborator

This PR was actually the product of our AI system, which we have used to explore the nanoGPT benchmark. It automatically searched for algorithmic improvements and surfaced and implemented these three optimization ideas. More information here.

That is extremely impressive. I would have expected an algorithmic approach to brute force grid search a large number of arbitrary tweaks to fit the validation set. The changes here, in particular to the optimizer, are quite principled.

@leloykun
Contributor

This is awesome @bernard24 ! I also built an RL env for it here. Would love to hear your thoughts about it!

@Gusarich
Contributor

This PR was actually the product of our AI system, which we have used to explore the nanoGPT benchmark. It automatically searched for algorithmic improvements and surfaced and implemented these three optimization ideas. More information here.

I love it! Been experimenting in this direction too, happy to see your result!

@YouJiacheng
Contributor

Because many WRs are not merged yet, the diffs GitHub shows for later PRs are too large (the diffs don't work well, since these PRs include changes from previous WRs). Maybe you can manually provide a diff of only the proposed changes as a comment on the PR (not code) for easier review.

@YouJiacheng
Contributor

Replacing in the Muon optimizer Python for-loops with vectorized tensor operations using PyTorch. This is done for improved gradient sharding, padding, and parameter synchronization.

The main benefit here is a larger message size in communication (higher achieved bandwidth!).
But it will also hurt overlapping (not that important for the small track, though).
Ideally there should be a binning strategy, balancing message size and overlapping.
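One way to sketch such a binning strategy (purely illustrative; `bin_params` and the 25 MiB threshold are assumptions, not code from the repo):

```python
import torch

def bin_params(params, max_bucket_bytes=25 * 2**20):
    """Greedily group tensors into buckets of at most max_bucket_bytes,
    so each communication call sends one reasonably large message."""
    buckets, current, current_bytes = [], [], 0
    for p in params:
        nbytes = p.numel() * p.element_size()
        if current and current_bytes + nbytes > max_bucket_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(p)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets

# Six 10 MiB bf16 tensors with a 25 MiB cap -> three buckets of two tensors each.
params = [torch.empty(10 * 2**20 // 2, dtype=torch.bfloat16) for _ in range(6)]
buckets = bin_params(params)
print([len(b) for b in buckets])  # [2, 2, 2]
```

Smaller caps favor overlapping communication with compute; larger caps favor achieved bandwidth, which is the trade-off described above.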

In addition, I have a suggestion: eff_lr_val and eff_weight_decay_val should be computed separately for each parameter, since lr_mul and wd_mul can differ.
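This suggestion is compatible with the vectorized update, since the `torch._foreach_*` ops also accept a list of per-tensor scalars. A sketch under that assumption (`lr_muls` and the parameter shapes are made up):

```python
import torch

params = [torch.ones(3), torch.ones(4)]
updates = [torch.full((3,), 2.0), torch.full((4,), 2.0)]
base_lr = 0.1
lr_muls = [1.0, 0.5]  # hypothetical per-parameter multipliers

# Scale each update by its own effective lr, then apply them all in fused calls.
scaled = torch._foreach_mul(updates, [-base_lr * m for m in lr_muls])
torch._foreach_add_(params, scaled)
print(round(params[0][0].item(), 4), round(params[1][0].item(), 4))  # 0.8 0.9
```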

@varunneal
Contributor

@YouJiacheng In the "Files Changed" section you may find it helpful to filter by commits, e.g. you can see diffs from just the latest commit via this page

@YouJiacheng
Contributor

thx

@ClassicLarry ClassicLarry merged commit b2307bd into KellerJordan:master Oct 15, 2025


7 participants