[draft] New WR (-1.4 sec): drop MLP block in first layer. Includes WR changes from PR #109, PR #117, and PR #118 #120
Conversation
Nice job! Did you experiment with taking out other MLPs? How did you decide on removing the first and last layer? Reproducing this run on the same machine as #118, it looks like sub-163 is nearly here.
Great job working backwards from the observations on the gated MLP. Sharing my speculation on why we can drop MLPs 0 and 11 and attention 7.

As you know, the role of attention is to pass information between positions, and the role of the MLP is to process information at a position. At layer 0, information passing is more important than information processing, since we already have a clean embedding. However, layer 0 only has 6 heads and 6 queries, which allows only a limited number of data-retrieval requests. The initial value embeddings and the x0 skip let further queries be run against the input data in the subsequent 2 layers. Until we have run sufficient queries, there is limited value in processing the results of those queries.

The middle layers have already collected information at each position, and then need to process it. They can be more agnostic to position, and a position may be collecting and summarizing data that further downstream positions will need. The final layers need to be acutely aware of position to make an accurate prediction. The U-Net design and value embeddings help with this, and this final data movement makes attention layers more important again.

Perhaps an hourglass design for the number of attention heads per layer, with a fixed 128 head dim, would accomplish the same objective as the current approach of dropping components. Having 2 attention layers in series has the side effect of extending the effective attention window. The MLP merge may have had issues with two squaring operations without normalization.
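The hourglass idea can be sketched numerically. This is a hypothetical illustration, not code from the repo: `hourglass_heads`, `max_heads`, and `min_heads` are made-up names, and the linear taper is just one arbitrary choice.

```python
# Hypothetical sketch of an "hourglass" head-count schedule: more attention
# heads in the first and last layers (where data movement matters most),
# fewer in the middle, with head dim fixed at 128 so per-layer width varies.
def hourglass_heads(n_layers: int, max_heads: int = 12, min_heads: int = 4) -> list[int]:
    mid = (n_layers - 1) / 2
    heads = []
    for i in range(n_layers):
        frac = abs(i - mid) / mid  # 1.0 at the ends, 0.0 in the middle
        heads.append(round(min_heads + (max_heads - min_heads) * frac))
    return heads

print(hourglass_heads(12))
```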
I tried to check this with the following code (`DoubledMLP` was substituted in place of the standard `MLP` class).

Code:

```python
INTERMEDIATE_ADD_DIAGONAL, INTERMEDIATE_ZERO_INIT, \
    NORM_AFTER_FIRST_NL, NORM_AFTER_INTERMEDIATE, \
    ADD_PSEUDO_RESIDUAL = [bool(int(x)) for x in arg.split(",")]

class DoubledMLP(nn.Module):
    # we want to make a 2-layer perceptron:
    # upscale dim -> 4*dim -> 4*dim -> dim
    # however we want to keep all parameters of nn.Parameter(torch.empty(dim, hdim)) shape,
    # so for the intermediate layer we initialize 4 matrices of shape (dim, 4*dim)
    # and manually implement forward
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        if ADD_PSEUDO_RESIDUAL:
            self.lambda_ = nn.Parameter(torch.tensor([0.5 / dist.get_world_size()] * dist.get_world_size()))
        hdim = 4 * dim
        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate1 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate2 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate3 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate4 = nn.Parameter(torch.empty(dim, hdim))
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.c_fc.uniform_(-bound, bound)
            self.c_proj.zero_()
            if not INTERMEDIATE_ZERO_INIT:
                self.c_intermediate1.uniform_(-bound, bound)
                self.c_intermediate2.uniform_(-bound, bound)
                self.c_intermediate3.uniform_(-bound, bound)
                self.c_intermediate4.uniform_(-bound, bound)
            else:
                self.c_intermediate1.zero_()
                self.c_intermediate2.zero_()
                self.c_intermediate3.zero_()
                self.c_intermediate4.zero_()
            if INTERMEDIATE_ADD_DIAGONAL:
                mids = [self.c_intermediate1, self.c_intermediate2,
                        self.c_intermediate3, self.c_intermediate4]
                for k, W in enumerate(mids):
                    W.diagonal(offset=k * dim, dim1=0, dim2=1).fill_(1.0)

    def forward(self, x: Tensor):
        y = F.linear(x, self.c_fc.T.type_as(x))
        y = F.relu(y).square()
        if NORM_AFTER_FIRST_NL:
            y = norm(y)
        y = F.linear(y, torch.cat([self.c_intermediate1, self.c_intermediate2,
                                   self.c_intermediate3, self.c_intermediate4], dim=0).T.type_as(x))
        if NORM_AFTER_INTERMEDIATE:
            y = norm(y)
        if ADD_PSEUDO_RESIDUAL:
            y = y.reshape(*y.shape[:-1], 4, self.dim)
            y = y + x.unsqueeze(-2) * self.lambda_.mean().type_as(x) * self.lambda_.numel()
            y = y.reshape(*y.shape[:-2], 4 * self.dim)
        y = F.relu(y).square()
        x = F.linear(y, self.c_proj.type_as(x))
        return x
```

Besides the doubled intermediate layer itself, it has options for where to put normalization (before or after the intermediate linear operation), for whether a residual-like connection helps, and for how the intermediate layer is initialized. Results (2 runs per config) are as follows: normalization indeed helps a little, but not enough to make `DoubledMLP` efficient in comparison to just 2 MLPs. Normalization before or after the intermediate doesn't matter (somewhat expected). The "residual" makes training slower (???); I checked the values of lambda, and they slightly increase from the initial 0.5 to ~1. (The implementation of lambda as an array of 8 elements is due to DistAdam not working with scalar values; that was the easiest way. Apparently the whole residual implementation is very inefficient: we lose 2 sec!) So for now I still don't understand why the current structure of blocks 6-7 is better than `DoubledMLP`.
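For anyone who wants to poke at the shapes without the full training harness, here is a minimal single-process numpy sketch of the same forward pass with all flags off (no normalization, no pseudo-residual). The sizes are illustrative and this is my own re-derivation, not code from the run above.

```python
import numpy as np

# numpy re-implementation of the DoubledMLP forward with all flags off:
# x -> relu(x @ c_fc)^2 -> relu(y @ W_mid)^2 -> y @ c_proj.T,
# where W_mid is four (dim, 4*dim) matrices concatenated into (4*dim, 4*dim).
dim = 64
hdim = 4 * dim
rng = np.random.default_rng(0)
std = 0.5 * dim ** -0.5
bound = (3 ** 0.5) * std

c_fc = rng.uniform(-bound, bound, (dim, hdim))
mids = [rng.uniform(-bound, bound, (dim, hdim)) for _ in range(4)]
c_proj = np.zeros((dim, hdim))  # zero init, as in the torch version

def relu_sq(t):
    return np.maximum(t, 0.0) ** 2

x = rng.normal(size=(8, dim))                   # (batch, dim)
y = relu_sq(x @ c_fc)                           # F.linear(x, c_fc.T) == x @ c_fc
y = relu_sq(y @ np.concatenate(mids, axis=0))   # intermediate (4*dim, 4*dim)
out = y @ c_proj.T                              # F.linear(y, c_proj) == y @ c_proj.T

print(out.shape)  # (8, 64); all zeros, since c_proj starts at zero
```

Because `c_proj` is zero-initialized, the block initially contributes nothing to the residual stream, matching the torch version's init.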
Something that would provide some insight is moving the consecutive attn and MLP layers to be in parallel: 12 heads on the first and last layers at double total width, and two independent MLPs that read from the same spot in the residual stream. Conventional wisdom is that attn and MLP should alternate. It would be interesting if there is some unique value in consecutive attn and MLP layers, or if this result really indicates that the number of heads should vary by layer. The last MLP layer being less important is still surprising to me. We may need to look at the specific distribution of activations coming from each layer to make more sense of it.
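A toy sketch of the sequential-vs-parallel placement being suggested (hypothetical code, not from the repo; `attn`, `mlp`, and `norm` are stand-ins for the real sublayers):

```python
# Sequential (the usual alternating design): the MLP reads a residual stream
# that already contains this layer's attention output.
def sequential_block(x, attn, mlp, norm):
    x = x + attn(norm(x))
    x = x + mlp(norm(x))
    return x

# Parallel (GPT-J / PaLM style): attn and mlp read the same residual state,
# so the MLP cannot see what attention just moved into this position.
def parallel_block(x, attn, mlp, norm):
    h = norm(x)
    return x + attn(h) + mlp(h)
```

With toy scalar sublayers (`attn = 2*h`, `mlp = 3*h`, identity norm) the two compose differently: sequential gives `(1 + 2) + 3*(1 + 2) = 12`, parallel gives `1 + 2 + 3 = 6`.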
And you are not wrong. It wasn't skipped because
No worries. I have had my own moments where I think I'm tuning the lr on a variable that isn't connected to the computational graph. With compute costs there's only so much statistical validation you can do for each change. That also checks out better with the timing drop. IIRC, attn and MLP each consume roughly 30% of the runtime when incrementally dropped on a single A100. Dropping an MLP was likely tried previously, especially when people were testing dropping attn 7. Perhaps the addition of the sparse attention gate is reducing data passing in the first 6 heads, making dropping the MLP finally worthwhile, since there is not yet enough new data at the position to process.
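A quick back-of-envelope check of that 30% figure (the numbers here are rough estimates quoted from the comment above, not new measurements):

```python
# If all 12 MLP sublayers together account for ~30% of step time, dropping a
# single MLP should save about 30% / 12 = 2.5% per step, in the same ballpark
# as the ~3% per-step reduction reported for this PR.
mlp_share = 0.30
n_layers = 12
per_mlp = mlp_share / n_layers
print(f"~{per_mlp:.1%} of step time per MLP")
```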
Probably it will be of some interest here: each of the following variants passes the smoke test, but so far none of them works:

- Increasing the number of attention layers in the first layer (the layer where we skipped the MLP)
- Increasing the number of attention layers in the last layer (the 11th if we count from zero)
- Increasing the number of MLP layers in the last layer
- Increasing the number of MLP layers in the 7th layer (the layer where we skip attention)
I believe that I started to play with the gated MLP (for which I found no loss detriment upon decreasing the MLP dimension) on a branch forked from the current master: 1b51e26. In that case, the reported gain is unrelated to gated attention. I plan to conduct a small binary search later this week to find after which changes dropping the 1st-layer MLP starts to work.

Update: This PR previously and incorrectly claimed that dropping the MLP layers in both the first and last blocks improves training time. While the record itself still stands, the claim about the last block was wrong: due to a bug, that layer was never actually dropped. (Which turns out to be fortunate: removing it causes the training loss to become much worse.)
This submission includes recent WR changes by @varunneal #118, @ClassicLarry (08/23/25) and @byronxu99 #109.
Skipping the MLP in the first ~~and last~~ blocks reduced per-step time by ~3%, but increased the number of steps needed to reach a loss of 3.28 by ~1.5%.

Note: due to recent changes in the data loader (the BOS aligner now drops the tails of long documents and shards in some cases), there isn't enough data for more than 1697 steps. It's recommended to load more data with:

### Timings and validation
I used 8 × H100 SXM NVLink 80GB on jarvislabs.ai for validation compute. Unfortunately, the exact timing appears to depend on the specific node I’m randomly assigned (and slightly on warming effects). To address this, I validated against the previous record #118 by interleaving runs (one of mine, then one of the previous).
So according to my timing, this is a 1.4 second mean improvement over #118.
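To make the interleaving procedure concrete, here is a hypothetical sketch; the timing values below are invented for illustration and are not the actual measurements.

```python
# Runs of this PR and of the previous record #118 alternate on the same node,
# so node-to-node variance and warm-up effects hit both sides roughly equally.
# We then compare the means of the two interleaved series.
this_pr     = [162.1, 162.4, 161.9]  # hypothetical wall-clock seconds
prev_record = [163.5, 163.8, 163.3]  # hypothetical wall-clock seconds

def mean(xs):
    return sum(xs) / len(xs)

improvement = mean(prev_record) - mean(this_pr)
print(f"mean improvement: {improvement:.2f} s")
```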
### Unsuccessful attempts

(I'm not experienced enough to claim these approaches don't work, only that I couldn't get them to work.)
### Gated MLP
The initial idea was to replace all MLP blocks with a gated MLP to test the following analysis of GPT-OSS by Sebastian Raschka: "So, overall, using the GLU variants results in fewer parameters, and they perform better as well." With `hdim == dim` applied to all blocks, per-step time decreased by ~20%, but the higher loss outweighed this gain. Increasing `hdim` to `2*dim` and even `4*dim` didn't reduce the loss and made steps slower.

However, the effect wasn't uniform across layers: the loss increase was substantially smaller for the first and last layers. When varying the hidden dimension, smaller `hdim` lowered the step time (as expected) while the loss stayed the same. This led to dropping the MLP in the first and ~~last~~ blocks (and adjusting the number of steps), which produced the record above.

### VE fiddling
Since dropping MLP layers may substantially affect what matters for the attention blocks, I checked whether the need for value embeddings is the same.
It turned out that a `ve` pattern of

`[None] + [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 5) + [ve[0], ve[1], ve[2]]`

has the same loss as the initial `012...012`. In other words, the first layer doesn't need an additional embedding (my interpretation is that the first layer already has its own embedding). However,

`[None, ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 8) + [ve[0], ve[1], ve[2], None]`

has a larger loss: the last layer still needs a custom input embedding.
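The working pattern above can be written as a small helper. This is my own sketch (the names are hypothetical), with the middle run of `None`s chosen so the list has exactly one entry per block:

```python
# Build the per-block value-embedding pattern: block 0 gets no extra
# embedding, blocks 1-3 reuse ve[0..2], and the last three blocks reuse
# them again; everything in between gets None.
def make_ve_pattern(n_blocks: int):
    ve = ["ve0", "ve1", "ve2"]  # stand-ins for the three value-embedding tensors
    return [None] + ve + [None] * (n_blocks - 7) + ve

print(make_ve_pattern(12))
```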
Other attempts to fiddle with `ve`, in particular reducing the number of `ve` from 3 to 2 to slightly speed up training, didn't succeed. Any speedup was roughly offset by needing more steps (at best). The following combinations were tried:

### Doubled MLP
In `block.7` there's no attention, so we have `block.6.mlp -> block.7.mlp` (+ residual + `x0*lambda`). This suggested we might be losing information by projecting the internal vector in `block.6.mlp` from `4*dim` down to `dim` only to project it back up to `4*dim` in `block.7.mlp`. To test this, I skipped the MLP in `block.7` (as with the first and last layers) and replaced `block.6.mlp` with a two-layer MLP. In this setup, block 7 mainly mixes `x0` with `lambda`.

In `block.6.double_mlp` we upscale `dim*dim -> dim*4*dim`, apply `ReLU^2`, multiply by another `4*dim*4*dim` tensor, apply another `ReLU^2`, and downscale `dim*4*dim -> dim*dim`. In my understanding, such multi-layer blocks often train poorly due to back-propagation issues (mitigated by residuals), but I guessed that with two layers it might be acceptable.

The naive version (as described above) led to a ~5% higher step time with no change in loss. Part of this overhead seemed related to parameters of different sizes (an issue solved by @byronxu99 in #109). So instead of a single `4*dim*4*dim` tensor, I used four `dim*4*dim` tensors (to match the other parameters) and concatenated them just before the multiplication. Note that since we dropped the MLPs in layers 1, 7, and 12, we now have only `9*2 = 18` `dim*4*dim` tensors, so adding another four still allows fitting three batches of `dim*4*dim` parameters. This change reduced the step-time overhead to only ~3% above the baseline without the doubled MLP, but again the loss didn't change.

P.S. Unfortunately, I don't yet have enough understanding to fully explain either the successful or the unsuccessful findings. I would be very grateful if others with deeper insights could share their interpretations.
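The shape trick in the last paragraph can be checked in isolation: four `dim*4*dim` tensors stacked along the first axis behave exactly like one `4*dim*4*dim` tensor, so nothing is lost by storing the intermediate weight as four optimizer-friendly pieces. A small numpy demonstration (sizes illustrative):

```python
import numpy as np

# Multiplying by the concatenation of four (dim, 4*dim) matrices equals the
# sum of four chunk-wise products, so the split is purely a storage/optimizer
# convenience and does not change the computation.
dim = 16
rng = np.random.default_rng(1)
parts = [rng.normal(size=(dim, 4 * dim)) for _ in range(4)]
big = np.concatenate(parts, axis=0)  # (4*dim, 4*dim)

y = rng.normal(size=(4 * dim,))
via_big = y @ big
via_parts = sum(y[k * dim:(k + 1) * dim] @ parts[k] for k in range(4))
assert np.allclose(via_big, via_parts)
print("concatenated and chunk-wise products match")
```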