New WR 151.5s: Drop first attn layer, extend all long windows for validation, update schedule #131
Conversation
Merge PR 118: …aining, improve skip connection gating, and enhance bfloat16 usage
Does this mean we dropped the 0th block entirely (no attention and no MLP) and have only 11 blocks stacked? Have you tried dropping the MLP in the 1st block in this case? If it was profitable for 12 blocks, maybe it still will be for 11?
Yes, there is no layer 0. I did not try dropping the layer-1 MLP, or shifting the value embeds back a layer. I also did not revisit the skip connection design (layer 11 now has no skip connection input). These recent architecture shifts probably mean the model is no longer in a hyperparameter local minimum, so there are likely some easier-to-reach gains from trying things out.
This PR builds on all recent WR improvements, including PR #130. Updates:
Several factors led to dropping the first attention layer:
Reason for iteration_extension:
Future Opportunities
This change brings the total number of [4,768,768] attention variables to 10. There are 22 MLP variables of size [4×768, 768]. In Muon, the attention variables are batched such that 6/16ths of the gradient calculations are spent on padding. There may be a way to move 2 of the attention variables into the MLP batch, so that MLP is 24/24 and attention is 8/8, instead of MLP being 22/24 and attention being 10/16.
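The fractions above can be sketched with some simple arithmetic. This is a hypothetical sketch, assuming Muon pads each group of same-shaped variables up to the next multiple of an 8-GPU world size (the PR does not state the exact padding rule); `WORLD_SIZE` and `padded_slots` are illustrative names, not the repo's API.

```python
import math

# Assumption: each batch of same-shaped variables is padded up to the next
# multiple of the world size (8 GPUs here) so every rank gets an equal share.
WORLD_SIZE = 8

def padded_slots(n_vars, world_size=WORLD_SIZE):
    return math.ceil(n_vars / world_size) * world_size

attn_vars, mlp_vars = 10, 22  # [4,768,768] attention vars, [4*768,768] MLP vars
# Note that 4*768*768 == 3072*768, so an attention variable could in principle
# be reshaped to the MLP shape and batched with the MLP group.

current_attn_waste = (padded_slots(attn_vars) - attn_vars) / padded_slots(attn_vars)
# attention: 10 vars in 16 slots -> 6/16 of the gradient work is padding

# Proposed: move 2 attention variables into the MLP batch.
proposed_attn, proposed_mlp = attn_vars - 2, mlp_vars + 2
proposed_attn_waste = (padded_slots(proposed_attn) - proposed_attn) / padded_slots(proposed_attn)
proposed_mlp_waste = (padded_slots(proposed_mlp) - proposed_mlp) / padded_slots(proposed_mlp)
# attention becomes 8/8 and MLP becomes 24/24: no padding in either batch.
```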
Investigating Muon for 1D variables
Currently the attention gates and the smear gate are passed into Muon. From light inspection, the Newton-Schulz implementation appears to roughly apply F.normalize(x, p=2, dim=-1) to 1D variables. This normalization makes every step cover roughly the same distance regardless of the gradient magnitude, so for 1D variables Muon turns into exponential smoothing over prior gradients, with each step normalized to roughly the same size. This seems somewhat reasonable. Swapping these variables over to Adam gave roughly a 0.5s runtime increase and no improvement in loss. Directly replacing Newton-Schulz with F.normalize(x, p=2, dim=-1) for these variables performed slightly worse. I do not understand the theory here yet, but empirically the performance is good.
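The observation above can be checked numerically. Below is a minimal sketch (not the repo's exact code) of the commonly used quintic Newton-Schulz iteration, applied to a 1×n "matrix", i.e. a 1D parameter: for a row vector, X·Xᵀ is just the scalar ‖X‖², so every step rescales X without rotating it, and the output points in the same direction as F.normalize(x, dim=-1), differing only slightly in length. The function name and coefficients are assumptions based on the widely circulated Muon implementation.

```python
import numpy as np

def newton_schulz_1d(x, steps=5):
    # Quintic Newton-Schulz orthogonalization iteration with the commonly
    # used coefficients; here applied to a 1 x n matrix (a 1D variable).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = x / np.linalg.norm(x)          # pre-normalize by the Frobenius norm
    for _ in range(steps):
        A = X @ X.T                    # for a 1 x n row this is the scalar ||X||^2
        B = b * A + c * (A @ A)
        X = a * X + B @ X              # each step only rescales X
    return X

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 64))
y = newton_schulz_1d(x)

unit = x / np.linalg.norm(x)           # the F.normalize(x, dim=-1) direction
cos = float((y @ unit.T) / np.linalg.norm(y))
# cos is ~1: the direction matches normalization; only the length differs a bit.
```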
Validation:
Code syntax/naming was lightly refactored after the validation runs were performed. Loss is roughly 0.001 lower than the prior record, which is roughly equal to 1s.