New WR Sparse Attention Gate (-55 steps). Including WR changes of @byronxu99#117
Conversation
|
Interestingly, I tried the same attention gate (inspired by https://arxiv.org/abs/2505.06708) on the 2.92 track but it didn't work. |
|
In addition, I kept failing at preventing massive activation on modded-nanogpt. Did you successfully prevent MA? |
|
I tried it again on 2.92 track, with first 16 dims as gate input. |
|
I have a hypothesis that fp8 lm_head and bos_align might be correlated to sparse gate. (by "correlated" I mean sparse gate's effectiveness might depend on them) |
|
hmmm I also wonder if this weight needs a separate lr (ideally, since we follow muP, we shouldn't need a separate lr). |
|
I'll add some ablation plots from some fresh tests on 8 H100 GPUs. I primarily test on a single A100 without fp8, and both bos_align and the sparse attention gate showed strong correlation there. Typically by step 500 I know whether a change is good or not.

I haven't followed the 2.92 track, so I'm curious whether there are other substantial differences in its setup. I'd expect the 64*1024 training batch size on the 2.92 track to reduce the benefit of bos_align, since fewer samples get truncated during training.

I can say off the bat that a non-sparse attention gate hurts performance, which matches what I saw when exploring attention mechanisms a month ago. The concept of non-sparse attention gates has likely been independently discovered and tested many times over; e.g., here's another paper on it from 2023 that I found after testing the idea: https://papers.nips.cc/paper_files/paper/2023/file/edbcb7583fd8921dad78adecfe06a99b-Paper-Conference.pdf. |
|
uh by correlation I mean, does sparse gate work without bos_align? |
|
large-batch NS is indeed faster, but based on previous profiling, NS is almost entirely hidden by all-gather communication. |
|
Ablations added. Main conclusions from them:
|
Based on the ablations above on H100, and prior testing on an A100 w/ no fp8, my hunch is that the lack of improvement on the 2.92 track has a different cause. One confounder here is that the gate is initialized to 0.5, which is not easily counteracted by the value weights since the value gets normed. The value embedding can scale up and compensate, but the 2.92 track has twice as many layers without value embeddings. The 2.92 track also has roughly twice as many heads. It may be that this initialization has a different impact on how the network heads wake up. But I have never run the 2.92 track, so my intuition there is quite limited. I am at my budget cap for now, but I think there is still a lot to explore here. |
|
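The norm interaction above can be checked in a small numpy sketch (hypothetical shapes; `rms_norm` stands in for whatever norm the value path uses): because the value is normed, scaling the value weights up cannot undo the 0.5 gate factor at init, while an un-normed additive path like the value embedding can.

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # normalize to unit RMS; removes any overall scale from the input
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
v = rng.standard_normal(64)

gate = 0.5  # sigmoid(0): the gate's value when its weights start at zero

out_small = gate * rms_norm(v)         # gated, normed value
out_big = gate * rms_norm(10.0 * v)    # scaling the value up changes nothing

# the norm strips the 10x factor, so value weights cannot counteract the gate
assert np.allclose(out_small, out_big)

# an un-normed additive path (e.g. a value embedding) CAN restore scale
ve = rng.standard_normal(64)
out_with_ve = gate * rms_norm(v) + ve
```

This is only meant to illustrate the scale-invariance argument, not the actual modded-nanogpt value path.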
thanks a lot! |
this is probably caused by the gate being optimized by Muon. Another factor is that the 2.92 track has weight decay. |
|
update: I made it work on the 2.92 track by a simple change.
This hypothesis is correct. Without the multiplier, it doesn't work on the 2.92 track:

```python
y = y * torch.sigmoid(F.linear(x[..., :self.gate_w.size(-1)], self.gate_w)[..., None])
```

With the multiplier, it works on the 2.92 track:

```python
y = y * torch.sigmoid(0.1 * F.linear(x[..., :self.gate_w.size(-1)], self.gate_w)[..., None])
```
I think |
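For readers following along, here is a minimal numpy translation of the gate quoted above (the 16-dim gate input, head count, and head dim are assumptions for illustration; the real code runs in torch inside the attention block):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

T, D, H, HD = 8, 64, 4, 16  # tokens, model dim, heads, head dim (assumed)
GATE_IN = 16                # first 16 dims of the residual stream feed the gate

rng = np.random.default_rng(0)
x = rng.standard_normal((T, D))       # residual-stream input to the block
y = rng.standard_normal((T, H, HD))   # per-head attention output
gate_w = np.zeros((H, GATE_IN))       # zero init => gate starts at exactly 0.5

# per-token, per-head scalar gate; the 0.1 multiplier shrinks the pre-sigmoid
# logits, which is the change that made the gate work on the 2.92 track
logits = x[:, :GATE_IN] @ gate_w.T            # (T, H)
y_gated = y * sigmoid(0.1 * logits)[..., None]

assert np.allclose(y_gated, 0.5 * y)  # at init the gate halves the output
```

The sketch makes the init behavior concrete: with zero gate weights every head's output is scaled by 0.5 regardless of the multiplier, and the multiplier only changes how fast the gate can move away from 0.5 during training.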
Neat! I'm curious about the runtime and implementation details of your gpt-oss code. I avoided softmax-denominator manipulation approaches because I wasn't sure how to make them performant with flex_attention. |
|
uh I use an extra sigmoid (plus flex_attention `return_lse=True`). It's about ~1.5% slower per step.

```python
y = y * torch.sigmoid(lse[..., None] - sinks[None, :, None, None])
```
|
|
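The trick can be checked numerically. Multiplying the usual attention output by sigmoid(lse - sink) gives exactly the result of running softmax with one extra sink logit in the denominator (a logit that attends to no value, as in GPT-OSS). A small numpy sketch for a single query and head, with a made-up sink value:

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal(6)   # attention logits for one query
v = rng.standard_normal((6, 4))   # values
sink = 0.7                        # learnable sink logit (assumed scalar here)

# standard attention output and its log-sum-exp (what return_lse=True gives)
m = scores.max()
lse = m + np.log(np.sum(np.exp(scores - m)))
out = np.exp(scores - lse) @ v    # softmax(scores) @ v

# sink attention: extra logit in the denominator, contributing no value
probs_sink = np.exp(scores) / (np.exp(scores).sum() + np.exp(sink))
out_sink = probs_sink @ v

# the sigmoid rescaling reproduces it exactly, since
# sigmoid(lse - sink) = Z / (Z + exp(sink)) with Z = sum(exp(scores))
out_rescaled = out * (1.0 / (1.0 + np.exp(sink - lse)))

assert np.allclose(out_sink, out_rescaled)
```

This is why the extra sigmoid plus `return_lse=True` is enough: no change to the flex_attention kernel itself is needed.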
Congrats @ClassicLarry! Very nice work.
If attention gating is a very difficult and important operation, we might expect that more compute-intensive approaches would give better per-step performance (but might be too slow for wall-clock time). @ClassicLarry and @YouJiacheng - I wonder how much you have experimented with other approaches, and if you have any interesting findings.
@YouJiacheng - That's a very efficient and elegant way to implement the GPT-OSS attention sink. If I'm understanding it correctly: the softmax denominator used in flex_attention is Z = sum_i exp(s_i) = exp(lse), so multiplying the output by sigmoid(lse - sink) = Z / (Z + exp(sink)) rescales it exactly as if an extra sink logit had been included in the denominator.
I also suspect that optimizing the gate matrix with Muon is affecting things.
|
I have not tried other attention approaches yet. I think there's a chance that gpt_oss and sparse attn gate approaches show improved results when combined, since they are doing different things. sparse attn gate lets a query decide if it wants to look for a key, independent of other tokens. gpt_oss assumes that a query does want to look for a key, and then modulates based on if it finds one. Ideally, a conceptually complete pairing mechanism should do both.
I had done one run at 6 and it performed slightly worse than 12 (basically the same, couldn't tell statistically from 1 run). Interesting insights on Muon. I agree that swapping it with Adam here seems worth trying. |
|
I'm still tuning gate lr. |
New record 08/23/25
Validated results (p=0.0059) with 14 runs:
Ablations
The following comparisons were made against this candidate WR run:
Runtimes below. I advise caution in interpreting the runtimes too heavily: a model that is 0.01 loss off when its lr has decayed to 0.1 is still very far from achieving 3.28 loss.

Negative and neutral test results during this process:
z = relu(log(p(y|x)/p(y))); embed = norm(rand_linear(z)), where p(y|x) is the bigram probability of token y given token x. This initialization makes tokens with similar bigram statistics start with similar embeddings. With a frozen embedding layer, this initialization performed better than random initialization; for non-frozen embeddings, however, the impact was not statistically significant.
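A small numpy sketch of this initialization on a toy corpus (the vocab size, corpus, and embedding width are made up; a random projection `W` plays the role of rand_linear and row-wise RMS normalization plays the role of norm):

```python
import numpy as np

rng = np.random.default_rng(0)
V, D = 5, 8  # toy vocab size and embedding dim (assumptions)

# toy corpus where token x is always followed by (x + 1) % V,
# so each token has strongly distinctive bigram statistics
corpus = np.arange(10_000) % V

# unigram p(y) and bigram p(y|x) estimates with add-one smoothing
uni = np.bincount(corpus, minlength=V).astype(float)
bi = np.ones((V, V))
np.add.at(bi, (corpus[:-1], corpus[1:]), 1.0)

p_y = uni / uni.sum()
p_y_given_x = bi / bi.sum(axis=1, keepdims=True)

# z[x, y] = relu(log(p(y|x) / p(y))): PMI-like bigram features per token x
z = np.maximum(np.log(p_y_given_x / p_y), 0.0)

# embed = norm(rand_linear(z)): random projection, then row-wise RMS norm
W = rng.standard_normal((V, D)) / np.sqrt(V)
e = z @ W
embed = e / np.sqrt(np.mean(e * e, axis=-1, keepdims=True) + 1e-6)
```

Tokens whose rows of z are close (similar bigram statistics) end up with nearby rows of `embed`, which is the property the initialization is after.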