diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index ae3e82ef8..1731cf610 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -37,6 +37,7 @@ class General(MixingConfig, BaseModel): a deployment.""" pipe_batch_size_in_chars: int = 20000000 """How many characters are piped at once into the meta_cat class""" + span_group: Optional[str] = None class Config: extra = Extra.allow diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index d92e6ea61..7f9615b56 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -5,7 +5,7 @@ import numpy from multiprocessing import Lock from torch import nn, Tensor -from spacy.tokens import Doc +from spacy.tokens import Doc, Span from datetime import datetime from typing import Iterable, Iterator, Optional, Dict, List, Tuple, cast, Union from medcat.utils.hasher import Hasher @@ -356,6 +356,17 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA meta_cat.model.load_state_dict(torch.load(model_save_path, map_location=device)) return meta_cat + + def get_ents(self, doc: Doc) -> List[Span]: + span_group_name = self.config.general.span_group + if span_group_name: + return doc.spans[span_group_name] + + # Should we annotate overlapping entities + if self.config.general['annotate_overlapping']: + return doc._.ents + + return doc.ents def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowercase: bool) -> Tuple: """Prepares document. @@ -381,11 +392,7 @@ def prepare_document(self, doc: Doc, input_ids: List, offset_mapping: List, lowe cntx_right = config.general['cntx_right'] replace_center = config.general['replace_center'] - # Should we annotate overlapping entities - if config.general['annotate_overlapping']: - ents = doc._.ents - else: - ents = doc.ents + ents = self.get_ents(doc) samples = [] last_ind = 0 @@ -522,10 +529,7 @@ def _set_meta_anns(self, predictions = all_predictions[start_ind:end_ind] confidences = all_confidences[start_ind:end_ind] - if config.general['annotate_overlapping']: - ents = doc._.ents - else: - ents = doc.ents + ents = self.get_ents(doc) for ent in ents: ent_ind = ent_id2ind[ent._.id] diff --git a/tests/test_meta_cat.py b/tests/test_meta_cat.py index df5be9f77..ac332a81a 100644 --- a/tests/test_meta_cat.py +++ b/tests/test_meta_cat.py @@ -7,7 +7,8 @@ from medcat.meta_cat import MetaCAT from medcat.config_meta_cat import ConfigMetaCAT from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT - +import spacy +from spacy.tokens import Span class MetaCATTests(unittest.TestCase): @@ -19,7 +20,7 @@ def setUpClass(cls) -> None: config.train['nepochs'] = 1 config.model['input_size'] = 100 - cls.meta_cat = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) + cls.meta_cat: MetaCAT = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config) cls.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") os.makedirs(cls.tmp_dir, exist_ok=True) @@ -44,6 +45,33 @@ def test_save_load(self): self.assertEqual(f1, n_f1) + def test_predict_spangroup(self): + Span.set_extension('id', default=0, force=True) + Span.set_extension('meta_anns', default=None, force=True) + + + json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'mct_export_for_meta_cat_test.json') + self.meta_cat.train(json_path, save_dir_path=self.tmp_dir) + self.meta_cat.save(self.tmp_dir) + n_meta_cat = MetaCAT.load(self.tmp_dir) + assert n_meta_cat.config.general.span_group is None + + spangroup_name = 'predict_spangroup' + n_meta_cat.config.general.span_group = spangroup_name + nlp = spacy.blank("en") + doc = nlp("No history of diabetes.") + span = doc.char_span(14, 22, label="foo_spantype") + assert span.text == 'diabetes' + doc.spans[spangroup_name] = [span] + doc = n_meta_cat(doc) + + # set back to None + n_meta_cat.config.general.span_group = None + assert doc.spans[spangroup_name][0]._.meta_anns['Status']['value'] == 'Affirmed' + + + + if __name__ == '__main__': unittest.main()