From 6d3100417c06b006d3299951cb4150d0e146fcc9 Mon Sep 17 00:00:00 2001
From: M Shehtab Zaman
Date: Mon, 4 Mar 2024 16:09:09 -0800
Subject: [PATCH] QM9 trainer up and running

---
 .../FLASK/Transformer/datasets/QM9.py        | 20 ++++++++---
 .../datasets/pretokenize/QM9_Pretokenize.py  | 36 +++++++++++--------
 applications/FLASK/Transformer/network.py    |  4 ++--
 3 files changed, 40 insertions(+), 20 deletions(-)

diff --git a/applications/FLASK/Transformer/datasets/QM9.py b/applications/FLASK/Transformer/datasets/QM9.py
index 34b292d925b..cf7c8c90e35 100644
--- a/applications/FLASK/Transformer/datasets/QM9.py
+++ b/applications/FLASK/Transformer/datasets/QM9.py
@@ -7,6 +7,7 @@
 import numpy as np
 
 from pretokenize.SMILES_tokenizer import MolTokenizer
+from pretokenize.data_utils import random_zero_array
 
 sequence_length = int(os.getenv("QM9_SEQUENCE_LENGTH", default="32"))
 
@@ -20,10 +21,12 @@
 tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
 tokenizer.load_vocab_file()
 
-dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"), allow_pickle=True)
+dataset_train = np.load(os.path.join(data_dir, "QM9_Pretokenized.npy"))
+# dataset_train = np.zeros((140000, 32), dtype=np.float32)
 
 _vocab_size = 46
-
+pad_index = tokenizer.token_to_id("<pad>")
+sep_index = tokenizer.token_to_id("<sep>")
 
 # ----------------------------------------------
 # Sample access functions
@@ -36,7 +39,16 @@ def num_train_samples():
 
 def get_train_sample(i):
     data = dataset_train[i]
-    return data
+
+    boundary = np.where(data == sep_index)[0][0]
+    masked_data = random_zero_array(
+        data[:boundary], 0.15, tokenizer.token_to_id(tokenizer.mask_token)
+    )
+    output = np.zeros((2 * sequence_length), dtype=np.int32)
+    output[0:boundary] = masked_data
+    output[boundary] = sep_index
+    output[sequence_length:] = data
+    return output
 
 
 def sample_dims():
@@ -50,4 +62,4 @@ def vocab_size():
 if __name__ == "__main__":
     print("Training samples:", num_train_samples())
     print("Training sample 101:")
-    print(get_train_sample(101))
+    print(get_train_sample(0))
diff --git a/applications/FLASK/Transformer/datasets/pretokenize/QM9_Pretokenize.py b/applications/FLASK/Transformer/datasets/pretokenize/QM9_Pretokenize.py
index 6fe886d9b61..d7a696cfffe 100644
--- a/applications/FLASK/Transformer/datasets/pretokenize/QM9_Pretokenize.py
+++ b/applications/FLASK/Transformer/datasets/pretokenize/QM9_Pretokenize.py
@@ -1,26 +1,34 @@
 import numpy as np
 from SMILES_tokenizer import MolTokenizer
 from data_utils import random_zero_array
+import os
+import os.path
 
 
 def main():
-    tokenizer = MolTokenizer("SMILES_vocab.json")
+    data_dir = os.getenv("QM9_DATA_DIR", "/p/vast1/lbann/datasets/FLASK/QM9")
+
+    tokenizer = MolTokenizer(os.path.join(data_dir, "QM9_vocab.json"))
     tokenizer.load_vocab_file()
-    with open("QM9_smiles.txt", 'r') as smiles_data:
-        smiles_data = smiles_data.readlines()
-        num_samples = len(smiles_data)
-        max_length = 32
-        tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(tokenizer.pad_token)
-        tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
+    data_file = os.path.join(data_dir, "QM9_smiles.txt")
+    with open(data_file, "r") as smiles_data:
+        smiles_data = smiles_data.readlines()
+        num_samples = len(smiles_data)
+        max_length = 32
+
+        tokenized_data = np.ones((num_samples, max_length)) * tokenizer.encode(
+            tokenizer.pad_token
+        )
+        tokenized_data[:, 0] = tokenizer.encode(tokenizer.sep_token)
 
-        for i, smiles in enumerate(smiles_data, start=1):
-            tokens = tokenizer.tokenize(smiles)
-            tokens = random_zero_array(tokens, 0.15, tokenizer.encode(tokenizer.mask_token))
-            tokenized_data[i, :len(tokens)] = tokens
-            tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.cls_token)
+        for i, smiles in enumerate(smiles_data, start=0):
+            tokens = tokenizer.tokenize(smiles)
+            tokenized_data[i, : len(tokens)] = tokens
+            tokenized_data[i, len(tokens)] = tokenizer.encode(tokenizer.sep_token)
+    save_file_loc = os.path.join(data_dir, "QM9_Pretokenized.npy")
+    np.save(save_file_loc, tokenized_data)
 
-    np.save('QM9_Pretokenized.npy', tokenized_data)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/applications/FLASK/Transformer/network.py b/applications/FLASK/Transformer/network.py
index f3b08c91c5f..362cc81f0d8 100644
--- a/applications/FLASK/Transformer/network.py
+++ b/applications/FLASK/Transformer/network.py
@@ -144,11 +144,11 @@ def _add_input_encoding(
 
     # Apply encoder
     if encoder_input is not None:
-        encoder_input = positional_encoder(
+        encoder_input = positional_encoder.apply_input(
             encoder_input, encoder_sequence_length, **kwargs
         )
     if decoder_input is not None:
-        decoder_input = positional_encoder(
+        decoder_input = positional_encoder.apply_input(
             decoder_input, decoder_sequence_length, **kwargs
         )
 
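
Note on random_zero_array: both call sites in this patch pass (token array, 0.15, mask token id), i.e. BERT-style masking of roughly 15% of the tokens, but pretokenize/data_utils.py itself is not part of the diff. A minimal sketch of what such a helper could look like, assuming independent per-token replacement; only the name and argument order are taken from the call sites, the body is a guess:

    import numpy as np

    def random_zero_array(arr, prob, mask_value):
        # Replace each element of `arr` independently with `mask_value`
        # with probability `prob`. Operates on a copy; input is untouched.
        arr = np.asarray(arr)
        mask = np.random.random(arr.shape) < prob
        out = arr.copy()
        out[mask] = mask_value
        return out

Under this reading, get_train_sample() in QM9.py returns a vector of length 2 * sequence_length: the masked tokens up to the separator in the first half, and the original unmasked sequence as the target in the second half.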