
New WR Sparse Attention Gate (-55 steps). Including WR changes of @byronxu99 (#117)

Merged
ClassicLarry merged 5 commits into KellerJordan:master from ClassicLarry:master
Oct 15, 2025

Conversation

@ClassicLarry
Collaborator

@ClassicLarry ClassicLarry commented Aug 23, 2025

New record 08/23/25

  1. Included WR improvements on Triton and grad batching from Transpose one of the MLP matrices + add Triton kernel for symmetric matmul (new WR) #109 by @byronxu99
  2. Added a sparse attention gate on the attention output to enable a context-based no-op. Found the mechanism was performant with 12 active dimensions from the residual stream. If curious, here is a related blog post from an earlier investigation into a non-sparse attention gate with detailed plots: https://medium.com/@larry36d/modulating-attention-scores-cc0bcd853f06. The blog demonstrates how the attention gate reduces the need for the bos_token to function as an attention sink. This is particularly relevant in a sliding window attention context because the bos_token is not always in the context window. RoPE embeddings cause the bos_token attention sink to change based on relative distance, whereas a sparse attention gate is indifferent to distance from the start of the sample. Estimate of impact: 50 steps fewer, with slight increase in time per step.
  3. As a follow-on from 2: Reduced number of iterations from 1750 to 1695.
  4. Reverted the lm head scaling changes made on Feb 10th: 85a0a52. When tested on a single A100, reverting this change drops the L2 norm of the LM head weights from 250 down to 10. The logits need to express values roughly from -10 to 10 in order to capture the range of token probabilities. Dividing by 27.5 (x.size(-1)**0.5) was causing the weights to grow substantially to accomplish this, since the residual stream was being normed prior to the lm_head. The second moment estimate of Adam depends on the parameter scale, and the Adam learning rates were likely heavily tuned prior to the Feb 10th update. If curious, more details near end of this blog post: https://medium.com/@larry36d/exploration-log-exploring-initializing-transformers-with-bigram-distribution-70f9c8800b21. Estimate of impact: 5-10 steps. (in this case just a cleaner cut below 3.28)
  5. Chose to keep the minimum lr at 0.1. The bos_align record decreased the minimum lr to 0.05, and a later refactor, perhaps unintentionally, moved it back to 0.1. On further testing, the impact of this value on mean loss is marginal, but a lower minimum lr appears to increase the variance of the final loss, making testing more challenging. A lower minimum lr may have higher variance because it commits to diving deep in the local space earlier, somewhat rolling the dice on whether that region is promising. On reflection, I likely originally picked 0.05 because taking the min loss over a grid search will naturally bias toward higher-variance configurations, which is the opposite of what we want.
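A minimal sketch of the sparse gate described in point 2, assuming a per-head sigmoid gate fed by the first 12 residual-stream dimensions (shapes and the name gate_w are illustrative, not the actual PR code):

```python
import torch
import torch.nn.functional as F

# Toy shapes for illustration only.
B, T, d_model, num_heads, head_dim, gate_dim = 2, 64, 768, 6, 128, 12

x = torch.randn(B, T, d_model)              # residual stream fed to the block
y = torch.randn(B, T, num_heads, head_dim)  # per-head attention output

gate_w = torch.zeros(num_heads, gate_dim)   # zero-init weights

# Only the first 12 dims of the residual stream drive a per-head sigmoid gate,
# giving each head a context-based no-op (gate near 0 masks that head's output).
g = torch.sigmoid(F.linear(x[..., :gate_dim], gate_w))  # (B, T, num_heads)
y_gated = y * g[..., None]
```

With zero-initialized weights the gate starts at sigmoid(0) = 0.5 everywhere, matching the 0.5 gate init discussed later in the thread.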

Validated results (p=0.0059) with 14 runs:

import scipy.stats
import torch

accs = [3.2774, 3.2782, 3.2796, 3.2815, 3.2760, 3.2777, 3.2784, 3.2795,
        3.2810, 3.2802, 3.2767, 3.2772, 3.2800, 3.2786]
times = [168.627, 169.037, 169.003, 168.727, 168.647, 169.024, 168.917,
         168.999, 168.728, 169.070, 168.981, 168.938, 168.718, 167.122]

print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
# p=0.0059

print('acc:',torch.std_mean(torch.tensor(accs)))
# acc: (tensor(0.0016), tensor(3.2787))

print('time:',torch.std_mean(torch.tensor(times)))
# time: (tensor(0.4946), tensor(168.7527)) 
# Running on a fresh cluster gave 167.695. Actively working in Jupyter notebooks
# on the same machines during these runs may be adding variance to the timings.

Ablations

The following comparisons were made against this candidate WR run:

  1. 96_attn_gate. Feed the attention gate with 96 dimensions of the residual stream instead of 12. This showed a decrease in accuracy. Performance roughly linearly decayed from d=12 up to d=768; d=(6,24,48) were also tested for 1 run. I don't have an explanation. Very hand wavy speculation: Masking out attention is a very sensitive operation. Perhaps if the full residual stream is attending to this task, it loses focus on the main objective of building a representation of the next token.
  2. no_attn_gate. Apply standard attention without a gate.
  3. full_attn_gate. Feed the attention gate with all 768 dimensions instead of 12. Higher loss than no attention gate, and the highest time per step.
  4. no_bos_align. Meaningful impact on loss.
  5. no_bos_align_or_gate. Highest loss of all ablations, as the model now has no good mechanism for handling attention sinks.
  6. no_bos_token. Updating document mask to treat bos_token as an isolated document. Interestingly, this still performs decently well as long as the attention gate exists and can take on the role of attention sink.
  7. no_logit_scale. Minor impact, point 4 from implemented changes.
(image: ablation loss comparison)

Runtimes below. I advise caution on interpreting the run times too heavily: a model that is 0.01 loss off when its lr has decayed to 0.1 is still very far from achieving 3.28 loss.
(image: ablation runtimes)

Negative and neutral test results during this process:

  1. Initialize embedding tokens using bigram distribution. Bigram statistics can be calculated for 100 million tokens in ~1 second or less. I tested initializing the embedding layer using z = relu(log(p(y|x)/p(y))); embed = norm(rand_linear(z)), where p(y|x) is the bigram prob of token y given x. This initialization makes it so that tokens with similar bigram statistics will have similar embeddings. If I froze the embedding layer, this initialization performed better than random initialization. However, for non-frozen embeddings, the impact was not statistically significant.
  2. Weight freezing during training. Since the majority of the time on each step is spent computing the gradient, freezing a subset of weights can substantially decrease time per step. Unfortunately, all combinations of this that I tested failed to yield an improvement. Typical matmul ops require N FLOPs on the forward pass and 2N FLOPs on the backward pass. The 2N is to compute the gradient with respect to the weights to update the weights, and the gradient with respect to the data to pass the gradient onwards. The torch compiler is smart enough to compute only N FLOPs on the backward pass for leaf operations. To leverage this, I tested updating the first 3 layers to run in parallel, and then froze the embedding after a portion of training, such that 3 layers became leaf operations. The change was not kept as the performance drop outweighed the speedup.
  3. Logit shift parameter. The residual space activations for all positions are heavily aligned away (>120 degrees) from the lm_head vectors of tokens that never appear in the training set. In other words, the ~400 tokens that never appear in the 50348 vocab size (including the 91 padding vocab entries) may be skewing the topology of the activations in the residual stream. Adding a simple logits += logit_shift enables the model to learn the unigram distribution directly (or even just a static variable that is -inf on padding tokens), without disrupting the residual space. Unfortunately, my implementation of this change was giving memory issues on an A100. On the H100 setup, the change dropped the loss by 0.01 but was slightly edged out by the increase in time per step. I don't have the budget to fiddle substantially with params I can't test on an A100. If a more compute optimized version can be found, this is an easy improvement to the loss, likely equivalent to 50+ steps.
  4. Removing torch.compile on zeropower_via_newtonschulz5(). Surprisingly, the torch compiler makes the output of newtonschulz() vary based on the batch dimension size, with a 2% change depending on the batch size. This is relevant when we are batching kqv in one op. This appears to occur because of rounding issues with bfloat16 and some internal accumulations the compiler is altering, as the percent diff drops to less than 0.1% for float32. On an A100, removing the compile gave an improvement when I was testing different batch sizes, but the change was not statistically significant on H100 w/ fp8 lm_head. Unclear exactly what is going on here, but noting that bfloat16 can lead to very unintuitive consequences.
  5. Megabatch NewtonSchulz. Inspired by @byronxu99, I tested further impacts of batching for zeropower_via_newtonschulz5(). The results were quite surprising on an A100. The run time was heavily dependent on the batch size, with larger batch sizes running up to twice as fast, based on initial testing (honestly need to sanity check this, seemed too crazy). As a result, I experimented with setting all MLP params in 1 contiguous variable and doing a single iteration of zeropower_via_newtonschulz5(), with [3,4*768,768] input to each GPU as a single pass, and [6,768,768] for Attn to each GPU as a second pass. This gave a total of only 2 iterations of zeropower_via_newtonschulz5() on each GPU per step. I was running into memory errors on the 8H100 setup, and need to get a cheaper distributed setup before I test further.
  6. 0.5 init weighting for x0 stream instead of 0. At the end of training on a A100, the x0 weight for many layers is 50x higher than the x weight. Updating the weighting to 0.5 gave a statistically significant improvement on A100, but this was not replicated on the 8H100 setup with fp8 lmhead.
  7. Normalize value embedding inputs during forward pass. Seemed like a natural thing to do given norms on the input embedding and the existing lambda to scale value weights. However, this yielded worse performance, perhaps because the value embeddings need to have much higher weight than the values, and the lambda scaling parameter was not tuned to handle this itself.
  8. Renormalize embedding in place between each forward pass. The L2 norm of the embedding layer is climbing from 27 to 500 over the course of training, leading to a different effective learning rate depending on the stage of training. Normalizing this parameter may enable the lr to be tuned more precisely. However, I found norm() still needed to be included in the forward pass for an accurate grad calc, at which point the compute penalty for a second norm outside the forward pass became not worthwhile.
  9. Removing value entirely (only use value embedding) for first and last 3 layers. The trained weights indicate that the value embedding is dominating the calculated attention value, and I can save some matmul ops if I can drop 6 layers of value calcs. The change cost roughly 0.015 loss, which unfortunately was worth more than the speedup achieved based on the parameters used.
  10. Bigram full initialization. Similar to 1, I tested initializing the lm_head and embedding layer to approximate the bigram distribution. (Bigram could in theory cause learning to start around 5.7 loss, with potentially better generalization during training). Unfortunately, it is not analytically simple to set embed and lm_head to achieve a known bigram distribution, because of the nonlinearity of the softmax. Attempting to approximate this yielded worse performance than random initialization.
  11. Dual loss on bigram distribution. I tested having the first X iterations minimize a combination of the next token prediction loss, along with the bigram distribution for that token. Intuition was that since I can compute the bigram distribution of 100 million tokens in 1s, the bigram distribution encodes a higher density of information than a single high variance loss signal of a 500,000 token batch. However, the 50,000x50,000 bigram matrix proved too bulky for compute efficient steps.
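The bigram-based embedding init from point 1 could look roughly like this (a toy sketch with made-up sizes and a random stand-in corpus; the projection is an assumed stand-in for rand_linear):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V, d_model = 8, 16  # toy vocab and embedding sizes, for illustration only
tokens = torch.randint(0, V, (10_000,))

# counts[x, y]: how often token y follows token x (add-one smoothing avoids log(0))
counts = torch.zeros(V, V)
counts.index_put_((tokens[:-1], tokens[1:]), torch.ones(tokens.numel() - 1),
                  accumulate=True)
p_y_given_x = (counts + 1) / (counts + 1).sum(1, keepdim=True)
p_y = (counts.sum(0) + 1) / (counts.sum() + V)

# z = relu(log(p(y|x) / p(y))): positive where y is unusually likely after x
z = torch.relu(torch.log(p_y_given_x / p_y[None, :]))
proj = torch.randn(V, d_model) / d_model**0.5   # stand-in for rand_linear
embed = F.normalize(z @ proj, dim=-1)  # similar bigram stats -> similar embeddings
```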

@YouJiacheng
Contributor

YouJiacheng commented Aug 23, 2025

Interestingly, I tried the same attention gate (inspired by https://arxiv.org/abs/2505.06708) on the 2.92 track but it didn't work.
Is using the first 12 dims necessary to see improvement, or is it just for speedup?
I tried per-head gate like your submission (but use full x instead of [..., :12] as input) but it makes a negligible difference.
I also tried per-channel gate, but it isn't worth the extra compute.
In contrast, I saw a significant improvement by using gpt-oss sinks on the 2.92 track (2.9196→2.9164).

@YouJiacheng
Contributor

In addition, I kept failing at preventing massive activation on modded-nanogpt. Did you successfully prevent MA?
And is bos_align necessary?

@YouJiacheng
Contributor

YouJiacheng commented Aug 23, 2025

I tried it again on 2.92 track, with first 16 dims as gate input.
but it still didn't work...
hmmmmmm

@YouJiacheng
Contributor

YouJiacheng commented Aug 23, 2025

I have a hypothesis that fp8 lm_head and bos_align might be correlated to sparse gate. (by "correlated" I mean sparse gate's effectiveness might depend on them)
Should you have the availability, your insights on testing this would be invaluable.

@YouJiacheng
Contributor

YouJiacheng commented Aug 23, 2025

hmmm I also wonder if this weight needs a separate lr. (ideally, cuz we follow muP, we don't need a separate lr).
I even observed a negative effect with [..., :16] → [..., num_heads] gate.

@ClassicLarry
Collaborator Author

I'll add some ablation plots from some fresh tests on 8 H100 GPUs. I primarily test on a single A100 without fp8, and both the bos_align and sparse attention gate showed strong correlation there. Typically by step 500 I know if a change is good or not. I haven't followed the 2.92 track. I am curious if there are other substantial differences in its setup.

I'd expect a training batch size of 64*1024 on the 2.92 track to cause a reduced benefit of bos_align, since fewer samples will get truncated during training. I can say off the bat that non-sparse attention gate hurts performance, which is the same thing I saw when I was exploring attention mechanisms a month ago. The concept of non-sparse attention gates has likely been independently discovered and tested many times over, eg here's another paper on it from 2023 that I found after testing the idea: https://papers.nips.cc/paper_files/paper/2023/file/edbcb7583fd8921dad78adecfe06a99b-Paper-Conference.pdf.
At least for this speed run architecture, it needs to be sparse or loss increases.

@YouJiacheng
Contributor

uh by correlation I mean, does sparse gate work without bos_align?

@YouJiacheng
Contributor

large batch NS is indeed faster, but based on previous profiling, NS is almost hidden by all-gather communication.
but on the other hand, batched AG is also faster… (but harder to overlap)

@ClassicLarry
Collaborator Author

Ablations added. Main conclusions from them:

  1. bos_align and sparse attention gate both independently improve loss.
  2. The attention gate needs to be sparse. dim=12 performs substantially better than dim=96, which performs better than dim=768 (in fact dim=768 performs worse than no attention gate at all)
  3. Removing the bos_token entirely from the context window still performs reasonably well if attention gate is present.
  4. Logit scaling helps bring loss more cleanly under 3.28.

@ClassicLarry
Collaborator Author

I have a hypothesis that fp8 lm_head and bos_align might be correlated to sparse gate. (by "correlated" I mean sparse gate's effectiveness might depend on them) Should you have the availability, your insights on testing this would be invaluable.

Based on the ablations above on H100, and prior testing on an A100 w/ no fp8, my hunch is that the lack of improvement on the 2.92 track has a different cause. One confounder here is that the gate is initialized to 0.5, which is not easily counteracted by the value weights since the value gets normed. The value embedding can scale up and compensate, but the 2.92 track has twice as many layers without value embeddings. The 2.92 track also has roughly twice as many heads. It may be that this initialization has a different impact on how the network heads wake up. But I have never run the 2.92 track, so my intuition there is quite limited. I am at my budget cap for now, but I think there is still a lot to explore here.

@YouJiacheng
Contributor

thanks a lot!
normalized value should not be a big problem because we have sa_lambdas, which can easily scale the value and compensate for the 0.5 init gate.

@YouJiacheng
Contributor

The attention gate needs to be sparse. dim=12 performs substantially better than dim=96, which performs better than dim=768 (in fact dim=768 performs worse than no attention gate at all)

this is probably caused by the gate being optimized by Muon.

another factor is that the 2.92 track has weight decay.

@YouJiacheng
Contributor

YouJiacheng commented Aug 24, 2025

update: I made it work on 2.92 track by a simple change.
this change might be also useful for 3.28 track (needs extra tuning, but the first value I try just worked on 2.92 track).

hmmm I also wonder if this weight needs a separate lr. (ideally, cuz we follow muP, we don't need a separate lr).

this hypothesis is correct.

without the multiplier, it doesn't work on 2.92 track with (num_heads, 16) shape self.gate_w.

y = y * torch.sigmoid(F.linear(x[..., :self.gate_w.size(-1)], self.gate_w)[..., None])

with the multiplier, it works on 2.92 track with (num_heads, 16) shape gate_w.

y = y * torch.sigmoid(0.1 * F.linear(x[..., :self.gate_w.size(-1)], self.gate_w)[..., None])

0.1 is the first value I tried, so it's unlikely to be optimal. A few runs show it's on-par with gpt-oss's sinks.
(but gpt-oss's sinks can mitigate MA better)
Since Muon is scale-invariant and gate_w is zero init, 0.1 multiplier here is equivalent to a separate lr of 0.1×.

I think multiplier might not be optimal on 3.28 track either.

@ClassicLarry
Collaborator Author

0.1 is the first value I tried, so it's unlikely to be optimal. A few runs show it's on-par with gpt-oss's sinks.
(but gpt-oss's sinks can mitigate MA better)

Neat! I'm curious on the runtime and implementation details of your gpt-oss code. I avoided softmax denominator manipulation approaches because I wasn't sure how to make it performant with flex_attention.

@YouJiacheng
Contributor

uh I use an extra sigmoid (plus flex_attention return_lse=True). it's about ~1.5% slower per step.
slightly slower than (num_heads, 16) gate.

y = y * torch.sigmoid(lse[..., None] - sinks[None, :, None, None])

@byronxu99
Contributor

Congrats @ClassicLarry! Very nice work.

Very hand wavy speculation: Masking out attention is a very sensitive operation. Perhaps if the full residual stream is attending to this task, it loses focus on the main objective of building a representation of the next token.

If attention gating is a very difficult and important operation, we might expect that more compute-intensive approaches would give better per-step performance (but might be too slow for wall-clock time).

@ClassicLarry and @YouJiacheng - I wonder how much you have experimented with other approaches, and if you have any interesting findings.

  • Learnable scalar attention bias like GPT-OSS: This would be a very lightweight approach with one scalar per attention head, and it also eliminates any potential effect on the residual stream. But maybe it has the drawback that the gating effect weakens with sequence length (scalar value stays the same while softmax denominator grows), whereas your approach gates independently of the number of previous tokens in the sequence.
  • Vector attention bias consisting of learnable K and zero V: As suggested by this paper, each attention head gets a learnable K vector with corresponding all-zero V. The dot product of each Q with this K vector produces zero output and thus functions as an attention sink. Being query-dependent, it would be more expressive than the GPT-OSS approach and can adjust the amount of sinking for each token. It still isn't truly independent of sequence length.
  • "QKVG" attention: Each attention head gets another matrix that outputs a head_dim vector G per token. The head's output is gated by the scalar value sigmoid(dot_product(Q, G)). This adds many more learnable parameters to the model - an entire additional (d_model, n_head * head_dim) matrix. If you have the intuitive sense that deciding how much to gate the attention head's output is as hard as deciding how much to attend to any previous token, this approach might make more sense. Maybe it would work better for larger models with GQA, since a single G could be shared for multiple Q's just like K and V are (and G doesn't need to be KV-cached for inference).
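The "learnable K, zero V" idea from the second bullet could be sketched as follows (assumed toy shapes, causal masking omitted for brevity):

```python
import torch

B, H, T, D = 2, 6, 16, 64  # batch, heads, sequence, head_dim (illustrative)
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
v = torch.randn(B, H, T, D)
k_sink = torch.randn(H, D)  # one learnable sink key per head; its value is all-zero

scores = q @ k.transpose(-1, -2) / D**0.5                             # (B, H, T, T)
sink = (q * k_sink[None, :, None, :]).sum(-1, keepdim=True) / D**0.5  # (B, H, T, 1)
p = torch.softmax(torch.cat([scores, sink], dim=-1), dim=-1)  # sink absorbs mass
out = p[..., :-1] @ v  # the sink's zero value contributes nothing to the output
```

Being query-dependent, the sink score varies per token, unlike a fixed per-head scalar bias.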

I use an extra sigmoid (plus flex_attention return_lse=True)

@YouJiacheng - That's a very efficient and elegant way to implement GPT-OSS attention sink. If I'm understanding it correctly:

The softmax denominator used in flex_attention is e^LSE, where LSE is the returned value. We would like it to be e^LSE + e^sink. Because sigmoid(LSE - sink) = e^(LSE-sink) / (1 + e^(LSE-sink)) = e^LSE / (e^sink + e^LSE), we can simply multiply the attention output by the sigmoid expression to change the denominator from e^LSE to e^sink + e^LSE. So it's identical to modifying the softmax denominator, except it's done efficiently without needing a custom attention kernel.
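The equivalence can be checked numerically (toy values, single query for clarity):

```python
import torch

# Multiplying the attention output by sigmoid(LSE - sink) equals redoing the
# softmax with an extra sink logit whose value vector is zero.
s = torch.tensor([1.0, 2.0, 0.5])        # attention scores for one query
v = torch.tensor([[1.0], [2.0], [3.0]])  # values
sink = torch.tensor(0.7)                 # learnable sink logit

out = torch.softmax(s, dim=0) @ v        # standard attention output
lse = torch.logsumexp(s, dim=0)
gated = out * torch.sigmoid(lse - sink)  # flex_attention return_lse style

p = torch.softmax(torch.cat([s, sink[None]]), dim=0)
ref = p[:-1] @ v                         # sink's zero value drops out
assert torch.allclose(gated, ref, atol=1e-6)
```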

The attention gate needs to be sparse. dim=12 performs substantially better than dim=96, which performs better than dim=768 (in fact dim=768 performs worse than no attention gate at all)

this is probably caused by the gate being optimized by Muon.

I also suspect that optimizing the gate matrix with Muon is affecting things.

  • The gate matrix doesn't behave the same as regular weight matrices (vector to vector linear transform). Instead it maps a vector to independent scalars, which you use to modulate each attention head's output. This seems closer to how the LM head behaves, so maybe it's better to use Adam for the gate.
  • Muon seems to work the least well with aggressive down-projection matrices. NanoGPT has only 6 attention heads. If you use the whole 768-dimension hidden state, the matrix has a very lopsided shape (d_in, d_out) = (768, 6). By truncating the input to only 12 dimensions, the (12, 6) matrix is much closer to square. I wonder how well a (6, 6) matrix would work, or if it becomes too restrictive.
  • This paper suggests that weight initialization for down-projection matrices should be scaled down by sqrt(d_out / d_in) (see Parametrization 1). NanoGPT uses zero init, and I saw that you also did the same for the gate matrix. Zero init certainly suffices to avoid large outputs at the start of training, but I do wonder if there are potential gains here.

@ClassicLarry
Collaborator Author

ClassicLarry commented Aug 24, 2025

I wonder how much you have experimented with other approaches, and if you have any interesting findings.

I have not tried other attention approaches yet. I think there's a chance that gpt_oss and sparse attn gate approaches show improved results when combined, since they are doing different things. sparse attn gate lets a query decide if it wants to look for a key, independent of other tokens. gpt_oss assumes that a query does want to look for a key, and then modulates based on if it finds one. Ideally, a conceptually complete pairing mechanism should do both.

Muon seems to work the least well with aggressive down-projection matrices. NanoGPT has only 6 attention heads. If you use the whole 768-dimension hidden state, the matrix has a very lopsided shape (d_in, d_out) = (768, 6). By truncating the input to only 12 dimensions, the (12, 6) matrix is much closer to square. I wonder how well a (6, 6) matrix would work, or if it becomes too restrictive.

I had done one run at 6 and it performed slightly worse than 12 (basically the same, couldn't tell statistically from 1 run). Interesting insights on Muon. I agree that swapping it with Adam here seems worth trying.

@YouJiacheng
Contributor

I'm still tuning gate lr.
with a lr_mul = 0.01, it's significantly better than 0.1 on 2.92 track. So I guess the optimal lr_mul on the 3.28 track is not 1.

@ClassicLarry ClassicLarry merged commit d7e16ad into KellerJordan:master Oct 15, 2025