QM9 trainer up and running
szaman19 committed Mar 5, 2024
1 parent 8d91f8b commit 6d31004
Showing 3 changed files with 40 additions and 20 deletions.
20 changes: 16 additions & 4 deletions applications/FLASK/Transformer/datasets/QM9.py
@@ -7,6 +7,7 @@

 import numpy as np
 from pretokenize.SMILES_tokenizer import MolTokenizer
+from pretokenize.data_utils import random_zero_array
 
 sequence_length = int(os.getenv("QM9_SEQUENCE_LENGTH", default="32"))

@@ -20,10 +21,12 @@
 tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
 tokenizer.load_vocab_file()
 
-dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"), allow_pickle=True)
+dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"))
 
-# dataset_train = np.zeros((140000, 32), dtype=np.float32)
 _vocab_size = 46
 
+pad_index = tokenizer.token_to_id("<pad>")
+sep_index = tokenizer.token_to_id("<eos>")
+
 # ----------------------------------------------
 # Sample access functions
@@ -36,7 +39,16 @@ def num_train_samples():

 def get_train_sample(i):
     data = dataset_train[i]
-    return data
+
+    boundary = np.where(data == sep_index)[0][0]
+    masked_data = random_zero_array(
+        data[:boundary], 0.15, tokenizer.token_to_id(tokenizer.mask_token)
+    )
+    output = np.zeros((2 * sequence_length), dtype=np.int32)
+    output[0:boundary] = masked_data
+    output[boundary] = sep_index
+    output[sequence_length:] = data
+    return output
 
 
 def sample_dims():
@@ -50,4 +62,4 @@ def vocab_size():
 if __name__ == "__main__":
     print("Training samples:", num_train_samples())
     print("Training sample 101:")
-    print(get_train_sample(101))
+    print(get_train_sample(0))
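
The rewritten get_train_sample returns a vector of length 2 * sequence_length: the first half is the sequence up to the <eos> boundary with roughly 15% of tokens masked, and the second half is the untouched sequence serving as the reconstruction target. The random_zero_array helper is imported but not shown in this commit; below is a minimal sketch of what it presumably does, judging from the call site (the body, including the copy semantics, is an assumption):

import numpy as np

def random_zero_array(arr, prob, mask_value):
    # Replace each entry of arr with mask_value with probability prob,
    # BERT-style; work on a copy so the caller's dataset row is untouched.
    out = np.array(arr)
    mask = np.random.random(out.shape) < prob
    out[mask] = mask_value
    return out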
@@ -1,26 +1,34 @@
 import numpy as np
 from SMILES_tokenizer import MolTokenizer
 from data_utils import random_zero_array
+import os
+import os.path
 
 
 def main():
-    tokenizer = MolTokenizer("SMILES_vocab.json")
+    data_dir = os.getenv("QM9_DATA_DIR", "/p/vast1/lbann/datasets/FLASK/QM9")
+
+    tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
     tokenizer.load_vocab_file()
-    with open("QM9_smiles.txt", 'r') as smiles_data:
-        smiles_data = smiles_data.readlines()
-        num_samples = len(smiles_data)
-        max_length = 32
-
-        tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(tokenizer.pad_token)
-        tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
+    data_file = os.path.join(data_dir, "QM9_smiles.txt")
+    with open(data_file, "r") as smiles_data:
+        smiles_data = smiles_data.readlines()
+        num_samples = len(smiles_data)
+        max_length = 32
+
+        tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(
+            tokenizer.pad_token
+        )
+        tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
 
-        for i, smiles in enumerate(smiles_data, start=1):
-            tokens = tokenizer.tokenize(smiles)
-            tokens = random_zero_array(tokens, 0.15, tokenizer.encode(tokenizer.mask_token))
-            tokenized_data[i, :len(tokens)] = tokens
-            tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.cls_token)
+        for i, smiles in enumerate(smiles_data, start=0):
+            tokens = tokenizer.tokenize(smiles)
+            tokenized_data[i, : len(tokens)] = tokens
+            tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.sep_token)
+        save_file_loc = os.path.join(data_dir, "QM9_Pretokenized.npy")
+        np.save(save_file_loc, tokenized_data)
 
-    np.save('QM9_Pretokenized.npy', tokenized_data)
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
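
After this change, each row of QM9_Pretokenized.npy holds one SMILES string as token ids followed by a single <eos> separator and <pad> ids out to max_length, with the random masking moved out of the pretokenizer and into the trainer's get_train_sample. A small illustration of the resulting row layout, using made-up token ids (the real ids come from QM9_vocab.json):

import numpy as np

max_length = 32
pad_id, eos_id = 0, 1       # assumed ids; actual values come from the vocab file
tokens = [5, 9, 9, 7, 12]   # stand-in for one tokenized SMILES string

row = np.full(max_length, pad_id)
row[: len(tokens)] = tokens
row[len(tokens)] = eos_id
# row -> [5, 9, 9, 7, 12, 1, 0, 0, ..., 0]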
4 changes: 2 additions & 2 deletions applications/FLASK/Transformer/network.py
@@ -144,11 +144,11 @@ def _add_input_encoding(

     # Apply encoder
     if encoder_input is not None:
-        encoder_input = positional_encoder(
+        encoder_input = positional_encoder.apply_input(
             encoder_input, encoder_sequence_length, **kwargs
         )
     if decoder_input is not None:
-        decoder_input = positional_encoder(
+        decoder_input = positional_encoder.apply_input(
             decoder_input, decoder_sequence_length, **kwargs
         )
