CU-2e77a31 improve print stats #366

Merged 12 commits on Dec 18, 2023
medcat/cat.py: 218 changes (13 additions, 205 deletions)
@@ -2,7 +2,6 @@
import glob
import shutil
import pickle
import traceback
import json
import logging
import math
@@ -24,23 +23,23 @@
from medcat.pipe import Pipe
from medcat.preprocessing.taggers import tag_skip_and_punct
from medcat.cdb import CDB
from medcat.utils.matutils import intersect_nonempty_set
from medcat.utils.data_utils import make_mc_train_test, get_false_positives
from medcat.utils.normalizers import BasicSpellChecker
from medcat.utils.checkpoint import Checkpoint, CheckpointConfig, CheckpointManager
from medcat.utils.helpers import tkns_from_doc, get_important_config_parameters, has_new_spacy
from medcat.utils.hasher import Hasher
from medcat.ner.vocab_based_ner import NER
from medcat.linking.context_based_linker import Linker
from medcat.utils.filters import get_project_filters
from medcat.preprocessing.cleaners import prepare_name
from medcat.meta_cat import MetaCAT
from medcat.utils.meta_cat.data_utils import json_to_fake_spacy
from medcat.config import Config, LinkingFilters
from medcat.config import Config
from medcat.vocab import Vocab
from medcat.utils.decorators import deprecated
from medcat.ner.transformers_ner import TransformersNER
from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY
from medcat.stats.stats import get_stats
from medcat.utils.filters import set_project_filters


logger = logging.getLogger(__name__) # separate logger from the package-level one
@@ -434,7 +433,8 @@ def _print_stats(self,
use_overlaps: bool = False,
use_cui_doc_limit: bool = False,
use_groups: bool = False,
extra_cui_filter: Optional[Set] = None) -> Tuple:
extra_cui_filter: Optional[Set] = None,
do_print: bool = True) -> Tuple:
"""TODO: Refactor and make nice
Print metrics on a dataset (F1, P, R), it will also print the concepts that have the most FP,FN,TP.

@@ -474,204 +474,12 @@
Number of occurrence for each CUI.
examples (dict):
Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].
do_print (bool):
Whether to print stats out. Defaults to True.
"""
tp = 0
fp = 0
fn = 0
fps: Dict = {}
fns: Dict = {}
tps: Dict = {}
cui_prec: Dict = {}
cui_rec: Dict = {}
cui_f1: Dict = {}
cui_counts: Dict = {}
examples: Dict = {'fp': {}, 'fn': {}, 'tp': {}}

fp_docs: Set = set()
fn_docs: Set = set()

orig_filters = self.config.linking.filters.copy_of()
local_filters = self.config.linking.filters
for pind, project in tqdm(enumerate(data['projects']), desc="Stats project", total=len(data['projects']), leave=False):
local_filters.cuis = set()

# Add extra filter if set
self._set_project_filters(local_filters, project, extra_cui_filter, use_project_filters)

for dind, doc in tqdm(
enumerate(project["documents"]),
desc="Stats document",
total=len(project["documents"]),
leave=False,
):
anns = self._get_doc_annotations(doc)

# Apply document-level filtering; in this case the project filter is ignored while extra_cui_filter is still respected
if use_cui_doc_limit:
_cuis = set([ann['cui'] for ann in anns])
if _cuis:
local_filters.cuis = intersect_nonempty_set(_cuis, extra_cui_filter)
else:
local_filters.cuis = {'empty'}

spacy_doc: Doc = self(doc['text']) # type: ignore

if use_overlaps:
p_anns = spacy_doc._.ents
else:
p_anns = spacy_doc.ents

anns_norm = []
anns_norm_neg = []
anns_examples = []
anns_norm_cui = []
for ann in anns:
cui = ann['cui']
if local_filters.check_filters(cui):
if use_groups:
cui = self.cdb.addl_info['cui2group'].get(cui, cui)

if ann.get('validated', True) and (not ann.get('killed', False) and not ann.get('deleted', False)):
anns_norm.append((ann['start'], cui))
anns_examples.append({"text": doc['text'][max(0, ann['start']-60):ann['end']+60],
"cui": cui,
"start": ann['start'],
"end": ann['end'],
"source value": ann['value'],
"acc": 1,
"project name": project.get('name'),
"document name": doc.get('name'),
"project id": project.get('id'),
"document id": doc.get('id')})
elif ann.get('validated', True) and (ann.get('killed', False) or ann.get('deleted', False)):
anns_norm_neg.append((ann['start'], cui))

if ann.get("validated", True):
# This is used to test whether someone was annotating this CUI in this document
anns_norm_cui.append(cui)
cui_counts[cui] = cui_counts.get(cui, 0) + 1

p_anns_norm = []
p_anns_examples = []
for ann in p_anns:
cui = ann._.cui
if use_groups:
cui = self.cdb.addl_info['cui2group'].get(cui, cui)

p_anns_norm.append((ann.start_char, cui))
p_anns_examples.append({"text": doc['text'][max(0, ann.start_char-60):ann.end_char+60],
"cui": cui,
"start": ann.start_char,
"end": ann.end_char,
"source value": ann.text,
"acc": float(ann._.context_similarity),
"project name": project.get('name'),
"document name": doc.get('name'),
"project id": project.get('id'),
"document id": doc.get('id')})
for iann, ann in enumerate(p_anns_norm):
cui = ann[1]
if ann in anns_norm:
tp += 1
tps[cui] = tps.get(cui, 0) + 1

example = p_anns_examples[iann]
examples['tp'][cui] = examples['tp'].get(cui, []) + [example]
else:
fp += 1
fps[cui] = fps.get(cui, 0) + 1
fp_docs.add(doc.get('name', 'unk'))

# Add example for this FP prediction
example = p_anns_examples[iann]
if ann in anns_norm_neg:
# Means that it really was annotated as negative
example['real_fp'] = True

examples['fp'][cui] = examples['fp'].get(cui, []) + [example]

for iann, ann in enumerate(anns_norm):
if ann not in p_anns_norm:
cui = ann[1]
fn += 1
fn_docs.add(doc.get('name', 'unk'))

fns[cui] = fns.get(cui, 0) + 1
examples['fn'][cui] = examples['fn'].get(cui, []) + [anns_examples[iann]]

try:
prec = tp / (tp + fp)
rec = tp / (tp + fn)
f1 = 2*(prec*rec) / (prec + rec)
print("Epoch: {}, Prec: {}, Rec: {}, F1: {}\n".format(epoch, prec, rec, f1))
print("Docs with false positives: {}\n".format("; ".join([str(x) for x in list(fp_docs)[0:10]])))
print("Docs with false negatives: {}\n".format("; ".join([str(x) for x in list(fn_docs)[0:10]])))

# Sort fps, fns and tps by descending count
fps = {k: v for k, v in sorted(fps.items(), key=lambda item: item[1], reverse=True)}
fns = {k: v for k, v in sorted(fns.items(), key=lambda item: item[1], reverse=True)}
tps = {k: v for k, v in sorted(tps.items(), key=lambda item: item[1], reverse=True)}


# F1 per concept
for cui in tps.keys():
prec = tps[cui] / (tps.get(cui, 0) + fps.get(cui, 0))
rec = tps[cui] / (tps.get(cui, 0) + fns.get(cui, 0))
f1 = 2*(prec*rec) / (prec + rec)
cui_prec[cui] = prec
cui_rec[cui] = rec
cui_f1[cui] = f1


# Get top 10
pr_fps = [(self.cdb.cui2preferred_name.get(cui,
list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fps[cui]) for cui in list(fps.keys())[0:10]]
pr_fns = [(self.cdb.cui2preferred_name.get(cui,
list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, fns[cui]) for cui in list(fns.keys())[0:10]]
pr_tps = [(self.cdb.cui2preferred_name.get(cui,
list(self.cdb.cui2names.get(cui, [cui]))[0]), cui, tps[cui]) for cui in list(tps.keys())[0:10]]


print("\n\nFalse Positives\n")
for one in pr_fps:
print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
print("\n\nFalse Negatives\n")
for one in pr_fns:
print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
print("\n\nTrue Positives\n")
for one in pr_tps:
print("{:70} - {:20} - {:10}".format(str(one[0])[0:69], str(one[1])[0:19], one[2]))
print("*"*110 + "\n")

except Exception:
traceback.print_exc()

self.config.linking.filters = orig_filters

return fps, fns, tps, cui_prec, cui_rec, cui_f1, cui_counts, examples
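
The returned examples dict follows the format given in the docstring above: examples['fp'][cui] is a list of example dicts. A minimal sketch of inspecting the most frequent false positives, using only field names that appear in this diff:

    # fps is sorted by descending count, and every CUI in fps has at least
    # one entry in examples['fp'] (one is appended per FP hit above).
    for cui, count in list(fps.items())[:3]:
        ex = examples['fp'][cui][0]
        print(cui, count, ex['source value'], ex['text'][:40])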

def _set_project_filters(self, local_filters: LinkingFilters, project: dict,
extra_cui_filter: Optional[Set], use_project_filters: bool):
"""Set the project filters to a LinkingFilters object based on
the specified project.

Args:
local_filters (LinkingFilters): The linking filters instance
project (dict): The project
extra_cui_filter (Optional[Set]): Extra CUIs (if specified)
use_project_filters (bool): Whether to use per-project filters
"""
if isinstance(extra_cui_filter, set):
local_filters.cuis = extra_cui_filter

if use_project_filters:
project_filter = get_project_filters(cuis=project.get('cuis', None),
type_ids=project.get('tuis', None),
cdb=self.cdb,
project=project)
# Intersect the project filter with the existing one if it is non-empty
if project_filter:
local_filters.cuis = intersect_nonempty_set(project_filter, local_filters.cuis)
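
For context, intersect_nonempty_set (used here and in the document-level filtering above) appears, from its usage in this diff, to intersect only when the second set is non-empty and to return the first set unchanged otherwise. A toy illustration under that assumption:

    from medcat.utils.matutils import intersect_nonempty_set

    # Assumed semantics, inferred from the call sites in this diff:
    print(intersect_nonempty_set({'C1', 'C2'}, {'C2', 'C3'}))  # {'C2'}
    print(intersect_nonempty_set({'C1', 'C2'}, set()))         # {'C1', 'C2'}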
return get_stats(self, data=data, epoch=epoch, use_project_filters=use_project_filters,
use_overlaps=use_overlaps, use_cui_doc_limit=use_cui_doc_limit,
use_groups=use_groups, extra_cui_filter=extra_cui_filter, do_print=do_print)
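
With the inline implementation gone, _print_stats now simply delegates to medcat.stats.stats.get_stats and returns the same tuple; the new do_print flag allows collecting the metrics without console output. A minimal usage sketch, assuming a loaded CAT instance cat and a MedCATtrainer export data (a dict with a 'projects' list):

    from medcat.stats.stats import get_stats

    # do_print=False suppresses the printed report; only the dicts are returned.
    (fps, fns, tps, cui_prec, cui_rec,
     cui_f1, cui_counts, examples) = get_stats(cat, data=data, do_print=False)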

def _init_ckpts(self, is_resumed, checkpoint):
if self.config.general.checkpoint.steps is not None or checkpoint is not None:
@@ -1102,15 +910,15 @@ def train_supervised_raw(self,
# then add the extra CUI filters
if retain_filters and extra_cui_filter and not retain_extra_cui_filter:
# adding project filters without extra_cui_filters
self._set_project_filters(local_filters, project, set(), use_filters)
set_project_filters(self.cdb.addl_info, local_filters, project, set(), use_filters)
orig_filters.merge_with(local_filters)
# adding extra_cui_filters, but NOT project filters
self._set_project_filters(local_filters, project, extra_cui_filter, False)
set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, False)
# refrain from doing it again for subsequent epochs
retain_filters = False
else:
# Set filters in case we are using the train_from_fp
self._set_project_filters(local_filters, project, extra_cui_filter, use_filters)
set_project_filters(self.cdb.addl_info, local_filters, project, extra_cui_filter, use_filters)

for idx_doc in trange(current_document, len(project['documents']), initial=current_document, total=len(project['documents']), desc='Document', leave=False):
doc = project['documents'][idx_doc]
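
The removed _set_project_filters method is superseded by the free function set_project_filters from medcat.utils.filters, which takes cdb.addl_info as its first argument. A hedged sketch of the new call form, with arguments passed positionally as in the hunks above (project being one entry of the export's 'projects' list):

    from medcat.utils.filters import set_project_filters

    # Same argument order as the calls above: addl_info, the filters object,
    # the project, the extra CUI filter, and whether to use per-project filters.
    set_project_filters(cat.cdb.addl_info, local_filters, project, set(), True)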
Empty file added: medcat/stats/__init__.py