Merge branch 'master' into CU-8694wh3d5-track-usage
mart-r committed Jul 18, 2024
2 parents 6dcaf70 + 018fd7a commit 5a409f2
Showing 5 changed files with 287 additions and 63 deletions.
4 changes: 4 additions & 0 deletions medcat/cat.py
@@ -160,6 +160,10 @@ def get_hash(self, force_recalc: bool = False) -> str:
str: The resulting hash
"""
hasher = Hasher()
if self.config.general.simple_hash:
logger.info("Using simplified hashing that only takes into account the model card")
hasher.update(self.get_model_card())
return hasher.hexdigest()
hasher.update(self.cdb.get_hash(force_recalc))

hasher.update(self.config.get_hash())
5 changes: 5 additions & 0 deletions medcat/config.py
@@ -385,6 +385,11 @@ class General(MixingConfig, BaseModel):
if `long` it will be CUI | Name | Confidence"""
map_cui_to_group: bool = False
"""If the cdb.addl_info['cui2group'] is provided and this option enabled, each CUI will be maped to the group"""
simple_hash: bool = False
"""Whether to use a simple hash.
NOTE: While using a simple hash is faster at save time, it is less
reliable due to not taking into account all the details of the changes."""

class Config:
extra = Extra.allow
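Together, the cat.py and config.py changes add an opt-in fast path: when `general.simple_hash` is set, `CAT.get_hash()` hashes only the model card instead of the CDB, config and the remaining components. A minimal sketch of how the option might be used (the model-pack path is a placeholder; the standard `CAT.load_model_pack` entry point is assumed):

from medcat.cat import CAT

cat = CAT.load_model_pack("path/to/model_pack.zip")  # placeholder path

# default behaviour: full hash over the CDB, config and other components
full_hash = cat.get_hash()

# opt into the simplified hash that only considers the model card
cat.config.general.simple_hash = True
quick_hash = cat.get_hash()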
153 changes: 98 additions & 55 deletions medcat/stats/kfold.py
@@ -1,8 +1,9 @@
from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any
from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any, Union

from abc import ABC, abstractmethod
from enum import Enum, auto
from copy import deepcopy
from pydantic import BaseModel

import numpy as np

@@ -299,52 +300,72 @@ def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport],
return metrics


def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]],
single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None:
if len(joined) != len(single):
raise ValueError(f"Incompatible lists. Joined {len(joined)} and single {len(single)}")
for j, s in zip(joined, single):
_update_one_weighted_average(j, s, cui2count)
def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
for ex_type, ex_dict in cur_examples.items():
if ex_type not in all_examples:
all_examples[ex_type] = {}
per_type_examples = all_examples[ex_type]
for ex_cui, cui_examples_list in ex_dict.items():
if ex_cui not in per_type_examples:
per_type_examples[ex_cui] = []
per_type_examples[ex_cui].extend(cui_examples_list)


# helper types
IntValuedMetric = Union[
Dict[str, int],
Dict[str, Tuple[int, float]]
]
FloatValuedMetric = Union[
Dict[str, float],
Dict[str, Tuple[float, float]]
]


class PerCUIMetrics(BaseModel):
weights: List[Union[int, float]] = []
vals: List[Union[int, float]] = []

def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
one: Dict[str, float],
cui2count: Dict[str, int]) -> None:
for k in one:
if k not in joined:
joined[k] = (0, 0)
prev_w, prev_val = joined[k]
new_w, new_val = cui2count[k], one[k]
total_w = prev_w + new_w
total_val = (prev_w * prev_val + new_w * new_val) / total_w
joined[k] = (total_w, total_val)
def add(self, val, weight: int = 1):
self.weights.append(weight)
self.vals.append(val)

def get_mean(self):
return sum(w * v for w, v in zip(self.weights, self.vals)) / sum(self.weights)

def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None:
def get_std(self):
mean = self.get_mean()
return (sum(w * (v - mean)**2 for w, v in zip(self.weights, self.vals)) / sum(self.weights))**.5


def _add_helper(joined: List[Dict[str, PerCUIMetrics]],
single: List[Dict[str, int]]) -> None:
if len(joined) != len(single):
raise ValueError(f"Incompatible number of stuff: {len(joined)} vs {len(single)}")
for j, s in zip(joined, single):
for k, v in s.items():
j[k] = j.get(k, 0) + v
if k not in j:
j[k] = PerCUIMetrics()
j[k].add(v)


def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
for ex_type, ex_dict in cur_examples.items():
if ex_type not in all_examples:
all_examples[ex_type] = {}
per_type_examples = all_examples[ex_type]
for ex_cui, cui_examples_list in ex_dict.items():
if ex_cui not in per_type_examples:
per_type_examples[ex_cui] = []
per_type_examples[ex_cui].extend(cui_examples_list)
def _add_weighted_helper(joined: List[Dict[str, PerCUIMetrics]],
single: List[Dict[str, float]],
cui2count: Dict[str, int]) -> None:
for j, s in zip(joined, single):
for k, v in s.items():
if k not in j:
j[k] = PerCUIMetrics()
j[k].add(v, cui2count[k])


def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]
) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]],
include_std: bool) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
"""The the mean of the provided metrics.
Args:
metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]): The metrics.
include_std (bool): Whether to include the standard deviation.
Returns:
fps (dict):
@@ -365,15 +386,15 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].
"""
# additives
all_fps: Dict[str, int] = {}
all_fns: Dict[str, int] = {}
all_tps: Dict[str, int] = {}
all_fps: Dict[str, PerCUIMetrics] = {}
all_fns: Dict[str, PerCUIMetrics] = {}
all_tps: Dict[str, PerCUIMetrics] = {}
# weighted-averages
all_cui_prec: Dict[str, Tuple[int, float]] = {}
all_cui_rec: Dict[str, Tuple[int, float]] = {}
all_cui_f1: Dict[str, Tuple[int, float]] = {}
all_cui_prec: Dict[str, PerCUIMetrics] = {}
all_cui_rec: Dict[str, PerCUIMetrics] = {}
all_cui_f1: Dict[str, PerCUIMetrics] = {}
# additive
all_cui_counts: Dict[str, int] = {}
all_cui_counts: Dict[str, PerCUIMetrics] = {}
# combined
all_additives = [
all_fps, all_fns, all_tps, all_cui_counts
@@ -386,29 +407,49 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
for current in metrics:
cur_wa: list = list(current[3:-2])
cur_counts = current[-2]
_update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts)
# update ones that just need to be added up
cur_adds = list(current[:3]) + [cur_counts]
_update_all_add(all_additives, cur_adds)
# merge examples
_add_helper(all_additives, cur_adds)
_add_weighted_helper(all_weighted_averages, cur_wa, cur_counts)
cur_examples = current[-1]
_merge_examples(all_examples, cur_examples)
cui_prec: Dict[str, float] = {}
cui_rec: Dict[str, float] = {}
cui_f1: Dict[str, float] = {}
final_wa = [
cui_prec, cui_rec, cui_f1
# conversion from PerCUI metrics to int/float and (if needed) STD
cui_fps: IntValuedMetric = {}
cui_fns: IntValuedMetric = {}
cui_tps: IntValuedMetric = {}
cui_prec: FloatValuedMetric = {}
cui_rec: FloatValuedMetric = {}
cui_f1: FloatValuedMetric = {}
final_counts: IntValuedMetric = {}
to_change: List[Union[IntValuedMetric, FloatValuedMetric]] = [
cui_fps, cui_fns, cui_tps, final_counts,
cui_prec, cui_rec, cui_f1,
]
# just remove the weight / count
for df, d in zip(final_wa, all_weighted_averages):
# get the mean and/or std
for nr, (df, d) in enumerate(zip(to_change, all_additives + all_weighted_averages)):
for k, v in d.items():
df[k] = v[1] # only the value, ignore the weight
return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2],
all_cui_counts, all_examples)
if nr == 3 and not include_std:
# counts need to be added up
# NOTE: the type:ignore comment _shouldn't_ be necessary
# but mypy thinks we're setting a float or integer
# where a tuple is expected
df[k] = sum(v.vals) # type: ignore
# NOTE: The current implementation shows the sum for counts
# if no STD is required, but the mean along with the
# standard deviation if the latter is required.
elif not include_std:
df[k] = v.get_mean()
else:
# NOTE: the type:ignore comment _shouldn't_ be necessary
# but mypy thinks we're setting a tuple
# where a float or integer is expected
df[k] = (v.get_mean(), v.get_std()) # type: ignore
return (cui_fps, cui_fns, cui_tps, cui_prec, cui_rec, cui_f1,
final_counts, all_examples)


def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3,
split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple:
split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED,
include_std: bool = False, *args, **kwargs) -> Tuple:
"""Get the k-fold stats for the model with the specified data.
First this will split the MCT export into `k` folds. You can do
@@ -424,13 +465,15 @@ def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int
mct_export_data (MedCATTrainerExport): The MCT export.
k (int): The number of folds. Defaults to 3.
split_type (SplitType): Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED.
include_std (bool): Whether to include the standard deviation. Defaults to False.
*args: Arguments passed to the `CAT.train_supervised_raw` method.
**kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method.
Returns:
Tuple: The averaged metrics.
Tuple: The averaged metrics. Potentially with their corresponding standard deviations.
"""
creator = get_fold_creator(mct_export_data, k, split_type=split_type)
folds = creator.create_folds()
per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs)
return get_metrics_mean(per_fold_metrics)
means = get_metrics_mean(per_fold_metrics, include_std)
return means
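With the refactor above, each per-CUI statistic is accumulated in a `PerCUIMetrics` instance and `get_metrics_mean` reports either plain sums/means or `(mean, std)` pairs. A usage sketch, assuming a trained CAT model pack and a MedCATtrainer export JSON (both paths are placeholders):

import json

from medcat.cat import CAT
from medcat.stats import kfold

cat = CAT.load_model_pack("path/to/model_pack.zip")  # placeholder path
with open("mct_export.json") as f:  # placeholder MCT export
    mct_export = json.load(f)

# three folds, standard deviations included
fps, fns, tps, prec, rec, f1, counts, examples = kfold.get_k_fold_stats(
    cat, mct_export, k=3, include_std=True)

# with include_std=True every per-CUI value is a (mean, std) tuple
for cui, (mean_f1, std_f1) in f1.items():
    print(cui, round(mean_f1, 3), "+/-", round(std_f1, 3))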
137 changes: 135 additions & 2 deletions tests/stats/test_kfold.py
@@ -216,6 +216,7 @@ class KFoldWeightedDocsMetricsTests(KFoldMetricsTests):

class KFoldDuplicatedTests(KFoldCATTests):
COPIES = 3
INCLUDE_STD = False

@classmethod
def setUpClass(cls) -> None:
@@ -232,8 +233,8 @@ def setUpClass(cls) -> None:
project['documents'] = copies
cls.docs_in_copy = kfold.count_all_docs(cls.data_copied)
cls.anns_in_copy = kfold.count_all_annotations(cls.data_copied)
cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES, include_std=cls.INCLUDE_STD)
cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES, include_std=cls.INCLUDE_STD)

# some stats with real model/data will be e.g 0.99 vs 0.9747
# so in that case, lower it to 1 or so
@@ -296,3 +297,135 @@ def test_metrics_3_fold(self):
self.assertIn(cui, new.keys(), f"CUI '{cui}' ({cuiname}) not in new")
v1, v2 = old[cui], new[cui]
self.assertEqual(v1, v2, f"Values not equal for {cui} ({self.cat.cdb.cui2preferred_name.get(cui, cui)})")


class MetricsMeanSTDTests(unittest.TestCase):
METRICS = [
# m1
[
# fps
{"FPCUI": 3},
# fns
{"FNCUI": 4},
# tps
{"TPCUI": 5},
# prec
{"PREC_CUI": 0.3},
# recall
{"REC_CUI": 0.4},
# f1
{"F1_CUI": 0.5},
# counts
{"FPCUI": 3, "FNCUI": 4, "TPCUI": 5,
"PREC_CUI": 3, "REC_CUI": 4, "F1_CUI": 5},
# examples
{}
],
# m2
[
# fps
{"FPCUI": 13},
# fns
{"FNCUI": 14},
# tps
{"TPCUI": 15},
# prec
{"PREC_CUI": 0.9},
# recall
{"REC_CUI": 0.8},
# f1
{"F1_CUI": 0.7},
# counts
{"FPCUI": 13, "FNCUI": 14, "TPCUI": 15,
"PREC_CUI": 13, "REC_CUI": 14, "F1_CUI": 15},
# examples
{}
]
]
EXPECTED_METRICS = (
# these are simple averages and std
# fps
{"FPCUI": (8, 5.0)},
# fns
{"FNCUI": (9, 5.0)},
# tps
{"TPCUI": (10, 5.0)},
# these are WEIGHTED averages and std
# NOTE: These are exact only to within numerical precision,
# but assertAlmostEqual should still work
# prec
{"PREC_CUI": (0.7875, 0.23418742493993994)},
# recall
{"REC_CUI": (0.7111111111111112, 0.16629588385661961)},
# f1
{"F1_CUI": (0.65, 0.08660254037844385)},
# counts
{"FPCUI": (8, 5.0), "FNCUI": (9, 5.0), "TPCUI": (10, 5.0),
"PREC_CUI": (8, 5.0), "REC_CUI": (9, 5.0), "F1_CUI": (10, 5.0)},
# examples
{}
)
_names = ['fps', 'fns', 'tps', 'prec', 'rec', 'f1', 'counts', 'examples']

def setUp(self) -> None:
self.metrics = kfold.get_metrics_mean(self.METRICS, include_std=True)

def test_gets_expected_means_and_std(self):
for name, part, exp_part in zip(self._names, self.metrics, self.EXPECTED_METRICS):
with self.subTest(f"{name}"):
self.assertEqual(part.keys(), exp_part.keys())
for cui in exp_part:
got_val, exp_val = part[cui], exp_part[cui]
with self.subTest(f"{name}-{cui}"):
# iterating since this way I can use the assertAlmostEqual method
for nr1, nr2 in zip(got_val, exp_val):
self.assertAlmostEqual(nr1, nr2)


class KFoldDuplicatedSTDTests(KFoldDuplicatedTests):
INCLUDE_STD = True

@classmethod
def setUpClass(cls) -> None:
super().setUpClass()
# NOTE: This would compare to the "regular" stats which
# do not contain standard deviations
# and as such the results are not directly comparable
cls.test_metrics_3_fold = lambda _: None

def test_gets_std(self):
for name, stat in zip(self._names, self.stats_copied):
if name == 'examples':
continue
for cui, val in stat.items():
cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui)
with self.subTest(f"{name}-{cui} [{cuiname}]"):
self.assertIsInstance(val, tuple)

def test_std_is_0(self):
# NOTE: 0 because the data is copied and as such there's no variance
for name, stat in zip(self._names, self.stats_copied):
if name == 'examples':
continue
for cui, val in stat.items():
cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui)
with self.subTest(f"{name}-{cui} [{cuiname}]"):
self.assertEqual(val[1], 0, f'STD for CUI {cui} not 0')

def test_std_nonzero_diff_nr_of_folds(self):
# NOTE: this will expect some standard deviations to be non-zero
# but they will not all be non-zero (e.g. when only 1 of the
# folds saw the example)
stats = kfold.get_k_fold_stats(self.cat, self.mct_export,
k=self.COPIES - 1, include_std=True)
total_cnt = 0
std_0_cnt = 0
for name, stat in zip(self._names, stats):
if name == 'examples':
continue
for val in stat.values():
total_cnt += 1
if val[1] == 0:
std_0_cnt += 1
self.assertGreaterEqual(total_cnt - std_0_cnt, 4,
"Expected some standard deviations to be nonzeros")
