From 01e9de18bc3743c6fce1655b235b99ac2c2d4a31 Mon Sep 17 00:00:00 2001 From: Kevin Maik Jablonka Date: Mon, 12 Aug 2024 14:07:40 -0700 Subject: [PATCH] refactor and test sampling engine --- .flake8 | 17 -- notebooks/example.ipynb | 40 --- notebooks/explore_inference.ipynb | 135 --------- notebooks/explore_training_loss.ipynb | 221 -------------- notebooks/huggingface_tokenizer.ipynb | 147 --------- src/chemnlp/data/constants.py | 272 +++++++++++++++++ src/chemnlp/data/container.py | 37 +++ src/chemnlp/data/random_variable.py | 29 ++ src/chemnlp/data/sampler.py | 417 ++++++++++++++++++++++++++ src/chemnlp/data/utils.py | 19 ++ src/chemnlp/trainer.py | 77 ----- tests/data/__init__.py | 0 tests/data/test_sampler.py | 90 ++++++ 13 files changed, 864 insertions(+), 637 deletions(-) delete mode 100644 .flake8 delete mode 100644 notebooks/example.ipynb delete mode 100644 notebooks/explore_inference.ipynb delete mode 100644 notebooks/explore_training_loss.ipynb delete mode 100644 notebooks/huggingface_tokenizer.ipynb create mode 100644 src/chemnlp/data/constants.py create mode 100644 src/chemnlp/data/container.py create mode 100644 src/chemnlp/data/random_variable.py create mode 100644 src/chemnlp/data/sampler.py delete mode 100644 src/chemnlp/trainer.py create mode 100644 tests/data/__init__.py create mode 100644 tests/data/test_sampler.py diff --git a/.flake8 b/.flake8 deleted file mode 100644 index eaf1fe2d1..000000000 --- a/.flake8 +++ /dev/null @@ -1,17 +0,0 @@ -[flake8] -ignore = - S301 - S403 - S404 - S603 - W503 - E203 - S101 - D101 - D102 - D103 - D104 - D400 -max-line-length = 120 -max-complexity = 20 -import-order-style = isort diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb deleted file mode 100644 index 4be8cea09..000000000 --- a/notebooks/example.ipynb +++ /dev/null @@ -1,40 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(\"Testing 123\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.8.16 ('chemnlp')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "7414a981235d2f3cf02118fe0f4e0887d77148dddbf4f0d3c9ba26e586868a93" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/explore_inference.ipynb b/notebooks/explore_inference.ipynb deleted file mode 100644 index a77de02ed..000000000 --- a/notebooks/explore_inference.ipynb +++ /dev/null @@ -1,135 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import AutoTokenizer, GPTNeoXForCausalLM\n", - "import os\n", - "import json" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def load_json(name: str):\n", - " with open(name, 'r') as f:\n", - " return json.load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ROOT_300M_MODELS = '/fsx/proj-chemnlp/experiments/checkpoints/finetuned/300M-tokenised-gridsearch-v1'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "models = {}\n", - "NAME 
= '1B_fine_tune_1' # xyz model\n", - "COLLECT_ALL = True # can be expensive (30 seconds per 1B param model)\n", - "\n", - "if COLLECT_ALL:\n", - " # traverse the directory for all models\n", - " for name in os.listdir(ROOT_300M_MODELS):\n", - " model_path = f\"{ROOT_300M_MODELS}/{name}\"\n", - " if not name.endswith('.json') and 'checkpoint-final' in os.listdir(model_path):\n", - " models[name] = {\n", - " 'model': GPTNeoXForCausalLM.from_pretrained(pretrained_model_name_or_path=f\"{model_path}/checkpoint-final\"),\n", - " 'configs': load_json(f\"{model_path}_global_0_local_0_rank_overrides.json\")\n", - " }\n", - "else:\n", - " model_path = f\"{ROOT_300M_MODELS}/{NAME}\"\n", - " models[NAME] = {\n", - " 'model': GPTNeoXForCausalLM.from_pretrained(pretrained_model_name_or_path=f\"{model_path}/checkpoint-final\"),\n", - " 'configs': load_json(f\"{model_path}_global_0_local_0_rank_overrides.json\")\n", - " }" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tokeniser = AutoTokenizer.from_pretrained(\"EleutherAI/pythia-1b\")\n", - "tokeniser.add_special_tokens({\"pad_token\": \"<|padding|>\"})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "TEST_CASES = [\n", - " 'I enjoy walking with my cute dog',\n", - " 'The heaviest element in the periodic table is',\n", - " 'C 6.39 2.84 -1.46 O 6.12 1.57 -0.86 P 5.14 1.10 0.31',\n", - " 'The element carbon is denoted with the following symbol',\n", - "]\n", - "OUTPUT_LEN = 20" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, test_case in enumerate(TEST_CASES): \n", - " print(f'TEST CASE {i+1} -> {test_case}')\n", - " input_ids = tokeniser.encode(test_case, return_tensors='pt')\n", - "\n", - " for model_configs in models.values():\n", - " greedy_output = model_configs['model'].generate(\n", - " input_ids, \n", - " max_length=input_ids.shape[-1]+OUTPUT_LEN, \n", - " pad_token_id=tokeniser.eos_token_id\n", - " )\n", - "\n", - " checkpoint_dir = list(model_configs['configs'].keys())[0]\n", - " dataset_name = model_configs['configs'][checkpoint_dir]['data']['path'].split('/')[-1]\n", - "\n", - " print(f\"\\nOutput for model trained on 300M {dataset_name}\\n\" + 100 * '-')\n", - " print(tokeniser.decode(greedy_output[0], skip_special_tokens=True))\n", - " print( )\n", - " print( )" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/explore_training_loss.ipynb b/notebooks/explore_training_loss.ipynb deleted file mode 100644 index ee4ecd7f6..000000000 --- a/notebooks/explore_training_loss.ipynb +++ /dev/null @@ -1,221 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model\n", - "from transformers import (\n", - " AutoTokenizer,\n", - " DataCollatorForLanguageModeling,\n", - " GPTNeoXForCausalLM,\n", - " Trainer,\n", - " TrainingArguments\n", - ")\n", - "from torch.utils.data import 
DataLoader\n", - "import datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = datasets.load_from_disk(\"/fsx/proj-chemnlp/data/EleutherAI/pythia-160m/marianna13/chemrxiv\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tokenizer = AutoTokenizer.from_pretrained(\n", - " pretrained_model_name_or_path=\"EleutherAI/pythia-160m\",\n", - " revision=\"main\",\n", - ")\n", - "tokenizer.add_special_tokens({\"pad_token\": \"<|padding|>\"})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)\n", - "dl = DataLoader(dataset, batch_size=2, collate_fn=data_collator)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = GPTNeoXForCausalLM.from_pretrained(\n", - " pretrained_model_name_or_path=\"EleutherAI/pythia-160m\",\n", - " revision=\"main\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "peft_config = PromptTuningConfig(\n", - " task_type=TaskType.CAUSAL_LM,\n", - " prompt_tuning_init=PromptTuningInit.TEXT,\n", - " num_virtual_tokens=10,\n", - " prompt_tuning_init_text=\" \",\n", - " tokenizer_name_or_path=\"EleutherAI/pythia-160m\",\n", - ")\n", - "peft_model = get_peft_model(model, peft_config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for i, (p1,p2) in enumerate(zip(model.parameters(), peft_model.parameters())):\n", - " import torch\n", - " if not torch.equal(p1,p2):\n", - " print(i, False)\n", - " else:\n", - " print(i)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for b in dl:\n", - " token_types = b.pop('token_type_ids')\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "out = model(**b)\n", - "out['loss']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "out = peft_model(**b)\n", - "out['loss']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all([p.requires_grad for p in model.parameters()]), all([p.requires_grad for p in peft_model.parameters()])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "split_dataset = dataset.train_test_split(test_size=0.025)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "training_args = TrainingArguments(\n", - " output_dir='./',\n", - " num_train_epochs=1,\n", - " learning_rate=3e-4,\n", - " evaluation_strategy='steps',\n", - " logging_steps=1,\n", - " eval_steps=50,\n", - " dataloader_num_workers=4,\n", - " bf16=True,\n", - " fp16=False,\n", - " per_device_train_batch_size=28,\n", - " per_device_eval_batch_size=28,\n", - " report_to=\"none\",\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer = Trainer(\n", - " model=peft_model,\n", - " args=training_args,\n", - " train_dataset=split_dataset[\"train\"],\n", - " eval_dataset=split_dataset[\"test\"],\n", - " 
tokenizer=tokenizer,\n", - " data_collator=data_collator,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.train()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/huggingface_tokenizer.ipynb b/notebooks/huggingface_tokenizer.ipynb deleted file mode 100644 index 5885a7365..000000000 --- a/notebooks/huggingface_tokenizer.ipynb +++ /dev/null @@ -1,147 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "7d4ff501-0ba7-41b6-8126-b29155ca7b54", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install tokenizers" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "daffceb5-7bf6-40ce-b1e8-d6ef7a5ebce5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from tokenizers import Tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "c688695b-e219-4588-a134-d63f3f6f5435", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "tokenizer = Tokenizer.from_pretrained(\"EleutherAI/pythia-12b\")" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "41560898-c9f4-402f-8726-583a3216f9f4", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "examples = [\n", - " \"This is a carbohydrate with sum formula C6H12O6 and a molecular weight of 180.16 g mol^-1\",\n", - " \"The studied metal complex has a composition of [Ru(py)(bpy)(terpy)](PF6)2 and is soluble in CD3CN\",\n", - " \"Nitrobenzene with a sum formula of C6H5NO2 has an InCHI-IDENtifier InChI=1S/C6H5NO2/c8-7(9)6-4-2-1-3-5-6/h1-5H/i1+1,2+1,3+1,4+1,5+1,6+1\",\n", - " \"Ibuprofen with a formula of (H3C)2CCH2-C6H4-CH(CH3)COOH has a SMILES string CC(C)Cc1cCc(cc1)C(C)C(O)=O\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "1b8af9b0-1a17-499d-95af-57cf10a15731", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "outputs = tokenizer.encode_batch(examples)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "f75b39db-fb58-4996-9fbc-bd321b313447", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['This', 'Ġis', 'Ġa', 'Ġcarbohydrate', 'Ġwith', 'Ġsum', 'Ġformula', 'ĠC', '6', 'H', '12', 'O', '6', 'Ġand', 'Ġa', 'Ġmolecular', 'Ġweight', 'Ġof', 'Ġ180', '.', '16', 'Ġg', 'Ġmol', '^-', '1']\n", - "['The', 'Ġstudied', 'Ġmetal', 'Ġcomplex', 'Ġhas', 'Ġa', 'Ġcomposition', 'Ġof', 'Ġ[', 'Ru', '(', 'py', ')(', 'b', 'py', ')(', 'ter', 'py', ')](', 'PF', '6', ')', '2', 'Ġand', 'Ġis', 'Ġsoluble', 'Ġin', 'ĠCD', '3', 'CN']\n", - "['N', 'itro', 'benz', 'ene', 'Ġwith', 'Ġa', 'Ġsum', 'Ġformula', 'Ġof', 'ĠC', '6', 'H', '5', 'NO', '2', 'Ġhas', 'Ġan', 'ĠIn', 'CH', 'I', '-', 'ID', 'EN', 't', 'ifier', 'ĠIn', 'Ch', 'I', '=', '1', 'S', '/', 'C', '6', 'H', '5', 'NO', '2', '/', 'c', '8', '-', '7', '(', '9', ')', '6', '-', '4', '-', '2', '-', '1', '-', '3', '-', '5', '-', '6', '/', 'h', '1', '-', '5', 'H', '/', 'i', 
'1', '+', '1', ',', '2', '+', '1', ',', '3', '+', '1', ',', '4', '+', '1', ',', '5', '+', '1', ',', '6', '+', '1']\n", - "['I', 'b', 'up', 'ro', 'fen', 'Ġwith', 'Ġa', 'Ġformula', 'Ġof', 'Ġ(', 'H', '3', 'C', ')', '2', 'CC', 'H', '2', '-', 'C', '6', 'H', '4', '-', 'CH', '(', 'CH', '3', ')', 'C', 'OO', 'H', 'Ġhas', 'Ġa', 'ĠSM', 'IL', 'ES', 'Ġstring', ' ', 'CC', '(', 'C', ')', 'C', 'c', '1', 'c', 'C', 'c', '(', 'cc', '1', ')', 'C', '(', 'C', ')', 'C', '(', 'O', ')=', 'O']\n" - ] - } - ], - "source": [ - "for output in outputs:\n", - " print(output.tokens)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "bebd084e-1ad1-4c1f-a4ad-ea6cfa3b48d6", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "This is a carbohydrate with sum formula C6H12O6 and a molecular weight of 180.16 g mol^-1\n", - "['This', 'Ġis', 'Ġa', 'Ġcarbohydrate', 'Ġwith', 'Ġsum', 'Ġformula', 'ĠC', '6', 'H', '12', 'O', '6', 'Ġand', 'Ġa', 'Ġmolecular', 'Ġweight', 'Ġof', 'Ġ180', '.', '16', 'Ġg', 'Ġmol', '^-', '1']\n", - "\n", - "The studied metal complex has a composition of [Ru(py)(bpy)(terpy)](PF6)2 and is soluble in CD3CN\n", - "['The', 'Ġstudied', 'Ġmetal', 'Ġcomplex', 'Ġhas', 'Ġa', 'Ġcomposition', 'Ġof', 'Ġ[', 'Ru', '(', 'py', ')(', 'b', 'py', ')(', 'ter', 'py', ')](', 'PF', '6', ')', '2', 'Ġand', 'Ġis', 'Ġsoluble', 'Ġin', 'ĠCD', '3', 'CN']\n", - "\n", - "Nitrobenzene with a sum formula of C6H5NO2 has an InCHI-IDENtifier InChI=1S/C6H5NO2/c8-7(9)6-4-2-1-3-5-6/h1-5H/i1+1,2+1,3+1,4+1,5+1,6+1\n", - "['N', 'itro', 'benz', 'ene', 'Ġwith', 'Ġa', 'Ġsum', 'Ġformula', 'Ġof', 'ĠC', '6', 'H', '5', 'NO', '2', 'Ġhas', 'Ġan', 'ĠIn', 'CH', 'I', '-', 'ID', 'EN', 't', 'ifier', 'ĠIn', 'Ch', 'I', '=', '1', 'S', '/', 'C', '6', 'H', '5', 'NO', '2', '/', 'c', '8', '-', '7', '(', '9', ')', '6', '-', '4', '-', '2', '-', '1', '-', '3', '-', '5', '-', '6', '/', 'h', '1', '-', '5', 'H', '/', 'i', '1', '+', '1', ',', '2', '+', '1', ',', '3', '+', '1', ',', '4', '+', '1', ',', '5', '+', '1', ',', '6', '+', '1']\n", - "\n", - "Ibuprofen with a formula of (H3C)2CCH2-C6H4-CH(CH3)COOH has a SMILES string CC(C)Cc1cCc(cc1)C(C)C(O)=O\n", - "['I', 'b', 'up', 'ro', 'fen', 'Ġwith', 'Ġa', 'Ġformula', 'Ġof', 'Ġ(', 'H', '3', 'C', ')', '2', 'CC', 'H', '2', '-', 'C', '6', 'H', '4', '-', 'CH', '(', 'CH', '3', ')', 'C', 'OO', 'H', 'Ġhas', 'Ġa', 'ĠSM', 'IL', 'ES', 'Ġstring', ' ', 'CC', '(', 'C', ')', 'C', 'c', '1', 'c', 'C', 'c', '(', 'cc', '1', ')', 'C', '(', 'C', ')', 'C', '(', 'O', ')=', 'O']\n", - "\n" - ] - } - ], - "source": [ - "for e in examples:\n", - " print(e)\n", - " enc = tokenizer.encode(e)\n", - " print(enc.tokens)\n", - " print()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "chemnlp", - "language": "python", - "name": "chemnlp" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.16" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/src/chemnlp/data/constants.py b/src/chemnlp/data/constants.py new file mode 100644 index 000000000..716d41275 --- /dev/null +++ b/src/chemnlp/data/constants.py @@ -0,0 +1,272 @@ +DEFAULT_SIGNIFICANT_DIGITS = 3 + + + +STANDARD_TABULAR_TEXT_TEMPLATES = [ + "The molecule with the {SMILES__description} {#representation of |!}{SMILES#} has a {TARGET__names__noun} of {TARGET#} {TARGET__units}.", # noqa: E501 + "Based on the 
{SMILES__description} {#representation of |!}{SMILES#}, the molecule has a {TARGET__names__noun} of {TARGET#} {TARGET__units}.", # noqa: E501 + "The {SMILES__description} {SMILES#} {#represents|is representing!} a molecule {#that has a|with a!} {TARGET__names__noun} of {TARGET#} {TARGET__units}.", # noqa: E501 + "The molecule with the {SMILES__description} {SMILES#} has a {TARGET__names__noun} of {TARGET#} {TARGET__units}.", + # Instruction tuning text templates + """Task: Please predict a molecule feature based on the description. +Description: Predict the {TARGET__names__noun} in {TARGET__units}. +{#Molecule |!}{SMILES__description}: {SMILES#} +Constraint: Even if you are {#uncertain|not sure!}, you must answer with a numeric value in {TARGET__units} without using any {#other|additional!} words. +Result: {TARGET#} {TARGET__units}""", # noqa: E501 + """Task: Please predict a molecule feature based on the description. +Description: Predict the {TARGET__names__noun} in {TARGET__units}. +{SMILES__description}: {SMILES#} +Constraint: Even if you are {#uncertain|not sure!}, you must answer with a numeric value in {TARGET__units} without the unit and without using any {#other|additional!} words. +Result: {TARGET#}""", # noqa: E501 + """Task: Please {#give me|create|generate!} a {#molecule|chemical|compound!} with {SMILES__description} based on the {#text |!}description{# below|!}. +Description: A molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}. +Result: {SMILES#}""", # noqa: E501 + # Conversational text templates + """User: Can you {#tell me|derive|estimate!} the {TARGET__names__noun} in {TARGET__units} of the molecule with the {SMILES__description} {SMILES#}? +Assistant: {#Yes|Of course|Sure|Yes, I'm happy to help!}, this molecule has a {TARGET__names__noun} of {TARGET#} {TARGET__units}.""", # noqa: E501 + """User: Can you {#give me|create|generate!} the {SMILES__description} of a molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}? +Assistant: {#Yes|Of course|Sure|Yes, I'm happy to help!}, here you go: {SMILES#}""", # noqa: E501 + """User: I'm {#searching|looking!} for the {SMILES__description} of a molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}. +Assistant: This is a molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}: {SMILES#}""", # noqa: E501 + """User: I want to {#come up with|create|generate!} the {SMILES__description} of a {#molecule|chemical|chemical compound!}. +Assistant: {#This sounds very exciting. |This sounds very interesting. !}Should I consider any {#constraints|specific points!} for the {#generation|creation!}? +User: Yes, please. The molecule should have a {TARGET__names__noun} of {TARGET#} {TARGET__units}. +Assistant: {#Ok|Got it!},{# here you go,|!} this {SMILES__description} represents a molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}: {SMILES#}""", # noqa: E501 + """User: I want to {#come up with|create|generate!} a {SMILES__description} of a {#molecule|chemical|chemical structure!}. +Assistant: {#This sounds very exciting. |This sounds very interesting. !}Should it be a special {#molecule|one!}? +User: Yes, the molecule should have a {TARGET__names__noun} of {TARGET#} {TARGET__units}. 
+Assistant: {#Understood|Got it|Ok!}, this {SMILES__description} represents a molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}: {SMILES#}""", # noqa: E501 + # Benchmarking text templates + "The {TARGET__names__noun} of the molecule with the {SMILES__description} {SMILES#} is:{TARGET#} {TARGET__units}", # noqa: E501 + "The {TARGET__names__noun} of the {SMILES__description} {SMILES#} is:{TARGET#} {TARGET__units}", # noqa: E501 + "The {TARGET__names__noun} of the molecule {SMILES__description} {SMILES#} is:{TARGET#} {TARGET__units}", # noqa: E501 + """Task: Please predict a molecule feature based on the description. +Description: Predict the {TARGET__names__noun} in {TARGET__units} of a molecule. +{#Molecule |!}{SMILES__description}: {SMILES#} +Constraint: Even if you are {#uncertain|not sure!}, you must answer with a numeric value in {TARGET__units} without using any {#other|additional!} words. +Result:{TARGET#} {TARGET__units}""", # noqa: E501 + """Task: Please predict a molecule feature based on the description. +Description: Predict the {TARGET__names__noun} in {TARGET__units} of a molecule. +{#Molecule |!}{SMILES__description}: {SMILES#} +Constraint: Even if you are {#uncertain|not sure!}, you must answer with a numeric value in {TARGET__units} without the unit and without using any {#other|additional!} words. +Result:{TARGET#}""", # noqa: E501 + """Task: Please {#give me|create|generate!} a {#molecule |!}{SMILES__description} based on the {#text |!}description{# below|!}. +Description: A molecule that has a {TARGET__names__noun} of {TARGET#} {TARGET__units}. +Result:{SMILES#}""", # noqa: E501 +] + + +EXCLUDE_FROM_STANDARD_TABULAR_TEXT_TEMPLATES = [ + "BACE", + "BBBP", # because it is boolean target data + "MUV_466", # boolean target data + "MUV_548", # boolean target data + "MUV_600", # boolean target data + "MUV_644", # boolean target data + "MUV_652", # boolean target data + "MUV_689", # boolean target data + "MUV_692", # boolean target data + "MUV_712", # boolean target data + "MUV_713", # boolean target data + "MUV_733", # boolean target data + "MUV_737", # boolean target data + "MUV_810", # boolean target data + "MUV_832", # boolean target data + "MUV_846", # boolean target data + "MUV_852", # boolean target data + "MUV_858", # boolean target data + "MUV_859", # boolean target data + "RedDB", + "SIDER", + "ames_mutagenicity", # because it is boolean target data + "aminoacids", + "bc5chem", + "bc5disease", + "bicerano_dataset", + "bio_ner", + "bioavailability_ma_et_al", # because it is boolean target data + "block_polymers_morphology", + "blood_brain_barrier_martins_et_al", # because it is boolean target data + "buchwald_hartwig", + "carcinogens", # because it is boolean target data + "cav3_t-type_calcium_channels_butkiewicz", # because it is boolean target data + "chebi_20", # target is text description + "chem_caption_smarts", + "chembl_v29", # text only, no SMILES + "chemcaption_fragments", + "chemcaption_rdkit", # text only, no SMILES + "chemdner", + "chemistry_stackexchange", + "choline_transporter_butkiewicz", # because it is boolean target data + "clintox", # because it is boolean target data + "compound_chebi_chebi_chebi_1", + "compound_chebi_chebi_chebi_2", + "core_mof_no_topo", + "cyp2c9_substrate_carbonmangels", # boolean target data + "cyp2d6_substrate_carbonmangels", # boolean target data + "cyp3a4_substrate_carbonmangels", # boolean target data + "cyp_p450_1a2_inhibition_veith_et_al", # boolean target data + "cyp_p450_2c19_inhibition_veith_et_al", 
# boolean target data
+    "cyp_p450_2c9_inhibition_veith_et_al", # boolean target data
+    "cyp_p450_2d6_inhibition_veith_et_al", # boolean target data
+    "cyp_p450_3a4_inhibition_veith_et_al", # boolean target data
+    "drug_chebi_chebi_chebi",
+    "drug_induced_liver_injury", # boolean target data
+    "drugchat_liang_zhang_et_al", # text
+    "fda_adverse_reactions",
+    "formation_energies",
+    "freesolv", # more than one target
+    "h2_storage_materials",
+    "herg_blockers", # more than one target
+    "herg_central_inhib", # boolean target data
+    "herg_karim_et_al", # boolean target data
+    "hiv", # boolean target data
+    "human_intestinal_absorption", # boolean target data
+    "iupac_goldbook", # text only, no SMILES
+    "iupac_smiles", # translation from IUPAC name to SMILES
+    "kcnq2_potassium_channel_butkiewicz", # boolean target data
+    "m1_muscarinic_receptor_agonists_butkiewicz", # boolean target data
+    "m1_muscarinic_receptor_antagonists_butkiewicz", # boolean target data
+    "mattermodeling_stackexchange",
+    "melting_points",
+    "mofdscribe",
+    "mol2svg",
+    "mol_repr_transl_canonical_inchi",
+    "mol_repr_transl_canonical_iupac_name",
+    "mol_repr_transl_deepsmiles_canonical",
+    "mol_repr_transl_deepsmiles_inchi",
+    "mol_repr_transl_deepsmiles_iupac_name",
+    "mol_repr_transl_inchi_iupac_name",
+    "mol_repr_transl_selfies_canonical",
+    "mol_repr_transl_selfies_deepsmiles",
+    "mol_repr_transl_selfies_inchi",
+    "mol_repr_transl_selfies_iupac_name",
+    "mol_repr_transl_smiles_canonical",
+    "mol_repr_transl_smiles_deepsmiles",
+    "mol_repr_transl_smiles_inchi",
+    "mol_repr_transl_smiles_iupac_name",
+    "mol_repr_transl_smiles_selfies",
+    "mona", # more than one target
+    "moses", # SMILES only, has no target
+    "mp_anisotropy",
+    "mp_bulk_modulus",
+    "mp_descriptions",
+    "mp_self_supervised",
+    "mp_shear_modulus",
+    "ncbi_disease",
+    "nlmchem", # text only, no SMILES
+    "nomad_structure",
+    "nr_ahr_tox21", # boolean target data
+    "nr_ar_lbd_tox21", # boolean target data
+    "nr_ar_tox21", # boolean target data
+    "nr_aromatase_tox21", # boolean target data
+    "nr_er_lbd_tox21", # boolean target data
+    "nr_er_tox21", # boolean target data
+    "nr_ppar_gamma_tox21", # boolean target data
+    "ocp",
+    "odd_one_out",
+    "opv",
+    "oqmd",
+    "orbnet_denali", # only makes sense for the structure files
+    "ord_masked",
+    "ord_predictions",
+    "ord_procedure_steps",
+    "ord_rxn_smiles_procedure",
+    "ord_rxn_smiles_yield_pred",
+    "ord_steps_yield",
+    "orexin1_receptor_butkiewicz", # boolean target data
+    "p_glycoprotein_inhibition_broccatelli_et_al", # boolean target data
+    "pampa_ncats", # boolean target data
+    "peptides_hemolytic", # boolean target data
+    "peptides_nonfouling", # boolean target data
+    "peptides_soluble", # boolean target data
+    "perovskite_db",
+    "physics_stackexchange",
+    "potassium_ion_channel_kir2_1_butkiewicz", # boolean target data
+    "qm8",
+    "qm9",
+    "qmof_gcmc",
+    "qmof_quantum",
+    "rhea_db_masked",
+    "rhea_db_predictions",
+    "sarscov2_3clpro_diamond", # boolean target data
+    "sarscov2_vitro_touret", # boolean target data
+    "serine_threonine_kinase_33_butkiewicz", # boolean target data
+    "skin_reaction", # boolean target data
+    "smiles_to_3d",
+    "sr_are_tox21", # boolean target data
+    "sr_atad5_tox21", # boolean target data
+    "sr_hse_tox21", # boolean target data
+    "sr_mmp_tox21", # boolean target data
+    "sr_p53_tox21", # boolean target data
+    "suzuki_miyaura_sach",
+    "tyrosyl-dna_phosphodiesterase_butkiewicz", # boolean target data
+    "uniprot_binding_single",
+    "uniprot_binding_sites_multiple",
+    "uniprot_organisms",
+    "uniprot_reactions",
+    "uniprot_sentences",
+    "uspto",
+    "uspto_yield",
+    "zinc", # SMILES only, has no target
+    # "h2_storage_materials", # only IUPAC identifier, more than one target, LOW PRIO: has only 30 samples
+]
+
+
+LM_EVAL_YAML_TEMPLATE_LOGLIKELIHOOD = {
+    "group": [
+        "chemnlp",
+        "loglikelihood",
+    ],
+    "task": None,
+    "dataset_path": None,
+    "dataset_name": None,
+    "output_type": "loglikelihood",
+    "doc_to_text": "input",
+    "doc_to_target": "output",
+    "metric_list": [
+        {
+            "metric": "perplexity",
+            "aggregation": "perplexity",
+            "higher_is_better": False,
+        },
+        {
+            "metric": "acc",
+            "aggregation": "mean",
+            "higher_is_better": True,
+        },
+    ],
+}
+
+LM_EVAL_YAML_TEMPLATE_MULTIPLE_CHOICE = {
+    "group": [
+        "chemnlp",
+        "multiple_choice",
+    ],
+    "task": None,
+    "dataset_path": None,
+    "dataset_name": None,
+    "output_type": "multiple_choice",
+    "doc_to_text": "input",
+    "doc_to_target": "output",
+    "doc_to_choice": "{{answer_choices}}",
+    "metric_list": [
+        {
+            "metric": "acc",
+            "aggregation": "mean",
+            "higher_is_better": True,
+        },
+        {
+            "metric": "acc_norm",
+            "aggregation": "mean",
+            "higher_is_better": True,
+        },
+        # todo: check acc_mutual_info because it breaks
+        # {
+        #     "metric": "acc_mutual_info",
+        #     "aggregation": "mean",
+        #     "higher_is_better": True,
+        # },
+    ],
+}
diff --git a/src/chemnlp/data/container.py b/src/chemnlp/data/container.py
new file mode 100644
index 000000000..4dba0c6a1
--- /dev/null
+++ b/src/chemnlp/data/container.py
@@ -0,0 +1,37 @@
+from dependency_injector import containers, providers
+from chemnlp.data.sampler import TemplateSampler
+from chemnlp.data.utils import load_yaml
+import pandas as pd
+import logging
+
+class Container(containers.DeclarativeContainer):
+
+    config = providers.Configuration(yaml_files=["config/default_config.yaml"])
+
+    # Configure logging (named logging_setup so the stdlib logging
+    # module is not shadowed inside the class body)
+    logging_setup = providers.Resource(
+        logging.basicConfig,
+        level=config.logging.level,
+        format=config.logging.format
+    )
+
+    # Provide the logger
+    logger = providers.Factory(logging.getLogger, name="chemnlp")
+
+    # Provide the YAML loader
+    yaml_loader = providers.Factory(load_yaml)
+
+    # Provide the DataFrame loader
+    df_loader = providers.Factory(
+        pd.read_csv,
+        low_memory=False
+    )
+
+    # Provide the TemplateSampler; df, meta, and an optional
+    # column_datafield_sampler are supplied at call time to match
+    # the TemplateSampler(df, meta, config, ...) signature
+    template_sampler = providers.Factory(
+        TemplateSampler,
+        config=config.template_sampler
+    )
diff --git a/src/chemnlp/data/random_variable.py b/src/chemnlp/data/random_variable.py
new file mode 100644
index 000000000..9b61636cf
--- /dev/null
+++ b/src/chemnlp/data/random_variable.py
@@ -0,0 +1,29 @@
+import random
+from functools import partial
+from typing import Callable, Optional
+
+
+def unwrap_list_length_1(list_input: list):
+    """Unwraps a list of length 1 and returns its single element."""
+    if isinstance(list_input, list):
+        assert len(list_input) == 1
+        return list_input[0]
+    else:
+        raise NotImplementedError()
+
+
+class RandomVariable:
+    """Simple random variable class that takes in a name, data, and a sampler.
+ The sampler needs to return a single element.""" + + def __init__(self, name: str, data: list, sampler: Optional[Callable] = None): + self.name = name + self.data = data + self.sampler = partial(random.sample, k=1) if sampler is None else sampler + + def __repr__(self): + return f"RandomVariable: {self.name}, {self.data}, {self.sampler}" + + def __call__(self) -> str: + """Carries out sampling and returns a single element.""" + return unwrap_list_length_1(self.sampler(self.data)) diff --git a/src/chemnlp/data/sampler.py b/src/chemnlp/data/sampler.py new file mode 100644 index 000000000..dd91ca4f1 --- /dev/null +++ b/src/chemnlp/data/sampler.py @@ -0,0 +1,417 @@ +from chemnlp.data.constants import DEFAULT_SIGNIFICANT_DIGITS +import pandas as pd +import random +import math +from typing import List, Dict, Union, Callable, Optional, Tuple +import re +from string import ascii_lowercase, ascii_uppercase +from chemnlp.data.random_variable import RandomVariable +from functools import partial +from functools import lru_cache + +class TemplateSampler: + """ + A class for sampling and generating text based on templates and data. + + This class handles the creation of text samples from templates, managing both + standard variable substitution and multiple-choice question generation. It supports + various data types and sampling methods, including class-balanced sampling and + benchmarking templates. + + Attributes: + df (pd.DataFrame): The dataset used for sampling. + meta (Dict): Metadata about the dataset, including identifiers and targets. + config (Dict): Configuration parameters for the sampler. + column_datafield_sampler (Callable): A function for sampling from multiple options. + + Examples: + >>> config = { + ... 'DEFAULT_SIGNIFICANT_DIGITS': 3, + ... 'multiple_choice_rnd_symbols': ["", ".", ".)", ")", ":", "()", "[]"], + ... 'multiple_choice_benchmarking_templates': False, + ... 'multiple_choice_benchmarking_format': None + ... } + >>> sampler = TemplateSampler(df, meta, config) + >>> template = "The molecule with SMILES {SMILES#} has a {property#} of {value#}." + >>> result = sampler.sample(df.iloc[0], template) + >>> print(result) + The molecule with SMILES CC(=O)OC1=CC=CC=C1C(=O)O has a solubility of 3.142. + """ + def __init__( + self, + df: pd.DataFrame, + meta: Dict, + config: Dict, + column_datafield_sampler: Optional[Callable] = None + ): + self.df_orig = df + self.df = df + self.meta = meta + self.config = config + self.column_datafield_sampler = column_datafield_sampler or (lambda x: random.sample(x, k=1)) + self.class_balanced = False + self.balance_column = None + + def _balance_classes(self, column: str) -> pd.DataFrame: + """ + Create a class-balanced version of the dataset. + + Args: + column (str): The column to use for balancing. + + Returns: + pd.DataFrame: A new dataframe with balanced classes. + """ + value_counts = self.df_orig[column].value_counts() + min_count = value_counts.min() + balanced_dfs = [] + + for value in value_counts.index: + class_df = self.df_orig[self.df_orig[column] == value] + if len(class_df) > min_count: + class_df = class_df.sample(min_count) + balanced_dfs.append(class_df) + + return pd.concat(balanced_dfs, ignore_index=True) + + def enable_class_balancing(self, column: str): + """ + Enable class-balanced sampling. + + Args: + column (str): The column to use for balancing. 
+ """ + self.class_balanced = True + self.balance_column = column + self.df = self._balance_classes(column) + + def disable_class_balancing(self): + """ + Disable class-balanced sampling and revert to the original dataset. + """ + self.class_balanced = False + self.balance_column = None + self.df = self.df_orig + + def _get_target_from_row(self, sample: pd.Series, var: str) -> str: + """ + Extract and process a target value from a sample row based on a variable string. + + This method handles various formats of the variable string to extract and process + data from the sample row. It supports multiple text string sampling, recoding, + multiple column selection, and special case handling for NaN values. + + The method also processes the extracted value based on its data type (continuous or not) + and handles cases where the extracted value itself contains multiple options. + + Args: + sample (pd.Series): A row from the dataset. + var (str): A string specifying how to extract the target value. This can include + special characters like '#', '!', '|', and '&' for different behaviors. + + Returns: + str: The extracted and processed target value. + + Raises: + ValueError: If a continuous value is not a number (float or int). + + Note: + - The behavior changes based on the format of the 'var' string: + - '#', '!', '|': Treats as synonm options. One is randomly selected. + - '#', '&': Treats as recoding information. + - '#', '|': Treats as multiple column selection. + - Only '#': Simple column value retrieval. + - Special handling is included for NaN values in certain column types. + - The method rounds continuous values to a specified number of significant digits. + - If the final value contains '|', it's split and a random option is chosen. + """ + if ("#" in var) and ("!" in var) and ("|" in var): + choices = var.replace("#", "").replace("!", "").split("|") + return self.column_datafield_sampler(choices)[0] + + elif ("#" in var) and ("&" in var): + var, choices = var.split("#") + choices = choices.split("&") + choice = choices[sample[var]] + return "" if choice == "NULL" else choice + + elif ("#" in var) and ("|" in var): + var = var.replace("#", "") + columns = var.split("|") + var = self.column_datafield_sampler(columns)[0] + out = sample[var] + + elif "#" in var: + var = var.replace("#", "") + out = sample[var] + if not isinstance(out, str) and math.isnan(out): + if "_smiles" in var: + out = sample[var.replace("_smiles", "_name")] + elif "_protein_names" in var: + out = sample[var.replace("_protein_names", "_name")] + + var_dict = next(x for x in self.meta["identifiers"] + self.meta["targets"] if x["id"] == var) + if var_dict["type"] == "continuous": + if not isinstance(out, (float, int)): + raise ValueError(f"out is not a number (int or float): {out}") + significant_digits = var_dict.get("significant_digits", self.config.get("DEFAULT_SIGNIFICANT_DIGITS", DEFAULT_SIGNIFICANT_DIGITS)) + out = f"{round(out, significant_digits):.{significant_digits}f}" + else: + out = str(out) + + if "|" in out: + choices = [c for c in out.split("|") if isinstance(c, str) or not math.isnan(c)] + out = self.column_datafield_sampler(choices)[0] + + return out + + def get_sample_dict(self, sample: pd.Series, template: str) -> Dict[str, str]: + """ + Extract and process all target values from a sample row based on a template. 
+ """ + input_variables = self._get_input_variables_from_template(template) + sample_dict = {} + + if any("%" in x for x in input_variables): + sample_dict.update(self._handle_multiple_choice(sample, input_variables)) + + for var in input_variables: + if "#" in var: + sample_dict[var] = self._get_target_from_row(sample, var) + elif "%" not in var: + sample_dict[var] = self._get_target_from_string(var)() + + return sample_dict + + def _get_symbols_from_multiple_choice_enum(self, enum_str: str) -> List[str]: + _, choice_count, symbol = enum_str.split('%')[1:] + if '-' in choice_count: + min_count, max_count = map(int, choice_count.split('-')) + count = random.randint(min_count, max_count) + else: + count = int(choice_count) + + if 'a' in symbol: + return list(ascii_lowercase[:count]) + elif 'A' in symbol: + return list(ascii_uppercase[:count]) + elif '1' in symbol: + return [str(i) for i in range(1, count + 1)] + + def _format_enum_string(self, symbols: List[str]) -> str: + """ + Format a list of symbols into a string representation for multiple-choice questions. + + This method takes a list of symbols (e.g., ['a', 'b', 'c']) and formats them + into a string like "a, b, or c" for use in multiple-choice question prompts. + + Args: + symbols (List[str]): A list of symbols representing multiple-choice options. + + Returns: + str: A formatted string of the symbols. + + Examples: + ['a', 'b'] -> "a or b" + ['a', 'b', 'c'] -> "a, b, or c" + ['1', '2', '3', '4'] -> "1, 2, 3, or 4" + """ + if len(symbols) == 0: + return "" + elif len(symbols) == 1: + return symbols[0] + elif len(symbols) == 2: + return f"{symbols[0]} or {symbols[1]}" + else: + return ", ".join(symbols[:-1]) + f", or {symbols[-1]}" + + def _handle_multiple_choice(self, sample: pd.Series, input_variables: List[str]) -> Dict[str, Union[str, List[str]]]: + multiple_choice_dict = {} + + # get multiple_choice_enum + multiple_choice_enum_idx = [i for i, x in enumerate(input_variables) if x.startswith("%multiple_choice_enum")] + assert len(multiple_choice_enum_idx) == 1 + multiple_choice_enum_idx = multiple_choice_enum_idx[0] + multiple_choice_enum = input_variables[multiple_choice_enum_idx] + + # get multiple_choice_var + multiple_choice_var_idx = [i for i, x in enumerate(input_variables) if x.endswith("%")] + assert len(multiple_choice_var_idx) == 1 + multiple_choice_var_idx = multiple_choice_var_idx[0] + multiple_choice_input = input_variables[multiple_choice_var_idx] + + if multiple_choice_input.count("%") > 1: + multiple_choice_var, multiple_choice_indicator, _ = multiple_choice_input.split("%") + else: + multiple_choice_var, multiple_choice_indicator = multiple_choice_input.split("%") + multiple_choice_indicator = "" # multiple_choice_indicator is here an empty string + + symbols = self._get_symbols_from_multiple_choice_enum(multiple_choice_enum) + + # get all and correct choices incl. 
index + correct_choice = self._get_target_from_row(sample, multiple_choice_var + "#") + + if multiple_choice_indicator == "": + multiple_choices, correct_choice_idx = self._get_choices_without_indicator(multiple_choice_var, symbols, correct_choice) + else: + multiple_choices, correct_choice_idx = self._get_choices_with_indicator(sample, multiple_choice_var, multiple_choice_indicator, symbols, correct_choice) + + multiple_choice_dict[multiple_choice_enum] = self._format_enum_string(symbols) + multiple_choice_dict[multiple_choice_input] = self._format_choices(symbols, multiple_choices) + multiple_choice_dict["%multiple_choice_result"] = self._format_result(symbols, correct_choice_idx) + multiple_choice_dict["%multiple_choice_symbols"] = symbols + multiple_choice_dict["%multiple_choice_result_idx"] = correct_choice_idx + + return multiple_choice_dict + + def _get_choices_without_indicator(self, multiple_choice_var: str, symbols: List[str], correct_choice: str) -> Tuple[List[str], int]: + cutoff_full_unique = 100 + if len(self.df[multiple_choice_var].unique()) < cutoff_full_unique: + all_choices = sorted([str(x) for x in self.df[multiple_choice_var].unique()]) + else: + all_choices = sorted([str(x) for x in self.df[multiple_choice_var].sample(cutoff_full_unique).unique()]) + + if all_choices == ["0", "1"]: + all_choices = ["False", "True"] + correct_choice = all_choices[int(correct_choice)] + + multiple_choices = random.sample(all_choices, k=len(symbols)) + if correct_choice not in multiple_choices: + multiple_choices = multiple_choices[:-1] + [correct_choice] + random.shuffle(multiple_choices) + + correct_choice_idx = multiple_choices.index(correct_choice) + return multiple_choices, correct_choice_idx + + def _get_choices_with_indicator(self, sample: pd.Series, multiple_choice_var: str, multiple_choice_indicator: str, symbols: List[str], correct_choice: str) -> Tuple[List[str], List[int]]: + correct_choice_indicator = self._get_target_from_row(sample, multiple_choice_indicator + "#") + df_sample = self.df.sample(len(symbols) - 1)[[multiple_choice_var, multiple_choice_indicator]] + + multiple_choices = df_sample[multiple_choice_var].astype(str).tolist() + [correct_choice] + multiple_choices_indicators = df_sample[multiple_choice_indicator].astype(str).tolist() + [correct_choice_indicator] + + multiple_choices_combined = list(zip(multiple_choices, multiple_choices_indicators)) + random.shuffle(multiple_choices_combined) + multiple_choices, multiple_choices_indicators = zip(*multiple_choices_combined) + + correct_choice_idx = [i for i, (choice, indicator) in enumerate(zip(multiple_choices, multiple_choices_indicators)) + if indicator == correct_choice_indicator] + + return list(multiple_choices), correct_choice_idx + + def _format_choices(self, symbols: List[str], choices: List[str]) -> str: + rnd_symbol = self._get_random_symbol() + rnd_symbol_prefix, rnd_symbol_suffix = self._get_symbol_affixes(rnd_symbol) + + return "\n".join([f"{rnd_symbol_prefix}{s}{rnd_symbol_suffix} {c}" for s, c in zip(symbols, choices)]) + + def _format_result(self, symbols: List[str], correct_choice_idx: Union[int, List[int]]) -> str: + if isinstance(correct_choice_idx, list): + return ", ".join([symbols[i] for i in correct_choice_idx]) + else: + return symbols[correct_choice_idx] + + def _get_random_symbol(self) -> str: + if self.config.get('multiple_choice_benchmarking_templates') and self.config.get('multiple_choice_benchmarking_format') is not None: + if len(self.config['multiple_choice_rnd_symbols']) > 1: + return 
self.config['multiple_choice_rnd_symbols'][self.config['multiple_choice_benchmarking_format']]
+            else:
+                return self.config['multiple_choice_rnd_symbols'][0]
+        else:
+            return random.choice(self.config['multiple_choice_rnd_symbols'])
+
+    def _get_symbol_affixes(self, symbol: str) -> Tuple[str, str]:
+        if symbol in ["()", "[]"]:
+            return symbol[0], symbol[1]
+        else:
+            return "", symbol
+
+    def _get_input_variables_from_template(self, template: str) -> List[str]:
+        return re.findall(r'\{([^}]+)\}', template)
+
+    @lru_cache(maxsize=None)
+    def _get_random_text_identifiers_and_targets(self) -> dict:
+        """Build and cache RandomVariables for all identifier and target names."""
+        rnd_texts = {}
+        for e in self.meta["identifiers"] + self.meta["targets"]:
+            rnd_texts[e["id"]] = {}
+            if "names" in e:
+                rnd_texts[e["id"]]["names"] = {}
+                name_types = set([list(x.keys())[0] for x in e["names"]])
+                for name in name_types:
+                    rnd_text = RandomVariable(
+                        f"{e['id']}__names__{name}",
+                        [x[name] for x in e["names"] if name in x],
+                    )
+                    rnd_texts[e["id"]]["names"][name] = rnd_text
+
+            if "description" in e:
+                rnd_texts[e["id"]]["description"] = partial(lambda x: x, e["description"])
+
+            if "units" in e:
+                rnd_texts[e["id"]]["units"] = partial(lambda x: x, e["units"])
+
+        return rnd_texts
+
+    def _get_target_from_string(self, var: str) -> Union[str, Callable]:
+        """
+        Retrieve a target value from the meta information based on a string key.
+
+        This method navigates through the nested structure of the meta dictionary
+        to find the appropriate value.
+
+        Args:
+            var (str): A string key representing the path to the target value
+                in the meta dictionary, with levels separated by '__'.
+
+        Returns:
+            Union[str, Callable]: The target value, or a callable that produces one.
+
+        Raises:
+            KeyError: If the specified path doesn't exist in the meta dictionary.
+
+        Example:
+            If var is "SMILES__names__noun", it will look for
+            self.meta["SMILES"]["names"]["noun"] and return a RandomVariable if it's a list.
+        """
+        keys = var.split("__")
+
+        def get_with_nested_keys(d: dict, keys: list) -> Union[str, Callable]:
+            t = d
+            for k in keys:
+                if k not in t:
+                    raise KeyError(f"Key '{k}' not found in nested dictionary.")
+                t = t[k]
+            return t
+
+        if len(keys) == 1 and keys[0] in self.meta:
+            return self.meta[keys[0]]
+        elif keys[0] in [x["id"] for x in self.meta["identifiers"] + self.meta["targets"]]:
+            rnd_texts = self._get_random_text_identifiers_and_targets()
+            return get_with_nested_keys(rnd_texts, keys)
+        else:
+            raise KeyError(f"Unable to find key '{var}' in meta information.")
+
+    def sample(self, sample: Optional[pd.Series], template: str) -> str:
+        """
+        Generate a text sample based on a template and a data sample.
+
+        If no sample is provided, a random sample is chosen from the current dataset
+        (which may be class-balanced if enabled).
+
+        Args:
+            sample (Optional[pd.Series]): A row from the dataset. If None, a random sample is chosen.
+            template (str): The template string to be filled.
+
+        Returns:
+            str: The completed text sample with all variables replaced by their values.
+ """ + if sample is None: + sample = self.df.sample(1).iloc[0] + sample_dict = self.get_sample_dict(sample, template) + return self._fill_template(template, sample_dict) + + def _fill_template(self, template: str, sample_dict: Dict[str, str]) -> str: + for key, value in sample_dict.items(): + template = template.replace('{' + key + '}', value) + return template diff --git a/src/chemnlp/data/utils.py b/src/chemnlp/data/utils.py index 69956e99c..21345d6db 100644 --- a/src/chemnlp/data/utils.py +++ b/src/chemnlp/data/utils.py @@ -8,6 +8,8 @@ import chemnlp.data.hf_datasets as hf_datasets +import yaml +from typing import Any def sample_dataset(dataset, num_samples): n = len(dataset) @@ -169,3 +171,20 @@ def oxford_comma_join(items: List[str]) -> str: return f"{items[0]} and {items[1]}" else: return ", ".join(items[:-1]) + f", and {items[-1]}" + + +def load_yaml(file_path: str) -> Any: + with open(file_path, 'r') as file: + return yaml.safe_load(file) + +def save_yaml(data: Any, file_path: str) -> None: + with open(file_path, 'w') as file: + yaml.dump(data, file, sort_keys=False) + +def str_presenter(dumper, data): + if len(data.splitlines()) > 1: # check for multiline string + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + +yaml.add_representer(str, str_presenter) +yaml.representer.SafeRepresenter.add_representer(str, str_presenter) diff --git a/src/chemnlp/trainer.py b/src/chemnlp/trainer.py deleted file mode 100644 index 20fa0a05a..000000000 --- a/src/chemnlp/trainer.py +++ /dev/null @@ -1,77 +0,0 @@ -"""A custom trainer for modifying data sampling behaviour""" - -from typing import Optional - -import datasets -import torch -from torch.utils.data import DataLoader, sampler -from transformers import Trainer -from transformers.trainer_pt_utils import IterableDatasetShard -from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available - - -class LLcheMTrainer(Trainer): - def __init__( - self, - sampler: Optional[sampler.Sampler] = None, - **kwargs, - ): - """ - Rewritten over from transformers 4.30.2 - * custom sampler - * all other kwargs get passed as normal - """ - super().__init__(**kwargs) - self.sampler = sampler - - def get_train_dataloader(self) -> DataLoader: - """ - Returns the training [`~torch.utils.data.DataLoader`]. - Uses default transformers behaviour unless a custom sampler is provided. 
- """ - if self.train_dataset is None: - raise ValueError("Trainer: training requires a train_dataset.") - - train_dataset = self.train_dataset - data_collator = self.data_collator - if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): - train_dataset = self._remove_unused_columns( - train_dataset, description="training" - ) - else: - data_collator = self._get_collator_with_removed_columns( - data_collator, description="training" - ) - - if isinstance(train_dataset, torch.utils.data.IterableDataset): - if self.args.world_size > 1: - train_dataset = IterableDatasetShard( - train_dataset, - batch_size=self._train_batch_size, - drop_last=self.args.dataloader_drop_last, - num_processes=self.args.world_size, - process_index=self.args.process_index, - ) - - return DataLoader( - train_dataset, - batch_size=self._train_batch_size, - collate_fn=data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) - - # NOTE change from original code - train_sampler = self.sampler if self.sampler else self._get_train_sampler() - - return DataLoader( - train_dataset, - batch_size=self._train_batch_size, - sampler=train_sampler, - collate_fn=data_collator, - drop_last=self.args.dataloader_drop_last, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - worker_init_fn=seed_worker, - ) diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py new file mode 100644 index 000000000..71d615386 --- /dev/null +++ b/tests/data/test_sampler.py @@ -0,0 +1,90 @@ +import pytest +import pandas as pd +from chemnlp.data.sampler import TemplateSampler + +@pytest.fixture +def sample_df(): + return pd.DataFrame({ + 'SMILES': ['CC(C)NCC(O)c1ccc(O)c(O)c1', 'CC1=C(C(=O)NC2=C1C=CC=C2)C3=CC=CC=C3'], + 'CYP2D6_Substrate': [1, 0], + 'compound_name': ['Isoproterenol', 'Phenytoin'], + 'split': ['train', 'test'] + }) + +@pytest.fixture +def sample_meta(): + return { + "identifiers": [ + {"id": "SMILES", "type": "SMILES", "description": "SMILES"}, + {"id": "compound_name", "type": "Other", "description": "drug name", + "names": [{"noun": "compound name"}, {"noun": "drug name"}, {"noun": "generic drug name"}]} + ], + "targets": [ + { + "id": "CYP2D6_Substrate", + "type": "boolean", + "description": "drugs that are metabolized by the CYP P450 2D6 (1) or not (0)", + "names": [ + {"noun": "CYP P450 2D6 substrate"}, + {"noun": "CYP2D6 substrate"}, + {"noun": "substrate for CYP2D6"}, + {"noun": "substrate for CYP P450 2D6"}, + {"verb": "metabolized by CYP2D6"}, + {"verb": "metabolized by CYP P450 2D6"} + ] + } + ] + } + +@pytest.fixture +def sample_config(): + return { + 'DEFAULT_SIGNIFICANT_DIGITS': 2, + 'multiple_choice_rnd_symbols': ["", ".)", ")"], + 'multiple_choice_benchmarking_templates': False, + 'multiple_choice_benchmarking_format': None + } + +def test_get_target_from_row(sample_df, sample_meta, sample_config): + sampler = TemplateSampler(sample_df, sample_meta, sample_config) + assert sampler._get_target_from_row(sample_df.iloc[0], "SMILES#") == "CC(C)NCC(O)c1ccc(O)c(O)c1" + assert sampler._get_target_from_row(sample_df.iloc[0], "CYP2D6_Substrate#") == "1" + assert sampler._get_target_from_row(sample_df.iloc[0], "compound_name#") == "Isoproterenol" + +def test_get_target_from_string(sample_df, sample_meta, sample_config): + sampler = TemplateSampler(sample_df, sample_meta, sample_config) + assert 
sampler._get_target_from_string("CYP2D6_Substrate__names__noun")() in [ + "CYP P450 2D6 substrate", "CYP2D6 substrate", "substrate for CYP2D6", "substrate for CYP P450 2D6" + ] + assert sampler._get_target_from_string("CYP2D6_Substrate__names__verb")() in [ + "metabolized by CYP2D6", "metabolized by CYP P450 2D6" + ] + +def test_sample_with_template(sample_df, sample_meta, sample_config): + sampler = TemplateSampler(sample_df, sample_meta, sample_config) + template = "The molecule with the {SMILES__description} {SMILES#} is {CYP2D6_Substrate#not &NULL}a {CYP2D6_Substrate__names__noun}." + result = sampler.sample(sample_df.iloc[0], template) + assert "CC(C)NCC(O)c1ccc(O)c(O)c1" in result + assert "is a" in result + assert "CYP P450 2D6 substrate" in result or "CYP2D6 substrate" in result or "substrate for CYP2D6" in result or "substrate for CYP P450 2D6" in result + +def test_multiple_choice_template(sample_df, sample_meta, sample_config): + sampler = TemplateSampler(sample_df, sample_meta, sample_config) + template = """ + Task: Please answer the multiple choice question. + Question: Is the molecule with the {SMILES__description} {SMILES#} {CYP2D6_Substrate__names__verb}? + Constraint: Even if you are uncertain, you must pick either {%multiple_choice_enum%2%aA1} without using any other words. + Options: + {CYP2D6_Substrate%} + Answer: {%multiple_choice_result} + """ + result = sampler.sample(sample_df.iloc[0], template) + assert "CC(C)NCC(O)c1ccc(O)c(O)c1" in result + assert "A or B" in result or "a or b" in result or "1 or 2" in result + assert result.strip().endswith("A") or result.strip().endswith("a") or result.strip().endswith("1") + +def test_class_balancing(sample_df, sample_meta, sample_config): + sampler = TemplateSampler(sample_df, sample_meta, sample_config) + sampler.enable_class_balancing("CYP2D6_Substrate") + balanced_df = sampler.df + assert len(balanced_df[balanced_df['CYP2D6_Substrate'] == 0]) == len(balanced_df[balanced_df['CYP2D6_Substrate'] == 1])
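
For reference, a minimal end-to-end sketch of how the sampling engine introduced in this patch fits together. It mirrors the fixtures in tests/data/test_sampler.py; the concrete df, meta, and config values below are illustrative only and not part of the patch:

    import pandas as pd

    from chemnlp.data.sampler import TemplateSampler

    # Illustrative inputs, modeled on the test fixtures above
    df = pd.DataFrame({
        "SMILES": ["CC(C)NCC(O)c1ccc(O)c(O)c1", "CC1=C(C(=O)NC2=C1C=CC=C2)C3=CC=CC=C3"],
        "CYP2D6_Substrate": [1, 0],
        "compound_name": ["Isoproterenol", "Phenytoin"],
    })
    meta = {
        "identifiers": [{"id": "SMILES", "type": "SMILES", "description": "SMILES"}],
        "targets": [{
            "id": "CYP2D6_Substrate",
            "type": "boolean",
            "description": "drugs that are metabolized by CYP P450 2D6 (1) or not (0)",
            "names": [{"noun": "CYP2D6 substrate"}],
        }],
    }
    config = {
        "DEFAULT_SIGNIFICANT_DIGITS": 2,
        "multiple_choice_rnd_symbols": ["", ".)", ")"],
        "multiple_choice_benchmarking_templates": False,
        "multiple_choice_benchmarking_format": None,
    }

    sampler = TemplateSampler(df, meta, config)

    # {SMILES#} pulls a column value; {...__names__noun} draws a random synonym
    template = "The molecule with SMILES {SMILES#} is a {CYP2D6_Substrate__names__noun}."
    print(sampler.sample(df.iloc[0], template))

    # With class balancing enabled, sample(None, ...) draws a balanced random row
    sampler.enable_class_balancing("CYP2D6_Substrate")
    print(sampler.sample(None, template))
    sampler.disable_class_balancing()

Keeping df and meta as constructor arguments, rather than wiring them through loaders, is also what the Container factory above assumes: it binds only the config and leaves the data objects to be supplied at call time.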