Support MuParametrization and MuTransfer #64

Draft · wants to merge 7 commits into base: main
Changes from 5 commits
1 change: 1 addition & 0 deletions protein_lm.yml
@@ -20,3 +20,4 @@ dependencies:
- evaluate
- pytest
- fair-esm
- mup
1 change: 1 addition & 0 deletions protein_lm/modeling/__init__.py
@@ -0,0 +1 @@
from models import *
@othertea (Collaborator) commented on Feb 12, 2024:
Suggested change:
- from models import *
+ from .models import *

I think you need these to be relative imports? They didn't work as-is for me.
Alternatively, instead of changing all of these to relative imports, we can remove these lines and import them by specifying the full module paths in test_coord_check.py (sketched below).
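For reference, the full-module-path alternative might look like the following in test_coord_check.py (a sketch only; the test file itself is not shown in this diff, and the paths assume the package layout used above):

from protein_lm.modeling.models.apt.config import APTConfig
from protein_lm.modeling.models.apt.model_pytorch import APTLMHeadModel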

1 change: 1 addition & 0 deletions protein_lm/modeling/models/__init__.py
@@ -0,0 +1 @@
from apt import *
Collaborator comment:
Suggested change:
- from apt import *
+ from .apt import *

2 changes: 2 additions & 0 deletions protein_lm/modeling/models/apt/__init__.py
@@ -0,0 +1,2 @@
from config import APTConfig
from model_pytorch import *
NZ99 marked this conversation as resolved.
25 changes: 25 additions & 0 deletions protein_lm/modeling/models/apt/config.py
@@ -11,11 +11,36 @@ def __init__(
position_embedding="learned",
tokenizer=None,
max_sequence_length = 1024,
query_zero_init = True,
n_layer = None,
contact_prediction_head = False,
initializer_range = 0.02,
# whether to use MuParametrization
use_mup = False,
# whether to initialize the output (readout) layer with zero-initialization
readout_zero_init = True,
# the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56
mup_output_mult = 1.0,
width_mult_for_weights = 2.0,
# rope
rope_theta = 0.0,
rope_scaling_factor=1,
**kwargs
):
super().__init__(**kwargs)
self.nn_model_type = "APT"
self.position_embedding = position_embedding
self.tokenizer = tokenizer
self.max_sequence_length = max_sequence_length

self.use_mup = use_mup
self.query_zero_init = query_zero_init
self.n_layer = n_layer
self.contact_prediction_head = contact_prediction_head
self.initializer_range = initializer_range
self.readout_zero_init = readout_zero_init
self.mup_output_mult = mup_output_mult
self.width_mult_for_weights = width_mult_for_weights
self.rope_theta = rope_theta
self.rope_scaling_factor = rope_scaling_factor
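
For illustration, a muP-enabled config could be constructed as follows (a sketch using the argument names added above; it assumes APTConfig forwards GPT-2 style kwargs such as the dropout probabilities, which are set to 0 here because the muP branches in the model replace dropout with nn.Identity):

from protein_lm.modeling.models.apt.config import APTConfig

config = APTConfig(
    position_embedding="learned",
    max_sequence_length=1024,
    use_mup=True,
    readout_zero_init=True,
    query_zero_init=True,
    mup_output_mult=1.0,
    width_mult_for_weights=2.0,
    attn_pdrop=0.0,   # muP path swaps dropout for nn.Identity, so keep these at 0
    resid_pdrop=0.0,
    embd_pdrop=0.0,
)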

153 changes: 131 additions & 22 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -1,4 +1,5 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -8,6 +9,7 @@
from transformers.pytorch_utils import Conv1D
from transformers.activations import ACT2FN
from transformers.utils import logging
from mup import MuReadout, MuSharedReadout, normal_
from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -16,6 +18,7 @@

logger = logging.get_logger(__name__)


class APTAttention(GPT2Attention):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
@@ -42,6 +45,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
f" {self.num_heads})."
)

# muP
self.use_mup = config.use_mup
self.attn_score = nn.Identity() # just for coordcheck
self.query = nn.Identity() # just for coordcheck
self.key = nn.Identity() # just for coordcheck
self.value = nn.Identity() # just for coordcheck

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention

@@ -55,13 +65,20 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)

self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
if self.use_mup:
self.attn_dropout = nn.Identity()
Collaborator comment:
Should we consider asserting that the dropout probabilities are set to 0 in this case (in configs)?

self.resid_dropout = nn.Identity()
else:
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()



self.rot_emb=None
if self.position_embedding == "rope":
self.rot_emb=RotaryEmbedding(dim=self.head_dim)
@@ -72,15 +89,23 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
elif self.position_embedding=="dynamic_rope_scaling":
self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)



def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

#muP
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)
if self.use_mup:
attn_weights = attn_weights / torch.full(
Collaborator comment:
Should we be multiplying by some attn_mult here that we add as a config option (as in Mu-Scaling or mutransformers)?

Contributor (author) reply:
Good catch, yes, thank you! Will update accordingly (see the sketch a few lines below).

[], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
)
else:
attn_weights = attn_weights / torch.full(
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
)

attn_weights = self.attn_score(attn_weights)

# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
@@ -97,7 +122,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
if alibi_bias is not None:
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
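
For concreteness, the attn_mult variant discussed in the review thread above could look roughly like the following (a sketch only: attn_mult would be a new config option that is not yet in this PR, and the torch.full dtype/device handling from the original code is omitted for brevity):

if self.use_mup:
    # muP: scale logits by 1/d_head (times an optional tunable multiplier)
    attn_weights = attn_weights * self.attn_mult / value.size(-1)
else:
    # standard parametrization: scale logits by 1/sqrt(d_head)
    attn_weights = attn_weights / (value.size(-1) ** 0.5)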
@@ -150,7 +175,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

if alibi_bias is not None:
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
Expand All @@ -171,7 +196,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

return attn_output, attn_weights


def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -202,11 +227,15 @@ def forward(
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)


query = self.query(query)
key = self.key(key)
value = self.value(value)

kv_seq_len=key.shape[-2]
if layer_past is not None:
kv_seq_len+=layer_past[0].shape[-2]

# Apply rope embedding to query and key
if self.rot_emb:
bsz, q_len, _ = hidden_states.size()
@@ -225,7 +254,6 @@ def forward(
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)


if use_cache is True:
present = (key, value)
else:
@@ -251,10 +279,20 @@ class APTMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size

#muP
use_mup = config.use_mup

self.c_fc = Conv1D(intermediate_size, embed_dim)

self.c_proj = Conv1D(embed_dim, intermediate_size)

self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)

if use_mup:
self.dropout = nn.Identity()
else:
self.dropout = nn.Dropout(config.resid_pdrop)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
@@ -270,6 +308,9 @@ def __init__(self, config, layer_idx=None):
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

#muP
self.use_mup = config.use_mup

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = APTAttention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -354,23 +395,32 @@ def __init__(self, config):
super().__init__(config)

self.embed_dim = config.hidden_size
use_mup = config.use_mup

self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.alibi = None
elif self.position_embedding=="alibi":
#muP TO DO: check proper behavior in alibi case
maxpos = config.n_positions
attn_heads = config.n_head
alibi = create_alibi_tensor(attn_heads,maxpos)
self.register_buffer('alibi',alibi)
else:
raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')

self.drop = nn.Dropout(config.embd_pdrop)

#muP
if use_mup:
self.drop = nn.Identity()
else:
self.drop = nn.Dropout(config.embd_pdrop)

self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

# Model parallel
@@ -477,7 +527,7 @@ def forward(
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds


if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
@@ -593,19 +643,78 @@ class APTLMHeadModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = APTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

# muP
# TO DO: look into weight tying
# TO DO: if weight tying is used, APTMuSharedReadout with the proper tied weight should be used instead
self.lm_head = MuReadout(config.n_embd,
config.vocab_size,
bias=False,
readout_zero_init=config.readout_zero_init,
output_mult=config.output_mult)
Collaborator comment:
Suggested change:
- output_mult=config.output_mult)
+ output_mult=config.mup_output_mult)

I believe you have a typo here.
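
On the weight-tying TO DO above: if the readout is tied to the token embedding, mup's MuSharedReadout (already imported at the top of this file) takes the shared weight directly, so a tied variant might look roughly like this (a sketch, assuming MuSharedReadout accepts the tied embedding weight plus the usual MuReadout keyword arguments; readout_zero_init is deliberately omitted because it would also zero the embedding):

self.lm_head = MuSharedReadout(self.transformer.wte.weight,
                               bias=False,
                               output_mult=config.mup_output_mult)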


# mup
# note that this has to be run after mup.set_base_shape for it to work
# see https://github.com/microsoft/mup#basic-usage
# not sure if this is required here
self.apply(self._init_weights)
Collaborator comment:
So it seems to me like we shouldn't call this here? As in your coordinate check example, you will have to call it again anyway (and only if you're using mup?)?

Contributor (author) reply:
Right, I think this might have been left over from some earlier testing that I forgot to remove. Indeed, this shouldn't have an effect, so there's no reason to keep it. Thanks!
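
For reference, the workflow the mup README describes is roughly the following (a sketch; the widths and the n_embd kwarg are illustrative assumptions about what the underlying GPT-2 style config accepts):

from mup import set_base_shapes

# base and delta configs differ only in width; the target model is the one actually trained
base_model  = APTLMHeadModel(APTConfig(n_embd=256, use_mup=True))
delta_model = APTLMHeadModel(APTConfig(n_embd=512, use_mup=True))
model       = APTLMHeadModel(APTConfig(n_embd=1024, use_mup=True))

set_base_shapes(model, base_model, delta=delta_model)
model.apply(model._init_weights)  # muP-aware re-init, valid only after set_base_shapes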


# Model parallel
self.model_parallel = False
self.device_map = None

self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
prepend_bos=True,
append_eos=True,
eos_idx=2)
# mup implementation does not currently support this
Collaborator comment:
Similar to the dropout case, should we consider adding an assertion that we are not using mup with this in the configs?

Contributor (author) reply:
Yes, I think this is a good idea!
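
A sketch of the config-level checks suggested here and in the earlier dropout thread (a hypothetical helper, not part of this PR):

def _validate_mup_config(config):
    # guard against options the current muP integration does not support
    if not config.use_mup:
        return
    assert config.attn_pdrop == 0.0 and config.resid_pdrop == 0.0 and config.embd_pdrop == 0.0, \
        "muP branches replace dropout with nn.Identity; set dropout probabilities to 0"
    assert not config.contact_prediction_head, \
        "the contact prediction head is not currently supported together with muP"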

if config.contact_prediction_head:
self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
prepend_bos=True,
append_eos=True,
eos_idx=2)

# Initialize weights and apply final processing
self.post_init()

# mup
# general function for mup-specific weight initialization
def _init_weights(self, module):
if isinstance(module, (MuReadout, MuSharedReadout)) and self.config.readout_zero_init:
module.weight.data.zero_()
elif isinstance(module, (nn.Linear, Conv1D)):
if hasattr(module.weight, 'infshape'):
normal_(module.weight, mean=0.0, std=self.config.initializer_range)
else:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

if isinstance(module, APTAttention):
if hasattr(module, "q_attn"):
# cross attention case
if self.config.query_zero_init:
# q_attn corresponds to the query, i.e. the first third of c_attn in the non-cross-attention case -- zero initialization
NZ99 marked this conversation as resolved.
module.q_attn.weight.data.zero_()
else:
if self.config.query_zero_init:
_, fanout = module.c_attn.weight.shape
assert fanout % 3 == 0
module.c_attn.weight.data[:, :fanout//3] = 0

depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
if hasattr(p, 'infshape'):
normal_(p, mean=0.0, std=depth_std)
else:
p.data.normal_(mean=0.0, std=depth_std)


def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
2 changes: 2 additions & 0 deletions protein_lm/modeling/scripts/train.py
@@ -53,6 +53,8 @@ def train(
config_dict["wandb"],
)

# TO DO: add support for mup's optimizers in case use_mup is used, see e.g. https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/optim.py
# available via mup.optim
trainer = Trainer(
model=model,
args=training_args,
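
A sketch of how the TO DO above could be wired up (assumptions: the Hugging Face Trainer is used, whose optimizers argument accepts an (optimizer, lr_scheduler) tuple, and use_mup stands in for however the training config exposes the flag):

from mup.optim import MuAdamW

# muP needs its own optimizer so that per-layer learning rates are scaled correctly
optimizer = MuAdamW(model.parameters(), lr=training_args.learning_rate) if use_mup else None

trainer = Trainer(
    model=model,
    args=training_args,
    optimizers=(optimizer, None),  # (optimizer, lr_scheduler); (None, None) falls back to Trainer defaults
)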