diff --git a/medcat/cat.py b/medcat/cat.py index f49a25022..d3003b24b 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -2,7 +2,6 @@ import glob import shutil import pickle -import traceback import json import logging import math @@ -24,7 +23,6 @@ from medcat.pipe import Pipe from medcat.preprocessing.taggers import tag_skip_and_punct from medcat.cdb import CDB -from medcat.utils.matutils import intersect_nonempty_set from medcat.utils.data_utils import make_mc_train_test, get_false_positives from medcat.utils.normalizers import BasicSpellChecker from medcat.utils.checkpoint import Checkpoint, CheckpointConfig, CheckpointManager @@ -32,15 +30,16 @@ from medcat.utils.hasher import Hasher from medcat.ner.vocab_based_ner import NER from medcat.linking.context_based_linker import Linker -from medcat.utils.filters import get_project_filters from medcat.preprocessing.cleaners import prepare_name from medcat.meta_cat import MetaCAT from medcat.utils.meta_cat.data_utils import json_to_fake_spacy -from medcat.config import Config, LinkingFilters +from medcat.config import Config from medcat.vocab import Vocab from medcat.utils.decorators import deprecated from medcat.ner.transformers_ner import TransformersNER from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY +from medcat.stats.stats import get_stats +from medcat.utils.filters import set_project_filters logger = logging.getLogger(__name__) # separate logger from the package-level one @@ -442,7 +441,8 @@ def _print_stats(self, use_overlaps: bool = False, use_cui_doc_limit: bool = False, use_groups: bool = False, - extra_cui_filter: Optional[Set] = None) -> Tuple: + extra_cui_filter: Optional[Set] = None, + do_print: bool = True) -> Tuple: """TODO: Refactor and make nice Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP,FN,TP. @@ -482,204 +482,12 @@ def _print_stats(self, Number of occurrence for each CUI. examples (dict): Examples for each of the fp, fn, tp. 
Format will be examples['fp']['cui'][]. + do_print (bool): + Whether to print stats out. Defaults to True. """ - tp = 0 - fp = 0 - fn = 0 - fps: Dict = {} - fns: Dict = {} - tps: Dict = {} - cui_prec: Dict = {} - cui_rec: Dict = {} - cui_f1: Dict = {} - cui_counts: Dict = {} - examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}} - - fp_docs: Set = set() - fn_docs: Set = set() - - orig_filters = self.config.linking.filters.copy_of() - local_filters = self.config.linking.filters - for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): - local_filters.cuis = set() - - # Add extra filter if set - self._set_project_filters(local_filters, project, extra_cui_filter, use_project_filters) - - for dind, doc in tqdm( - enumerate(project["documents"]), - desc="Stats document", - total=len(project["documents"]), - leave=False, - ): - anns = self._get_doc_annotations(doc) - - # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still - if use_cui_doc_limit: - _cuis = set([ann['cui'] for ann in anns]) - if _cuis: - local_filters.cuis = intersect_nonempty_set(_cuis, extra_cui_filter) - else: - local_filters.cuis = {'empty'} - - spacy_doc: Doc = self(doc['text']) # type: ignore - - if use_overlaps: - p_anns = spacy_doc._.ents - else: - p_anns = spacy_doc.ents - - anns_norm = [] - anns_norm_neg = [] - anns_examples = [] - anns_norm_cui = [] - for ann in anns: - cui = ann['cui'] - if local_filters.check_filters(cui): - if use_groups: - cui = self.cdb.addl_info['cui2group'].get(cui, cui) - - if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)): - anns_norm.append((ann['start'], cui)) - anns_examples.append({"text": doc['text'][max(0, ann['start']-60):ann['end']+60], - "cui": cui, - "start": ann['start'], - "end": ann['end'], - "source value": ann['value'], - "acc": 1, - "project name": project.get('name'), - "document 
name": doc.get('name'), - "project id": project.get('id'), - "document id": doc.get('id')}) - elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)): - anns_norm_neg.append((ann['start'], cui)) - - if ann.get("validated", True): - # This is used to test was someone annotating for this CUI in this document - anns_norm_cui.append(cui) - cui_counts[cui] = cui_counts.get(cui, 0) + 1 - - p_anns_norm = [] - p_anns_examples = [] - for ann in p_anns: - cui = ann._.cui - if use_groups: - cui = self.cdb.addl_info['cui2group'].get(cui, cui) - - p_anns_norm.append((ann.start_char, cui)) - p_anns_examples.append({"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60], - "cui": cui, - "start": ann.start_char, - "end": ann.end_char, - "source value": ann.text, - "acc": float(ann._.context_similarity), - "project name": project.get('name'), - "document name": doc.get('name'), - "project id": project.get('id'), - "document id": doc.get('id')}) - for iann, ann in enumerate(p_anns_norm): - cui = ann[1] - if ann in anns_norm: - tp += 1 - tps[cui] = tps.get(cui, 0) + 1 - - example = p_anns_examples[iann] - examples['tp'][cui] = examples['tp'].get(cui, []) + [example] - else: - fp += 1 - fps[cui] = fps.get(cui, 0) + 1 - fp_docs.add(doc.get('name', 'unk')) - - # Add example for this FP prediction - example = p_anns_examples[iann] - if ann in anns_norm_neg: - # Means that it really was annotated as negative - example['real_fp'] = True - - examples['fp'][cui] = examples['fp'].get(cui, []) + [example] - - for iann, ann in enumerate(anns_norm): - if ann not in p_anns_norm: - cui = ann[1] - fn += 1 - fn_docs.add(doc.get('name', 'unk')) - - fns[cui] = fns.get(cui, 0) + 1 - examples['fn'][cui] = examples['fn'].get(cui, []) + [anns_examples[iann]] - - try: - prec = tp / (tp + fp) - rec = tp / (tp + fn) - f1 = 2*(prec*rec) / (prec + rec) - print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1)) - print("Docs with false positives: 
{}\n".format("; ".join([str(x) for x in list(fp_docs)[0:10]]))) - print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(fn_docs)[0:10]]))) - - # Sort fns & prec - fps = {k: v for k, v in sorted(fps.items(), key=lambda item: item[1], reverse=True)} - fns = {k: v for k, v in sorted(fns.items(), key=lambda item: item[1], reverse=True)} - tps = {k: v for k, v in sorted(tps.items(), key=lambda item: item[1], reverse=True)} - - - # F1 per concept - for cui in tps.keys(): - prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0)) - rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0)) - f1 = 2*(prec*rec) / (prec + rec) - cui_prec[cui] = prec - cui_rec[cui] = rec - cui_f1[cui] = f1 - - - # Get top 10 - pr_fps = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]] - pr_fns = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]] - pr_tps = [(self.cdb.cui2preferred_name.get(cui, - list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]] - - - print("\n\nFalse Positives\n") - for one in pr_fps: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("\n\nFalse Negatives\n") - for one in pr_fns: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("\n\nTrue Positives\n") - for one in pr_tps: - print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) - print("*"*110 + "\n") - - except Exception: - traceback.print_exc() - - self.config.linking.filters = orig_filters - - return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples - - def _set_project_filters(self, local_filters: LinkingFilters, project: dict, - extra_cui_filter: Optional[Set], use_project_filters: bool): - """Set the project filters to a LinkingFilters object based on - the 
specified project. - - Args: - local_filters (LinkingFilters): The linking filters instance - project (dict): The project - extra_cui_filter (Optional[Set]): Extra CUIs (if specified) - use_project_filters (bool): Whether to use per-project filters - """ - if isinstance(extra_cui_filter, set): - local_filters.cuis = extra_cui_filter - - if use_project_filters: - project_filter = get_project_filters(cuis=project.get('cuis', None), - type_ids=project.get('tuis', None), - cdb=self.cdb, - project=project) - # Intersect project filter with existing if it has something - if project_filter: - local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis) + return get_stats(self, data=data, epoch=epoch, use_project_filters=use_project_filters, + use_overlaps=use_overlaps, use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, extra_cui_filter=extra_cui_filter, do_print=do_print) def _init_ckpts(self, is_resumed, checkpoint): if self.config.general.checkpoint.steps is not None or checkpoint is not None: @@ -1114,15 +922,15 @@ def train_supervised_raw(self, # then add the extra CUI filters if retain_filters and extra_cui_filter and not retain_extra_cui_filter: # adding project filters without extra_cui_filters - self._set_project_filters(local_filters, project, set(), use_filters) + set_project_filters(self.cdb.addl_info, local_filters, project, set(), use_filters) orig_filters.merge_with(local_filters) # adding extra_cui_filters, but NOT project filters - self._set_project_filters(local_filters, project, extra_cui_filter, False) + set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, False) # refrain from doing it again for subsequent epochs retain_filters = False else: # Set filters in case we are using the train_from_fp - self._set_project_filters(local_filters, project, extra_cui_filter, use_filters) + set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, use_filters) for idx_doc in 
trange(current_document, len(project['documents']), initial=current_document, total=len(project['documents']), desc='Document', leave=False): doc = project['documents'][idx_doc] diff --git a/medcat/stats/__init__.py b/medcat/stats/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat/stats/stats.py b/medcat/stats/stats.py new file mode 100644 index 000000000..06b712158 --- /dev/null +++ b/medcat/stats/stats.py @@ -0,0 +1,340 @@ +from typing import Dict, Optional, Set, Tuple, Callable, List, cast + +from tqdm import tqdm +import traceback + +from spacy.tokens import Doc + +from medcat.utils.filters import set_project_filters +from medcat.utils.matutils import intersect_nonempty_set +from medcat.config import LinkingFilters + + +class StatsBuilder: + + def __init__(self, + filters: LinkingFilters, + addl_info: dict, + doc_getter: Callable[[Optional[str], bool], Optional[Doc]], + doc_annotation_getter: Callable[[dict], list], + cui2group: Dict[str, str], + cui2preferred_name: Dict[str, str], + cui2names: Dict[str, Set[str]], + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + extra_cui_filter: Optional[Set] = None) -> None: + self.filters = filters + self.addl_info = addl_info + self.doc_getter = doc_getter + self._get_doc_annotations = doc_annotation_getter + self.cui2group = cui2group + self.cui2preferred_name = cui2preferred_name + self.cui2names = cui2names + self.use_project_filters = use_project_filters + self.use_overlaps = use_overlaps + self.use_cui_doc_limit = use_cui_doc_limit + self.use_groups = use_groups + self.extra_cui_filter = extra_cui_filter + self._reset_stats() + + def _reset_stats(self): + self.tp = 0 + self.fp = 0 + self.fn = 0 + self.fps: Dict = {} + self.fns: Dict = {} + self.tps: Dict = {} + self.cui_prec: Dict = {} + self.cui_rec: Dict = {} + self.cui_f1: Dict = {} + self.cui_counts: Dict = {} + self.examples: Dict = {'fp': {}, 'fn': {}, 
'tp': {}} + self.fp_docs: Set = set() + self.fn_docs: Set = set() + + def process_project(self, project: dict) -> None: + self.filters.cuis = set() + + # Add extra filter if set + set_project_filters(self.addl_info, self.filters, project, self.extra_cui_filter, self.use_project_filters) + + documents = project["documents"] + for dind, doc in tqdm( + enumerate(documents), + desc="Stats document", + total=len(documents), + leave=False, + ): + self.process_document(cast(str, project.get('name')), + cast(str, project.get('id')), doc) + + def process_document(self, project_name: str, project_id: str, doc: dict) -> None: + anns = self._get_doc_annotations(doc) + + # Apply document level filtering, in this case project_filter is ignored while the extra_cui_filter is respected still + if self.use_cui_doc_limit: + _cuis = set([ann['cui'] for ann in anns]) + if _cuis: + self.filters.cuis = intersect_nonempty_set(_cuis, self.extra_cui_filter) + else: + self.filters.cuis = {'empty'} + + spacy_doc: Doc = self.doc_getter(doc['text']) # type: ignore + + if self.use_overlaps: + p_anns = spacy_doc._.ents + else: + p_anns = spacy_doc.ents + + (anns_norm, anns_norm_neg, + anns_examples, _) = self._preprocess_annotations(project_name, project_id, doc, anns) + + p_anns_norm, p_anns_examples = self._process_p_anns(project_name, project_id, + doc, p_anns) + self._count_p_anns_norm(doc, anns_norm, anns_norm_neg, + p_anns_norm, p_anns_examples) + self._process_anns_norm(doc, anns_norm, p_anns_norm, anns_examples) + + def _process_anns_norm(self, doc: dict, anns_norm: list, p_anns_norm: list, + anns_examples: list) -> None: + for iann, ann in enumerate(anns_norm): + if ann not in p_anns_norm: + cui = ann[1] + self.fn += 1 + self.fn_docs.add(doc.get('name', 'unk')) + + self.fns[cui] = self.fns.get(cui, 0) + 1 + self.examples['fn'][cui] = self.examples['fn'].get(cui, []) + [anns_examples[iann]] + + def _process_p_anns(self, project_name: str, project_id: str, doc: dict, p_anns: list) -> 
Tuple[list, list]: + p_anns_norm = [] + p_anns_examples = [] + for ann in p_anns: + cui = ann._.cui + if self.use_groups: + cui = self.cui2group.get(cui, cui) + + p_anns_norm.append((ann.start_char, cui)) + p_anns_examples.append(self._create_annoation_2(project_name, project_id, cui, doc, ann)) + return p_anns_norm, p_anns_examples + + def _count_p_anns_norm(self, doc: dict, anns_norm: list, anns_norm_neg: list, + p_anns_norm: list, p_anns_examples: list) -> None: + for iann, ann in enumerate(p_anns_norm): + cui = ann[1] + if ann in anns_norm: + self.tp += 1 + self.tps[cui] = self.tps.get(cui, 0) + 1 + + example = p_anns_examples[iann] + self.examples['tp'][cui] = self.examples['tp'].get(cui, []) + [example] + else: + self.fp += 1 + self.fps[cui] = self.fps.get(cui, 0) + 1 + self.fp_docs.add(doc.get('name', 'unk')) + + # Add example for this FP prediction + example = p_anns_examples[iann] + if ann in anns_norm_neg: + # Means that it really was annotated as negative + example['real_fp'] = True + + self.examples['fp'][cui] = self.examples['fp'].get(cui, []) + [example] + + def _create_annoation(self, project_name: str, project_id: str, cui: str, doc: dict, ann: Dict) -> Dict: + return {"text": doc['text'][max(0, ann['start']-60):ann['end']+60], + "cui": cui, + "start": ann['start'], + "end": ann['end'], + "source value": ann['value'], + "acc": 1, + "project name": project_name, + "document name": doc.get('name'), + "project id": project_id, + "document id": doc.get('id')} + + def _create_annoation_2(self, project_name: str, project_id: str, cui: str, doc: dict, ann) -> Dict: + return {"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60], + "cui": cui, + "start": ann.start_char, + "end": ann.end_char, + "source value": ann.text, + "acc": float(ann._.context_similarity), + "project name": project_name, + "document name": doc.get('name'), + "project id": project_id, + "document id": doc.get('id')} + + def _preprocess_annotations(self, project_name: str, 
project_id: str, + doc: dict, anns: List[Dict]) -> Tuple[list, list, list, list]: + anns_norm = [] + anns_norm_neg = [] + anns_examples = [] + anns_norm_cui = [] + for ann in anns: + cui = ann['cui'] + if self.filters.check_filters(cui): + if self.use_groups: + cui = self.cui2group.get(cui, cui) + + if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)): + anns_norm.append((ann['start'], cui)) + anns_examples.append(self._create_annoation(project_name, project_id, cui, doc, ann)) + elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)): + anns_norm_neg.append((ann['start'], cui)) + + if ann.get("validated", True): + # This is used to test was someone annotating for this CUI in this document + anns_norm_cui.append(cui) + self.cui_counts[cui] = self.cui_counts.get(cui, 0) + 1 + return anns_norm, anns_norm_neg, anns_examples, anns_norm_cui + + def finalise_report(self, epoch: int, do_print: bool = True): + try: + prec = self.tp / (self.tp + self.fp) + rec = self.tp / (self.tp + self.fn) + f1 = 2*(prec*rec) / (prec + rec) + if do_print: + print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1)) + print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(self.fp_docs)[0:10]]))) + print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(self.fn_docs)[0:10]]))) + + # Sort fns & prec + fps = {k: v for k, v in sorted(self.fps.items(), key=lambda item: item[1], reverse=True)} + fns = {k: v for k, v in sorted(self.fns.items(), key=lambda item: item[1], reverse=True)} + tps = {k: v for k, v in sorted(self.tps.items(), key=lambda item: item[1], reverse=True)} + + + # F1 per concept + for cui in tps.keys(): + prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0)) + rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0)) + f1 = 2*(prec*rec) / (prec + rec) + self.cui_prec[cui] = prec + self.cui_rec[cui] = rec + self.cui_f1[cui] = f1 + + + # 
Get top 10 + pr_fps = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]] + pr_fns = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]] + pr_tps = [(self.cui2preferred_name.get(cui, + list(self.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]] + + if do_print: + print("\n\nFalse Positives\n") + for one in pr_fps: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("\n\nFalse Negatives\n") + for one in pr_fns: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("\n\nTrue Positives\n") + for one in pr_tps: + print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2])) + print("*"*110 + "\n") + + except Exception: + traceback.print_exc() + + def unwrap(self) -> Tuple: + return (self.fps, self.fns, self.tps, + self.cui_prec, self.cui_rec, self.cui_f1, + self.cui_counts, self.examples) + + @classmethod + def from_cat(cls, cat, + local_filters: LinkingFilters, + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + extra_cui_filter: Optional[Set] = None) -> 'StatsBuilder': + return StatsBuilder(filters=local_filters, + addl_info=cat.cdb.addl_info, + doc_getter=cat.__call__, + doc_annotation_getter=cat._get_doc_annotations, + cui2group=cat.cdb.addl_info['cui2group'], + cui2preferred_name=cat.cdb.cui2preferred_name, + cui2names=cat.cdb.cui2names, + use_project_filters=use_project_filters, + use_overlaps=use_overlaps, + use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, + extra_cui_filter=extra_cui_filter) + + +def get_stats(cat, + data: Dict, + epoch: int = 0, + use_project_filters: bool = False, + use_overlaps: bool = False, + use_cui_doc_limit: bool = False, + use_groups: bool = False, + 
extra_cui_filter: Optional[Set] = None, + do_print: bool = True) -> Tuple: + """TODO: Refactor and make nice + Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP, FN, TP. + + Args: + cat (CAT): + The model pack. + data (Dict): + The json object that we get from MedCATtrainer on export. + epoch (int): + Used during training, so we know what epoch it is. + use_project_filters (boolean): + Each project in MedCATtrainer can have filters, do we want to respect those filters + when calculating metrics. + use_overlaps (boolean): + Allow overlapping entities, nearly always False as it is very difficult to annotate overlapping entities. + use_cui_doc_limit (boolean): + If True the metrics for a CUI will be only calculated if that CUI appears in a document, in other words + if the document was annotated for that CUI. Useful in very specific situations when during the annotation + process the set of CUIs changed. + use_groups (boolean): + If True concepts that have groups will be combined and stats will be reported on groups. + extra_cui_filter (Optional[Set]): + This filter will be intersected with all other filters, or if all others are not set then only this one will be used. + do_print (bool): + Whether to print stats out. Defaults to True. + + Returns: + fps (dict): + False positives for each CUI. + fns (dict): + False negatives for each CUI. + tps (dict): + True positives for each CUI. + cui_prec (dict): + Precision for each CUI. + cui_rec (dict): + Recall for each CUI. + cui_f1 (dict): + F1 for each CUI. + cui_counts (dict): + Number of occurrences for each CUI. + examples (dict): + Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][]. 
+ """ + orig_filters = cat.config.linking.filters.copy_of() + local_filters = cat.config.linking.filters + builder = StatsBuilder.from_cat(cat, + local_filters=local_filters, + use_project_filters=use_project_filters, + use_overlaps=use_overlaps, + use_cui_doc_limit=use_cui_doc_limit, + use_groups=use_groups, + extra_cui_filter=extra_cui_filter) + for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False): + builder.process_project(project) + + # this is the part that prints out the stats + builder.finalise_report(epoch, do_print=do_print) + + cat.config.linking.filters = orig_filters + + return builder.unwrap() diff --git a/medcat/utils/filters.py b/medcat/utils/filters.py index c4803027a..cb85e0e26 100644 --- a/medcat/utils/filters.py +++ b/medcat/utils/filters.py @@ -1,3 +1,9 @@ +from typing import Optional, Set, Dict + +from medcat.config import LinkingFilters +from medcat.utils.matutils import intersect_nonempty_set + + def check_filters(cui, filters): """Checks is a CUI in the filters @@ -15,7 +21,7 @@ def check_filters(cui, filters): return False -def get_all_irrelevant_cuis(project, cdb): +def get_all_irrelevant_cuis(project): i_cuis = set() for d in project['documents']: for a in d['annotations']: @@ -24,7 +30,7 @@ def get_all_irrelevant_cuis(project, cdb): return i_cuis -def get_project_filters(cuis, type_ids, cdb, project=None): +def get_project_filters(cuis, type_ids, addl_info: Dict, project=None): cui_filter = set() if isinstance(cuis, str): if cuis is not None and cuis: @@ -33,10 +39,10 @@ def get_project_filters(cuis, type_ids, cdb, project=None): type_ids = [x.strip().upper() for x in type_ids.split(",")] # Convert type_ids to cuis - if 'type_id2cuis' in cdb.addl_info: + if 'type_id2cuis' in addl_info: for type_id in type_ids: - if type_id in cdb.addl_info['type_id2cuis']: - cui_filter.update(cdb.addl_info['type_id2cuis'][type_id]) + if type_id in addl_info['type_id2cuis']: + 
cui_filter.update(addl_info['type_id2cuis'][type_id]) else: raise Exception("Impossible to create filters, disable them.") else: @@ -45,8 +51,33 @@ def get_project_filters(cuis, type_ids, cdb, project=None): cui_filter = set(cuis) if project is not None: - i_cuis = get_all_irrelevant_cuis(project, cdb) + i_cuis = get_all_irrelevant_cuis(project) for i_cui in i_cuis: cui_filter.remove(i_cui) return cui_filter + + +def set_project_filters(addl_info: Dict, local_filters: LinkingFilters, project: dict, + extra_cui_filter: Optional[Set], use_project_filters: bool): + """Set the project filters to a LinkingFilters object based on + the specified project. + + Args: + addl_info (Dict): The CDB additional information + local_filters (LinkingFilters): The linking filters instance + project (dict): The project + extra_cui_filter (Optional[Set]): Extra CUIs (if specified) + use_project_filters (bool): Whether to use per-project filters + """ + if isinstance(extra_cui_filter, set): + local_filters.cuis = extra_cui_filter + + if use_project_filters: + project_filter = get_project_filters(cuis=project.get('cuis', None), + type_ids=project.get('tuis', None), + addl_info=addl_info, + project=project) + # Intersect project filter with existing if it has something + if project_filter: + local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis) diff --git a/medcat/utils/regression/targeting.py b/medcat/utils/regression/targeting.py index 19f19bb3f..7a13b2bcc 100644 --- a/medcat/utils/regression/targeting.py +++ b/medcat/utils/regression/targeting.py @@ -25,12 +25,12 @@ class TranslationLayer: Args: cui2names (Dict[str, Set[str]]): The map from CUI to names - name2cuis (Dict[str, Set[str]]): The map from name to CUIs + name2cuis (Dict[str, List[str]]): The map from name to CUIs cui2type_ids (Dict[str, Set[str]]): The map from CUI to type_ids cui2children (Dict[str, Set[str]]): The map from CUI to child CUIs """ - def __init__(self, cui2names: Dict[str, Set[str]], 
name2cuis: Dict[str, Set[str]], + def __init__(self, cui2names: Dict[str, Set[str]], name2cuis: Dict[str, List[str]], cui2type_ids: Dict[str, Set[str]], cui2children: Dict[str, Set[str]]) -> None: self.cui2names = cui2names self.name2cuis = name2cuis diff --git a/setup.py b/setup.py index ab49eaff1..34963943a 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ url="https://github.com/CogStack/MedCAT", packages=['medcat', 'medcat.utils', 'medcat.preprocessing', 'medcat.ner', 'medcat.linking', 'medcat.datasets', 'medcat.tokenizers', 'medcat.utils.meta_cat', 'medcat.pipeline', 'medcat.utils.ner', - 'medcat.utils.saving', 'medcat.utils.regression'], + 'medcat.utils.saving', 'medcat.utils.regression', 'medcat.stats'], install_requires=[ 'numpy>=1.22.0', # first to support 3.11 'pandas>=1.4.2', # first to support 3.11 diff --git a/tests/test_cat.py b/tests/test_cat.py index acd337e71..368b1e885 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -17,6 +17,7 @@ class CATTests(unittest.TestCase): + SUPERVISED_TRAINING_JSON = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json") @classmethod def setUpClass(cls) -> None: @@ -39,7 +40,8 @@ def setUpClass(cls) -> None: @classmethod def tearDownClass(cls) -> None: cls.undertest.destroy_pipe() - shutil.rmtree(cls.meta_cat_dir) + if os.path.exists(cls.meta_cat_dir): + shutil.rmtree(cls.meta_cat_dir) def tearDown(self) -> None: self.cdb.config.annotation_output.include_text_in_output = False @@ -214,7 +216,7 @@ def test_get_entities_multi_texts_including_text(self): def test_train_supervised(self): nepochs = 3 num_of_documents = 27 - data_path = os.path.join(os.path.dirname(__file__), "resources", "medcat_trainer_export.json") + data_path = self.SUPERVISED_TRAINING_JSON ckpt_dir_path = tempfile.TemporaryDirectory().name checkpoint = Checkpoint(dir_path=ckpt_dir_path, steps=1, max_to_keep=sys.maxsize) fp, fn, tp, p, r, f1, cui_counts, examples = self.undertest.train_supervised(data_path, @@ 
-391,6 +393,42 @@ def test_hashing(self): cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) self.assertEqual(cat.get_hash(), cat.config.version.id) + def test_print_stats(self): + # based on current JSON + EXP_FALSE_NEGATIVES = {'C0017168': 2, 'C0020538': 43, 'C0038454': 4, 'C0007787': 1, 'C0155626': 4, 'C0011860': 12, + 'C0042029': 6, 'C0010068': 2, 'C0007222': 1, 'C0027051': 6, 'C0878544': 1, 'C0020473': 12, + 'C0037284': 21, 'C0003864': 4, 'C0011849': 12, 'C0005686': 1, 'C0085762': 3, 'C0030920': 2, + 'C0854135': 3, 'C0004096': 4, 'C0010054': 10, 'C0497156': 10, 'C0011334': 2, 'C0018939': 1, + 'C1561826': 2, 'C0276289': 2, 'C0041834': 9, 'C0000833': 2, 'C0238792': 1, 'C0040034': 3, + 'C0035078': 5, 'C0018799': 5, 'C0042109': 1, 'C0035439': 1, 'C0035435': 1, 'C0018099': 1, + 'C1277187': 1, 'C0024117': 7, 'C0004238': 4, 'C0032227': 6, 'C0008679': 1, 'C0013146': 6, + 'C0032285': 1, 'C0002871': 7, 'C0149871': 4, 'C0442886': 1, 'C0022104': 1, 'C0034065': 5, + 'C0011854': 6, 'C1398668': 1, 'C0020676': 2, 'C1301700': 1, 'C0021167': 1, 'C0029456': 2, + 'C0011570': 10, 'C0009324': 1, 'C0011882': 1, 'C0020615': 1, 'C0242510': 2, 'C0033581': 2, + 'C0011168': 3, 'C0039082': 2, 'C0009241': 2, 'C1404970': 1, 'C0018524': 3, 'C0150063': 1, + 'C0917799': 1, 'C0178417': 1, 'C0033975': 1, 'C0011253': 1, 'C0018802': 8, 'C0022661': 4, + 'C0017658': 1, 'C0023895': 2, 'C0003123': 1, 'C0041582': 4, 'C0085096': 4, 'C0403447': 2, + 'C2363741': 2, 'C0457949': 1, 'C0040336': 1, 'C0037315': 2, 'C0024236': 3, 'C0442874': 1, + 'C0028754': 4, 'C0520679': 5, 'C0028756': 2, 'C0029408': 5, 'C0409959': 2, 'C0018801': 1, + 'C3844825': 1, 'C0022660': 2, 'C0005779': 4, 'C0011175': 1, 'C0018965': 4, 'C0018889': 1, + 'C0022354': 2, 'C0033377': 1, 'C0042769': 1, 'C0035222': 1, 'C1456868': 2, 'C1145670': 1, + 'C0018790': 1, 'C0263746': 1, 'C0206172': 1, 'C0021400': 1, 'C0243026': 1, 'C0020443': 1, + 'C0001883': 1, 'C0031350': 1, 'C0010709': 4, 'C1565489': 7, 
'C3489393': 1, 'C0005586': 2, + 'C0158288': 5, 'C0700594': 4, 'C0158266': 3, 'C0006444': 2, 'C0024003': 1} + with open(self.SUPERVISED_TRAINING_JSON) as f: + data = json.load(f) + (fps, fns, tps, + cui_prec, cui_rec, cui_f1, + cui_counts, examples) = self.undertest._print_stats(data) + self.assertEqual(fps, {}) + self.assertEqual(fns, EXP_FALSE_NEGATIVES) + self.assertEqual(tps, {}) + self.assertEqual(cui_prec, {}) + self.assertEqual(cui_rec, {}) + self.assertEqual(cui_f1, {}) + self.assertEqual(len(cui_counts), 136) + self.assertEqual(len(examples), 3) + def _assertNoLogs(self, logger: logging.Logger, level: int): if hasattr(self, 'assertNoLogs'): return self.assertNoLogs(logger=logger, level=level)