Fixed position encoding schemes (pos_embeds) as per OpenBioML#65
rich authored and committed Mar 15, 2024
1 parent e04681c commit 64def51
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion protein_lm/configs/train/full_colabfold.yaml
@@ -46,7 +46,7 @@ tokenizer:
model:
nn_model_type: "APT"
nn_model_config_args:
position_embedding: "alibi"
position_embedding: "rope"
rope_scaling_factor: 1.0
rope_theta: 10000
max_sequence_length: 128
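For context, here is a minimal sketch of how the model block above could be loaded and sanity-checked before model construction, assuming the YAML nests as rendered. The `load_model_block` helper and the `ALLOWED_POSITION_EMBEDDINGS` set are illustrative assumptions, not the repository's actual loader; the allowed values mirror the branches in the APTModel diff below.

```python
# Illustrative sketch only: load the "model" block from a training YAML and
# check that position_embedding is one of the schemes the APT model handles.
import yaml

# Assumption: this set mirrors the branches in APTModel.__init__ shown below.
ALLOWED_POSITION_EMBEDDINGS = {
    "learned", "rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling", "alibi",
}

def load_model_block(path: str) -> dict:
    """Read a train config and return its nn_model_config_args block."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    args = cfg["model"]["nn_model_config_args"]
    # Default matches the model code: missing key falls back to "learned".
    pe = args.get("position_embedding", "learned")
    if pe not in ALLOWED_POSITION_EMBEDDINGS:
        raise ValueError(f"invalid position_embedding: {pe}")
    return args

# Example usage (hypothetical call site):
# args = load_model_block("protein_lm/configs/train/full_colabfold.yaml")
# print(args["position_embedding"], args["rope_theta"])
```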
4 changes: 2 additions & 2 deletions protein_lm/configs/train/toy_colabfold.yaml
@@ -18,7 +18,7 @@ dataset:
training_arguments:
output_dir: "checkpoints/toy"
max_steps: 10
num_train_epochs: 1
num_train_epochs: 10
learning_rate: 1
lr_scheduler_type: "linear"
warmup_steps: 4
@@ -47,7 +47,7 @@ tokenizer:
model:
nn_model_type: "APT"
nn_model_config_args:
position_embedding: "alibi"
position_embedding: "rope"
rope_scaling_factor: 1.0
rope_theta: 10000
max_sequence_length: 10
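Note that the toy config keeps max_steps: 10, and with Hugging Face TrainingArguments a positive max_steps takes precedence over num_train_epochs, so raising the epoch count from 1 to 10 does not lengthen the smoke test. A minimal sketch, assuming the training_arguments block is passed more or less directly to TrainingArguments:

```python
# Sketch of how the toy training_arguments block maps onto Hugging Face
# TrainingArguments; values are copied from the YAML above.
from transformers import TrainingArguments

toy_args = TrainingArguments(
    output_dir="checkpoints/toy",
    max_steps=10,              # a positive max_steps overrides num_train_epochs
    num_train_epochs=10,
    learning_rate=1.0,         # value taken from the toy YAML
    lr_scheduler_type="linear",
    warmup_steps=4,
)
print(toy_args.max_steps, toy_args.num_train_epochs)
```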
12 changes: 8 additions & 4 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -452,9 +452,12 @@ def __init__(self, config):
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
if self.position_embedding=="learned":
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.alibi = None
elif self.position_embedding in ['rope','rerope','linear_rope_scaling','dynamic_rope_scaling']:
self.wpe = None
self.alibi = None
elif self.position_embedding=="alibi":
maxpos = config.n_positions
attn_heads = config.n_head
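The remainder of the alibi branch is collapsed in this view; for reference, here is a minimal sketch of the standard ALiBi slope/bias construction it presumably performs (Press et al.), assuming the number of heads is a power of two. Helper names and tensor shapes are illustrative assumptions, not the repository's code.

```python
# Illustrative sketch of the usual ALiBi bias construction (Press et al.).
# Assumes attn_heads is a power of two; the real repository code may differ.
import torch

def build_alibi_bias(maxpos: int, attn_heads: int) -> torch.Tensor:
    """Return an (attn_heads, 1, maxpos) tensor of linear position biases."""
    # Geometric slope schedule: 2^(-8/n), 2^(-16/n), ... for n heads.
    start = 2.0 ** (-8.0 / attn_heads)
    slopes = torch.tensor([start ** (i + 1) for i in range(attn_heads)])
    # Key positions 0..maxpos-1; each head scales them by its own slope.
    positions = torch.arange(maxpos).unsqueeze(0)            # (1, maxpos)
    return slopes.unsqueeze(1).unsqueeze(1) * positions      # (heads, 1, maxpos)

# bias = build_alibi_bias(maxpos=128, attn_heads=8)
# attn_scores = attn_scores + bias   # added per head before the softmax
```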
@@ -566,12 +569,13 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
if self.position_embedding=="learned":
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
elif self.position_embedding in ['rope','rerope','linear_rope_scaling','dynamic_rope_scaling','alibi']:
hidden_states = inputs_embeds

else:
raise Exception(f'invalid {self.position_embedding} provided')

if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
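In the rope, rerope, and rope-scaling branches the hidden states are deliberately left as the raw token embeddings, because the rotation is applied to queries and keys inside the attention layers rather than added to the input. A minimal sketch of that rotation under the standard RoPE formulation, using the rope_theta and rope_scaling_factor values from the configs above; the function itself is an assumption for illustration, not the repository's implementation.

```python
# Illustrative RoPE sketch: rotate query/key pairs by position-dependent angles.
# Parameter names mirror the config (rope_theta, rope_scaling_factor).
import torch

def apply_rope(q, k, rope_theta=10000.0, rope_scaling_factor=1.0):
    """Rotate q/k tensors of shape (batch, heads, seq_len, head_dim)."""
    _, _, seq_len, dim = q.shape
    # Linear scaling divides positions so longer sequences map into the trained range.
    positions = torch.arange(seq_len, dtype=torch.float32) / rope_scaling_factor
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = positions[:, None] * inv_freq[None, :]           # (seq_len, dim/2)
    cos = torch.cos(angles).repeat_interleave(2, dim=-1)      # (seq_len, dim)
    sin = torch.sin(angles).repeat_interleave(2, dim=-1)

    def rotate_half(x):
        # Pairwise rotation helper: (x0, x1) -> (-x1, x0) for each 2D pair.
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    return q_rot, k_rot
```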
