
'fc1' outgoing_weights are getting reset https://github.com/timoklein/redo/blob/main/src/redo.py#L120 #3

Closed
SaminYeasar opened this issue Mar 8, 2024 · 4 comments


SaminYeasar commented Mar 8, 2024

fc1 has (3136, 512) params, and it seems like the current implementation always resets and zeroes the dead neurons along the outgoing dimension of size 512. The implementation is supposed to reset the incoming weights of the dead neurons and zero their outgoing weights. https://github.com/timoklein/redo/blob/main/src/redo.py#L120
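
For reference, a minimal sketch of what "reset the incoming weights, zero the outgoing weights" could look like for a plain linear-to-linear pair. The sizes (3136, 512, 6 actions) and the plain kaiming_uniform_ re-init are only illustrative and are not the repo's actual helpers; note that PyTorch stores a Linear weight as (out_features, in_features).

import torch
import torch.nn as nn

# Hypothetical sizes matching the Nature-DQN head discussed here.
fc1 = nn.Linear(3136, 512)   # weight shape (512, 3136)
q = nn.Linear(512, 6)        # weight shape (6, 512)

# Boolean mask over fc1's 512 activations; True marks a dormant neuron.
mask = torch.zeros(512, dtype=torch.bool)
mask[10] = True

with torch.no_grad():
    # Reset the incoming weights of the dormant neurons (rows of fc1).
    fc1.weight[mask] = nn.init.kaiming_uniform_(torch.empty_like(fc1.weight[mask]))
    fc1.bias[mask] = 0.0
    # Zero the outgoing weights of the dormant neurons (columns of q).
    q.weight[:, mask] = 0.0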

@timoklein (Owner)

I'm not quite sure I understood you correctly, but let me try to explain:

  1. I have a mask for the activations of fc1 of shape $(512, )$.
  2. The ingoing weights of my neuron are thus defined by fc1 and I can select and re-initialize them with fc1.weight.data[:, mask].
  3. The outgoing weights of my neuron are in q, which is of shape $(512, 6)$. Here, I can select the outgoing weights with q.weight.data[mask, :] and set them to zero.

The code on the new branch does that; please see here:

redo/src/redo.py (Lines 109 to 137 in 6fbd0c3):

for i in range(len(layers[:-1])):
    mask = redo_masks[i]
    layer = layers[i]
    next_layer = layers[i + 1]

    # Skip if there are no dead neurons
    if torch.all(~mask):
        # No dormant neurons in this layer
        continue

    # The initialization scheme is the same for conv2d and linear
    # 1. Reset the ingoing weights using the initialization distribution
    if use_lecun_init:
        _lecun_normal_reinit(layer, mask)
    else:
        _kaiming_uniform_reinit(layer, mask)

    # 2. Reset the outgoing weights to 0
    # NOTE: Don't reset the bias for the following layer or else you will create new dormant neurons
    if isinstance(layer, nn.Conv2d) and isinstance(next_layer, nn.Linear):
        # Special case: Transition from conv to linear layer
        # Reset the outgoing weights to 0 with a mask created from the conv filters
        num_repeatition = next_layer.weight.data.shape[0] // mask.shape[0]
        linear_mask = torch.repeat_interleave(mask, num_repeatition)
        next_layer.weight.data[linear_mask, :] = 0.0
    else:
        # Standard case: layer and next_layer are both conv or both linear
        # Reset the outgoing weights to 0
        next_layer.weight.data[:, mask, ...] = 0.0
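
As a side note, a toy illustration of the repeat_interleave step: it expands the per-channel mask to a per-feature mask, assuming the conv output is flattened channel-major (e.g. via nn.Flatten), so each channel maps to a contiguous block of features. The sizes here are made up for illustration.

import torch

# Toy example: 3 conv channels, each flattened to 2 spatial positions (6 features total).
channel_mask = torch.tensor([False, True, False])
features_per_channel = 6 // channel_mask.shape[0]
feature_mask = torch.repeat_interleave(channel_mask, features_per_channel)
print(feature_mask)  # tensor([False, False,  True,  True, False, False])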

I have fixed everything except the moment resets on that branch; please refer to it.

@SaminYeasar (Author)

I just rechecked. I believe the reset of the incoming weights is fine. The error is where the code tries to zero the outgoing weights at this step: https://github.com/timoklein/redo/blob/main/src/redo.py#L135, i.e. the transition from conv3 (64, 64, 3, 3) to fc1 (3136, 512). Here the activation dim is 64, which needs to be matched against the outgoing layer's input dimension of 3136. However, next_layer.weight.data stores the matrix transposed, as (512, 3136), so the current implementation ends up zeroing incoming weights instead. The following code should fix this. Let me know if you think this is wrong. Thanks.

num_repeatition = next_layer.weight.data.shape[1] // mask.shape[0]
linear_mask = torch.repeat_interleave(mask, num_repeatition)
next_layer.weight.data[:, linear_mask] = 0.0
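
A quick shape check (hypothetical sizes matching the ones above) illustrates the point: nn.Linear stores its weight as (out_features, in_features), so the 64-entry channel mask has to be repeated to the 3136 input features and applied along dim 1.

import torch
import torch.nn as nn

conv3 = nn.Conv2d(64, 64, kernel_size=3)    # 64 output channels feeding fc1
fc1 = nn.Linear(3136, 512)

print(fc1.weight.data.shape)                # torch.Size([512, 3136]), not (3136, 512)

mask = torch.zeros(64, dtype=torch.bool)    # per-channel dormancy mask for conv3
mask[5] = True

num_repeatition = fc1.weight.data.shape[1] // mask.shape[0]     # 3136 // 64 = 49
linear_mask = torch.repeat_interleave(mask, num_repeatition)    # shape (3136,)
with torch.no_grad():
    fc1.weight[:, linear_mask] = 0.0        # zeroes the outgoing weights of dead channel 5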

timoklein (Owner) commented Mar 11, 2024

> The following code should fix this. Let me know if you think this is wrong. Thanks.

No, your fix is actually correct. Good catch. I'll apply it and run some experiments to verify it. The same must also be fixed for the moment resets of the outgoing weights here:

redo/src/redo.py (Lines 163 to 172 in 7069e9f):

if (
    len(optimizer.state_dict()["state"][i * 2]["exp_avg"].shape) == 4
    and len(optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape) == 2
):
    # Catch transition from conv to linear layer through moment shapes
    num_repeatition = optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"].shape[0] // mask.shape[0]
    linear_mask = torch.repeat_interleave(mask, num_repeatition)
    optimizer.state_dict()["state"][i * 2 + 2]["exp_avg"][linear_mask, ...] = 0.0
    optimizer.state_dict()["state"][i * 2 + 2]["exp_avg_sq"][linear_mask, ...] = 0.0
    optimizer.state_dict()["state"][i * 2 + 2]["step"].zero_()
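
For completeness, a standalone sketch of what the analogous fix for the moments might look like (untested, simplified to a toy Adam setup rather than the repo's optimizer handling): the exp_avg / exp_avg_sq buffers share fc1.weight's (512, 3136) shape, so the repeated mask would index dim 1 here as well.

import torch
import torch.nn as nn

# Toy setup mirroring the conv3 -> fc1 transition (hypothetical sizes).
fc1 = nn.Linear(3136, 512)
optimizer = torch.optim.Adam(fc1.parameters())
fc1(torch.randn(1, 3136)).sum().backward()
optimizer.step()                                 # populates exp_avg / exp_avg_sq

mask = torch.zeros(64, dtype=torch.bool)         # per-channel dormancy mask for conv3
mask[5] = True

state = optimizer.state[fc1.weight]              # moments have shape (512, 3136)
num_repeatition = state["exp_avg"].shape[1] // mask.shape[0]
linear_mask = torch.repeat_interleave(mask, num_repeatition)
state["exp_avg"][:, linear_mask] = 0.0
state["exp_avg_sq"][:, linear_mask] = 0.0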

It's quite interesting that the wrong resets are already improving performance substantially.

@timoklein (Owner)

Closing this as it's implemented in #2.
