
[draft] New WR(-1.4sec): drop MLP blocks in first layer. Includes WR Changes from PR#109, PR#117 and PR#118 #120

Merged
ClassicLarry merged 8 commits into KellerJordan:master from EmelyanenkoK:pr-118
Oct 15, 2025

Conversation

@EmelyanenkoK
Contributor

@EmelyanenkoK commented Sep 5, 2025

Update: This PR previously and incorrectly claimed that dropping the MLP layers in both the first and last blocks improves training time. While the record itself still stands, the claim about the last block was wrong: due to a bug, that layer was never actually dropped. (Which turns out to be fortunate: removing it causes the training loss to become much worse.)

This submission includes recent WR changes by @varunneal #118, @ClassicLarry (08/23/25) and @byronxu99 #109.

Skipping the MLP in the first and last blocks reduced per-step time by ~3%, but increased the number of steps needed to reach a loss of 3.28 by ~1.5%.

Note: due to recent changes in the data loader (the BOS aligner now drops the tails of long documents and shards in some cases), there isn't enough data for more than 1697 steps. It's recommended to load more data with:

python data/cached_fineweb10B.py 9

Timings and validation

I used 8 × H100 SXM NVLink 80GB on jarvislabs.ai for validation compute. Unfortunately, the exact timing appears to depend on the specific node I’m randomly assigned (and slightly on warming effects). To address this, I validated against the previous record #118 by interleaving runs (one of mine, then one of the previous).

import scipy.stats
import numpy as np

# final validation losses (new record vs. previous record #118)
accs_new = [3.2806, 3.2764, 3.2787, 3.2778, 3.2776, 3.2754, 3.2779, 3.2775, 3.2782, 3.2779]
accs_prev = [3.2784, 3.2778, 3.2766, 3.2771, 3.2764, 3.2768, 3.276, 3.279, 3.2747, 3.2772]

# wall-clock training times (ms -> s)
times_new = [i / 1000 for i in [163078, 163129, 163029, 163141, 162534, 163100, 163046, 163144, 162926, 163224]]
times_prev = [i / 1000 for i in [164443, 164403, 164348, 164370, 164817, 164833, 164301, 164459, 164244, 164306]]

print('p=%.4f' % scipy.stats.ttest_1samp(accs_new, 3.28, alternative='less').pvalue)
# p=0.0003
print('p=%.4f' % scipy.stats.ttest_1samp(accs_prev, 3.28, alternative='less').pvalue)
# p=0.0000

print(f"{np.mean(times_new):.4f}")
# 163.0351
print(f"{np.mean(times_prev):.4f}")
# 164.4524

print(f"Time improvement: {np.mean(times_prev) - np.mean(times_new):.4f} sec")
# Time improvement: 1.4173 sec

So according to my timing, this is a 1.4-second mean improvement over #118.
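As an additional sanity check (not part of the original validation), a two-sample Welch test on the interleaved timings shows the gap is far larger than node-to-node noise:

```python
import scipy.stats

times_new = [i / 1000 for i in [163078, 163129, 163029, 163141, 162534,
                                163100, 163046, 163144, 162926, 163224]]
times_prev = [i / 1000 for i in [164443, 164403, 164348, 164370, 164817,
                                 164833, 164301, 164459, 164244, 164306]]

# Welch's t-test: does not assume equal variances between the two sets of runs
res = scipy.stats.ttest_ind(times_new, times_prev, equal_var=False)
diff = sum(times_prev) / 10 - sum(times_new) / 10
print(f"diff = {diff:.4f} s, p = {res.pvalue:.2e}")
```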

Unsuccessful attempts (I'm not experienced enough to claim these approaches don't work—only that I couldn't get them to work.)

Gated MLP

The initial idea was to replace all MLP blocks with a gated MLP to test the following analysis of GPT-OSS by Sebastian Raschka: "So, overall, using the GLU variants results in fewer parameters, and they perform better as well." With hdim == dim applied to all blocks, per-step time decreased by ~20%, but the higher loss outweighed this gain. Increasing hdim to 2*dim and even 4*dim didn't reduce the loss and made steps slower.
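For reference, a minimal sketch of the gated-MLP shape being described (the class and weight names are illustrative, not the PR's code; common GLU variants use SiLU as here, whereas this codebase's MLPs use ReLU²):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLP(nn.Module):
    """GLU-style MLP sketch: out = W_down(act(W_gate x) * W_up x)."""
    def __init__(self, dim: int, hdim: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hdim, bias=False)
        self.w_up = nn.Linear(dim, hdim, bias=False)
        self.w_down = nn.Linear(hdim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # elementwise gate: the activated gate branch modulates the up-projection
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

x = torch.randn(2, 16, 64)
y = GatedMLP(dim=64, hdim=64)(x)  # hdim == dim, as in the experiment above
print(y.shape)  # torch.Size([2, 16, 64])
```

With hdim == dim this uses three dim×dim matrices (3·dim² parameters) versus the standard MLP's two dim×4dim matrices (8·dim²), which is where the per-step speedup came from.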

However, the effect wasn't uniform across layers: the loss increase was substantially smaller for the first and last layers. When varying the hidden dimension, smaller hdim lowered the step time (as expected) while the loss stayed the same. This led to dropping the MLP in the first and last blocks (and adjusting the number of steps), which produced the record above.

VE fiddling

Since dropping MLP layers may substantially affect what matters for the attention blocks, I checked whether the need for value embeddings is the same.

It turned out that a ve pattern of
[None] + [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 5) + [ve[0], ve[1], ve[2]]
has the same loss as the initial 012...012. In other words, the first layer doesn't need an additional embedding (my interpretation is that the first layer already has its own embedding).

However,
[None, ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 8) + [ve[0], ve[1], ve[2], None]
has a larger loss: the last layer still needs a custom input embedding.

Other attempts to fiddle with ve, in particular, reducing the number of ve from 3 to 2 to slightly speed up training, didn't succeed. Any speedup was roughly offset by needing more steps (at best). The following combinations were tried:

ve = [None, ve[0], ve[1]] + [None] * (len(self.blocks) - 5) + [ve[0], ve[1]]
ve = [None, ve[0], ve[1]] + [None] * (len(self.blocks) - 7) + [ve[0], ve[1]] * 2
ve = [None] * (len(self.blocks) - 2) + [ve[0], ve[1]]
ve = [None, ve[0], ve[1]] + [None] * (len(self.blocks) - 5) + [ve[0], ve[1]]  # but all ve were normalized
ve = [None, ve[0], ve[1]] + [None] * (len(self.blocks) - 5) + [x0, ve[0], ve[1]]

The last variant uses the initial embedding x0 as the value embedding for the -3 layer. I expected that mixing the same vector into different parts of the blocks might severely hamper training, since self.embed would get conflicting signals. Instead, the loss increased only very slightly.
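A small helper makes it easy to sanity-check that a candidate pattern assigns exactly one entry per block (illustrative only; the exact slot arithmetic in the snippets above varies with how the block list is counted):

```python
def ve_pattern(n_blocks, ve):
    # illustrative layout: extra value embeddings only in the early and late blocks
    pat = [None, ve[0], ve[1], ve[2]] + [None] * (n_blocks - 7) + [ve[0], ve[1], ve[2]]
    assert len(pat) == n_blocks, "pattern must assign one entry per block"
    return pat

pat = ve_pattern(12, ["ve0", "ve1", "ve2"])
print(pat[0], pat[-1])  # None ve2
```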

Doubled MLP

In block.7 there's no attention, so we have block.6.mlp -> block.7.mlp (+ residual + x0*lambda). This suggested we might be losing information by projecting the internal vector in block.6.mlp from 4*dim to dim only to project it back to 4*dim in block.7.mlp. To test this, I skipped the MLP in block.7 (as with the first and last layers) and replaced block.6.mlp with a two-layer MLP. In this setup, block 7 mainly mixes x0 with lambda.

In block.6.double_mlp we upscale dim*dim -> dim*4*dim, apply ReLU^2, multiply by another 4*dim*4*dim tensor, apply another ReLU^2, and downscale dim*4*dim -> dim*dim. In my understanding, such multi-layer blocks often train poorly due to back-propagation issues (mitigated by residuals), but I guessed that with two layers it might be acceptable.

The naive version (as described above) led to a ~5% higher step time with no change in loss. Part of this overhead seemed related to parameters of different sizes (an issue solved by @byronxu99 in #109). So instead of a single 4*dim*4*dim tensor, I used four dim*4*dim tensors (to match other parameters) and concatenated them just before multiplication. Note that since we dropped the MLPs in layers 1, 7, and 12, we now have only 9*2 = 18 dim*4*dim tensors, so adding another four still allows fitting three batches of dim*4*dim parameters. This change reduced the step-time overhead so it's now only ~3% higher than without the doubled MLP, but again the loss didn't change.

P.S. Unfortunately, I don't yet have enough understanding to fully explain either the successful or unsuccessful findings. I would be very grateful if others with deeper insights could share their interpretations.

@varunneal
Contributor

Nice job! Did you experiment with taking out other MLPs? How did you decide on removing the first and last layer?

Reproducing this run on the same machine as #118 I got

163.238
163.048
163.039
162.941
162.860

so it looks like sub-163 is nearly here.

@ClassicLarry
Collaborator

Great job working backwards from the observations on the gated mlp.

Sharing my speculation on why we can drop MLPs 0 and 11 and attention 7:

As you know, the role of attention is to pass information between positions, and the role of the MLP is to process information at a position. At layer 0, information passing is more important than information processing, since we already have a clean embedding. However, layer 0 only has 6 heads and 6 queries, which allows only a limited number of data-retrieval requests. The initial value embeddings and the x0 skip let further queries be run against the input data in the subsequent 2 layers. Until we have run sufficient queries, there is limited value in processing their results.

The middle layers have already collected information at each position, and then need to process it. They can be more agnostic to position, and a position may be collecting and summarizing data that further downstream positions will need.

The final layers need to be acutely aware of position to make an accurate prediction. The U-net design and value embeddings help with this, and this final data movement makes attention layers more important again.

Perhaps an hourglass design for number of attention heads per layer, with a fixed 128 head dim, would accomplish the same objective as the current approach of dropping components. Having 2 attention layers in series has a side effect of extending the effective attention window.

The mlp merge may have had issues with two squaring operations without normalization.

@EmelyanenkoK
Contributor Author

EmelyanenkoK commented Sep 6, 2025

How did you decide on removing the first and last layer?

It came from non-systematic attempts to decrease the number of parameters in the MLP blocks by replacing them with MLPGLU. I tried a few combinations; changes in the first and last layers seemed promising, changes in the others didn't.

However, following your suggestion, I ran a more systematic study of which MLP blocks can be dropped (see the figure below). The blue line is when we drop only one block; the red line is when we always drop the first and last plus one other. Single run per point, so treat the data as quite noisy.

(figure: mlp_off — val loss vs. index of the dropped MLP block; note the point with x-coordinate 12 is meaningless)

Generally, it supports the idea that only the first and last blocks are worth dropping.

Two more peculiarities about this graph:

I expected to see some dip in the red line at blocks 6 and 7; my intuition was that two sequential MLP blocks are less effective, so dropping one of them should cause less deterioration than dropping any other MLP block. But either the effect is too small, or it is nonexistent.

There seems to be a larger effect from dropping the MLP in layers 9-11 when 0 is already off (the tail of the red line).

@EmelyanenkoK
Contributor Author

The mlp merge may have had issues with two squaring operations without normalization.

I tried to check this with the following code (DoubledMLP was substituted into block.6.mlp, with the MLP in blocks 0, 7, and 12 dropped):

# flags selecting the DoubledMLP variant, unpacked from a comma-separated
# command-line arg of five 0/1 values (e.g. "0,0,1,0,0")
INTERMEDIATE_ADD_DIAGONAL, INTERMEDIATE_ZERO_INIT, \
NORM_AFTER_FIRST_NL, NORM_AFTER_INTERMEDIATE, \
ADD_PSEUDO_RESIDUAL = [bool(int(x)) for x in arg.split(",")]

class DoubledMLP(nn.Module):
    # we want to make 2layer perceptron:
    # upscale dim -> 4*dim -> 4*dim -> dim
    # however we want to keep all parameters of nn.Parameter(torch.empty(dim, hdim)) shape
    # so for intermediate layer we will initialize 4 matrices of shape (dim, 4*dim) and manually implement forward
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        if ADD_PSEUDO_RESIDUAL:
          self.lambda_ = nn.Parameter(torch.tensor([0.5/dist.get_world_size()]*dist.get_world_size()))
        hdim = 4 * dim
        self.c_fc = nn.Parameter(torch.empty(dim, hdim))
        self.c_proj = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate1 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate2 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate3 = nn.Parameter(torch.empty(dim, hdim))
        self.c_intermediate4 = nn.Parameter(torch.empty(dim, hdim))
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.c_fc.uniform_(-bound, bound)
            self.c_proj.zero_()
            if not INTERMEDIATE_ZERO_INIT:
                self.c_intermediate1.uniform_(-bound, bound)
                self.c_intermediate2.uniform_(-bound, bound)
                self.c_intermediate3.uniform_(-bound, bound)
                self.c_intermediate4.uniform_(-bound, bound)
            else:
                self.c_intermediate1.zero_()
                self.c_intermediate2.zero_()
                self.c_intermediate3.zero_()
                self.c_intermediate4.zero_()
            if INTERMEDIATE_ADD_DIAGONAL:
                mids = [self.c_intermediate1, self.c_intermediate2,
                        self.c_intermediate3, self.c_intermediate4]
                for k, W in enumerate(mids):
                    W.diagonal(offset=k*dim, dim1=0, dim2=1).fill_(1.0)

    def forward(self, x: Tensor):
        y = F.linear(x, self.c_fc.T.type_as(x))
        y = F.relu(y).square()
        if NORM_AFTER_FIRST_NL:
            y = norm(y)
        y = F.linear(y, torch.cat([self.c_intermediate1, self.c_intermediate2, self.c_intermediate3, self.c_intermediate4], dim=0).T.type_as(x))
        if NORM_AFTER_INTERMEDIATE:
            y = norm(y)
        if ADD_PSEUDO_RESIDUAL:
            y = y.reshape(*y.shape[:-1], 4, self.dim)
            y = y + x.unsqueeze(-2) * self.lambda_.mean().type_as(x) * self.lambda_.numel()
            y = y.reshape(*y.shape[:-2], 4*self.dim)
        y = F.relu(y).square()
        x = F.linear(y, self.c_proj.type_as(x))
        return x

Besides normalization, it also has options for where to put the normalization (before or after the intermediate linear operation), whether a residual-like connection helps, and a few initialization schemes for the intermediate layer.

Results (2 runs per config) are as follows:

All False:
step:1705/1705 val_loss:3.2825 train_time:166176ms step_avg:97.46ms
step:1705/1705 val_loss:3.2828 train_time:166786ms step_avg:97.82ms

NORM_AFTER_FIRST_NL True
step:1705/1705 val_loss:3.2784 train_time:166501ms step_avg:97.65ms
step:1705/1705 val_loss:3.2793 train_time:166072ms step_avg:97.40ms

NORM_AFTER_INTERMEDIATE True
step:1705/1705 val_loss:3.2801 train_time:166600ms step_avg:97.71ms
step:1705/1705 val_loss:3.2785 train_time:165363ms step_avg:96.99ms

NORM_AFTER_FIRST_NL=True ADD_PSEUDO_RESIDUAL=True
step:1705/1705 val_loss:3.2786 train_time:167808ms step_avg:98.42ms
step:1705/1705 val_loss:3.2800 train_time:167727ms step_avg:98.37ms
model.blocks[6].mlp.lambda_ -> 0.93

INTERMEDIATE_ADD_DIAGONAL=True INTERMEDIATE_ZERO_INIT=True
NORM_AFTER_FIRST_NL=True ADD_PSEUDO_RESIDUAL=True
step:1705/1705 val_loss:3.2880 train_time:168361ms step_avg:98.75ms
model.blocks[6].mlp.lambda_ -> 1.07
step:1705/1705 val_loss:3.2828 train_time:168304ms step_avg:98.71ms
model.blocks[6].mlp.lambda_ -> 0.59

INTERMEDIATE_ADD_DIAGONAL=True
NORM_AFTER_FIRST_NL=True ADD_PSEUDO_RESIDUAL=True

step:1705/1705 val_loss:3.2811 train_time:167969ms step_avg:98.52ms
model.blocks[6].mlp.lambda_ -> 0.79
step:1705/1705 val_loss:3.2837 train_time:168171ms step_avg:98.63ms
model.blocks[6].mlp.lambda_ -> 0.75

Normalization indeed helps a little, but not enough to make DoubledMLP competitive with just 2 MLPs. Normalization before or after the intermediate layer doesn't matter (somewhat expected). The "residual" makes training slower (???); I checked the values of lambda, and they slightly increase from the initial 0.5 to ~1.

(The implementation of lambda as an array of 8 elements is because DistAdam doesn't work with scalar values; that was the easiest workaround. Apparently the whole residual implementation is very inefficient: we lose 2 seconds!)

So for now I still don't understand why the current structure of blocks 6-7 is better than DoubledMLP.

@ClassicLarry
Collaborator

Something that would provide some insights is if the consecutive attn and mlp layers are moved to be in parallel. So 12 heads on first and last layer at double total width, and two independent MLPs that read from same residual stream spot. Conventional wisdom is that attn and mlp should alternate. It would be interesting if there is some unique value from consecutive attn and mlp layers, or if this result is really indicating that number of heads should vary by layer. The last mlp layer being less important is still surprising to me. We may need to look at the specific distribution of activations coming from each layer to make more sense of it.

@EmelyanenkoK EmelyanenkoK marked this pull request as ready for review September 6, 2025 18:13
@EmelyanenkoK EmelyanenkoK marked this pull request as draft September 8, 2025 09:21
@EmelyanenkoK
Contributor Author

The last mlp layer being less important is still surprising to me.

And you are not wrong. It wasn't skipped because [0,12] doesn't skip last layer ([0,11] would).
I apologize for misleading.

@EmelyanenkoK EmelyanenkoK changed the title [draft] New WR(-1.4sec): drop MLP blocks in first and last layers. Includes WR Changes from PR#109, PR#117 and PR#118 [draft] New WR(-1.4sec): drop MLP blocks in first layer. Includes WR Changes from PR#109, PR#117 and PR#118 Sep 8, 2025
@ClassicLarry
Collaborator

The last mlp layer being less important is still surprising to me.

And you are not wrong. It wasn't skipped because [0,12] doesn't skip last layer ([0,11] would). I apologize for misleading.

No worries. I have had my own moments where I think I’m tuning the lr on a variable that isn’t connected to the computational graph. With compute costs there’s only so much statistical validation you can do for each change.

That also checks out more with the timing drop. Iirc attn and mlp both consume roughly 30% of the runtime when incrementally dropped on a single A100. Dropping an MLP was likely tried previously, especially when people were testing dropping attn 7. Perhaps the addition of the sparse attention gate is reducing data passing in the first 6 heads, making dropping the MLP finally worthwhile since there is not yet enough new data at the position to process.

@EmelyanenkoK
Contributor Author

EmelyanenkoK commented Sep 9, 2025

Probably this will be of some interest here:
I played with the following VariableBlock code, which allows running a few Attn/MLP blocks of the same size in parallel and then projecting the result to the common dimension before passing it forward (as well as skipping unwanted layers).

It passes the smoke test that attn_paths=int(i not in [7]), mlp_paths=int(i not in [0]) has the same timing/loss as the code in this PR.
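The VariableBlock code itself isn't included in the thread; a minimal sketch of the parallel-paths idea might look like the following (the class name, zero-output skip, and concat-then-project choice are my assumptions, not the PR's implementation):

```python
import torch
import torch.nn as nn

class ParallelPaths(nn.Module):
    """Run k same-size sublayers on the same input, concatenate, project back to dim.
    k = 0 skips the sublayer (returns zeros, so the residual stream passes through)."""
    def __init__(self, dim: int, k: int, make_path):
        super().__init__()
        self.paths = nn.ModuleList([make_path(dim) for _ in range(k)])
        self.proj = nn.Linear(k * dim, dim, bias=False) if k > 1 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.paths:            # skipped layer
            return torch.zeros_like(x)
        ys = [p(x) for p in self.paths]
        if self.proj is None:         # single path: no extra projection needed
            return ys[0]
        return self.proj(torch.cat(ys, dim=-1))

layer = ParallelPaths(dim=64, k=2, make_path=lambda d: nn.Linear(d, d, bias=False))
out = layer(torch.randn(3, 64))
print(out.shape)  # torch.Size([3, 64])
```

With k > 1 the extra projection adds parameters and time, which is consistent with the ~5% per-path step-time overhead reported below.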
Experiments:

first layer

Increasing the number of attention paths in the first layer (the layer where we skipped the MLP) to 2 gave +5% time, same loss:

step:1705/1705 val_loss:3.2793 train_time:171500ms step_avg:100.59ms
repeat:
step:1705/1705 val_loss:3.2818 train_time:169778ms step_avg:99.58ms

Increasing the number of attention paths in the first layer (the layer where we skipped the MLP) to 4 gave +10% time, slightly lower loss:

step:1705/1705 val_loss:3.2759 train_time:179095ms step_avg:105.04ms

last layer

Increasing the number of attention paths in the last layer (the 11th if we count from zero) to 2 gave +5% time, same loss:

step:1705/1705 val_loss:3.2802 train_time:169803ms step_avg:99.59ms

Increasing the number of attention paths in the last layer (the 11th if we count from zero) to 4 gave +10% time, a small but notable loss reduction:

step:1705/1705 val_loss:3.2719 train_time:179289ms step_avg:105.16ms

Increasing the number of MLP paths in the last layer to 2 gave +5% time, slightly lower loss:

step:1705/1705 val_loss:3.2766 train_time:169998ms step_avg:99.71ms

7th layer

Increasing the number of MLP paths in the layer where we skip attention to 2 gave +4% time, slightly lower loss:

step:1705/1705 val_loss:3.2752 train_time:168596ms step_avg:98.88ms

So far, none of these have worked out.

Dropping an MLP was likely tried previously, especially when people were testing dropping attn 7. Perhaps the addition of the sparse attention gate is reducing data passing in the first 6 heads, making dropping the MLP finally worthwhile since there is not yet enough new data at the position to process.

I believe I started playing with the gated MLP (where I found no loss detriment from decreasing the MLP dimension) on a branch forked from the current master: 1b51e26. In that case, the reported gain is unrelated to gated attention. Later this week I plan to run a small binary search to find after which change dropping the first-layer MLP starts to work.
