From 4e618aa2f918540c49bdb7cded8f46c21474fc46 Mon Sep 17 00:00:00 2001
From: Mart Ratas
Date: Mon, 8 Jan 2024 16:59:40 +0200
Subject: [PATCH] v1.10.0 (#388)

* Bump urllib3 from 1.26.5 to 1.26.17 in /webapp/webapp

  Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.5 to 1.26.17.
  - [Release notes](https://github.com/urllib3/urllib3/releases)
  - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
  - [Commits](https://github.com/urllib3/urllib3/compare/1.26.5...1.26.17)

  ---
  updated-dependencies:
  - dependency-name: urllib3
    dependency-type: direct:production
  ...

  Signed-off-by: dependabot[bot]

* Cu 8692wbcq5 docs builds (#359)
* CU-8692wbcq5: Pin max version of numpy
* CU-8692wbcq5: Pin max version of numpy in setup.py
* CU-8692wbcq5: Bump python version for readthedocs workflow
* CU-8692wbcq5: Pin all requirement versions in docs requirements
* CU-8692wbcq5: Move docs requirements before setuptools
* CU-8692wbcq5: Fix typo in docs requirements
* CU-8692wbcq5: Remove some less relevant stuff from docs requirements
* CU-8692wbcq5: Add back sphinx-based requirements to docs requirements
* CU-8692wbcq5: Move back to python 3.9 on docs build workflow
* CU-8692wbcq5: Bump sphinx-autoapi version
* CU-8692wbcq5: Bump sphinx version
* CU-8692wbcq5: Bump python version back to 3.10 for future-proofing
* CU-8692wbcq5: Undo pinning numpy to max version in setup.py
* CU-8692wbcq5: Remove docs-build specific dependencies in setup.py
* Bump urllib3 from 1.26.17 to 1.26.18 in /webapp/webapp

  Bumps [urllib3](https://github.com/urllib3/urllib3) from 1.26.17 to 1.26.18.
  - [Release notes](https://github.com/urllib3/urllib3/releases)
  - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst)
  - [Commits](https://github.com/urllib3/urllib3/compare/1.26.17...1.26.18)

  ---
  updated-dependencies:
  - dependency-name: urllib3
    dependency-type: direct:production
  ...

  Signed-off-by: dependabot[bot]

* CU-8692uznvd: Allow empty-dict config.linking.filters.cuis and convert to set in memory (#352)
* CU-8692uznvd: Allow empty-dict config.linking.filters.cuis and convert to set in memory
* CU-8692uznvd: Move the empty-set detection and conversion from validator to init
* CU-8692uznvd: Remove unused import
* CU-8692t3fdf separate config on save (#350)
* CU-8692t3fdf Move saving config outside of the cdb.dat; Add test to make sure the config does not get saved with the CDB; patch a few existing tests
* CU-8692t3fdf Use class methods on class instead of instance in a few tests
* CU-8692t3fdf Fix typing issue
* CU-8692t3fdf Add additional tests for 2 configs and zero configs when loading model pack
* CU-8692t3fdf: Make sure CDB is linked to the correct config; Treat incorrect configs as dirty CDBs and force a recalc of the hash
* CU-2cdpd4t: Unify default addl_info in different methods. (#363)
* Bump django from 3.2.20 to 3.2.23 in /webapp/webapp

  Bumps [django](https://github.com/django/django) from 3.2.20 to 3.2.23.
  - [Commits](https://github.com/django/django/compare/3.2.20...3.2.23)

  ---
  updated-dependencies:
  - dependency-name: django
    dependency-type: direct:production
  ...

  Signed-off-by: dependabot[bot]

* Changing cdb.add_concept to a protected method
* Re-added deprecated method with deprecated flag and additional comments
* Initial commit for merge_cdb method
* Added indentation to make merge_cdb a class method
* fixed syntax issues
* more lint fixes
* more lint fixes
* bug fixes of merge_cdb
* removed print statements
* CU-86931prq4: Update GHA versions (checkout and setup-python) to v4 (#368)
* Cu 1yn0v9e duplicate multiprocessing methods (#364)
* CU-1yn0v9e: Rename and deprecate one of the multiprocessing methods; Add docstring.
  Trying to be more explicit regarding usage and differences between different methods
* CU-1yn0v9e: Rename and deprecate the multiprocessing_pipe method; Add docstring.
  Trying to be more explicit regarding usage and differences between different methods
* CU-1yn0v9e: Fix typo in docstring; more consistent naming
* 869377m3u: Add comment regarding demo link load times to README (#376)
* intermediate changes of merge_cdb and testing function
* Added README.md documentation for CPU only installations (#365)
* changed README.md to reflect installation options.
* added setup script to demonstrate how wrapper could look for CPU installations
* removed setup.sh as unnecessary for cpu only builds
* Initial commit for merge_cdb method
* Added indentation to make merge_cdb a class method
* fixed syntax issues
* more lint fixes
* more lint fixes
* bug fixes of merge_cdb
* removed print statements
* Added commentary on disk space usage of pytorch-gpu
* removed merge_cdb from branch

  ---------

  Co-authored-by: adam-sutton-1992

* Cu 8692zguyq no preferred name (#367)
* CU-8692zguyq: Slight simplification of minimum-name-length logic
* CU-8692zguyq: Add some tests for prepare_name preprocessor
* CU-8692zguyq: Add warning if no preferred name was added along a new CUI
* CU-8692zguyq: Add additional warning messages when adding/training a new CUI with no preferred name
* CU-8692zguyq: Make no preferred name warnings only run if name status is preferred
* CU-8692zguyq: Add tests for no-preferred name warnings
* CU-8692zguyq: Add Vocab.make_unigram_table to CAT tests
* CU-8692zguyq: Move to built in asserting for logging instead of patching the method
* CU-8692zguyq: Add workaround for assertNoLogs on python 3.8 and 3.9
* Add trainer callbacks for Transformer NER (#377)

  CU-86938vf30 add trainer callbacks for Transformer NER

* changes to merge_cdb and adding unit tests for method
* fixing lint issues
* fixing flake8 linting
* bug fixes, additional tests, and more documentation
* moved set up of cdbs to be merged to tests.helper
* moved merge_cdb to utils and created test_cdb_utils
* removed class wrapper in cdb utils and fixed class set up in tests
* changed test object setup to class setup
* removed erroneous new line
* CU-2e77a31 improve print stats (#366)
* Add base class for CAT
* Add CDB base class
* Some whitespace fixes for base modules
* CU-2e77a31: Move print stats to their own module and class
* CU-2e77a31: Fix issues introduced by moving print stats
* CU-2e77a31: Rename print_stats to get_stats and add option to avoid printed output
* CU-2e77a31: Add test for print_stats
* CU-2e77a31: Remove unused import
* CU-2e77a31: Add new package to setup.py
* CU-2e77a31: Fix a bunch of typing issues
* CU-2e77a31: Revert CAT and CDB abstraction
* Load stopwords in Defaults before spacy model
* CU-8693az82g Remove cdb tests side effects (#380)
* 8693az82g: Add method to CDBMaker to reset the CDB
* 8693az82g: Add test in CDB tests to ensure a new CDB is used for each test
* 8693az82g: Reset CDB in CDB tests before each test to avoid side effects
* Added tests
* CU-8693bpq82 fallback spacy model (#384)
* CU-8693bpq82: Add fallback spacy model along with test
* CU-8693bpq82: Remove debug output
* CU-8693bpq82: Add exception info to warning upon spacy model load failure and fallback
* Remove tests of internals where possible
* Add test for skipping of stopwords
* Avoid supporting only English for stopwords
* Remove debug output
* Make sure stopwords language getter works for file-path spacy models
* CU-8693cv3w0 Fix fallback spacy model existence on pip installs (#386)
* CU-8693cv3w0: Add method to ensure spacy model and use it when falling back to default model
* CU-8693cv3w0: Add logged output when installing/downloading spacy model
* CU-8693b0a61 Add method to get spacy model version (#381)
* CU-8693b0a61: Add method to find spacy folder in model pack along with some tests
* CU-8693b0a61: Add test for spacy folder finding (full path)
* CU-8693b0a61: Add method for finding spacy model in model pack along with tests
* CU-8693b0a61: Add method for finding current spacy version
* CU-8693b0a61: Add method for getting spacy model version installed
* CU-8693b0a61: Fix getting spacy model folder return path
* CU-8693b0a61: Add method to get name and meta of spacy model based on model pack
* CU-8693b0a61: Add missing fake spacy model meta
* CU-8693b0a61: Add missing docstrings
* CU-8693b0a61: Change name of method for clarity
* CU-8693b0a61: Add method to get spacy model name and version from model pack path
* CU-8693b0a61: Fix a few typing issues
* CU-8693b0a61: Add a missing docstring
* CU-8693b0a61: Match folder name of fake spacy model to its name
* CU-8693b0a61: Make the final method return true name of spacy model instead of folder name
* Add additional output to method for getting spacy model version - the compatible spacy versions
* CU-8693b0a61: Add method for querying whether the spacy version is compatible with a range
* CU-8693b0a61: Add better abstraction for spacy version mocking in tests
* CU-8693b0a61: Add some more abstraction for fake model pack in tests
* CU-8693b0a61: Add method for checking whether a model pack has a spacy model compatible with installed spacy version
* CU-8693b0a61: Improve abstraction within tests
* CU-8693b0a61: Add method to check which of two versions is older
* CU-8693b0a61: Fix fake spacy model versioning
* CU-8693b0a61: Add method for determining whether a model pack has semi-compatible spacy model
* CU-8693b0a61: Add missing word in docstring.
* CU-8693b0a61: Change some methods to protected ones

---------

Signed-off-by: dependabot[bot]
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: tomolopolis
Co-authored-by: adam-sutton-1992
Co-authored-by: adam-sutton-1992 <60137864+adam-sutton-1992@users.noreply.github.com>
Co-authored-by: Xi Bai <82581439+baixiac@users.noreply.github.com>
Co-authored-by: jenniferajiang
Co-authored-by: Jennifer Jiang <37081323+jenniferajiang@users.noreply.github.com>
---
 .github/workflows/main.yml | 8 +-
 .github/workflows/production.yml | 4 +-
 .readthedocs.yaml | 6 +-
 README.md | 12 +
 docs/requirements.txt | 106 +++++-
 medcat/cat.py | 283 ++++-----
 medcat/cdb.py | 132 ++++++-
 medcat/cdb_maker.py | 16 +-
 medcat/config.py | 21 +-
 medcat/ner/transformers_ner.py | 16 +-
 medcat/pipe.py | 30 +-
 medcat/preprocessing/cleaners.py | 3 +-
 medcat/stats/__init__.py | 0
 medcat/stats/stats.py | 340 ++++++++++++++++++
 medcat/utils/cdb_utils.py | 117 ++++++
 medcat/utils/filters.py | 43 ++-
 medcat/utils/helpers.py | 32 ++
 medcat/utils/regression/targeting.py | 4 +-
 medcat/utils/saving/serializer.py | 15 +-
 medcat/utils/spacy_compatibility.py | 211 +++++++++++
 setup.py | 8 +-
 tests/archive_tests/test_cdb_maker_archive.py | 2 +-
 tests/helper.py | 35 ++
 tests/ner/test_transformers_ner.py | 50 +++
 tests/preprocessing/__init__.py | 0
 tests/preprocessing/test_cleaners.py | 104 ++++++
 tests/resources/ff_core_fake_dr/meta.json | 8 +
 tests/test_cat.py | 228 +++++++++++-
 tests/test_cdb.py | 18 +
 tests/test_config.py | 50 ++-
 tests/test_pipe.py | 9 +-
 tests/utils/saving/test_serialization.py | 2 +-
 tests/utils/test_cdb_utils.py | 43 +++
 tests/utils/test_hashing.py | 53 ++-
 tests/utils/test_helpers.py | 24 ++
 tests/utils/test_spacy_compatibility.py | 302 ++++++++++++++++
 webapp/webapp/requirements.txt | 4 +-
 37 files changed, 2064 insertions(+), 275 deletions(-)
 create mode 100644 medcat/stats/__init__.py
 create mode 100644 medcat/stats/stats.py create
mode 100644 medcat/utils/cdb_utils.py create mode 100644 medcat/utils/spacy_compatibility.py create mode 100644 tests/ner/test_transformers_ner.py create mode 100644 tests/preprocessing/__init__.py create mode 100644 tests/preprocessing/test_cleaners.py create mode 100644 tests/resources/ff_core_fake_dr/meta.json create mode 100644 tests/utils/test_cdb_utils.py create mode 100644 tests/utils/test_helpers.py create mode 100644 tests/utils/test_spacy_compatibility.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c769dfc2e..a5468fb9b 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -16,9 +16,9 @@ jobs: max-parallel: 4 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -48,13 +48,13 @@ jobs: steps: - name: Checkout master - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: ref: 'master' fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.github/workflows/production.yml b/.github/workflows/production.yml index 5088c1000..9ad9a5d90 100644 --- a/.github/workflows/production.yml +++ b/.github/workflows/production.yml @@ -14,13 +14,13 @@ jobs: steps: - name: Checkout production - uses: actions/checkout@v2 + uses: actions/checkout@v4 with: ref: ${{ github.event.release.target_commitish }} fetch-depth: 0 - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 8c4e65615..5cc0d97f0 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,13 +7,13 @@ version: 2 build: os: ubuntu-20.04 tools: - python: "3.9" + python: "3.10" sphinx: configuration: docs/conf.py python: install: + - requirements: 
docs/requirements.txt - method: setuptools - path: . - - requirements: docs/requirements.txt \ No newline at end of file + path: . \ No newline at end of file diff --git a/README.md b/README.md index 395aecf69..bf34f00c6 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,20 @@ To download any of these models, please [follow this link](https://uts.nlm.nih.g - **Paper**: [What’s in a Summary? Laying the Groundwork for Advances in Hospital-Course Summarization](https://www.aclweb.org/anthology/2021.naacl-main.382.pdf) - ([more...](https://github.com/CogStack/MedCAT/blob/master/media/news.md)) +## Installation +To install the latest version of MedCAT run the following command: +``` +pip install medcat +``` +Normal installations of MedCAT will install torch-gpu and all relevant dependencies (such as CUDA). This can require as much as 10 GB more disk space, which isn't required for CPU-only usage. + +To install the latest version of MedCAT without torch GPU support run the following command: +``` +pip install medcat --extra-index-url https://download.pytorch.org/whl/cpu/ +``` ## Demo A demo application is available at [MedCAT](https://medcat.rosalind.kcl.ac.uk). This was trained on MIMIC-III and all of SNOMED-CT. +PS: This link can take a long time to load the first time around. The machine spins up as needed and spins down when inactive. ## Tutorials A guide on how to use MedCAT is available at [MedCAT Tutorials](https://github.com/CogStack/MedCATtutorials). Read more about MedCAT on [Towards Data Science](https://towardsdatascience.com/medcat-introduction-analyzing-electronic-health-records-e1c420afa13a). 
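Since the flattened hunk above makes the .readthedocs.yaml change hard to follow, here is a sketch of the file as it should look after this patch (reconstructed from the diff; indentation is assumed, as the flattened diff does not preserve it):

```yaml
# Reconstructed sketch of .readthedocs.yaml after this patch (indentation assumed).
version: 2

build:
  os: ubuntu-20.04
  tools:
    python: "3.10"

sphinx:
  configuration: docs/conf.py

python:
  install:
    # Docs requirements are now installed before the setuptools step.
    - requirements: docs/requirements.txt
    - method: setuptools
      path: .
```

Installing the pinned docs requirements before the package itself is what the "Move docs requirements before setuptools" commit refers to.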
diff --git a/docs/requirements.txt b/docs/requirements.txt index be517876f..7e7df6e01 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,6 +1,104 @@ -Sphinx~=4.0 +sphinx==6.2.1 sphinx-rtd-theme~=1.0 myst-parser~=0.17 -sphinx-autoapi~=1.8 -setuptools>=60.0 -aiohttp==3.8.5 \ No newline at end of file +sphinx-autoapi~=3.0.0 +MarkupSafe==2.1.3 +accelerate==0.23.0 +aiofiles==23.2.1 +aiohttp==3.8.5 +aiosignal==1.3.1 +asttokens==2.4.0 +async-timeout==4.0.3 +attrs==23.1.0 +backcall==0.2.0 +blis==0.7.11 +catalogue==2.0.10 +certifi==2023.7.22 +charset-normalizer==3.3.0 +click==8.1.7 +comm==0.1.4 +confection==0.1.3 +cymem==2.0.8 +datasets==2.14.5 +decorator==5.1.1 +dill==0.3.7 +exceptiongroup==1.1.3 +executing==2.0.0 +filelock==3.12.4 +flake8==4.0.1 +frozenlist==1.4.0 +fsspec==2023.6.0 +gensim==4.3.2 +huggingface-hub==0.17.3 +idna==3.4 +ipython==8.16.1 +ipywidgets==8.1.1 +jedi==0.19.1 +jinja2==3.1.2 +joblib==1.3.2 +jsonpickle==3.0.2 +jupyterlab-widgets==3.0.9 +langcodes==3.3.0 +matplotlib-inline==0.1.6 +mccabe==0.6.1 +mpmath==1.3.0 +multidict==6.0.4 +multiprocess==0.70.15 +murmurhash==1.0.10 +mypy==1.0.0 +mypy-extensions==0.4.3 +networkx==3.1 +numpy==1.25.2 +packaging==23.2 +pandas==2.1.1 +parso==0.8.3 +pathy==0.10.2 +pexpect==4.8.0 +pickleshare==0.7.5 +preshed==3.0.9 +prompt-toolkit==3.0.39 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyarrow==13.0.0 +pycodestyle==2.8.0 +pydantic==1.10.13 +pyflakes==2.4.0 +pygments==2.16.1 +python-dateutil==2.8.2 +pytz==2023.3.post1 +pyyaml==6.0.1 +regex==2023.10.3 +requests==2.31.0 +safetensors==0.4.0 +scikit-learn==1.3.1 +scipy==1.9.3 +six==1.16.0 +smart-open==6.4.0 +spacy==3.4.4 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +srsly==2.4.8 +stack-data==0.6.3 +sympy==1.12 +thinc==8.1.12 +threadpoolctl==3.2.0 +tokenizers==0.14.1 +tomli==2.0.1 +torch==2.1.0 +tqdm==4.66.1 +traitlets==5.11.2 +transformers==4.34.0 +triton==2.1.0 +typer==0.7.0 +types-PyYAML==6.0.3 +types-aiofiles==0.8.3 +types-setuptools==57.4.10 
+typing-extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.6 +wasabi==0.10.1 +wcwidth==0.2.8 +widgetsnbextension==4.0.9 +xxhash==3.4.1 +yarl==1.9.2 \ No newline at end of file diff --git a/medcat/cat.py b/medcat/cat.py index 2323cd737..d3003b24b 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -2,7 +2,6 @@ import glob import shutil import pickle -import traceback import json import logging import math @@ -24,7 +23,6 @@ from medcat.pipe import Pipe from medcat.preprocessing.taggers import tag_skip_and_punct from medcat.cdb import CDB -from medcat.utils.matutils import intersect_nonempty_set from medcat.utils.data_utils import make_mc_train_test, get_false_positives from medcat.utils.normalizers import BasicSpellChecker from medcat.utils.checkpoint import Checkpoint, CheckpointConfig, CheckpointManager @@ -32,15 +30,16 @@ from medcat.utils.hasher import Hasher from medcat.ner.vocab_based_ner import NER from medcat.linking.context_based_linker import Linker -from medcat.utils.filters import get_project_filters from medcat.preprocessing.cleaners import prepare_name from medcat.meta_cat import MetaCAT from medcat.utils.meta_cat.data_utils import json_to_fake_spacy -from medcat.config import Config, LinkingFilters +from medcat.config import Config from medcat.vocab import Vocab from medcat.utils.decorators import deprecated from medcat.ner.transformers_ner import TransformersNER from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY +from medcat.stats.stats import get_stats +from medcat.utils.filters import set_project_filters logger = logging.getLogger(__name__) # separate logger from the package-level one @@ -271,6 +270,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M cdb_path = os.path.join(save_dir_path, "cdb.dat") self.cdb.save(cdb_path, json_path) + # Save the config + config_path = os.path.join(save_dir_path, "config.json") + self.cdb.config.save(config_path) + # Save the Vocab vocab_path = 
os.path.join(save_dir_path, "vocab.dat") if self.vocab is not None: @@ -362,6 +365,10 @@ def load_model_pack(cls, logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format') cdb = CDB.load(cdb_path, json_path) + # load config + config_path = os.path.join(model_pack_path, "config.json") + cdb.load_config(config_path) + # TODO load addl_ner # Modify the config to contain full path to spacy model @@ -434,7 +441,8 @@ def _print_stats(self, use_overlaps: bool = False, use_cui_doc_limit: bool = False, use_groups: bool = False, - extra_cui_filter: Optional[Set] = None) -> Tuple: + extra_cui_filter: Optional[Set] = None, + do_print: bool = True) -> Tuple: """TODO: Refactor and make nice Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP,FN,TP. @@ -474,204 +482,12 @@ def _print_stats(self, Number of occurrence for each CUI. examples (dict): Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][]. + do_print (bool): + Whether to print stats out. Defaults to True. 
""" - tp = 0 - fp = 0 - fn = 0 - fps: Dict = {} - fns: Dict = {} - tps: Dict = {} - cui_prec: Dict = {} - cui_rec: Dict = {} - cui_f1: Dict = {} - cui_counts: Dict = {} - examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}} - - fp_docs: Set = set() - fn_docs: Set = set() - - orig_filters = self.config.linking.filters.copy_of() - local_filters = self.config.linking.filters - for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): - local_filters.cuis = set() - - # Add extra filter if set - self._set_project_filters(local_filters, project, extra_cui_filter, use_project_filters) - - for dind, doc in tqdm( - enumerate(project["documents"]), - desc="Stats document", - total=len(project["documents"]), - leave=False, - ): - anns = self._get_doc_annotations(doc) - - # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still - if use_cui_doc_limit: - _cuis = set([ann['cui'] for ann in anns]) - if _cuis: - local_filters.cuis = intersect_nonempty_set(_cuis, extra_cui_filter) - else: - local_filters.cuis = {'empty'} - - spacy_doc: Doc = self(doc['text']) # type: ignore - - if use_overlaps: - p_anns = spacy_doc._.ents - else: - p_anns = spacy_doc.ents - - anns_norm = [] - anns_norm_neg = [] - anns_examples = [] - anns_norm_cui = [] - for ann in anns: - cui = ann['cui'] - if local_filters.check_filters(cui): - if use_groups: - cui = self.cdb.addl_info['cui2group'].get(cui, cui) - - if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)): - anns_norm.append((ann['start'], cui)) - anns_examples.append({"text": doc['text'][max(0, ann['start']-60):ann['end']+60], - "cui": cui, - "start": ann['start'], - "end": ann['end'], - "source value": ann['value'], - "acc": 1, - "project name": project.get('name'), - "document name": doc.get('name'), - "project id": project.get('id'), - "document id": doc.get('id')}) - elif 
ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)): - anns_norm_neg.append((ann['start'], cui)) - - if ann.get("validated", True): - # This is used to test was someone annotating for this CUI in this document - anns_norm_cui.append(cui) - cui_counts[cui] = cui_counts.get(cui, 0) + 1 - - p_anns_norm = [] - p_anns_examples = [] - for ann in p_anns: - cui = ann._.cui - if use_groups: - cui = self.cdb.addl_info['cui2group'].get(cui, cui) - - p_anns_norm.append((ann.start_char, cui)) - p_anns_examples.append({"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60], - "cui": cui, - "start": ann.start_char, - "end": ann.end_char, - "source value": ann.text, - "acc": float(ann._.context_similarity), - "project name": project.get('name'), - "document name": doc.get('name'), - "project id": project.get('id'), - "document id": doc.get('id')}) - for iann, ann in enumerate(p_anns_norm): - cui = ann[1] - if ann in anns_norm: - tp += 1 - tps[cui] = tps.get(cui, 0) + 1 - - example = p_anns_examples[iann] - examples['tp'][cui] = examples['tp'].get(cui, []) + [example] - else: - fp += 1 - fps[cui] = fps.get(cui, 0) + 1 - fp_docs.add(doc.get('name', 'unk')) - - # Add example for this FP prediction - example = p_anns_examples[iann] - if ann in anns_norm_neg: - # Means that it really was annotated as negative - example['real_fp'] = True - - examples['fp'][cui] = examples['fp'].get(cui, []) + [example] - - for iann, ann in enumerate(anns_norm): - if ann not in p_anns_norm: - cui = ann[1] - fn += 1 - fn_docs.add(doc.get('name', 'unk')) - - fns[cui] = fns.get(cui, 0) + 1 - examples['fn'][cui] = examples['fn'].get(cui, []) + [anns_examples[iann]] - - try: - prec = tp / (tp + fp) - rec = tp / (tp + fn) - f1 = 2*(prec*rec) / (prec + rec) - print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1)) - print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(fp_docs)[0:10]]))) - print("Docs with false negatives: 
{}\n".format("; ".join([str(x) for x in list(fn_docs)[0:10]]))) - - # Sort fns & prec - fps = {k: v for k, v in sorted(fps.items(), key=lambda item: item[1], reverse=True)} - fns = {k: v for k, v in sorted(fns.items(), key=lambda item: item[1], reverse=True)} - tps = {k: v for k, v in sorted(tps.items(), key=lambda item: item[1], reverse=True)} - - - # F1 per concept - for cui in tps.keys(): - prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0)) - rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0)) - f1 = 2*(prec*rec) / (prec + rec) - cui_prec[cui] = prec - cui_rec[cui] = rec - cui_f1[cui] = f1 - - - # Get top 10 - pr_fps = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]] - pr_fns = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]] - pr_tps = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]] - - - print("\n\nFalse Positives\n") - for one in pr_fps: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("\n\nFalse Negatives\n") - for one in pr_fns: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("\n\nTrue Positives\n") - for one in pr_tps: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("*"*110 + "\n") - - except Exception: - traceback.print_exc() - - self.config.linking.filters = orig_filters - - return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples - - def _set_project_filters(self, local_filters: LinkingFilters, project: dict, - extra_cui_filter: Optional[Set], use_project_filters: bool): - """Set the project filters to a LinkingFilters object based on - the specified project. 
- - Args: - local_filters (LinkingFilters): The linking filters instance - project (dict): The project - extra_cui_filter (Optional[Set]): Extra CUIs (if specified) - use_project_filters (bool): Whether to use per-project filters - """ - if isinstance(extra_cui_filter, set): - local_filters.cuis = extra_cui_filter - - if use_project_filters: - project_filter = get_project_filters(cuis=project.get('cuis', None), - type_ids=project.get('tuis', None), - cdb=self.cdb, - project=project) - # Intersect project filter with existing if it has something - if project_filter: - local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis) + return get_stats(self, data=data, epoch=epoch, use_project_filters=use_project_filters, + use_overlaps=use_overlaps, use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, extra_cui_filter=extra_cui_filter, do_print=do_print) def _init_ckpts(self, is_resumed, checkpoint): if self.config.general.checkpoint.steps is not None or checkpoint is not None: @@ -832,9 +648,13 @@ def add_and_train_concept(self, Refer to medcat.cat.cdb.CDB.add_concept """ names = prepare_name(name, self.pipe.spacy_nlp, {}, self.config) + if not names and cui not in self.cdb.cui2preferred_name and name_status == 'P': + logger.warning("No names were able to be prepared in CAT.add_and_train_concept " + "method. As such no preferred name will be able to be specified. 
" + "The CUI: '%s' and raw name: '%s'", cui, name) # Only if not negative, otherwise do not add the new name if in fact it should not be detected if do_add_concept and not negative: - self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description, + self.cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description, full_build=full_build) if spacy_entity is not None and spacy_doc is not None: @@ -1102,15 +922,15 @@ def train_supervised_raw(self, # then add the extra CUI filters if retain_filters and extra_cui_filter and not retain_extra_cui_filter: # adding project filters without extra_cui_filters - self._set_project_filters(local_filters, project, set(), use_filters) + set_project_filters(self.cdb.addl_info, local_filters, project, set(), use_filters) orig_filters.merge_with(local_filters) # adding extra_cui_filters, but NOT project filters - self._set_project_filters(local_filters, project, extra_cui_filter, False) + set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, False) # refrain from doing it again for subsequent epochs retain_filters = False else: # Set filters in case we are using the train_from_fp - self._set_project_filters(local_filters, project, extra_cui_filter, use_filters) + set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, use_filters) for idx_doc in trange(current_document, len(project['documents']), initial=current_document, total=len(project['documents']), desc='Document', leave=False): doc = project['documents'][idx_doc] @@ -1327,19 +1147,42 @@ def _save_docs_to_file(self, docs: Iterable, annotated_ids: List[str], save_dir_ pickle.dump((annotated_ids, part_counter), open(annotated_ids_path, 'wb')) return part_counter + @deprecated(message="Use `multiprocessing_batch_char_size` instead") def multiprocessing(self, data: Union[List[Tuple], 
Iterable[Tuple]], nproc: int = 2, batch_size_chars: int = 5000 * 1000, only_cui: bool = False, - addl_info: List[str] = [], + addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'], separate_nn_components: bool = True, out_split_size_chars: Optional[int] = None, save_dir_path: str = os.path.abspath(os.getcwd()), min_free_memory=0.1) -> Dict: + return self.multiprocessing_batch_char_size(data=data, nproc=nproc, + batch_size_chars=batch_size_chars, + only_cui=only_cui, addl_info=addl_info, + separate_nn_components=separate_nn_components, + out_split_size_chars=out_split_size_chars, + save_dir_path=save_dir_path, + min_free_memory=min_free_memory) + + def multiprocessing_batch_char_size(self, + data: Union[List[Tuple], Iterable[Tuple]], + nproc: int = 2, + batch_size_chars: int = 5000 * 1000, + only_cui: bool = False, + addl_info: List[str] = [], + separate_nn_components: bool = True, + out_split_size_chars: Optional[int] = None, + save_dir_path: str = os.path.abspath(os.getcwd()), + min_free_memory=0.1) -> Dict: r"""Run multiprocessing for inference, if out_save_path and out_split_size_chars is used this will also continue annotating documents if something is saved in that directory. + This method batches the data based on the number of characters as specified by the user. + + PS: This method is unlikely to work on a Windows machine. + Args: data: Iterator or array with format: [(id, text), (id, text), ...] 
@@ -1523,15 +1366,35 @@ def _multiprocessing_batch(self, return docs - def multiprocessing_pipe(self, - in_data: Union[List[Tuple], Iterable[Tuple]], + @deprecated(message="Use `multiprocessing_batch_docs_size` instead") + def multiprocessing_pipe(self, in_data: Union[List[Tuple], Iterable[Tuple]], nproc: Optional[int] = None, batch_size: Optional[int] = None, only_cui: bool = False, addl_info: List[str] = [], return_dict: bool = True, batch_factor: int = 2) -> Union[List[Tuple], Dict]: - """Run multiprocessing NOT FOR TRAINING + return self.multiprocessing_batch_docs_size(in_data=in_data, nproc=nproc, + batch_size=batch_size, + only_cui=only_cui, + addl_info=addl_info, + return_dict=return_dict, + batch_factor=batch_factor) + + def multiprocessing_batch_docs_size(self, + in_data: Union[List[Tuple], Iterable[Tuple]], + nproc: Optional[int] = None, + batch_size: Optional[int] = None, + only_cui: bool = False, + addl_info: List[str] = ['cui2icd10', 'cui2ontologies', 'cui2snomed'], + return_dict: bool = True, + batch_factor: int = 2) -> Union[List[Tuple], Dict]: + """Run multiprocessing NOT FOR TRAINING. + + This method batches the data based on the number of documents as specified by the user. + + PS: + This method supports Windows. Args: in_data (Union[List[Tuple], Iterable[Tuple]]): List with format: [(id, text), (id, text), ...] 
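The hunks above rename the multiprocessing entry points and keep the old names as deprecated thin wrappers that delegate to the new, more explicit ones. A minimal sketch of that pattern follows; the `deprecated` decorator here is a generic stand-in (MedCAT's own lives in `medcat.utils.decorators` and may differ), and `Annotator` with its placeholder logic is made up for illustration:

```python
import functools
import warnings


def deprecated(message: str):
    """Generic stand-in decorator: emit a DeprecationWarning, then delegate."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(f"{func.__name__} is deprecated. {message}",
                          DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)
        return wrapper
    return decorator


class Annotator:
    """Made-up class standing in for CAT; the processing logic is a placeholder."""

    def multiprocessing_batch_docs_size(self, in_data, nproc=None):
        # New, explicitly named entry point (placeholder logic: upper-case each text).
        return {doc_id: text.upper() for doc_id, text in in_data}

    @deprecated(message="Use `multiprocessing_batch_docs_size` instead")
    def multiprocessing_pipe(self, in_data, nproc=None):
        # Old name kept as a thin wrapper that delegates to the new method.
        return self.multiprocessing_batch_docs_size(in_data, nproc=nproc)
```

Calling the old name still returns the same result but emits a `DeprecationWarning`, which gives downstream users a migration window before the alias is removed.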
diff --git a/medcat/cdb.py b/medcat/cdb.py index 44d4fd9dd..76cb7327e 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -5,13 +5,15 @@ import logging import aiofiles import numpy as np -from typing import Dict, Set, Optional, List, Union +from typing import Dict, Set, Optional, List, Union, cast from functools import partial +import os from medcat import __version__ from medcat.utils.hasher import Hasher from medcat.utils.matutils import unitvec from medcat.utils.ml_utils import get_lr_linking +from medcat.utils.decorators import deprecated from medcat.config import Config, weighted_average, workers from medcat.utils.saving.serializer import CDBSerializer @@ -61,8 +63,10 @@ class CDB(object): def __init__(self, config: Union[Config, None] = None) -> None: if config is None: self.config = Config() + self._config_from_file = False else: self.config = config + self._config_from_file = True self.name2cuis: Dict = {} self.name2cuis2status: Dict = {} @@ -95,6 +99,12 @@ def __init__(self, config: Union[Config, None] = None) -> None: self._optim_params = None self.is_dirty = False self._hash: Optional[str] = None + # the config hash is kept track of here so that + # the CDB hash can be re-calculated when the config changes + # it can also be used to make sure the config loaded with + # a CDB matches the config it was saved with + # since the config is now saved separately + self._config_hash: Optional[str] = None self._memory_optimised_parts: Set[str] = set() def get_name(self, cui: str) -> str: @@ -213,8 +223,9 @@ def add_names(self, cui: str, names: Dict, name_status: str = 'A', full_build: b # Name status must be one of the three name_status = 'A' - self.add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status, type_ids=set(), description='', full_build=full_build) + self._add_concept(cui=cui, names=names, ontologies=set(), name_status=name_status, type_ids=set(), description='', full_build=full_build) + @deprecated("Use `cdb._add_concept` as this 
will be removed in a future release.") def add_concept(self, cui: str, names: Dict, @@ -223,6 +234,43 @@ def add_concept(self, type_ids: Set[str], description: str, full_build: bool = False) -> None: + """ + Deprecated: Use `cdb._add_concept` as this will be removed in a future release. + + Add a concept to internal Concept Database (CDB). Depending on what you are providing + this will add a large number of properties for each concept. + + Args: + cui (str): + Concept ID or unique identifier in this database, all concepts that have + the same CUI will be merged internally. + names (Dict[str, Dict]): + Names for this concept, or the value that if found in free text can be linked to this concept. + Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}` + Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name' + ontologies (Set[str]): + ontologies in which the concept exists (e.g. SNOMEDCT, HPO) + name_status (str): + One of `P`, `N`, `A` + type_ids (Set[str]): + Semantic type identifier (have a look at TUIs in UMLS or SNOMED-CT) + description (str): + Description of this concept. + full_build (bool): + If True the dictionary self.addl_info will also be populated, contains a lot of extra information + about concepts, but can be very memory consuming. This is not necessary + for normal functioning of MedCAT (Default Value `False`). + """ + self._add_concept(cui, names, ontologies, name_status, type_ids, description, full_build) + + def _add_concept(self, + cui: str, + names: Dict, + ontologies: set, + name_status: str, + type_ids: Set[str], + description: str, + full_build: bool = False) -> None: """Add a concept to internal Concept Database (CDB). Depending on what you are providing this will add a large number of properties for each concept. @@ -232,7 +280,8 @@ def add_concept(self, the same CUI will be merged internally. 
names (Dict[str, Dict]): Names for this concept, or the value that if found in free text can be linked to this concept. - Names is an dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}` + Names is a dict like: `{name: {'tokens': tokens, 'snames': snames, 'raw_name': raw_name}, ...}` + Names should be generated by helper function 'medcat.preprocessing.cleaners.prepare_name' ontologies (Set[str]): ontologies in which the concept exists (e.g. SNOMEDCT, HPO) name_status (str): @@ -309,6 +358,21 @@ def add_concept, if name_status == 'P' and cui not in self.cui2preferred_name: # Do not overwrite old preferred names self.cui2preferred_name[cui] = name_info['raw_name'] + elif names: + # if no name_info and names is NOT an empty dict + # this shouldn't really happen in the current setup + raise ValueError("Unknown state where there is no name_info, " + "yet the `names` dict is not empty (%s)" % (names,)) + elif name_status == 'P' and cui not in self.cui2preferred_name: + # this means names is an empty `names` dict + logger.warning("Did not manage to add a preferred name in `add_cui`. " + "Was trying to do so for cui: '%s'. " + "This means that the `names` dict passed was empty. " + "This is _usually_ caused by either no name or too short " + "a name passed to the `prepare_name` method. 
" + "The minimum length is defined in: " + "'config.cdb_maker.min_letters_required' and " + "is currently set at %s", cui, self.config.cdb_maker['min_letters_required']) # Add other fields if full_build if full_build: @@ -458,6 +522,35 @@ async def save_async(self, path: str) -> None: } await f.write(dill.dumps(to_save)) + def load_config(self, config_path: str) -> None: + if not os.path.exists(config_path): + if not self._config_from_file: + # if there's no config defined anywhere + raise ValueError("Could not find a config in the CDB nor " + "in the config.json for this model " + f"({os.path.dirname(config_path)})", + ) + # if there is a config, but it's defined in the cdb.dat file + logger.warning("Could not find config.json in model pack folder " + f"({os.path.dirname(config_path)}). " + "This is probably an older model. Please save the model " + "again in the new format to avoid potential issues.") + else: + if self._config_from_file: + # if there's a config.json and one defined in the cbd.dat file + raise ValueError("Found a config in the CDB and in the config.json " + f"for model ({os.path.dirname(config_path)}) - " + "this is ambiguous. Please either remove the " + "config.json or load the CDB without the config.json " + "in the folder and re-save in the newer format " + "(the default save in this version)") + # if the only config is in the separate config.json file + # this should be the behaviour for all newer models + self.config = cast(Config, Config.load(config_path)) + logger.debug("Loaded config from CDB from %s", config_path) + # mark config read from file + self._config_from_file = True + @classmethod def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[Dict] = None) -> "CDB": """Load and return a CDB. This allows partial loads in probably not the right way at all. 
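The `load_config` branches above implement a small decision matrix over two booleans: whether a `config.json` exists on disk, and whether the CDB carried its own embedded config. A standalone sketch of that logic follows; the function and return values are hypothetical names that mirror the behaviour rather than reuse MedCAT's code:

```python
def resolve_config_source(has_config_json: bool, config_in_cdb: bool) -> str:
    """Decide where the config should come from when loading a model pack.

    Returns 'load_json' or 'keep_embedded'; raises ValueError for the
    two invalid combinations (no config anywhere, or two configs).
    """
    if not has_config_json:
        if not config_in_cdb:
            # no config defined anywhere -> cannot proceed
            raise ValueError("Could not find a config in the CDB or in config.json")
        # older model: config is only embedded in cdb.dat
        return "keep_embedded"
    if config_in_cdb:
        # both a config.json and an embedded config -> ambiguous
        raise ValueError("Found a config in the CDB and in config.json")
    # newer model: config lives only in the separate config.json
    return "load_json"
```

Laying the four cases out this way makes clear why the ambiguous case must raise: silently preferring either copy could load a config that does not match the CDB it was saved with, which is exactly what the separate `_config_hash` tracking is meant to detect.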
@@ -777,8 +870,34 @@ def _check_medcat_version(cls, config_data: Dict) -> None: or download the compatible model.""" ) + def _should_recalc_hash(self, force_recalc: bool) -> bool: + if force_recalc: + return True + if self.config.hash is None: + # TODO - perhaps this is not the best? + # as this is a side effect + # get and save result in config + self.config.get_hash() + if not self._hash or self.is_dirty: + # if no hash saved or is dirty + # need to calculate + logger.debug("Recalculating hash due to %s", + "no hash saved" if not self._hash else "CDB is dirty") + return True + # recalc config hash in case it changed + self.config.get_hash() + if self._config_hash is None or self._config_hash != self.config.hash: + # if no config hash saved + # or if the config hash is different from one saved in here + logger.debug("Recalculating hash due to %s", + "no config hash saved" if not self._config_hash + else "config hash has changed") + return True + return False + def get_hash(self, force_recalc: bool = False): - if not force_recalc and self._hash and not self.is_dirty: + should_recalc = self._should_recalc_hash(force_recalc) + if not should_recalc: logger.info("Reusing old hash of CDB since the CDB has not changed: %s", self._hash) return self._hash self.is_dirty = False @@ -791,7 +910,7 @@ def calculate_hash(self): for k,v in self.__dict__.items(): if k in ['cui2countext_vectors', 'name2cuis']: hasher.update(v, length=False) - elif k in ['_hash', 'is_dirty']: + elif k in ['_hash', 'is_dirty', '_config_hash']: # ignore _hash since if it previously didn't exist, the # new hash would be different when the value does exist # and ignore is_dirty so that we get the same hash as previously @@ -799,6 +918,9 @@ def calculate_hash(self): elif k != 'config': hasher.update(v, length=True) + # set cached config hash + self._config_hash = self.config.hash + self._hash = hasher.hexdigest() logger.info("Found new CDB hash: %s", self._hash) return self._hash diff --git 
a/medcat/cdb_maker.py b/medcat/cdb_maker.py index e9c72d12e..a4dd7dd27 100644 --- a/medcat/cdb_maker.py +++ b/medcat/cdb_maker.py @@ -49,6 +49,14 @@ def __init__(self, config: Config, cdb: Optional[CDB] = None) -> None: name='skip_and_punct', additional_fields=['is_punct']) + def reset_cdb(self) -> None: + """Re-create the internal CDB based on the same config. + + This will be necessary if/when you wish to call `prepare_csvs` + multiple times on the same `CDBMaker` instance. + """ + self.cdb = CDB(config=self.config) + def prepare_csvs(self, csv_paths: Union[pd.DataFrame, List[str]], sep: str = ',', @@ -59,6 +67,12 @@ def prepare_csvs, only_existing_cuis: bool = False, **kwargs) -> CDB: r"""Compile one or multiple CSVs into a CDB. + Note: This class/method generally uses the same instance of the CDB. + So if you're using the same CDBMaker and calling `prepare_csvs` + multiple times, you are likely to get leakage from prior calls + into new ones. + To reset the CDB, call `reset_cdb`. + Args: csv_paths (Union[pd.DataFrame, List[str]]): An array of paths to the csv files that should be processed. 
Can also be an array of pd.DataFrames @@ -173,7 +187,7 @@ def prepare_csvs, if len(raw_name) >= self.config.cdb_maker['remove_parenthesis']: prepare_name(raw_name, self.pipe.spacy_nlp, names, self.config) - self.cdb.add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, + self.cdb._add_concept(cui=cui, names=names, ontologies=ontologies, name_status=name_status, type_ids=type_ids, description=description, full_build=full_build) # DEBUG logger.debug("\n\n**** Added\n CUI: %s\n Names: %s\n Ontologies: %s\n Name status: %s\n Type IDs: %s\n Description: %s\n Is full build: %s", diff --git a/medcat/config.py b/medcat/config.py index b2e324deb..e60c2eafc 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -433,6 +433,19 @@ class LinkingFilters(MixingConfig, BaseModel): cuis: Set[str] = set() cuis_exclude: Set[str] = set() + def __init__(self, **data): + if 'cuis' in data: + cuis = data['cuis'] + if isinstance(cuis, dict) and len(cuis) == 0: + logger.warning("Loading an old model where " + "config.linking.filters.cuis has been " + "set to an empty dict instead of an empty " + "set. Converting the dict to a set in memory " + "as that is what is expected. 
Please consider " + "saving the model again.") + data['cuis'] = set(cuis.keys()) + super().__init__(**data) + def check_filters(self, cui: str) -> bool: """Checks is a CUI in the filters @@ -535,6 +548,7 @@ class Config(MixingConfig, BaseModel): linking: Linking = Linking() word_skipper: re.Pattern = re.compile('') # empty pattern gets replaced upon init punct_checker: re.Pattern = re.compile('') # empty pattern gets replaced upon init + hash: Optional[str] = None class Config: # this if for word_skipper and punct_checker which would otherwise @@ -559,6 +573,9 @@ def rebuild_re(self) -> None: def get_hash(self): hasher = Hasher() for k, v in self.dict().items(): + if k in ['hash', ]: + # ignore hash + continue if k not in ['version', 'general', 'linking']: hasher.update(v, length=True) elif k == 'general': @@ -574,5 +591,5 @@ def get_hash(self): hasher.update(v2, length=False) else: hasher.update(v2, length=True) - - return hasher.hexdigest() + self.hash = hasher.hexdigest() + return self.hash diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 9623b1b93..227ccc083 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -1,6 +1,7 @@ import os import json import logging +import datasets from spacy.tokens import Doc from datetime import datetime from typing import Iterable, Iterator, Optional, Dict, List, cast, Union @@ -18,7 +19,7 @@ from transformers import Trainer, AutoModelForTokenClassification, AutoTokenizer from transformers import pipeline, TrainingArguments -import datasets +from transformers.trainer_callback import TrainerCallback # It should be safe to do this always, as all other multiprocessing #will be finished before data comes to meta_cat @@ -137,7 +138,12 @@ def merge_data_loaded(base, other): return out_path - def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=False, dataset=None, meta_requirements=None): + def train(self, + json_path: Union[str, list, None]=None, + 
ignore_extra_labels=False, + dataset=None, + meta_requirements=None, + trainer_callbacks: Optional[List[TrainerCallback]]=None): """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -149,6 +155,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals ignore_extra_labels: Makes only sense when an existing deid model was loaded and from the new data we want to ignore labels that did not exist in the old model. + trainer_callbacks (List[TrainerCallback]): + A list of trainer callbacks for collecting metrics during the training at the client side. The + transformers Trainer object will be passed in when each callback is called. """ if dataset is None and json_path is not None: @@ -193,6 +202,9 @@ def train(self, json_path: Union[str, list, None]=None, ignore_extra_labels=Fals compute_metrics=lambda p: metrics(p, tokenizer=self.tokenizer, dataset=encoded_dataset['test'], verbose=self.config.general['verbose_metrics']), data_collator=data_collator, # type: ignore tokenizer=None) + if trainer_callbacks: + for callback in trainer_callbacks: + trainer.add_callback(callback(trainer)) trainer.train() # type: ignore diff --git a/medcat/pipe.py b/medcat/pipe.py index 3861267df..7bf06364b 100644 --- a/medcat/pipe.py +++ b/medcat/pipe.py @@ -1,4 +1,5 @@ import types +import os import spacy import gc import logging @@ -17,11 +18,15 @@ from medcat.pipeline.pipe_runner import PipeRunner from medcat.preprocessing.taggers import tag_skip_and_punct from medcat.ner.transformers_ner import TransformersNER +from medcat.utils.helpers import ensure_spacy_model logger = logging.getLogger(__name__) # different logger from the package-level one +DEFAULT_SPACY_MODEL = 'en_core_web_md' + + class Pipe(object): """A wrapper around the standard spacy pipeline. 
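The `trainer_callbacks` hook above instantiates each callback with the trainer before registering it. A framework-free sketch of that wiring is below; `FakeTrainer` and `MetricsCallback` are hypothetical stand-ins so the example runs without `transformers` (the real code calls `transformers.Trainer.add_callback`):

```python
class FakeTrainer:
    """Stand-in for transformers.Trainer, just enough to show callback wiring."""
    def __init__(self):
        self.callbacks = []

    def add_callback(self, callback):
        self.callbacks.append(callback)

    def train(self):
        # a real Trainer fires many events; one is enough for the sketch
        for cb in self.callbacks:
            cb.on_train_begin()


class MetricsCallback:
    """A client-side callback that receives the trainer on construction."""
    def __init__(self, trainer):
        self.trainer = trainer
        self.began = False

    def on_train_begin(self):
        self.began = True


def attach_callbacks(trainer, trainer_callbacks=None):
    # mirrors: `for callback in trainer_callbacks: trainer.add_callback(callback(trainer))`
    if trainer_callbacks:
        for callback in trainer_callbacks:
            trainer.add_callback(callback(trainer))
```

Passing callback *classes* (factories) rather than instances lets each callback hold a reference to the trainer that drives it, which is what allows client-side metric collection during training.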
@@ -38,9 +43,27 @@ class Pipe: """ def __init__(self, tokenizer: Tokenizer, config: Config) -> None: - self._nlp = spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components) if config.preprocessing.stopwords is not None: - self._nlp.Defaults.stop_words = set(config.preprocessing.stopwords) + lang = os.path.basename(config.general.spacy_model).split('_', 1)[0] + cls = spacy.util.get_lang_class(lang) + cls.Defaults.stop_words = set(config.preprocessing.stopwords) + try: + self._nlp = self._init_nlp(config) + except Exception as e: + if config.general.spacy_model == DEFAULT_SPACY_MODEL: + raise e + logger.warning("Could not load spacy model from '%s'. " + "Falling back to installed en_core_web_md. " + "For best compatibility, we'd recommend " + "packaging and using your model pack with " + "the spacy model it was designed for", + config.general.spacy_model, exc_info=e) + # we're changing the config value so that this propagates + # to other places that try to load the model. E.g: + # medcat.utils.normalizers.TokenNormalizer.__init__ + ensure_spacy_model(DEFAULT_SPACY_MODEL) + config.general.spacy_model = DEFAULT_SPACY_MODEL + self._nlp = self._init_nlp(config) self._nlp.tokenizer = tokenizer(self._nlp, config) # Set max document length self._nlp.max_length = config.preprocessing.max_document_length @@ -48,6 +71,9 @@ def __init__(self, tokenizer: Tokenizer, config: Config) -> None: # Set log level logger.setLevel(self.config.general.log_level) + def _init_nlp(self, config: Config) -> Language: + return spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components) + def add_tagger(self, tagger: Callable, name: Optional[str] = None, additional_fields: List[str] = []) -> None: """Add any kind of a tagger for tokens. 
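The spacy-model fallback in `Pipe.__init__` above boils down to: try the configured model; on failure, fall back to the default model (unless the configured model already *is* the default) and write that choice back into the config so every later component loads the same model. A generic sketch of that pattern with a hypothetical `loader` callable standing in for `spacy.load`:

```python
DEFAULT_MODEL = "en_core_web_md"


def load_with_fallback(config: dict, loader):
    """Try config['spacy_model']; on failure fall back to DEFAULT_MODEL.

    `loader` is expected to raise on unknown models. The config is updated
    in place so the fallback propagates to other components that read the
    model name later.
    """
    try:
        return loader(config["spacy_model"])
    except Exception:
        if config["spacy_model"] == DEFAULT_MODEL:
            raise  # nothing left to fall back to
        config["spacy_model"] = DEFAULT_MODEL
        return loader(DEFAULT_MODEL)
```

Mutating the config rather than just swapping the local variable is the key design choice: as the inline comment in the patch notes, other code paths (e.g. the token normalizer) re-read `config.general.spacy_model`, and they must all agree on the fallback.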
diff --git a/medcat/preprocessing/cleaners.py b/medcat/preprocessing/cleaners.py index 18314d562..43e8098e2 100644 --- a/medcat/preprocessing/cleaners.py +++ b/medcat/preprocessing/cleaners.py @@ -48,7 +48,8 @@ def prepare_name(raw_name: str, nlp: Language, names: Dict, config: Config) -> D snames = set() name = config.general['separator'].join(tokens) - if not config.cdb_maker.get('min_letters_required', 0) or len(re.sub("[^A-Za-z]*", '', name)) >= config.cdb_maker.get('min_letters_required', 0): + min_letters = config.cdb_maker.get('min_letters_required', 0) + if not min_letters or len(re.sub("[^A-Za-z]*", '', name)) >= min_letters: if name not in names: sname = "" for token in tokens: diff --git a/medcat/stats/__init__.py b/medcat/stats/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py new file mode 100644 index 000000000..06b712158 --- /dev/null +++ b/medcat/stats/stats.py @@ -0,0 +1,340 @@ +from typing import Dict, Optional, Set, Tuple, Callable, List, cast + +from tqdm import tqdm +import traceback + +from spacy.tokens import Doc + +from medcat.utils.filters import set_project_filters +from medcat.utils.matutils import intersect_nonempty_set +from medcat.config import LinkingFilters + + +class StatsBuilder: + + def __init__(self, + filters: LinkingFilters, + addl_info: dict, + doc_getter: Callable[[Optional[str], bool], Optional[Doc]], + doc_annotation_getter: Callable[[dict], list], + cui2group: Dict[str, str], + cui2preferred_name: Dict[str, str], + cui2names: Dict[str, Set[str]], + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + extra_cui_filter: Optional[Set] = None) -> None: + self.filters = filters + self.addl_info = addl_info + self.doc_getter = doc_getter + self._get_doc_annotations = doc_annotation_getter + self.cui2group = cui2group + self.cui2preferred_name = cui2preferred_name + self.cui2names = 
cui2names + self.use_project_filters = use_project_filters + self.use_overlaps = use_overlaps + self.use_cui_doc_limit = use_cui_doc_limit + self.use_groups = use_groups + self.extra_cui_filter = extra_cui_filter + self._reset_stats() + + def _reset_stats(self): + self.tp = 0 + self.fp = 0 + self.fn = 0 + self.fps: Dict = {} + self.fns: Dict = {} + self.tps: Dict = {} + self.cui_prec: Dict = {} + self.cui_rec: Dict = {} + self.cui_f1: Dict = {} + self.cui_counts: Dict = {} + self.examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}} + self.fp_docs: Set = set() + self.fn_docs: Set = set() + + def process_project(self, project: dict) -> None: + self.filters.cuis = set() + + # Add extra filter if set + set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters) + + documents = project["documents"] + for dind, doc in tqdm( + enumerate(documents), + desc="Stats document", + total=len(documents), + leave=False, + ): + self.process_document(cast(str, project.get('name')), + cast(str, project.get('id')), doc) + + def process_document(self, project_name: str, project_id: str, doc: dict) -> None: + anns = self._get_doc_annotations(doc) + + # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still + if self.use_cui_doc_limit: + _cuis = set([ann['cui'] for ann in anns]) + if _cuis: + self.filters.cuis = intersect_nonempty_set(_cuis, self.extra_cui_filter) + else: + self.filters.cuis = {'empty'} + + spacy_doc: Doc = self.doc_getter(doc['text']) # type: ignore + + if self.use_overlaps: + p_anns = spacy_doc._.ents + else: + p_anns = spacy_doc.ents + + (anns_norm, anns_norm_neg, + anns_examples, _) = self._preprocess_annotations(project_name, project_id, doc, anns) + + p_anns_norm, p_anns_examples = self._process_p_anns(project_name, project_id, + doc, p_anns) + self._count_p_anns_norm(doc, anns_norm, anns_norm_neg, + p_anns_norm, p_anns_examples) + self._process_anns_norm(doc, 
anns_norm, p_anns_norm, anns_examples) + + def _process_anns_norm(self, doc: dict, anns_norm: list, p_anns_norm: list, + anns_examples: list) -> None: + for iann, ann in enumerate(anns_norm): + if ann not in p_anns_norm: + cui = ann[1] + self.fn += 1 + self.fn_docs.add(doc.get('name', 'unk')) + + self.fns[cui] = self.fns.get(cui, 0) + 1 + self.examples['fn'][cui] = self.examples['fn'].get(cui, []) + [anns_examples[iann]] + + def _process_p_anns(self, project_name: str, project_id: str, doc: dict, p_anns: list) -> Tuple[list, list]: + p_anns_norm = [] + p_anns_examples = [] + for ann in p_anns: + cui = ann._.cui + if self.use_groups: + cui = self.cui2group.get(cui, cui) + + p_anns_norm.append((ann.start_char, cui)) + p_anns_examples.append(self._create_annoation_2(project_name, project_id, cui, doc, ann)) + return p_anns_norm, p_anns_examples + + def _count_p_anns_norm(self, doc: dict, anns_norm: list, anns_norm_neg: list, + p_anns_norm: list, p_anns_examples: list) -> None: + for iann, ann in enumerate(p_anns_norm): + cui = ann[1] + if ann in anns_norm: + self.tp += 1 + self.tps[cui] = self.tps.get(cui, 0) + 1 + + example = p_anns_examples[iann] + self.examples['tp'][cui] = self.examples['tp'].get(cui, []) + [example] + else: + self.fp += 1 + self.fps[cui] = self.fps.get(cui, 0) + 1 + self.fp_docs.add(doc.get('name', 'unk')) + + # Add example for this FP prediction + example = p_anns_examples[iann] + if ann in anns_norm_neg: + # Means that it really was annotated as negative + example['real_fp'] = True + + self.examples['fp'][cui] = self.examples['fp'].get(cui, []) + [example] + + def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: dict, ann: Dict) -> Dict: + return {"text": doc['text'][max(0, ann['start']-60):ann['end']+60], + "cui": cui, + "start": ann['start'], + "end": ann['end'], + "source value": ann['value'], + "acc": 1, + "project name": project_name, + "document name": doc.get('name'), + "project id": project_id, + "document id": 
doc.get('id')} + + def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc: dict, ann) -> Dict: + return {"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60], + "cui": cui, + "start": ann.start_char, + "end": ann.end_char, + "source value": ann.text, + "acc": float(ann._.context_similarity), + "project name": project_name, + "document name": doc.get('name'), + "project id": project_id, + "document id": doc.get('id')} + + def _preprocess_annotations(self, project_name: str, project_id: str, + doc: dict, anns: List[Dict]) -> Tuple[list, list, list, list]: + anns_norm = [] + anns_norm_neg = [] + anns_examples = [] + anns_norm_cui = [] + for ann in anns: + cui = ann['cui'] + if self.filters.check_filters(cui): + if self.use_groups: + cui = self.cui2group.get(cui, cui) + + if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)): + anns_norm.append((ann['start'], cui)) + anns_examples.append(self._create_annoation(project_name, project_id, cui, doc, ann)) + elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)): + anns_norm_neg.append((ann['start'], cui)) + + if ann.get("validated", True): + # This is used to test was someone annotating for this CUI in this document + anns_norm_cui.append(cui) + self.cui_counts[cui] = self.cui_counts.get(cui, 0) + 1 + return anns_norm, anns_norm_neg, anns_examples, anns_norm_cui + + def finalise_report(self, epoch: int, do_print: bool = True): + try: + prec = self.tp / (self.tp + self.fp) + rec = self.tp / (self.tp + self.fn) + f1 = 2*(prec*rec) / (prec + rec) + if do_print: + print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1)) + print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(self.fp_docs)[0:10]]))) + print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(self.fn_docs)[0:10]]))) + + # Sort fns & prec + fps = {k: v for k, v in 
sorted(self.fps.items(), key=lambda item: item[1], reverse=True)} + fns = {k: v for k, v in sorted(self.fns.items(), key=lambda item: item[1], reverse=True)} + tps = {k: v for k, v in sorted(self.tps.items(), key=lambda item: item[1], reverse=True)} + + + # F1 per concept + for cui in tps.keys(): + prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0)) + rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0)) + f1 = 2*(prec*rec) / (prec + rec) + self.cui_prec[cui] = prec + self.cui_rec[cui] = rec + self.cui_f1[cui] = f1 + + + # Get top 10 + pr_fps = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]] + pr_fns = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]] + pr_tps = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]] + + if do_print: + print("\n\nFalse Positives\n") + for one in pr_fps: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("\n\nFalse Negatives\n") + for one in pr_fns: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("\n\nTrue Positives\n") + for one in pr_tps: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("*"*110 + "\n") + + except Exception: + traceback.print_exc() + + def unwrap(self) -> Tuple: + return (self.fps, self.fns, self.tps, + self.cui_prec, self.cui_rec, self.cui_f1, + self.cui_counts, self.examples) + + @classmethod + def from_cat(cls, cat, + local_filters: LinkingFilters, + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + extra_cui_filter: Optional[Set] = None) -> 'StatsBuilder': + return StatsBuilder(filters=local_filters, + addl_info=cat.cdb.addl_info, + doc_getter=cat.__call__, + 
doc_annotation_getter=cat._get_doc_annotations, + cui2group=cat.cdb.addl_info['cui2group'], + cui2preferred_name=cat.cdb.cui2preferred_name, + cui2names=cat.cdb.cui2names, + use_project_filters=use_project_filters, + use_overlaps=use_overlaps, + use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, + extra_cui_filter=extra_cui_filter) + + +def get_stats(cat, + data: Dict, + epoch: int = 0, + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + extra_cui_filter: Optional[Set] = None, + do_print: bool = True) -> Tuple: + """TODO: Refactor and make nice + Print metrics on a dataset (F1, P, R); it will also print the concepts that have the most FP, FN, TP. + + Args: + cat (CAT): + The model pack. + data (list of dict): + The json object that we get from MedCATtrainer on export. + epoch (int): + Used during training, so we know what epoch it is. + use_project_filters (boolean): + Each project in MedCATtrainer can have filters; whether we want to respect those filters + when calculating metrics. + use_overlaps (boolean): + Allow overlapping entities; nearly always False as it is very difficult to annotate overlapping entities. + use_cui_doc_limit (boolean): + If True the metrics for a CUI will only be calculated if that CUI appears in a document, in other words + if the document was annotated for that CUI. Useful in very specific situations when during the annotation + process the set of CUIs changed. + use_groups (boolean): + If True concepts that have groups will be combined and stats will be reported on groups. + extra_cui_filter (Optional[Set]): + This filter will be intersected with all other filters, or if all others are not set then only this one will be used. + + Returns: + fps (dict): + False positives for each CUI. + fns (dict): + False negatives for each CUI. + tps (dict): + True positives for each CUI. + cui_prec (dict): + Precision for each CUI. 
+ cui_rec (dict): + Recall for each CUI. + cui_f1 (dict): + F1 for each CUI. + cui_counts (dict): + Number of occurrences for each CUI. + examples (dict): + Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][]. + do_print (bool): + Whether to print stats out. Defaults to True. + """ + orig_filters = cat.config.linking.filters.copy_of() + local_filters = cat.config.linking.filters + builder = StatsBuilder.from_cat(cat, + local_filters=local_filters, + use_project_filters=use_project_filters, + use_overlaps=use_overlaps, + use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, + extra_cui_filter=extra_cui_filter) + for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): + builder.process_project(project) + + # this is the part that prints out the stats + builder.finalise_report(epoch, do_print=do_print) + + cat.config.linking.filters = orig_filters + + return builder.unwrap() diff --git a/medcat/utils/cdb_utils.py b/medcat/utils/cdb_utils.py new file mode 100644 index 000000000..445fb7d6f --- /dev/null +++ b/medcat/utils/cdb_utils.py @@ -0,0 +1,117 @@ +import logging +import numpy as np + +from copy import deepcopy +from medcat.cdb import CDB + +logger = logging.getLogger(__name__) # separate logger from the package-level one + + +def merge_cdb(cdb1: "CDB", + cdb2: "CDB", + overwrite_training: int = 0, + full_build: bool = False): + """Merge two CDBs together to produce a new, single CDB. The contents of the input CDBs will not be changed. + `addl_info` cannot be perfectly merged and will prioritise cdb1; see `full_build`. + + Args: + cdb1 (medcat.cdb.CDB): + The first medcat cdb to merge. In cases where merging isn't ideal (such as + cui2preferred_name), this CDB's values will be prioritised over cdb2's. + cdb2 (medcat.cdb.CDB): + The second medcat cdb to merge. 
+ overwrite_training (int): + Choose to prioritise a CDB's context vector values over merging gracefully. 0 - no prio, 1 - CDB1, 2 - CDB2 + full_build (bool): + Add additional information from "addl_info" dicts "cui2ontologies" and "cui2description" + """ + config = deepcopy(cdb1.config) + cdb = CDB(config) + + # Copy CDB 1 - as all settings from CDB 1 will be carried over + cdb.cui2names = deepcopy(cdb1.cui2names) + cdb.cui2snames = deepcopy(cdb1.cui2snames) + cdb.cui2count_train = deepcopy(cdb1.cui2count_train) + cdb.cui2info = deepcopy(cdb1.cui2info) + cdb.cui2context_vectors = deepcopy(cdb1.cui2context_vectors) + cdb.cui2tags = deepcopy(cdb1.cui2tags) + cdb.cui2type_ids = deepcopy(cdb1.cui2type_ids) + cdb.cui2preferred_name = deepcopy(cdb1.cui2preferred_name) + cdb.name2cuis = deepcopy(cdb1.name2cuis) + cdb.name2cuis2status = deepcopy(cdb1.name2cuis2status) + cdb.name2count_train = deepcopy(cdb1.name2count_train) + cdb.name_isupper = deepcopy(cdb1.name_isupper) + if full_build: + cdb.addl_info = deepcopy(cdb1.addl_info) + + # handles cui2names, cui2snames, name_isupper, name2cuis, name2cuis2status, cui2preferred_name + for cui in cdb2.cui2names: + names = dict() + for name in cdb2.cui2names[cui]: + names[name] = {'snames': cdb2.cui2snames.get(cui, set()), 'is_upper': cdb2.name_isupper.get(name, False), 'tokens': {}, 'raw_name': cdb2.get_name(cui)} + name_status = cdb2.name2cuis2status.get(name, {}).get(cui, 'A') # get the name status if it exists, default to 'A' + # For addl_info check cui2original_names as they MUST be added + ontologies = set() + description = '' + to_build = False + if full_build and (cui in cdb2.addl_info['cui2original_names'] or cui in cdb2.addl_info['cui2description']): + to_build = True + if 'cui2ontologies' in cdb2.addl_info: + ontologies.update(cdb2.addl_info['cui2ontologies'][cui]) + if 'cui2description' in cdb2.addl_info: + description = cdb2.addl_info['cui2description'][cui] + cdb._add_concept(cui=cui, names=names, 
ontologies=ontologies, name_status=name_status, + type_ids=cdb2.cui2type_ids[cui], description=description, full_build=to_build) + if cui in cdb1.cui2names: + if (cui in cdb1.cui2count_train or cui in cdb2.cui2count_train) and not (overwrite_training == 1 and cui in cdb1.cui2count_train): + if overwrite_training == 2 and cui in cdb2.cui2count_train: + cdb.cui2count_train[cui] = cdb2.cui2count_train[cui] + else: + cdb.cui2count_train[cui] = cdb1.cui2count_train.get(cui, 0) + cdb2.cui2count_train.get(cui, 0) + if cui in cdb1.cui2context_vectors and not (overwrite_training == 1 and cui in cdb1.cui2context_vectors[cui]): + if overwrite_training == 2 and cui in cdb2.cui2context_vectors: + weights = [0, 1] + else: + norm = cdb.cui2count_train[cui] + weights = [np.divide(cdb1.cui2count_train.get(cui, 0), norm), np.divide(cdb2.cui2count_train.get(cui, 0), norm)] + contexts = set(list(cdb1.cui2context_vectors.get(cui, {}).keys()) + list(cdb2.cui2context_vectors.get(cui, {}).keys())) # xlong, long, medium, short + for s in contexts: + cdb.cui2context_vectors[cui][s] = (weights[0] * cdb1.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + (weights[1] * cdb2.cui2context_vectors[cui].get(s, np.zeros(shape=(300)))) + if cui in cdb1.cui2tags: + cdb.cui2tags[cui].append(cdb2.cui2tags[cui]) + if cui in cdb1.cui2type_ids: + cdb.cui2type_ids[cui] = cdb1.cui2type_ids[cui].union(cdb2.cui2type_ids[cui]) + else: + if cui in cdb2.cui2count_train: + cdb.cui2count_train[cui] = cdb2.cui2names[cui] + if cui in cdb2.cui2info: + cdb.cui2info[cui] = cdb2.cui2info[cui] + if cui in cdb2.cui2context_vectors: + cdb.cui2context_vectors[cui] = cdb2.cui2context_vectors[cui] + if cui in cdb2.cui2tags: + cdb.cui2tags[cui] = cdb2.cui2tags[cui] + if cui in cdb2.cui2type_ids: + cdb.cui2type_ids[cui] = cdb2.cui2type_ids[cui] + + if overwrite_training != 1: + for name in cdb2.name2cuis: + if name in cdb1.name2cuis and overwrite_training == 0: # if they exist in both cdbs + if name in 
cdb1.name2count_train and name in cdb2.name2count_train: + cdb.name2count_train[name] = str(int(cdb1.name2count_train[name]) + int(cdb2.name2count_train[name])) # these are strings for some reason + else: + if name in cdb2.name2count_train: + cdb.name2count_train[name] = cdb2.name2count_train[name] + + # snames + cdb.snames = cdb1.snames.union(cdb2.snames) + + # vocab, adding counts if they occur in both + cdb.vocab = deepcopy(cdb1.vocab) + if overwrite_training != 1: + for word in cdb2.vocab: + if word in cdb.vocab and overwrite_training == 0: + cdb.vocab[word] += cdb2.vocab[word] + else: + cdb.vocab[word] = cdb2.vocab[word] + + return cdb diff --git a/medcat/utils/filters.py b/medcat/utils/filters.py index c4803027a..cb85e0e26 100644 --- a/medcat/utils/filters.py +++ b/medcat/utils/filters.py @@ -1,3 +1,9 @@ +from typing import Optional, Set, Dict + +from medcat.config import LinkingFilters +from medcat.utils.matutils import intersect_nonempty_set + + def check_filters(cui, filters): """Checks is a CUI in the filters @@ -15,7 +21,7 @@ def check_filters(cui, filters): return False -def get_all_irrelevant_cuis(project, cdb): +def get_all_irrelevant_cuis(project): i_cuis = set() for d in project['documents']: for a in d['annotations']: @@ -24,7 +30,7 @@ def get_all_irrelevant_cuis(project, cdb): return i_cuis -def get_project_filters(cuis, type_ids, cdb, project=None): +def get_project_filters(cuis, type_ids, addl_info: Dict, project=None): cui_filter = set() if isinstance(cuis, str): if cuis is not None and cuis: @@ -33,10 +39,10 @@ def get_project_filters(cuis, type_ids, cdb, project=None): type_ids = [x.strip().upper() for x in type_ids.split(",")] # Convert type_ids to cuis - if 'type_id2cuis' in cdb.addl_info: + if 'type_id2cuis' in addl_info: for type_id in type_ids: - if type_id in cdb.addl_info['type_id2cuis']: - cui_filter.update(cdb.addl_info['type_id2cuis'][type_id]) + if type_id in addl_info['type_id2cuis']: + 
cui_filter.update(addl_info['type_id2cuis'][type_id]) else: raise Exception("Impossible to create filters, disable them.") else: @@ -45,8 +51,33 @@ def get_project_filters(cuis, type_ids, cdb, project=None): cui_filter = set(cuis) if project is not None: - i_cuis = get_all_irrelevant_cuis(project, cdb) + i_cuis = get_all_irrelevant_cuis(project) for i_cui in i_cuis: cui_filter.remove(i_cui) return cui_filter + + +def set_project_filters(addl_info: Dict, local_filters: LinkingFilters, project: dict, + extra_cui_filter: Optional[Set], use_project_filters: bool): + """Set the project filters to a LinkingFilters object based on + the specified project. + + Args: + addl_info (Dict): The CDB additional information + local_filters (LinkingFilters): The linking filters instance + project (dict): The project + extra_cui_filter (Optional[Set]): Extra CUIs (if specified) + use_project_filters (bool): Whether to use per-project filters + """ + if isinstance(extra_cui_filter, set): + local_filters.cuis = extra_cui_filter + + if use_project_filters: + project_filter = get_project_filters(cuis=project.get('cuis', None), + type_ids=project.get('tuis', None), + addl_info=addl_info, + project=project) + # Intersect project filter with existing if it has something + if project_filter: + local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis) diff --git a/medcat/utils/helpers.py b/medcat/utils/helpers.py index f783a9b06..816b316ce 100644 --- a/medcat/utils/helpers.py +++ b/medcat/utils/helpers.py @@ -537,3 +537,35 @@ def has_new_spacy() -> bool: return (major > 3 or (major == 3 and minor > 3) or (major == 3 and minor == 3 and patch >= 1)) + + +def has_spacy_model(model_name: str) -> bool: + """Checks if the spacy model is available. + + Args: + model_name (str): The model name. + + Returns: + bool: True if the model is available, False otherwise. 
+ """ + import spacy.util + return model_name in spacy.util.get_installed_models() + + +def ensure_spacy_model(model_name: str) -> None: + """Ensure the specified spacy model exists. + + If the model does not currently exist, it will attempt downloading it. + + Args: + model_name (str): The spacy model name. + """ + import subprocess + if has_spacy_model(model_name): + return + # running in subprocess so that we can catch the exception + # if the model name is unknown. Otherwise we'd just be bumped + # out of python (sys.exit). + logger.info("Installing the spacy model %s using the CLI command " + "'python -m spacy download %s'", model_name, model_name) + subprocess.run(["python", "-m", "spacy", "download", model_name], check=True) diff --git a/medcat/utils/regression/targeting.py b/medcat/utils/regression/targeting.py index 19f19bb3f..7a13b2bcc 100644 --- a/medcat/utils/regression/targeting.py +++ b/medcat/utils/regression/targeting.py @@ -25,12 +25,12 @@ class TranslationLayer: Args: cui2names (Dict[str, Set[str]]): The map from CUI to names - name2cuis (Dict[str, Set[str]]): The map from name to CUIs + name2cuis (Dict[str, List[str]]): The map from name to CUIs cui2type_ids (Dict[str, Set[str]]): The map from CUI to type_ids cui2children (Dict[str, Set[str]]): The map from CUI to child CUIs """ - def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, Set[str]], + def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str]], cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]]) -> None: self.cui2names = cui2names self.name2cuis = name2cuis diff --git a/medcat/utils/saving/serializer.py b/medcat/utils/saving/serializer.py index d82df751c..25529c778 100644 --- a/medcat/utils/saving/serializer.py +++ b/medcat/utils/saving/serializer.py @@ -135,13 +135,12 @@ def serialize(self, cdb, overwrite: bool = False) -> None: raise ValueError(f'Unable to overwrite shelf path "{self.json_path}"' ' - specify overrwrite=True 
if you wish to overwrite') to_save = {} - to_save['config'] = cdb.config.asdict() # This uses different names so as to not be ambiguous # when looking at files whether the json parts should # exist separately or not to_save['cdb_main' if self.jsons is not None else 'cdb'] = dict( ((key, val) for key, val in cdb.__dict__.items() if - key != 'config' and + key not in ('config', '_config_from_file') and (self.jsons is None or key not in SPECIALITY_NAMES))) logger.info('Dumping CDB to %s', self.main_path) with open(self.main_path, 'wb') as f: @@ -165,7 +164,17 @@ def deserialize(self, cdb_cls): logger.info('Reading CDB data from %s', self.main_path) with open(self.main_path, 'rb') as f: data = dill.load(f) - config = cast(Config, Config.from_dict(data['config'])) + if 'config' in data: + logger.warning("Found config in CDB for model (%s). " + "This is an old format. Please re-save the " + "model in the new format to avoid potential issues", + os.path.dirname(self.main_path)) + config = cast(Config, Config.from_dict(data['config'])) + else: + # by passing None as config to constructor + # the CDB should identify that there has been + # no config loaded + config = None cdb = cdb_cls(config=config) if self.jsons is None: cdb_main = data['cdb'] diff --git a/medcat/utils/spacy_compatibility.py b/medcat/utils/spacy_compatibility.py new file mode 100644 index 000000000..a64737f21 --- /dev/null +++ b/medcat/utils/spacy_compatibility.py @@ -0,0 +1,211 @@ +"""This module attempts to read the spacy compatibilty of +a model pack and (if necessary) compare it to the installed +spacy version. +""" +from typing import Tuple, List, cast +import os +import re +from packaging import version +from packaging.specifiers import SpecifierSet + +import spacy + + +SPACY_MODEL_REGEX = re.compile(r"(\w{2}_core_(\w{3,4})_(sm|md|lg|trf|xxl|\w+))|(spacy_model)") + + +def _is_spacy_model_folder(folder_name: str) -> bool: + """Check if a folder within a model pack contains a spacy model. 
+
+    The idea is to do this without loading the model. That is because
+    the version of the model may be incompatible with what we've got.
+    And as such, loading may not be possible.
+
+    Args:
+        folder_name (str): The folder to check.
+
+    Returns:
+        bool: Whether the folder contains a spacy model.
+    """
+    # since we're trying to identify this solely from the
+    # folder name, we only care about the base name.
+    folder_name = os.path.basename(folder_name)
+    if folder_name.startswith("meta_"):
+        # these are MetaCAT files (or should be)
+        return False
+    return bool(SPACY_MODEL_REGEX.match(folder_name))
+
+
+def _find_spacy_model_folder(model_pack_folder: str) -> str:
+    """Find the spacy model folder in a model pack folder.
+
+    Args:
+        model_pack_folder (str): The model pack folder
+
+    Raises:
+        ValueError: If it's ambiguous or there's no model folder.
+
+    Returns:
+        str: The full path to the model folder.
+    """
+    options: List[str] = []
+    for folder_name in os.listdir(model_pack_folder):
+        full_folder_path = os.path.join(model_pack_folder, folder_name)
+        if not os.path.isdir(full_folder_path):
+            continue
+        if _is_spacy_model_folder(folder_name):
+            options.append(full_folder_path)
+    if len(options) != 1:
+        raise ValueError("Unable to determine spacy folder name from "
+                         f"{len(options)} ambiguous folders: {options}")
+    return options[0]
+
+
+def get_installed_spacy_version() -> str:
+    """Get the currently installed spacy version.
+
+    Returns:
+        str: The currently installed spacy version.
+    """
+    return spacy.__version__
+
+
+def get_installed_model_version(model_name: str) -> str:
+    """Get the version of a model installed in spacy.
+
+    Args:
+        model_name (str): The model name.
+
+    Returns:
+        str: The version of the installed model.
+ """ + if model_name not in spacy.util.get_installed_models(): + return 'N/A' + # NOTE: I don't really know when spacy.info + # might return a str instead + return cast(dict, spacy.info(model_name))['version'] + + +def _get_name_and_meta_of_spacy_model_in_medcat_modelpack(model_pack_path: str) -> Tuple[str, dict]: + """Gets the name and meta information about a spacy model within a medcat model pack. + + PS: This gets the raw (folder) name of the spacy model. + While this is usually (in models created after v1.2.4) + identical to the spacy model version, that may not always + be the case. + + Args: + model_pack_path (str): The model pack path. + + Returns: + Tuple[str, dict]: The name of the spacy model, and the meta information. + """ + spacy_model_folder = _find_spacy_model_folder(model_pack_path) + # NOTE: I don't really know when spacy.info + # might return a str instead + info = cast(dict, spacy.info(spacy_model_folder)) + return os.path.basename(spacy_model_folder), info + + +def get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path: str) -> Tuple[str, str, str]: + """Get the name, version, and compatible spacy versions of a spacy model within a medcat model pack. + + PS: This gets the real name of the spacy model. + While this is usually (in models created after v1.2.4) + identical to the folder name, that may not always + be the case. + + Args: + model_pack_path (str): The model pack path. + + Returns: + Tuple[str, str, str]: The name of the spacy model, its version, and supported spacy version. + """ + _, info = _get_name_and_meta_of_spacy_model_in_medcat_modelpack(model_pack_path) + true_name = info["lang"] + "_" + info['name'] + return true_name, info['version'], info["spacy_version"] + + +def _is_spacy_version_within_range(spacy_version_range: str) -> bool: + """Checks whether the spacy version is within the specified range. + + The expected format of the version range is similar to that used + in requirements and/or pip installs. 
E.g:
+    - >=3.1.0,<3.2.0
+    - ==3.1.0
+    - >=3.1.0
+    - <3.20
+
+    Args:
+        spacy_version_range (str): The required spacy version range.
+
+    Returns:
+        bool: Whether the installed spacy version is within the specified range.
+    """
+    spacy_version = version.parse(get_installed_spacy_version())
+    spec_set = SpecifierSet(spacy_version_range)  # avoid shadowing the built-in `range`
+    return spec_set.contains(spacy_version)
+
+
+def medcat_model_pack_has_compatible_spacy_model(model_pack_path: str) -> bool:
+    """Checks whether a medcat model pack has a spacy model compatible with the installed spacy version.
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        bool: Whether the spacy model in the model pack is compatible.
+    """
+    _, _, spacy_range = get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path)
+    return _is_spacy_version_within_range(spacy_range)
+
+
+def is_older_spacy_version(model_version: str) -> bool:
+    """Checks if the specified version is older than the installed version.
+
+    Args:
+        model_version (str): The specified spacy version.
+
+    Returns:
+        bool: Whether the specified version is older.
+    """
+    installed_version = version.parse(get_installed_spacy_version())
+    parsed_model_version = version.parse(model_version)
+    return parsed_model_version <= installed_version
+
+
+def medcat_model_pack_has_semi_compatible_spacy_model(model_pack_path: str) -> bool:
+    """Checks whether the spacy model within a medcat model pack is
+    compatible or older than the installed spacy version.
+
+    This method returns `True` if the spacy model is compatible or
+    released with a lower version number compared to the spacy
+    version currently installed.
+
+    We've found that most of the time older models will work with
+    a newer version of spacy. Though there is a warning on spacy's
+    side and they do not guarantee 100% compatibility, we've not
+    seen issues so far.
+
+    E.g. for installed spacy 3.4.4 all of the following will be suitable:
+    - en_core_web_md-3.1.0
+    - en_core_web_md-3.2.0
+    - en_core_web_md-3.3.0
+    - en_core_web_md-3.4.1
+    However, for the same version, the following would not be suitable:
+    - en_core_web_md-3.5.0
+    - en_core_web_md-3.6.0
+    - en_core_web_md-3.7.1
+
+    Args:
+        model_pack_path (str): The model pack path.
+
+    Returns:
+        bool: Whether the spacy model in the model pack is compatible.
+    """
+    (_,
+     model_version,
+     spacy_range) = get_name_and_version_of_spacy_model_in_medcat_modelpack(model_pack_path)
+    if _is_spacy_version_within_range(spacy_range):
+        return True
+    return is_older_spacy_version(model_version)
diff --git a/setup.py b/setup.py
index 8b152cb77..34963943a 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
     url="https://github.com/CogStack/MedCAT",
     packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets',
               'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner',
-              'medcat.utils.saving', 'medcat.utils.regression'],
+              'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'],
     install_requires=[
         'numpy>=1.22.0', # first to support 3.11
         'pandas>=1.4.2', # first to support 3.11
@@ -40,12 +40,6 @@
         'blis>=0.7.5', # allow later versions, tested with 0.7.9
         'click>=8.0.4', # allow later versions, tested with 8.1.3
         'pydantic>=1.10.0,<2.0', # for spacy compatibility; avoid 2.0 due to breaking changes
-        # the following are not direct dependencies of MedCAT but needed for docs/building
-        # hopefully will no longer need the transitive dependencies
-        'aiohttp==3.8.5', # 3.8.3 is needed for compatibility with fsspec <- datasets <- medcat
-        'blis<0.8.0,>=0.7.8', # as required by thinc <- spacy <- medcat
-        # 'smart-open==5.2.1', # 5.2.1 is needed for compatibility with pathy
-        # 'joblib~=1.2',
     ],
     classifiers=[
         "Programming Language :: Python :: 3",
diff --git a/tests/archive_tests/test_cdb_maker_archive.py
b/tests/archive_tests/test_cdb_maker_archive.py
index 329408999..9e2fc2d72 100644
--- a/tests/archive_tests/test_cdb_maker_archive.py
+++ b/tests/archive_tests/test_cdb_maker_archive.py
@@ -108,7 +108,7 @@ def test_concept_similarity(self):
         for i in range(500):
             cui = "C" + str(i)
             type_ids = {'T-' + str(i%10)}
-            cdb.add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(),
+            cdb._add_concept(cui=cui, names=prepare_name('Name: ' + str(i), self.maker.pipe.get_spacy_nlp(), {}, self.config), ontologies=set(),
                              name_status='P', type_ids=type_ids, description='', full_build=True)
         vectors = {}
diff --git a/tests/helper.py b/tests/helper.py
index 9fb66589b..52943c3cd 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -6,6 +6,8 @@
 import numpy as np
 
 from medcat.vocab import Vocab
+from medcat.cdb_maker import CDBMaker
+from medcat.config import Config
 
 
 class AsyncMock(unittest.mock.MagicMock):
@@ -86,3 +88,36 @@ def check_or_download(self):
             return
         with open(self.vocab_path, 'wb') as f:
             f.write(tmp.content)
+
+
+class ForCDBMerging:
+
+    def __init__(self) -> None:
+        # generating cdbs - two makers are required as they would otherwise point to the same created CDB.
config = Config()
+        config.general["spacy_model"] = "en_core_web_md"
+        maker1 = CDBMaker(config)
+        maker2 = CDBMaker(config)  # a second maker is required as it would otherwise point to the same object
+        path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "model_creator", "umls_sample.csv")
+        self.cdb1 = maker1.prepare_csvs(csv_paths=[path])
+        self.cdb2 = maker2.prepare_csvs(csv_paths=[path])
+
+        # generating context vectors here for testing the weighted average function (based off cui2count_train)
+        zeroes = np.zeros(shape=(1,300))
+        ones = np.ones(shape=(1,300))
+        for i, cui in enumerate(self.cdb1.cui2names):
+            self.cdb1.cui2context_vectors[cui] = {"short": ones}
+            self.cdb2.cui2context_vectors[cui] = {"short": zeroes}
+            self.cdb1.cui2count_train[cui] = 1
+            self.cdb2.cui2count_train[cui] = i + 1
+        # adding new names and cuis to each cdb to test after merging
+        test_add = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
+        self.cdb1.add_names("C0006826", test_add)
+        unique_test = {"test": {'tokens': "test_token", 'snames': ["test_name"], 'raw_name': "test_raw_name", "is_upper": "P"}}
+        self.cdb2.add_names("UniqueTest", unique_test)
+        self.cdb2.cui2context_vectors["UniqueTest"] = {"short": zeroes}
+        self.cdb2.addl_info["cui2ontologies"] = {}
+        self.cdb2.addl_info["cui2description"] = {}
+        for cui in self.cdb2.cui2names:
+            self.cdb2.addl_info["cui2ontologies"][cui] = {"test_ontology"}
+            self.cdb2.addl_info["cui2description"][cui] = "test_description"
diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py
new file mode 100644
index 000000000..de9eae32c
--- /dev/null
+++ b/tests/ner/test_transformers_ner.py
@@ -0,0 +1,50 @@
+import os
+import unittest
+from spacy.lang.en import English
+from spacy.tokens import Doc, Span
+from transformers import TrainerCallback
+from medcat.ner.transformers_ner import TransformersNER
+from medcat.config import Config
+from
medcat.cdb_maker import CDBMaker + + +class TransformerNERTest(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + config = Config() + config.general["spacy_model"] = "en_core_web_md" + cdb_maker = CDBMaker(config) + cdb_csv = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "..", "examples", "cdb.csv") + cdb = cdb_maker.prepare_csvs([cdb_csv], full_build=True) + Doc.set_extension("ents", default=[], force=True) + Span.set_extension("confidence", default=-1, force=True) + Span.set_extension("id", default=0, force=True) + Span.set_extension("detected_name", default=None, force=True) + Span.set_extension("link_candidates", default=None, force=True) + Span.set_extension("cui", default=-1, force=True) + Span.set_extension("context_similarity", default=-1, force=True) + cls.undertest = TransformersNER(cdb) + cls.undertest.create_eval_pipeline() + + def test_pipe(self): + doc = English().make_doc("\nPatient Name: John Smith\nAddress: 15 Maple Avenue\nCity: New York\nCC: Chronic back pain\n\nHX: Mr. 
Smith") + doc = next(self.undertest.pipe([doc])) + assert len(doc.ents) > 0, "No entities were recognised" + + def test_train(self): + tracker = unittest.mock.Mock() + class _DummyCallback(TrainerCallback): + def __init__(self, trainer) -> None: + self._trainer = trainer + def on_epoch_end(self, *args, **kwargs) -> None: + tracker.call() + + train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json") + self.undertest.training_arguments.num_train_epochs = 1 + df, examples, dataset = self.undertest.train(train_data, trainer_callbacks=[_DummyCallback, _DummyCallback]) + assert "fp" in examples + assert "fn" in examples + assert dataset["train"].num_rows == 48 + assert dataset["test"].num_rows == 12 + self.assertEqual(tracker.call.call_count, 2) diff --git a/tests/preprocessing/__init__.py b/tests/preprocessing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/preprocessing/test_cleaners.py b/tests/preprocessing/test_cleaners.py new file mode 100644 index 000000000..b879d9ee6 --- /dev/null +++ b/tests/preprocessing/test_cleaners.py @@ -0,0 +1,104 @@ +from medcat.preprocessing.cleaners import prepare_name +from medcat.config import Config +from medcat.cdb_maker import CDBMaker + +import logging, os + +import unittest + + +class BaseCDBMakerTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + config = Config() + config.general['log_level'] = logging.DEBUG + config.general["spacy_model"] = "en_core_web_md" + cls.maker = CDBMaker(config) + csvs = [ + os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'examples', 'cdb.csv'), + os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..', 'examples', 'cdb_2.csv') + ] + cls.cdb = cls.maker.prepare_csvs(csvs, full_build=True) + + +class BasePrepareNameTest(BaseCDBMakerTests): + raw_name = 'raw' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.do_prepare_name() + + # 
method called after setup, when raw_name has been specified + @classmethod + def do_prepare_name(cls) -> None: + cls.name = cls.cdb.config.general.separator.join(cls.raw_name.split()) + cls.names = prepare_name(cls.raw_name, cls.maker.pipe.spacy_nlp, {}, cls.cdb.config) + + def _dict_has_key_val_type(self, d: dict, key, val_type): + self.assertIn(key, d) + self.assertIsInstance(d[key], val_type) + + def _names_has_key_val_type(self, key, val_type): + self._dict_has_key_val_type(self.names, key, val_type) + + def test_result_has_name(self): + self._names_has_key_val_type(self.name, dict) + + def test_name_info_has_tokens(self): + self._dict_has_key_val_type(self.names[self.name], 'tokens', list) + + def test_name_info_has_words_as_tokens(self): + name_info = self.names[self.name] + tokens = name_info['tokens'] + for word in self.raw_name.split(): + with self.subTest(word): + self.assertIn(word, tokens) + + +class NamePreparationTests_OneLetter(BasePrepareNameTest): + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.raw_name = "a" + # the minimum name length is defined by the following config option + # if I don't set this to 1 here, I would see the tests fail + # that would be because the result from `prepare_names` would be empty + cls.cdb.config.cdb_maker.min_letters_required = 1 + super().do_prepare_name() + + +class NamePreparationTests_TwoLetters(BasePrepareNameTest): + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.raw_name = "an" + super().do_prepare_name() + + +class NamePreparationTests_MultiToken(BasePrepareNameTest): + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.raw_name = "this raw name" + super().do_prepare_name() + + +class NamePreparationTests_Empty(BaseCDBMakerTests): + """In case of an empty name, I would expect the names dict + returned by `prepare_name` to be empty. 
+ """ + empty_raw_name = '' + + @classmethod + def setUpClass(cls) -> None: + super().setUpClass() + cls.names = prepare_name(cls.empty_raw_name, cls.maker.pipe.spacy_nlp, {}, cls.cdb.config) + + def test_names_dict_is_empty(self): + self.assertEqual(len(self.names), 0) + self.assertEqual(self.names, {}) diff --git a/tests/resources/ff_core_fake_dr/meta.json b/tests/resources/ff_core_fake_dr/meta.json new file mode 100644 index 000000000..fe9825db7 --- /dev/null +++ b/tests/resources/ff_core_fake_dr/meta.json @@ -0,0 +1,8 @@ +{ + "lang":"ff", + "name":"core_fake_dr", + "version":"3.1.0", + "description":"This is a FAKE model", + "author":"Fakio Martimus", + "spacy_version":">=3.1.0,<3.2.0" + } \ No newline at end of file diff --git a/tests/test_cat.py b/tests/test_cat.py index 0baa0d35d..bc49a2808 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -4,10 +4,14 @@ import unittest import tempfile import shutil +import logging +import contextlib from transformers import AutoTokenizer from medcat.vocab import Vocab -from medcat.cdb import CDB -from medcat.cat import CAT +from medcat.cdb import CDB, logger as cdb_logger +from medcat.cat import CAT, logger as cat_logger +from medcat.config import Config +from medcat.pipe import logger as pipe_logger from medcat.utils.checkpoint import Checkpoint from medcat.meta_cat import MetaCAT from medcat.config_meta_cat import ConfigMetaCAT @@ -15,11 +19,13 @@ class CATTests(unittest.TestCase): + SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json") @classmethod def setUpClass(cls) -> None: cls.cdb = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat")) cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")) + cls.vocab.make_unigram_table() cls.cdb.config.general.spacy_model = "en_core_web_md" cls.cdb.config.ner.min_name_len = 2 cls.cdb.config.ner.upper_case_limit_len = 3 
@@ -36,7 +42,8 @@ def setUpClass(cls) -> None: @classmethod def tearDownClass(cls) -> None: cls.undertest.destroy_pipe() - shutil.rmtree(cls.meta_cat_dir) + if os.path.exists(cls.meta_cat_dir): + shutil.rmtree(cls.meta_cat_dir) def tearDown(self) -> None: self.cdb.config.annotation_output.include_text_in_output = False @@ -60,7 +67,7 @@ def test_multiprocessing(self): (2, ""), (3, None) ] - out = self.undertest.multiprocessing(in_data, nproc=1) + out = self.undertest.multiprocessing_batch_char_size(in_data, nproc=1) self.assertEqual(3, len(out)) self.assertEqual(1, len(out[1]['entities'])) @@ -73,7 +80,7 @@ def test_multiprocessing_pipe(self): (2, "The dog is sitting outside the house."), (3, "The dog is sitting outside the house."), ] - out = self.undertest.multiprocessing_pipe(in_data, nproc=2, return_dict=False) + out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=2, return_dict=False) self.assertTrue(type(out) == list) self.assertEqual(3, len(out)) self.assertEqual(1, out[0][0]) @@ -89,7 +96,7 @@ def test_multiprocessing_pipe_with_malformed_texts(self): (2, ""), (3, None), ] - out = self.undertest.multiprocessing_pipe(in_data, nproc=1, batch_size=1, return_dict=False) + out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=1, batch_size=1, return_dict=False) self.assertTrue(type(out) == list) self.assertEqual(3, len(out)) self.assertEqual(1, out[0][0]) @@ -105,7 +112,7 @@ def test_multiprocessing_pipe_return_dict(self): (2, "The dog is sitting outside the house."), (3, "The dog is sitting outside the house.") ] - out = self.undertest.multiprocessing_pipe(in_data, nproc=2, return_dict=True) + out = self.undertest.multiprocessing_batch_docs_size(in_data, nproc=2, return_dict=True) self.assertTrue(type(out) == dict) self.assertEqual(3, len(out)) self.assertEqual({'entities': {}, 'tokens': []}, out[1]) @@ -211,7 +218,7 @@ def test_get_entities_multi_texts_including_text(self): def test_train_supervised(self): nepochs = 3 
num_of_documents = 27 - data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json") + data_path = self.SUPERVISED_TRAINING_JSON ckpt_dir_path = tempfile.TemporaryDirectory().name checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize) fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path, @@ -367,7 +374,7 @@ def test_load_model_pack(self): meta_cat = _get_meta_cat(self.meta_cat_dir) cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name") - cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) + cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) self.assertTrue(isinstance(cat, CAT)) self.assertIsNotNone(cat.config.version.medcat_version) self.assertEqual(repr(cat._meta_cats), repr([meta_cat])) @@ -377,7 +384,7 @@ def test_load_model_pack_without_meta_cat(self): meta_cat = _get_meta_cat(self.meta_cat_dir) cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name") - cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False) + cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False) self.assertTrue(isinstance(cat, CAT)) self.assertIsNotNone(cat.config.version.medcat_version) self.assertEqual(cat._meta_cats, []) @@ -385,9 +392,208 @@ def test_load_model_pack_without_meta_cat(self): def test_hashing(self): save_dir_path = tempfile.TemporaryDirectory() full_model_pack_name = self.undertest.create_model_pack(save_dir_path.name, model_pack_name="mp_name") - cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, 
f"{full_model_pack_name}.zip")) + cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) self.assertEqual(cat.get_hash(), cat.config.version.id) + def test_print_stats(self): + # based on current JSON + EXP_FALSE_NEGATIVES = {'C0017168': 2, 'C0020538': 43, 'C0038454': 4, 'C0007787': 1, 'C0155626': 4, 'C0011860': 12, + 'C0042029': 6, 'C0010068': 2, 'C0007222': 1, 'C0027051': 6, 'C0878544': 1, 'C0020473': 12, + 'C0037284': 21, 'C0003864': 4, 'C0011849': 12, 'C0005686': 1, 'C0085762': 3, 'C0030920': 2, + 'C0854135': 3, 'C0004096': 4, 'C0010054': 10, 'C0497156': 10, 'C0011334': 2, 'C0018939': 1, + 'C1561826': 2, 'C0276289': 2, 'C0041834': 9, 'C0000833': 2, 'C0238792': 1, 'C0040034': 3, + 'C0035078': 5, 'C0018799': 5, 'C0042109': 1, 'C0035439': 1, 'C0035435': 1, 'C0018099': 1, + 'C1277187': 1, 'C0024117': 7, 'C0004238': 4, 'C0032227': 6, 'C0008679': 1, 'C0013146': 6, + 'C0032285': 1, 'C0002871': 7, 'C0149871': 4, 'C0442886': 1, 'C0022104': 1, 'C0034065': 5, + 'C0011854': 6, 'C1398668': 1, 'C0020676': 2, 'C1301700': 1, 'C0021167': 1, 'C0029456': 2, + 'C0011570': 10, 'C0009324': 1, 'C0011882': 1, 'C0020615': 1, 'C0242510': 2, 'C0033581': 2, + 'C0011168': 3, 'C0039082': 2, 'C0009241': 2, 'C1404970': 1, 'C0018524': 3, 'C0150063': 1, + 'C0917799': 1, 'C0178417': 1, 'C0033975': 1, 'C0011253': 1, 'C0018802': 8, 'C0022661': 4, + 'C0017658': 1, 'C0023895': 2, 'C0003123': 1, 'C0041582': 4, 'C0085096': 4, 'C0403447': 2, + 'C2363741': 2, 'C0457949': 1, 'C0040336': 1, 'C0037315': 2, 'C0024236': 3, 'C0442874': 1, + 'C0028754': 4, 'C0520679': 5, 'C0028756': 2, 'C0029408': 5, 'C0409959': 2, 'C0018801': 1, + 'C3844825': 1, 'C0022660': 2, 'C0005779': 4, 'C0011175': 1, 'C0018965': 4, 'C0018889': 1, + 'C0022354': 2, 'C0033377': 1, 'C0042769': 1, 'C0035222': 1, 'C1456868': 2, 'C1145670': 1, + 'C0018790': 1, 'C0263746': 1, 'C0206172': 1, 'C0021400': 1, 'C0243026': 1, 'C0020443': 1, + 'C0001883': 1, 'C0031350': 1, 'C0010709': 4, 'C1565489': 7, 'C3489393': 
1, 'C0005586': 2, + 'C0158288': 5, 'C0700594': 4, 'C0158266': 3, 'C0006444': 2, 'C0024003': 1} + with open(self.SUPERVISED_TRAINING_JSON) as f: + data = json.load(f) + (fps, fns, tps, + cui_prec, cui_rec, cui_f1, + cui_counts, examples) = self.undertest._print_stats(data) + self.assertEqual(fps, {}) + self.assertEqual(fns, EXP_FALSE_NEGATIVES) + self.assertEqual(tps, {}) + self.assertEqual(cui_prec, {}) + self.assertEqual(cui_rec, {}) + self.assertEqual(cui_f1, {}) + self.assertEqual(len(cui_counts), 136) + self.assertEqual(len(examples), 3) + + def _assertNoLogs(self, logger: logging.Logger, level: int): + if hasattr(self, 'assertNoLogs'): + return self.assertNoLogs(logger=logger, level=level) + else: + return self.__assertNoLogs(logger=logger, level=level) + + @contextlib.contextmanager + def __assertNoLogs(self, logger: logging.Logger, level: int): + try: + with self.assertLogs(logger, level) as captured_logs: + yield + except AssertionError: + return + if captured_logs: + raise AssertionError("Logs were found: {}".format(captured_logs)) + + def assertLogsDuringAddAndTrainConcept(self, logger: logging.Logger, log_level, + name: str, name_status: str, nr_of_calls: int): + cui = 'CUI-%d'%(hash(name) + id(name)) + with (self.assertLogs(logger=logger, level=log_level) + if nr_of_calls == 1 + else self._assertNoLogs(logger=logger, level=log_level)): + self.undertest.add_and_train_concept(cui, name, name_status=name_status) + + def test_add_and_train_concept_cat_nowarn_long_name(self): + long_name = 'a very long name' + self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=long_name, name_status='', nr_of_calls=0) + + def test_add_and_train_concept_cdb_nowarn_long_name(self): + long_name = 'a very long name' + self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=long_name, name_status='', nr_of_calls=0) + + def test_add_and_train_concept_cat_nowarn_short_name_not_pref(self): + short_name = 'a' + 
self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=short_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cdb_nowarn_short_name_not_pref(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=short_name, name_status='', nr_of_calls=0)
+
+    def test_add_and_train_concept_cat_warns_short_name(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cat_logger, logging.WARNING, name=short_name, name_status='P', nr_of_calls=1)
+
+    def test_add_and_train_concept_cdb_warns_short_name(self):
+        short_name = 'a'
+        self.assertLogsDuringAddAndTrainConcept(cdb_logger, logging.WARNING, name=short_name, name_status='P', nr_of_calls=1)
+
+
+class GetEntitiesWithStopWords(unittest.TestCase):
+    # NB! The order in which the different CDBs are created matters here:
+    # the stop words are set at the class level, so they carry over to the
+    # next CDB/CAT created in this process regardless of whether or not
+    # they should've been set
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        cls.cdb1 = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
+        cls.cdb2 = CDB.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat"))
+        cls.vocab = Vocab.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat"))
+        cls.vocab.make_unigram_table()
+        cls.cdb1.config.general.spacy_model = "en_core_web_md"
+        cls.cdb1.config.ner.min_name_len = 2
+        cls.cdb1.config.ner.upper_case_limit_len = 3
+        cls.cdb1.config.general.spell_check = True
+        cls.cdb1.config.linking.train_count_threshold = 10
+        cls.cdb1.config.linking.similarity_threshold = 0.3
+        cls.cdb1.config.linking.train = True
+        cls.cdb1.config.linking.disamb_length_limit = 5
+        cls.cdb1.config.general.full_unlink = True
+        cls.cdb2.config = Config.from_dict(cls.cdb1.config.asdict())
+        # the regular CAT
without stopwords + cls.no_stopwords = CAT(cdb=cls.cdb1, config=cls.cdb1.config, vocab=cls.vocab, meta_cats=[]) + # this (the following two lines) + # needs to be done before initialising the CAT + # since that initialises the pipe + cls.cdb2.config.preprocessing.stopwords = {"stop", "words"} + cls.cdb2.config.preprocessing.skip_stopwords = True + # the CAT that skips the stopwords + cls.w_stopwords = CAT(cdb=cls.cdb2, config=cls.cdb2.config, vocab=cls.vocab, meta_cats=[]) + + def test_stopwords_are_skipped(self, text: str = "second words csv"): + # without stopwords no entities are captured + # with stopwords, the `second words csv` entity is captured + doc_no_stopwords = self.no_stopwords(text) + doc_w_stopwords = self.w_stopwords(text) + self.assertGreater(len(doc_w_stopwords._.ents), len(doc_no_stopwords._.ents)) + + +class ModelWithTwoConfigsLoadTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples") + cdb = CDB.load(os.path.join(cls.model_path, "cdb.dat")) + # save config next to the CDB + cls.config_path = os.path.join(cls.model_path, 'config.json') + cdb.config.save(cls.config_path) + + + @classmethod + def tearDownClass(cls) -> None: + # REMOVE config next to the CDB + os.remove(cls.config_path) + + def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(self): + with self.assertRaises(ValueError): + CAT.load_model_pack(self.model_path) + + +class ModelLoadsUnreadableSpacy(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.temp_dir = tempfile.TemporaryDirectory() + model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples") + cdb = CDB.load(os.path.join(model_path, 'cdb.dat')) + cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md") + # save CDB in new location + cdb.save(os.path.join(cls.temp_dir.name, 'cdb.dat')) + # save config in new location 
+ cdb.config.save(os.path.join(cls.temp_dir.name, 'config.json')) + # copy vocab into new location + vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat") + cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat') + shutil.copyfile(vocab_path, cls.vocab_path) + + @classmethod + def tearDownClass(cls) -> None: + # REMOVE temp dir + cls.temp_dir.cleanup() + + def test_loads_without_specified_spacy_model(self): + with self.assertLogs(logger=pipe_logger, level=logging.WARNING): + cat = CAT.load_model_pack(self.temp_dir.name) + self.assertTrue(isinstance(cat, CAT)) + + +class ModelWithZeroConfigsLoadTests(unittest.TestCase): + + @classmethod + def setUpClass(cls) -> None: + cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat") + cdb = CDB.load(cdb_path) + vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat") + # copy the CDB and vocab to a temp dir + cls.temp_dir = tempfile.TemporaryDirectory() + cls.cdb_path = os.path.join(cls.temp_dir.name, 'cdb.dat') + cdb.save(cls.cdb_path) # save without internal config + cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat') + shutil.copyfile(vocab_path, cls.vocab_path) + + + @classmethod + def tearDownClass(cls) -> None: + # REMOVE temp dir + cls.temp_dir.cleanup() + + def test_loading_model_pack_without_any_config_raises_exception(self): + with self.assertRaises(ValueError): + CAT.load_model_pack(self.temp_dir.name) + def _get_meta_cat(meta_cat_dir): config = ConfigMetaCAT() diff --git a/tests/test_cdb.py b/tests/test_cdb.py index 96425bc8c..1be74edfe 100644 --- a/tests/test_cdb.py +++ b/tests/test_cdb.py @@ -6,6 +6,7 @@ import numpy as np from medcat.config import Config from medcat.cdb_maker import CDBMaker +from medcat.cdb import CDB class CDBTests(unittest.TestCase): @@ -21,11 +22,21 @@ def setUp(self) -> None: cdb_2_csv = 
os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_2.csv") self.tmp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "tmp") os.makedirs(self.tmp_dir, exist_ok=True) + # resetting the CDB because otherwise the CDBMaker + # will refer to and modify the same instance of the CDB + # and this can (and does!) create side effects + CDBTests.cdb_maker.reset_cdb() self.undertest = CDBTests.cdb_maker.prepare_csvs([cdb_csv, cdb_2_csv], full_build=True) def tearDown(self) -> None: shutil.rmtree(self.tmp_dir) + def test_setup_changes_cdb(self): + id1 = id(self.undertest) + self.setUp() + id2 = id(self.undertest) + self.assertNotEqual(id1, id2) + def test_name2cuis(self): self.assertEqual({ 'second~csv': ['C0000239'], @@ -53,6 +64,13 @@ def test_save_and_load(self): self.undertest.save(f.name) self.undertest.load(f.name) + def test_load_has_no_config(self): + with tempfile.NamedTemporaryFile() as f: + self.undertest.save(f.name) + cdb = CDB.load(f.name) + self.assertFalse(cdb._config_from_file) + + def test_save_async_and_load(self): with tempfile.NamedTemporaryFile() as f: asyncio.run(self.undertest.save_async(f.name)) diff --git a/tests/test_config.py b/tests/test_config.py index 2f9cd5a84..ce6ed76eb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,7 +1,7 @@ import unittest import pickle import tempfile -from medcat.config import Config, MixingConfig, VersionInfo, General +from medcat.config import Config, MixingConfig, VersionInfo, General, LinkingFilters from pydantic import ValidationError import os @@ -179,6 +179,54 @@ def test_from_dict(self): config = Config.from_dict({"key": "value"}) self.assertEqual("value", config.key) + def test_config_no_hash_before_get(self): + config = Config() + self.assertIsNone(config.hash) + + def test_config_has_hash_after_get(self): + config = Config() + config.get_hash() + self.assertIsNotNone(config.hash) + + def test_config_hash_recalc_same_def(self): + config = Config() + h1 
= config.get_hash() + h2 = config.get_hash() + self.assertEqual(h1, h2) + + def test_config_hash_changes_after_change(self): + config = Config() + h1 = config.get_hash() + config.linking.filters.cuis = {"a", "b"} + h2 = config.get_hash() + self.assertNotEqual(h1, h2) + + def test_config_hash_recalc_same_changed(self): + config = Config() + config.linking.filters.cuis = {"a", "b"} + h1 = config.get_hash() + h2 = config.get_hash() + self.assertEqual(h1, h2) + + +class ConfigLinkingFiltersTests(unittest.TestCase): + + def test_allows_empty_dict_for_cuis(self): + lf = LinkingFilters(cuis={}) + self.assertIsNotNone(lf) + + def test_empty_dict_converted_to_empty_set(self): + lf = LinkingFilters(cuis={}) + self.assertEqual(lf.cuis, set()) + + def test_not_allow_nonempty_dict_for_cuis(self): + with self.assertRaises(ValidationError): + LinkingFilters(cuis={"KEY": "VALUE"}) + + def test_not_allow_empty_dict_for_cuis_exclude(self): + with self.assertRaises(ValidationError): + LinkingFilters(cuis_exclude={}) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_pipe.py b/tests/test_pipe.py index e6da42898..8ce47cfb5 100644 --- a/tests/test_pipe.py +++ b/tests/test_pipe.py @@ -28,6 +28,7 @@ def setUpClass(cls) -> None: cls.config.ner['max_skip_tokens'] = 1 cls.config.ner['upper_case_limit_len'] = 4 cls.config.linking['disamb_length_limit'] = 2 + cls.config.preprocessing.stopwords = {'stop', 'words'} cls.cdb = CDB(config=cls.config) downloader = VocabDownloader() @@ -42,7 +43,7 @@ def setUpClass(cls) -> None: _tokenizer = TokenizerWrapperBERT(hf_tokenizers=AutoTokenizer.from_pretrained("bert-base-uncased")) cls.meta_cat = MetaCAT(tokenizer=_tokenizer) - cls.text = "CDB - I was running and then Movar Virus attacked and CDb" + cls.text = "stop of CDB - I was running and then Movar Virus attacked and CDb" cls.undertest = Pipe(tokenizer=spacy_split_all, config=cls.config) @classmethod @@ -81,6 +82,12 @@ def test_add_meta_cat(self): 
PipeTests.undertest.add_meta_cat(PipeTests.meta_cat) self.assertEqual(PipeTests.meta_cat.name, Language.get_factory_meta(PipeTests.meta_cat.name).factory) + + def test_stopwords_loading(self): + self.assertEqual(PipeTests.undertest._nlp.Defaults.stop_words, PipeTests.config.preprocessing.stopwords) + doc = PipeTests.undertest(PipeTests.text) + self.assertEqual(doc[0].is_stop, True) + self.assertEqual(doc[1].is_stop, False) def test_batch_multi_process(self): PipeTests.undertest.add_tagger(tagger=tag_skip_and_punct, additional_fields=["is_punct"]) diff --git a/tests/utils/saving/test_serialization.py b/tests/utils/saving/test_serialization.py index f0cc75de1..c2c44da16 100644 --- a/tests/utils/saving/test_serialization.py +++ b/tests/utils/saving/test_serialization.py @@ -87,7 +87,7 @@ def test_dill_to_json(self): model_pack_folder = os.path.join( self.json_model_pack.name, model_pack_path) json_path = os.path.join(model_pack_folder, "*.json") - jsons = glob.glob(json_path) + jsons = [fn for fn in glob.glob(json_path) if not fn.endswith("config.json")] # there is also a model_card.json # but nothing for cui2many or name2many # so can remove the length of ONE2MANY diff --git a/tests/utils/test_cdb_utils.py b/tests/utils/test_cdb_utils.py new file mode 100644 index 000000000..777a2506b --- /dev/null +++ b/tests/utils/test_cdb_utils.py @@ -0,0 +1,43 @@ +import unittest +import numpy as np +from tests.helper import ForCDBMerging +from medcat.utils.cdb_utils import merge_cdb + + +class CDBMergeTests(unittest.TestCase): + + @classmethod + def setUpClass(cls): + to_merge = ForCDBMerging() + cls.cdb1 = to_merge.cdb1 + cls.cdb2 = to_merge.cdb2 + cls.merged_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2) + cls.overwrite_cdb = merge_cdb(cdb1=cls.cdb1, cdb2=cls.cdb2, overwrite_training=2, full_build=True) + cls.zeroes = np.zeros(shape=(1,300)) + cls.ones = np.ones(shape=(1,300)) + + def test_merge_inserts(self): + self.assertIn("test", self.merged_cdb.cui2names["C0006826"]) + 
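The merge tests that follow check how `merge_cdb` combines per-CUI context vectors from two CDBs. As an illustration only — `merge_cdb` has its own weighting and `overwrite_training` rules, so this is not medcat's exact formula — a training-count-weighted average of two vectors looks like:

```python
def merge_context_vectors(vec1, count1, vec2, count2):
    """Combine two context vectors with a training-count-weighted average.

    Illustration only: the hypothetical weighting here just demonstrates
    why merged vectors end up between their two sources.
    """
    total = count1 + count2
    return [(a * count1 + b * count2) / total for a, b in zip(vec1, vec2)]


# a vector trained 3 times dominates one trained once
merged = merge_context_vectors([1.0, 1.0], 3, [0.0, 0.0], 1)
assert merged == [0.75, 0.75]
```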
self.assertIn("test_name", self.merged_cdb.cui2snames["C0006826"]) + self.assertEqual("Cancer", self.merged_cdb.cui2preferred_name["C0006826"]) + + def test_no_full_build(self): + self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict()) + self.assertEqual(self.merged_cdb.addl_info["cui2ontologies"], dict()) + + def test_full_build(self): + for cui in self.cdb2.cui2names: + self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"}) + self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description") + + def test_vector_merge(self): + self.assertTrue(np.array_equal(self.zeroes, self.merged_cdb.cui2context_vectors["UniqueTest"]["short"])) + for i, cui in enumerate(self.cdb1.cui2names): + self.assertTrue(np.array_equal(self.merged_cdb.cui2context_vectors[cui]["short"], np.divide(self.ones, i+2))) + + + def test_overwrite_parameter(self): + for cui in self.cdb2.cui2names: + self.assertTrue(np.array_equal(self.overwrite_cdb.cui2context_vectors[cui]["short"], self.zeroes)) + self.assertEqual(self.overwrite_cdb.addl_info["cui2ontologies"][cui], {"test_ontology"}) + self.assertEqual(self.overwrite_cdb.addl_info["cui2description"][cui], "test_description") diff --git a/tests/utils/test_hashing.py b/tests/utils/test_hashing.py index 99c10b153..0fd6b5891 100644 --- a/tests/utils/test_hashing.py +++ b/tests/utils/test_hashing.py @@ -1,4 +1,5 @@ import os +from typing import Optional import tempfile import unittest import unittest.mock @@ -6,6 +7,7 @@ from medcat.cat import CAT from medcat.cdb import CDB from medcat.vocab import Vocab +from medcat.config import Config class CDBHashingTests(unittest.TestCase): @@ -30,6 +32,43 @@ def test_CDB_hash_saves_on_disk(self): self.assertEqual(h, cdb._hash) +class CDBHashingWithConfigTests(unittest.TestCase): + temp_dir = tempfile.TemporaryDirectory() + + @classmethod + def setUpClass(cls) -> None: + cls.cdb = CDB.load(os.path.join(os.path.dirname( + 
os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat")) + # ensure config has hash + h = cls.cdb.get_hash() + cls.config = cls.config_copy(cls.cdb.config) + cls._config_hash = cls.cdb.config.hash + + @classmethod + def config_copy(cls, config: Optional[Config] = None) -> Config: + if config is None: + config = cls.config + return Config(**config.asdict()) + + def setUp(self) -> None: + # reset config + self.cdb.config = self.config_copy() + # reset config hash + self.cdb._config_hash = self._config_hash + self.cdb.config.hash = self._config_hash + + def test_CDB_same_hash_no_need_recalc(self): + self.assertFalse(self.cdb._should_recalc_hash(force_recalc=False)) + + def test_CDB_hash_recalc_if_no_config_hash(self): + self.cdb._config_hash = None + self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False)) + + def test_CDB_hash_recalc_after_config_change(self): + self.cdb.config.linking.filters.cuis = {"a", "b", "c"} + self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False)) + + class BaseCATHashingTests(unittest.TestCase): @classmethod @@ -75,8 +114,14 @@ def test_no_changes_recalc_same(self): class CATHashingTestsWithoutChange(CATHashingTestsWithFakeHash): - def test_no_changes_no_calc(self): + def setUp(self) -> None: + self._calculate_hash = self.undertest.cdb.calculate_hash + # make sure the hash exists + self.undertest.cdb._config_hash = self.undertest.cdb.config.get_hash() + self.undertest.cdb.get_hash() self.undertest.cdb.calculate_hash = unittest.mock.Mock() + + def test_no_changes_no_calc(self): hash = self.undertest.get_hash() self.assertIsInstance(hash, str) self.undertest.cdb.calculate_hash.assert_not_called() @@ -90,7 +135,7 @@ class CATHashingTestsWithChange(CATHashingTestsWithFakeHash): def test_when_changes_do_calc(self): with unittest.mock.patch.object(CDB, 'calculate_hash', return_value='abcd1234') as patch_method: - self.undertest.cdb.add_concept(**self.concept_kwargs) + 
self.undertest.cdb._add_concept(**self.concept_kwargs) hash = self.undertest.get_hash() self.assertIsInstance(hash, str) patch_method.assert_called() @@ -106,10 +151,10 @@ def test_default_cdb_not_dirty(self): self.assertFalse(self.undertest.cdb.is_dirty) def test_after_add_concept_is_dirty(self): - self.undertest.cdb.add_concept(**self.concept_kwargs) + self.undertest.cdb._add_concept(**self.concept_kwargs) self.assertTrue(self.undertest.cdb.is_dirty) def test_after_recalc_not_dirty(self): - self.undertest.cdb.add_concept(**self.concept_kwargs) + self.undertest.cdb._add_concept(**self.concept_kwargs) self.undertest.get_hash() self.assertFalse(self.undertest.cdb.is_dirty) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py new file mode 100644 index 000000000..6703ce91a --- /dev/null +++ b/tests/utils/test_helpers.py @@ -0,0 +1,24 @@ +from medcat.utils.helpers import has_spacy_model, ensure_spacy_model +from medcat.pipe import DEFAULT_SPACY_MODEL + +import unittest +import subprocess + + +class HasSpacyModelTests(unittest.TestCase): + + def test_no_rubbish_model(self, model_name='rubbish_model'): + self.assertFalse(has_spacy_model(model_name)) + + def test_has_def_model(self, model_name=DEFAULT_SPACY_MODEL): + self.assertTrue(has_spacy_model(model_name)) + + +class EnsureSpacyModelTests(unittest.TestCase): + + def test_fails_rubbish_model(self, model_name='rubbish_model'): + with self.assertRaises(subprocess.CalledProcessError): + ensure_spacy_model(model_name) + + def test_success_def_model(self, model_name=DEFAULT_SPACY_MODEL): + ensure_spacy_model(model_name) diff --git a/tests/utils/test_spacy_compatibility.py b/tests/utils/test_spacy_compatibility.py new file mode 100644 index 000000000..5cf0dd03e --- /dev/null +++ b/tests/utils/test_spacy_compatibility.py @@ -0,0 +1,302 @@ +import medcat.utils.spacy_compatibility as module_under_test +from medcat.utils.spacy_compatibility import _is_spacy_model_folder, _find_spacy_model_folder +from 
medcat.utils.spacy_compatibility import get_installed_spacy_version, get_installed_model_version
+from medcat.utils.spacy_compatibility import _get_name_and_meta_of_spacy_model_in_medcat_modelpack
+from medcat.utils.spacy_compatibility import get_name_and_version_of_spacy_model_in_medcat_modelpack
+from medcat.utils.spacy_compatibility import _is_spacy_version_within_range
+from medcat.utils.spacy_compatibility import medcat_model_pack_has_compatible_spacy_model
+from medcat.utils.spacy_compatibility import is_older_spacy_version
+from medcat.utils.spacy_compatibility import medcat_model_pack_has_semi_compatible_spacy_model
+
+import unittest
+
+from typing import Callable
+import random
+import string
+import tempfile
+import os
+from contextlib import contextmanager
+
+
+FAKE_SPACY_MODEL_NAME = "ff_core_fake_dr"
+FAKE_SPACY_MODEL_DIR = os.path.join("tests", "resources", FAKE_SPACY_MODEL_NAME)
+FAKE_MODELPACK_MODEL_DIR = os.path.join(FAKE_SPACY_MODEL_DIR, '..')
+
+
+class SpacyModelFolderIdentifierTests(unittest.TestCase):
+    expected_working_spacy_models = [
+        "en_core_sci_sm",
+        "en_core_web_sm",
+        "en_core_web_md",
+        "en_core_web_lg",
+        "en_core_web_trf",
+        "nl_core_news_sm",
+        "nl_core_news_md",
+        "nl_core_news_lg",
+    ]
+    # the following were used in medcat models created prior
+    # to v1.2.4
+    expected_working_legacy_names = [
+        "spacy_model"
+    ]
+
+    def test_works_expected_models(self):
+        for model_name in self.expected_working_spacy_models:
+            with self.subTest(model_name):
+                self.assertTrue(_is_spacy_model_folder(model_name))
+
+    def test_works_legacy_models(self):
+        for model_name in self.expected_working_legacy_names:
+            with self.subTest(model_name):
+                self.assertTrue(_is_spacy_model_folder(model_name))
+
+    def test_works_full_path(self):
+        for model_name in self.expected_working_legacy_names:
+            full_folder_path = os.path.join("some", "folder", "structure", model_name)
+            with self.subTest(full_folder_path):
+
self.assertTrue(_is_spacy_model_folder(full_folder_path))
+
+    def get_all_garbage(self) -> list:
+        """Generate garbage "spacy names".
+
+        Returns:
+            List[str]: Some random strings that shouldn't be spacy models.
+        """
+        my_examples = ["garbage_in_and_out", "meta_Presence", "something"]
+        true_randoms_N10 = [''.join(random.choices(string.ascii_uppercase + string.digits, k=10)) for _ in range(10)]
+        true_randoms_N20 = [''.join(random.choices(string.ascii_uppercase + string.digits, k=20)) for _ in range(10)]
+        return my_examples + true_randoms_N10 + true_randoms_N20
+
+    def test_does_not_work_garbage(self):
+        for garbage in self.get_all_garbage():
+            with self.subTest(garbage):
+                self.assertFalse(_is_spacy_model_folder(garbage))
+
+
+class FindSpacyFolderJustOneFolderEmptyFilesTests(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls, spacy_folder_name='en_core_web_md') -> None:
+        # setup temp folder
+        cls.temp_folder = tempfile.TemporaryDirectory()
+        cls.fake_modelpack_folder_name = cls.temp_folder.name
+        # create spacy folder
+        cls.spacy_folder = os.path.join(cls.fake_modelpack_folder_name, spacy_folder_name)
+        os.makedirs(cls.spacy_folder)
+        # create 2 empty files
+        filenames = ["file1.dat", "file2.json"]
+        filenames = [os.path.join(cls.fake_modelpack_folder_name, fn) for fn in filenames]
+        for fn in filenames:
+            with open(fn, 'w'):
+                pass  # create an empty file
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        cls.temp_folder.cleanup()
+
+    def test_finds(self):
+        found_folder_path = _find_spacy_model_folder(self.fake_modelpack_folder_name)
+        self.assertEqual(found_folder_path, self.spacy_folder)
+
+
+class FindSpacyFolderMoreFoldersEmptyFilesTests(FindSpacyFolderJustOneFolderEmptyFilesTests):
+
+    @classmethod
+    def setUpClass(cls, spacy_folder_name='en_core_web_md') -> None:
+        super().setUpClass(spacy_folder_name)
+        # add a few folders
+        folder_names = ["meta_Presence", "garbage_in_garbage_out"]
+        folder_names = [os.path.join(cls.fake_modelpack_folder_name,
fn) for fn in folder_names] + for folder in folder_names: + os.makedirs(folder) + + +class SpacyVersionTests(unittest.TestCase): + + def test_version_received(self): + installed = get_installed_spacy_version() + import spacy + expected = spacy.__version__ + self.assertEqual(installed, expected) + + +class InstalledVersionChecker(unittest.TestCase): + + def test_existing(self, model_name: str = 'en_core_web_md'): + version = get_installed_model_version(model_name) + self.assertIsInstance(version, str) + self.assertNotEqual(version, "N/A") + + def test_non_existing(self, model_name: str = 'en_core_web_lg'): + version = get_installed_model_version(model_name) + self.assertIsInstance(version, str) + self.assertEqual(version, "N/A") + + +class GetSpacyModelInfoTests(unittest.TestCase): + expected_version = "3.1.0" + + @classmethod + def setUpClass(cls) -> None: + cls.name, cls.info = _get_name_and_meta_of_spacy_model_in_medcat_modelpack(FAKE_MODELPACK_MODEL_DIR) + + def test_reads_name(self): + self.assertEqual(self.name, FAKE_SPACY_MODEL_NAME) + + def test_reads_info(self): + self.assertIsInstance(self.info, dict) + self.assertTrue(self.info) # not empty + + +class GetSpacyModelVersionTests(GetSpacyModelInfoTests): + expected_spacy_version = ">=3.1.0,<3.2.0" + + @classmethod + def setUpClass(cls) -> None: + (cls.name, + cls.version, + cls.spacy_version) = get_name_and_version_of_spacy_model_in_medcat_modelpack(FAKE_MODELPACK_MODEL_DIR) + + def test_name_correct(self): + self.assertEqual(self.name, FAKE_SPACY_MODEL_NAME) + + def test_version_correct(self): + self.assertEqual(self.version, self.expected_version) + + def test_spacy_version_correct(self): + self.assertEqual(self.spacy_version, self.expected_spacy_version) + + +@contextmanager +def custom_spacy_version(mock_version: str): + """Changes the apparently installed spacy version. 
+ """ + print(f"Mocking spacy version to: {mock_version}") + _old_method = module_under_test.get_installed_spacy_version + module_under_test.get_installed_spacy_version = lambda: mock_version + yield mock_version + print("Returning regular spacy version getter") + module_under_test.get_installed_spacy_version = _old_method + + +class VersionMockBaseTests(unittest.TestCase): + + def base_subtest_for(self, target_fun: Callable[[str], bool], + spacy_model_range: str, spacy_version: str, should_work: bool) -> None: + with self.subTest(spacy_version): + if should_work: + self.assertTrue(target_fun(spacy_model_range)) + else: + self.assertFalse(target_fun(spacy_model_range)) + + def base_check_version(self, target_fun: Callable[[str], bool], + spacy_model_range: str, spacy_version: str, should_work: bool = True) -> None: + with custom_spacy_version(spacy_version): + self.base_subtest_for(target_fun, spacy_model_range, spacy_version, should_work) + +class SpacyVersionMockBaseTests(VersionMockBaseTests): + + def _subtest_for(self, spacy_model_range: str, spacy_version: str, should_work: bool) -> None: + return self.base_subtest_for(_is_spacy_version_within_range, + spacy_model_range, spacy_version, should_work) + + def _check_version(self, spacy_model_range: str, spacy_version: str, should_work: bool = True) -> None: + return self.base_check_version(_is_spacy_version_within_range, + spacy_model_range, spacy_version, should_work) + + +class SpacyVersionInRangeOldRangeTests(SpacyVersionMockBaseTests): + """This is for versions before 1.7.0. + Those versions used to have spacy constraints of 'spacy<3.1.4,>=3.1.0' + and as such, they used v3.1.0 of en_core_web_md. 
+ """ + spacy_model_range = ">=3.1.0,<3.2.0" # model range for en_core_web_md-3.1.0 + useful_spacy_versions = ["3.1.0", "3.1.2", "3.1.3"] + unsupported_spacy_versions = ["3.2.0", "3.5.3", "3.6.0"] + + def test_works_in_range(self): + for spacy_version in self.useful_spacy_versions: + self._check_version(self.spacy_model_range, spacy_version, should_work=True) + + def test_not_suitable_outside_range(self): + for spacy_version in self.unsupported_spacy_versions: + self._check_version(self.spacy_model_range, spacy_version, should_work=False) + + +class SpacyVersionInRangeNewRangeTests(SpacyVersionInRangeOldRangeTests): + """This is for versions AFTER (and includring) 1.7.0. + Those versions used to have spacy constraints of 'spacy>=3.1.0' + and as such, we use v3.4.0 of en_core_web_md. + + In this setup, generally (in GHA at 14.12.2023) + the spacy version for python version: + 3.8 -> spacy-3.7.2 + 3.9 -> spacy-3.7.2 + 3.10 -> spacy-3.7.2 + 3.11 -> spacy-3.7.2 + Alongside the `en_core_web_md-3.4.0` is installed. + It technically has the compatibility of >=3.4.0,<3.5.0. + But practically, I've seen no issues with spacy==3.7.2. 
+ """ + spacy_model_range = ">=3.1.0" # model range for medcat>=1.7.0 + useful_spacy_versions = ["3.1.0", "3.1.2", "3.1.3", + "3.7.2", "3.6.3"] + unsupported_spacy_versions = ["3.0.0"] + + +class ModelPackHasCompatibleSpacyRangeTests(unittest.TestCase): + test_spacy_version = "3.1.0" + + def test_is_in_range(self): + with custom_spacy_version(self.test_spacy_version): + b = medcat_model_pack_has_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR) + self.assertTrue(b) + +class ModelPackHasInCompatibleSpacyRangeTests(unittest.TestCase): + test_spacy_version = "3.2.0" + + def test_is_in_range(self): + with custom_spacy_version(self.test_spacy_version): + b = medcat_model_pack_has_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR) + self.assertFalse(b) + + +class IsOlderSpacyVersionTests(VersionMockBaseTests): + test_spacy_version = "3.4.4" + expected_older = ["3.1.0", "3.2.0", "3.3.0", "3.4.0"] + expected_newer = ["3.5.0", "3.6.0", "3.7.1"] + + def _check_version(self, model_version: str, should_work: bool = True) -> None: + self.base_check_version(is_older_spacy_version, model_version, self.test_spacy_version, should_work) + + def test_older_works(self): + for model_version in self.expected_older: + self._check_version(model_version, should_work=True) + + def test_newer_fails(self): + for model_version in self.expected_newer: + self._check_version(model_version, should_work=False) + + +class HasSemiCompatibleSpacyModelTests(unittest.TestCase): + # model version on file is 3.1.0, + # and spacy_version range >=3.1.0,<3.2.0" + good_spacy_version = "3.1.3" + semi_good_spacy_version = "3.4.4" # newer than the model + bad_spacy_version = "3.0.0" # older than the model + + def run_subtest(self, spacy_version: str, should_work: bool) -> None: + with custom_spacy_version(spacy_version): + if should_work: + self.assertTrue(medcat_model_pack_has_semi_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR)) + else: + 
self.assertFalse(medcat_model_pack_has_semi_compatible_spacy_model(FAKE_MODELPACK_MODEL_DIR)) + + def test_works_compatible_spacy_version(self): + self.run_subtest(self.good_spacy_version, should_work=True) + + def test_works_semi_compatible_spacy_version(self): + self.run_subtest(self.semi_good_spacy_version, should_work=True) + + def test_fails_incompatible_spacy_version(self): + self.run_subtest(self.bad_spacy_version, should_work=False) diff --git a/webapp/webapp/requirements.txt b/webapp/webapp/requirements.txt index a4b7827ad..ce68f853d 100644 --- a/webapp/webapp/requirements.txt +++ b/webapp/webapp/requirements.txt @@ -1,6 +1,6 @@ -Django==3.2.20 +Django==3.2.23 django-dbbackup==4.0.0b0 django-storages[boto3]==1.12.3 django-cron==0.5.1 medcat==1.2.7 -urllib3==1.26.5 +urllib3==1.26.18
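The version-range checks exercised throughout tests/utils/test_spacy_compatibility.py above compare an installed spacy version against specifier strings such as ">=3.1.0,<3.2.0". A minimal standalone sketch of such a check — the real `_is_spacy_version_within_range` may instead rely on `packaging.specifiers`, and `is_version_within_range` here is a hypothetical name:

```python
def version_tuple(v: str):
    """Turn "3.1.0" into (3, 1, 0) for lexicographic comparison."""
    return tuple(int(part) for part in v.split("."))


def is_version_within_range(range_spec: str, installed: str) -> bool:
    """Check an installed version against a ">=3.1.0,<3.2.0"-style range.

    Sketch only: a production implementation would likely use
    packaging.specifiers.SpecifierSet instead of hand-rolled parsing.
    """
    ops = {
        ">=": lambda a, b: a >= b,
        "<=": lambda a, b: a <= b,
        "==": lambda a, b: a == b,
        ">": lambda a, b: a > b,
        "<": lambda a, b: a < b,
    }
    got = version_tuple(installed)
    for clause in range_spec.split(","):
        clause = clause.strip()
        # try two-character operators before one-character ones
        for op in (">=", "<=", "==", ">", "<"):
            if clause.startswith(op):
                if not ops[op](got, version_tuple(clause[len(op):])):
                    return False
                break
    return True


# mirrors the ranges used by the tests above
assert is_version_within_range(">=3.1.0,<3.2.0", "3.1.3")
assert not is_version_within_range(">=3.1.0,<3.2.0", "3.2.0")
```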