From fc992a7f933bbd3d60511f233d6dfd0c34c97b03 Mon Sep 17 00:00:00 2001 From: Xi Bai Date: Mon, 4 Dec 2023 16:25:09 +0000 Subject: [PATCH 1/2] CU-86938vf30 add trainer callbacks for Transformer NER --- medcat/ner/transformers_ner.py | 16 +++++++++-- tests/ner/test_transformers_ner.py | 46 ++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 tests/ner/test_transformers_ner.py diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 9623b1b93..227ccc083 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -1,6 +1,7 @@ import os import json import logging +import datasets from spacy.tokens import Doc from datetime import datetime from typing import Iterable, Iterator, Optional, Dict, List, cast, Union @@ -18,7 +19,7 @@ from transformers import Trainer, AutoModelForTokenClassification, AutoTokenizer from transformers import pipeline, TrainingArguments -import datasets +from transformers.trainer_callback import TrainerCallback # It should be safe to do this always, as all other multiprocessing #will be finished before data comes to meta_cat @@ -137,7 +138,12 @@ def merge_data_loaded(base, other): return out_path - def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=False, dataset=None, meta_requirements=None): + def train(self, + json_path: Union[str, list, None]=None, + ignore_extra_labels=False, + dataset=None, + meta_requirements=None, + trainer_callbacks: Optional[List[TrainerCallback]]=None): """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -149,6 +155,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals ignore_extra_labels: Makes only sense when an existing deid model was loaded and from the new data we want to ignore labels that did not exist in the old model. + trainer_callbacks (List[TrainerCallback]): + A list of trainer callbacks for collecting metrics during the training at the client side. The + transformers Trainer object will be passed in when each callback is called. """ if dataset is None and json_path is not None: @@ -193,6 +202,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals compute_metrics=lambda p: metrics(p, tokenizer=self.tokenizer, dataset=encoded_dataset['test'], verbose=self.config.general['verbose_metrics']), data_collator=data_collator, # type: ignore tokenizer=None) + if trainer_callbacks: + for callback in trainer_callbacks: + trainer.add_callback(callback(trainer)) trainer.train() # type: ignore diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py new file mode 100644 index 000000000..019785395 --- /dev/null +++ b/tests/ner/test_transformers_ner.py @@ -0,0 +1,46 @@ +import os +import unittest +from spacy.lang.en import English +from spacy.tokens import Doc, Span +from transformers import TrainerCallback +from medcat.ner.transformers_ner import TransformersNER +from medcat.config import Config +from medcat.cdb_maker import CDBMaker + + +class TransformerNERTest(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + config = Config() + config.general["spacy_model"] = "en_core_web_md" + cdb_maker = CDBMaker(config) + cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.csv") + cdb = cdb_maker.prepare_csvs([cdb_csv], full_build=True) + Doc.set_extension("ents", default=[], force=True) + Span.set_extension("confidence", default=-1, force=True) + Span.set_extension("id", default=0, force=True) + Span.set_extension("detected_name", default=None, force=True) + Span.set_extension("link_candidates", default=None, force=True) + Span.set_extension("cui", default=-1, force=True) + Span.set_extension("context_similarity", default=-1, force=True) + cls.undertest = TransformersNER(cdb) + cls.undertest.create_eval_pipeline() + + def test_pipe(self): + doc = English().make_doc("Intracerebral hemorrhage is not Movar Virus") + doc = next(self.undertest.pipe([doc])) + assert len(doc.ents) > 0, "No entities were recognised" + + def test_train_with_callbacks(self): + tracker = unittest.mock.Mock() + class _DummyCallback(TrainerCallback): + def __init__(self, trainer) -> None: + self._trainer = trainer + def on_epoch_end(self, *args, **kwargs) -> None: + tracker.call() + + train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json") + self.undertest.training_arguments.num_train_epochs = 1 + self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback]) + self.assertEqual(tracker.call.call_count, 2) From d2baeaa6cc8e49157dc176719a87d6bd844fee6f Mon Sep 17 00:00:00 2001 From: Xi Bai Date: Tue, 5 Dec 2023 17:02:01 +0000 Subject: [PATCH 2/2] CU-86938vf30 improve tests --- tests/ner/test_transformers_ner.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py index 019785395..de9eae32c 100644 --- a/tests/ner/test_transformers_ner.py +++ b/tests/ner/test_transformers_ner.py @@ -28,11 +28,11 @@ def setUpClass(cls) -> None: cls.undertest.create_eval_pipeline() def test_pipe(self): - doc = English().make_doc("Intracerebral hemorrhage is not Movar Virus") + doc = English().make_doc("\nPatient Name: John Smith\nAddress: 15 Maple Avenue\nCity: New York\nCC: Chronic back pain\n\nHX: Mr. Smith") doc = next(self.undertest.pipe([doc])) assert len(doc.ents) > 0, "No entities were recognised" - def test_train_with_callbacks(self): + def test_train(self): tracker = unittest.mock.Mock() class _DummyCallback(TrainerCallback): def __init__(self, trainer) -> None: @@ -42,5 +42,9 @@ def on_epoch_end(self, *args, **kwargs) -> None: train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json") self.undertest.training_arguments.num_train_epochs = 1 - self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback]) + df, examples, dataset = self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback]) + assert "fp" in examples + assert "fn" in examples + assert dataset["train"].num_rows == 48 + assert dataset["test"].num_rows == 12 self.assertEqual(tracker.call.call_count, 2)