Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,848 changes: 2,848 additions & 0 deletions records/071825_TritonMuon/record.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/020630eb-2191-4ba2-9ee4-4cdc94316943.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/21e732fb-4c4b-4db9-94bc-9fcd5d59b080.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/4518e917-cec2-4c81-9c1a-53b0644c2326.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/48b19604-5049-48c9-956c-8ddc4d0781fb.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/50524dcb-cf95-4b75-bf89-ba8ff3c5e1af.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/53ecb4ef-77ed-4af6-b776-47cd4006614b.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/6701af06-6c40-4553-bb04-f501fdd56284.txt

Large diffs are not rendered by default.

2,802 changes: 2,802 additions & 0 deletions records/082325_SparseAttnGate/6df384bb-9c24-46b3-826b-f7c07168c27a.txt

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions records/082325_SparseAttnGate/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
## New record 08/23/25

1. Included WR improvements on Triton and grad batching from https://github.com/KellerJordan/modded-nanogpt/pull/109 by @byronxu99
2. Added a sparse attention gate on the attention output to enable a context based no-op. Found the mechanism was performant with 12 active dimensions from the residual stream. If curious, here is a related blog post from an earlier investigation into non-sparse attention gate with detailed plots: https://medium.com/@larry36d/modulating-attention-scores-cc0bcd853f06. The blog demonstrates how the attention gate reduces the need for the bos_token to function as an attention sink. This is particularly relevant in a sliding window attention context because the bos_token is not always in the context window. ROPE embeddings cause the bos_token attention sink to change based on relative distance, whereas a sparse attention gate is indifferent to distance from start of sample. Estimate of impact: 50 steps fewer, with slight increase in time per step.
3. As a follow-on from 2: Reduced number of iterations from 1750 to 1695.
4. Reverted the lm head scaling changes made on Feb 10th: https://github.com/KellerJordan/modded-nanogpt/commit/85a0a5201f08c4d6bb288ef348bb252d9c33e132. When tested on a single A100, reverting this change drops the L2 norm of the LM head weights from 250 down to 10. The logits need to express values roughly from -10 to 10 in order to capture the range of token probabilities. Dividing by 27.5 (x.size(-1)**0.5) was causing the weights to grow substantially to accomplish this, since the residual stream was being normed prior to the lm_head. The second moment estimate of Adam depends on the parameter scale, and the Adam learning rates were likely heavily tuned prior to the Feb 10th update. If curious, more details near end of this blog post: https://medium.com/@larry36d/exploration-log-exploring-initializing-transformers-with-bigram-distribution-70f9c8800b21. Estimate of impact: 5-10 steps. (in this case just a cleaner cut below 3.28)
5. Chose to keep the minimum lr at 0.1. The bos_align record decreased the minimum lr to 0.05, and a later refactor, perhaps unintentionally, moved it back to 0.1. On further testing, the impact of this value on mean loss is marginal, but lower minimum lr appear to increase the variance of the final loss, making testing more challenging. Lower minimum lr may have higher variance because its committing to diving deep in the local space earlier, and is somewhat rolling the dice on if its a promising region or not. On reflection, I likely originally picked 0.05 because taking the min loss over a grid search will naturally bias to higher variance configurations, which is the opposite of what we want.


Validated results (p=0.0059) with 14 runs:
```
import scipy.stats
import torch

accs = [3.2774, 3.2782, 3.2796, 3.2815, 3.276 , 3.2777, 3.2784, 3.2795,
3.281 , 3.2802, 3.2767, 3.2772, 3.28 , 3.2786
]
times = [
168.627, 169.037, 169.003, 168.727, 168.647, 169.024, 168.917,
168.999, 168.728, 169.07 , 168.981, 168.938, 168.718, 167.122]

print('p=%.4f' % scipy.stats.ttest_1samp(accs, 3.28, alternative='less').pvalue)
# p=0.0059

print('acc:',torch.std_mean(torch.tensor(accs)))
# acc: (tensor(0.0016), tensor(3.2787))

print('time:',torch.std_mean(torch.tensor(times)))
# time: (tensor(0.4946), tensor(168.7527))
# Running on fresh cluster gave 167.695. actively working in jupyter notebooks on same machines during these runs may be adding variance to timing
```

###Negative and neutral test results during this process:

1. Initialize embedding tokens using bigram distribution. Bigram statistics can be calculated for 100 million tokens in ~1 second or less. I tested initializing the embedding layer using `z = relu(log(p(y|x)/p(y))); embed = norm(rand_linear(z))`, where p(y|x) is the bigram prob of token y given x. This initialization makes it so that tokens with similar bigram statistics will have similar embeddings. If I froze the embedding layer, this initialization performed better than random initialization. However, for non frozen embeddings, the impact was not statistically significant.
2. Weight freezing during training. Since the majority of the time on each step is spent computing the gradient, freezing a subset of weights can substantially decrease time per step. Unfortunately, all combinations tested of this failed to yield an improvement. Typical matmul ops require N FLOPS on the forward pass and 2N FLOPs on the backwards pass. The 2N is to compute the gradient with respect to the weights to update the weights, and the gradient with respect to the data to pass the gradient onwards. The torch compiler is smart enough to compute only N FLOPS on the backwards pass for leaf operations. To leverage this, I tested updating the first 3 layers to run in parallel, and then froze the embedding after a portion of training, such that 3 layers became leaf operations. The change was not kept as the performance drop outweighed the speedup.
3. Logit shift parameter. The residual space activations for all positions are heavily aligned away (>120 degrees) from the lm_head vectors of tokens that never appear in the training set. In other words, the ~400 tokens that never appear in the 50348 vocab size (including the 91 padding vocab entries) may be skewing the topology of the activations in the residual stream. Adding a simple logits += logit_shift enables the model to learn the unigram distribution directly (or even just a static variable that is -inf on padding tokens), without disrupting the residual space. Unfortunately, my implementation of this change was giving memory issues on an A100. On the H100 setup, the change dropped the loss by 0.01 but was slightly edged out by the increase in time per step. I don't have the budget to fiddle substantially with params I can't test on an A100. If a more compute optimized version can be found, this is an easy improvement to the loss, likely equivalent to 50+ steps.
4. Removing torch.compile on zeropower_via_newtonschulz5(). Surprisingly, the torch compiler makes the output of newtonschulz() vary based on the batch dimension size, with a 2% change depending on the batch size. This is relevant when we are batching kqv in one op. This appears to occur because of rounding issues with bfloat16 and some internal accumulations the compiler is altering, as the percent diff drops to less then 0.1% for float32. On an A100 removing the compile gave an improvement when I was testing different batch sizes, but the change was not statistically significant on H100 w/ fp8 lm_head. Unclear exactly what is going on here, but noting that bfloat16 can lead to very unintuitive consequences.
5. Megabatch NetwonSchulz. Inspired by @byronxu99, I tested further impacts of batching for zeropower_via_newtonschulz5(). The results were quite surprising on an A100. The run time was heavily dependent on the batch size, with larger batch sizes running up to twice as fast, based on initial testing (honestly need to sanity check this, seemed too crazy). As a result, I experimented with setting all MLP params in 1 contiguous variable and doing a single iteration of zeropower_via_newtonschulz5(), with [3,4*768,768] input to each GPU as a single pass, and [6,768,768] for Attn to each GPU as a second pass. This gave a total of only 2 iterations of zeropower_via_newtonschulz5() on each GPU per step. I was running into memory errors on the 8H100 setup, and need to get a cheaper distributed setup before I test further.
6. 0.5 init weighting for x0 stream instead of 0. At the end of training on a A100, the x0 weight for many layers is 50x higher than the x weight. Updating the weighting to 0.5 gave a statistically significant improvement on A100, but this was not replicated on the 8H100 setup with fp8 lmhead.
7. Normalize value embedding inputs during forward pass. Seemed like a natural thing to do given norms on the input embedding and the existing lambda to scale value weights. However, this yielded worse performance, perhaps because the value embeddings need to have much high weight than the values and the lambda scaling parameter was not tuned to handle this itself.
8. Renormalize embedding in place between each forward pass. The L2 norm of the embedding layer is climbing from 27 to 500 over the course of training, leading to a different effective learning rate depending on the stage of training. Normalizing this parameter may enable the lr to be tuned more precisely. However, I found norm() still needed to be included in the forward pass for an accurate grad calc, at which point the compute penalty for a second norm outside the forward pass became not worthwhile.
9. Removing value entirely (only use value embedding) for first and last 3 layers. The trained weights indicate that the value embedding is dominating the calculated attention value, and I can save some matmul ops if I can drop 6 layers of value calcs. The change cost roughly 0.015 loss, which unfortunately was worth more than the speedup achieved based on the parameters used.
10. Bigram full initialization. Similar to 1, I tested initializing the lm_head and embedding layer to approximate the bigram distribution. (Bigram could in theory cause learning to start around 5.7 loss, with potentially better generalization during training). Unfortunately, it is not analytically simple to set embed and lm_head to achieve a known bigram distribution, because of the nonlinearity of the softmax. Attempting to approximate this yielded worse performance than random initialization.
11. Dual loss on bigram distribution. I tested having the first X iterations minimize a combination of the next token prediction loss, along with the bigram distribution for that token. Intuition was that since I can compute the bigram distribution of 100 million tokens in 1s, the bigram distribution encodes a higher density of information than a single high variance loss signal of a 500,000 token batch. However, the 50,000x50,000 bigram matrix proved too bulky for compute efficient steps.
Loading