
New WR 149.6s (-0.7s): MuonCustomSizing, perform mlp and attn reduce scatter in shared call#132

Merged
ClassicLarry merged 17 commits into KellerJordan:master from ClassicLarry:MuonCustomSizing
Oct 15, 2025
Conversation

@ClassicLarry
Collaborator

This PR builds on all recent WR improvements including PR #131. Updates:

  • Add Muon Custom Sizing
class Muon(torch.optim.Optimizer):
    """

    ...

    Custom distributed sizing:
    The model stores all attn and mlp weights in the same shape, and then updates the view as 
    needed on the forward pass. This enables attn and mlp weights to be contained within the same 
    dist.reduce_scatter_tensor() call. The model architecture has been customized to enable 
    (n_attn_layers+n_mlp_layers*2)%4==0 for batching across 8 GPUs with zero padding. The scheduling is:
        1. reduce scatter smear_gate (1 param, 7 padding params)
        2. reduce scatter attn_gate (10 params, 6 padding params)
        3. reduce scatter attn/mlp round 1 (10 attn params, 6 mlp params)
        4. reduce scatter attn/mlp round 2 (16 mlp params)
        5. wait on step 1, then compute NS of 1 and schedule all gather
        6. wait on step 2, then compute NS of 2 and schedule all gather
        7. wait on step 3, then compute NS of 3 and schedule all gather
            GPUs receive [2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 ATTN, 2 MLP, 2 MLP, 2 MLP]
            GPUs that receive params of type attn reshape before NS
        8. wait on 4, then compute NS of 4 and schedule all gather
        9. wait for each all gather to complete and update params
    Empirically, leading with small params provides an additional 0.2s improvement.
    """

    def generate_custom_param_groups(self, params):
        # implementation requires that a single GPU does not receive both attn
        # and mlp params when a param group is split across GPUs
        module_ranks = {
            'smear_gate': 1, # 1 param
            'attn_gate': 2, # 10 params
            'attn': 3, # 10 params
            'mlp': 4, # 22 params
        }
        params = list(params)
        params.sort(key=lambda x: module_ranks.get(x.module))
        idx = 0
        group_sizes = [1,10,16,16]
        assert len(params)==sum(group_sizes)
        param_groups = []
        for size in group_sizes:
            group_params = params[idx:idx+size]
            param_groups.append(dict(params=group_params))
            idx += size
        return param_groups
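As a sanity check on the schedule in the docstring, the group sizes above map directly onto the four reduce-scatter rounds. The following standalone sketch (WORLD_SIZE and the padding arithmetic are illustrative, not code from this PR) computes the zero padding each round needs to fill a multiple of 8 GPUs:

```python
# Sketch: how the four Muon param groups line up with the 8-GPU
# reduce_scatter rounds described in the docstring. Padding fills each
# round to a multiple of the world size so no ragged edges remain.
WORLD_SIZE = 8
group_sizes = [1, 10, 16, 16]  # smear_gate, attn_gate, attn+mlp round 1, mlp round 2

rounds = []
for size in group_sizes:
    padding = (-size) % WORLD_SIZE  # zero-pad up to the next multiple of 8
    rounds.append((size, padding))

print(rounds)  # [(1, 7), (10, 6), (16, 0), (16, 0)]
```

This reproduces the padding counts in the docstring: 7 padding params alongside smear_gate, 6 alongside attn_gate, and none for the two full rounds of 16.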

    # Reshaping attn before Newton-Schulz (each attn param holds q/k/v/o in one buffer):
    if getattr(params[module_idx], 'module', 'none') == 'attn':
        batch = 4 * original_shape[0]
        d1 = original_shape[1]
        d2 = original_shape[2] // 4
        batched = batched_update_grads.view(batch, d1, d2)
        v_chunk = newton_schulz_triton(batched)
        v_chunk = v_chunk.view(original_shape)
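A minimal numpy sketch of why this view works (toy shapes only; newton_schulz_triton is the project's kernel and is not reproduced here): because each attn param stores q/k/v/o contiguously in one (hdim, 4*dim) buffer, reshaping a stack of such params to (4*b, hdim, dim) yields each param's four sub-matrices as consecutive batch entries, so one batched NS call processes them independently.

```python
import numpy as np

# Toy sizes; the real model uses hdim/dim from the GPT config.
b, hdim, dim = 2, 3, 5

# Stacked grads for b attn params, each stored in the shared (hdim, 4*dim) shape.
grads = np.arange(b * hdim * 4 * dim).reshape(b, hdim, 4 * dim)

# The optimizer's view: (b, hdim, 4*dim) -> (4*b, hdim, dim).
batched = grads.reshape(4 * b, hdim, dim)

# Each attn param contributes its four (hdim, dim) sub-blocks q, k, v, o
# as consecutive matrices, matching the forward-pass view(4, hdim, dim).
for p in range(b):
    per_param = grads[p].reshape(4, hdim, dim)  # q, k, v, o of param p
    assert (batched[4 * p : 4 * p + 4] == per_param).all()
```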

Reshaping attn on forward pass:

self.qkvo_w = nn.Parameter(torch.empty(self.hdim, self.dim*4))
q, k, v = F.linear(x, self.qkvo_w.view(4,self.hdim, self.dim)[:3].flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
y = F.linear(y, self.qkvo_w.view(4,self.hdim, self.dim)[3].type_as(y))
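The fused projection above relies on a standard identity: one matmul against the q/k/v weights stacked row-wise equals the three separate projections concatenated along the feature dimension. A small numpy sketch of that equivalence (toy shapes; names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
hdim, dim, tokens = 6, 4, 3

# qkvo stored as one (4, hdim, dim) stack; q/k/v/o are slices of the same buffer.
qkvo = rng.standard_normal((4, hdim, dim))
x = rng.standard_normal((tokens, dim))

# Fused projection: one matmul against the first three slices flattened,
# mirroring F.linear(x, qkvo_w.view(4, hdim, dim)[:3].flatten(end_dim=1)).
fused = x @ qkvo[:3].reshape(3 * hdim, dim).T

# Separate q, k, v projections, concatenated along the feature dim.
separate = np.concatenate([x @ qkvo[0].T, x @ qkvo[1].T, x @ qkvo[2].T], axis=-1)
assert np.allclose(fused, separate)
```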

I also tested a single param group of all 32 MLP and attn params, where the GPU that received 2 attn and 2 mlp params would then perform 2 iterations of NS. That version was roughly 0.4s slower.

Skipping ML validation since ML is identical.

(looks like this time I was allocated slightly faster GPUs compared to prior PR)
Rerunning prior record: 150.3843 [150.393, 150.347, 150.413]

New runtime: 149.6905 [149.686, 149.678, 149.775, 149.623]

@ClassicLarry
Collaborator Author

ClassicLarry commented Sep 24, 2025

Adding comment on future promising areas for those looking to test:

  • Adam and Muon could interleave their communication and data updates.
  • iteration_extension = 40 was the only value tested for this hyperparameter. Use caution here: skewing the training schedule toward a lower lr or higher cooldown adds variance to the outcome, so any increase needs to be tested to statistical significance (a single run at 3.274 loss can be misleading).
  • CPU-to-GPU transfer of input data can probably be made asynchronous, as mentioned in an earlier PR. It could also be explicitly scheduled to interleave with the forward pass, backward pass, or optimizer step, depending on where the communication bottleneck is.
  • May be worth running specific ablations on rotary in bfloat16. No specific ablation was run on this change, and earlier GitHub issues give conflicting info on it.
  • Gate params executing faster in Muon than in Adam is suspicious, suggesting something may be suboptimal in the Adam implementation for small params, which is where the scalar params currently get updated. A custom optimizer for small params may perform faster.

Gusarich added a commit to Gusarich/modded-nanogpt that referenced this pull request Sep 27, 2025
@ClassicLarry ClassicLarry merged commit c16617b into KellerJordan:master Oct 15, 2025
