Skip to content

Commit 5a409f2

Browse files
committed
Merge branch 'master' into CU-8694wh3d5-track-usage
2 parents 6dcaf70 + 018fd7a commit 5a409f2

File tree

5 files changed

+287
-63
lines changed

5 files changed

+287
-63
lines changed

medcat/cat.py

+4
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ def get_hash(self, force_recalc: bool = False) -> str:
160160
str: The resulting hash
161161
"""
162162
hasher = Hasher()
163+
if self.config.general.simple_hash:
164+
logger.info("Using simplified hashing that only takes into account the model card")
165+
hasher.update(self.get_model_card())
166+
return hasher.hexdigest()
163167
hasher.update(self.cdb.get_hash(force_recalc))
164168

165169
hasher.update(self.config.get_hash())

medcat/config.py

+5
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,11 @@ class General(MixingConfig, BaseModel):
385385
if `long` it will be CUI | Name | Confidence"""
386386
map_cui_to_group: bool = False
387387
"""If the cdb.addl_info['cui2group'] is provided and this option is enabled, each CUI will be mapped to the group"""
388+
simple_hash: bool = False
389+
"""Whether to use a simple hash.
390+
391+
NOTE: While using a simple hash is faster at save time, it is less
392+
reliable due to not taking into account all the details of the changes."""
388393

389394
class Config:
390395
extra = Extra.allow

medcat/stats/kfold.py

+98-55
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any
1+
from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any, Union
22

33
from abc import ABC, abstractmethod
44
from enum import Enum, auto
55
from copy import deepcopy
6+
from pydantic import BaseModel
67

78
import numpy as np
89

@@ -299,52 +300,72 @@ def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport],
299300
return metrics
300301

301302

302-
def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]],
303-
single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None:
304-
if len(joined) != len(single):
305-
raise ValueError(f"Incompatible lists. Joined {len(joined)} and single {len(single)}")
306-
for j, s in zip(joined, single):
307-
_update_one_weighted_average(j, s, cui2count)
303+
def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
304+
for ex_type, ex_dict in cur_examples.items():
305+
if ex_type not in all_examples:
306+
all_examples[ex_type] = {}
307+
per_type_examples = all_examples[ex_type]
308+
for ex_cui, cui_examples_list in ex_dict.items():
309+
if ex_cui not in per_type_examples:
310+
per_type_examples[ex_cui] = []
311+
per_type_examples[ex_cui].extend(cui_examples_list)
312+
313+
314+
# helper types
315+
IntValuedMetric = Union[
316+
Dict[str, int],
317+
Dict[str, Tuple[int, float]]
318+
]
319+
FloatValuedMetric = Union[
320+
Dict[str, float],
321+
Dict[str, Tuple[float, float]]
322+
]
323+
308324

325+
class PerCUIMetrics(BaseModel):
326+
weights: List[Union[int, float]] = []
327+
vals: List[Union[int, float]] = []
309328

310-
def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
311-
one: Dict[str, float],
312-
cui2count: Dict[str, int]) -> None:
313-
for k in one:
314-
if k not in joined:
315-
joined[k] = (0, 0)
316-
prev_w, prev_val = joined[k]
317-
new_w, new_val = cui2count[k], one[k]
318-
total_w = prev_w + new_w
319-
total_val = (prev_w * prev_val + new_w * new_val) / total_w
320-
joined[k] = (total_w, total_val)
329+
def add(self, val, weight: int = 1):
330+
self.weights.append(weight)
331+
self.vals.append(val)
321332

333+
def get_mean(self):
334+
return sum(w * v for w, v in zip(self.weights, self.vals)) / sum(self.weights)
322335

323-
def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None:
336+
def get_std(self):
337+
mean = self.get_mean()
338+
return (sum(w * (v - mean)**2 for w, v in zip(self.weights, self.vals)) / sum(self.weights))**.5
339+
340+
341+
def _add_helper(joined: List[Dict[str, PerCUIMetrics]],
342+
single: List[Dict[str, int]]) -> None:
324343
if len(joined) != len(single):
325344
raise ValueError(f"Incompatible number of stuff: {len(joined)} vs {len(single)}")
326345
for j, s in zip(joined, single):
327346
for k, v in s.items():
328-
j[k] = j.get(k, 0) + v
347+
if k not in j:
348+
j[k] = PerCUIMetrics()
349+
j[k].add(v)
329350

330351

331-
def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
332-
for ex_type, ex_dict in cur_examples.items():
333-
if ex_type not in all_examples:
334-
all_examples[ex_type] = {}
335-
per_type_examples = all_examples[ex_type]
336-
for ex_cui, cui_examples_list in ex_dict.items():
337-
if ex_cui not in per_type_examples:
338-
per_type_examples[ex_cui] = []
339-
per_type_examples[ex_cui].extend(cui_examples_list)
352+
def _add_weighted_helper(joined: List[Dict[str, PerCUIMetrics]],
353+
single: List[Dict[str, float]],
354+
cui2count: Dict[str, int]) -> None:
355+
for j, s in zip(joined, single):
356+
for k, v in s.items():
357+
if k not in j:
358+
j[k] = PerCUIMetrics()
359+
j[k].add(v, cui2count[k])
340360

341361

342-
def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]
343-
) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
362+
def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]],
363+
include_std: bool) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
344364
"""Get the mean of the provided metrics.
345365
346366
Args:
347367
metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]): The metrics.
368+
include_std (bool): Whether to include the standard deviation.
348369
349370
Returns:
350371
fps (dict):
@@ -365,15 +386,15 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
365386
Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].
366387
"""
367388
# additives
368-
all_fps: Dict[str, int] = {}
369-
all_fns: Dict[str, int] = {}
370-
all_tps: Dict[str, int] = {}
389+
all_fps: Dict[str, PerCUIMetrics] = {}
390+
all_fns: Dict[str, PerCUIMetrics] = {}
391+
all_tps: Dict[str, PerCUIMetrics] = {}
371392
# weighted-averages
372-
all_cui_prec: Dict[str, Tuple[int, float]] = {}
373-
all_cui_rec: Dict[str, Tuple[int, float]] = {}
374-
all_cui_f1: Dict[str, Tuple[int, float]] = {}
393+
all_cui_prec: Dict[str, PerCUIMetrics] = {}
394+
all_cui_rec: Dict[str, PerCUIMetrics] = {}
395+
all_cui_f1: Dict[str, PerCUIMetrics] = {}
375396
# additive
376-
all_cui_counts: Dict[str, int] = {}
397+
all_cui_counts: Dict[str, PerCUIMetrics] = {}
377398
# combined
378399
all_additives = [
379400
all_fps, all_fns, all_tps, all_cui_counts
@@ -386,29 +407,49 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
386407
for current in metrics:
387408
cur_wa: list = list(current[3:-2])
388409
cur_counts = current[-2]
389-
_update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts)
390-
# update ones that just need to be added up
391410
cur_adds = list(current[:3]) + [cur_counts]
392-
_update_all_add(all_additives, cur_adds)
393-
# merge examples
411+
_add_helper(all_additives, cur_adds)
412+
_add_weighted_helper(all_weighted_averages, cur_wa, cur_counts)
394413
cur_examples = current[-1]
395414
_merge_examples(all_examples, cur_examples)
396-
cui_prec: Dict[str, float] = {}
397-
cui_rec: Dict[str, float] = {}
398-
cui_f1: Dict[str, float] = {}
399-
final_wa = [
400-
cui_prec, cui_rec, cui_f1
415+
# conversion from PerCUI metrics to int/float and (if needed) STD
416+
cui_fps: IntValuedMetric = {}
417+
cui_fns: IntValuedMetric = {}
418+
cui_tps: IntValuedMetric = {}
419+
cui_prec: FloatValuedMetric = {}
420+
cui_rec: FloatValuedMetric = {}
421+
cui_f1: FloatValuedMetric = {}
422+
final_counts: IntValuedMetric = {}
423+
to_change: List[Union[IntValuedMetric, FloatValuedMetric]] = [
424+
cui_fps, cui_fns, cui_tps, final_counts,
425+
cui_prec, cui_rec, cui_f1,
401426
]
402-
# just remove the weight / count
403-
for df, d in zip(final_wa, all_weighted_averages):
427+
# get the mean and/or std
428+
for nr, (df, d) in enumerate(zip(to_change, all_additives + all_weighted_averages)):
404429
for k, v in d.items():
405-
df[k] = v[1] # only the value, ignore the weight
406-
return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2],
407-
all_cui_counts, all_examples)
430+
if nr == 3 and not include_std:
431+
# counts need to be added up
432+
# NOTE: the type:ignore comment _shouldn't_ be necessary
433+
# but mypy thinks we're setting a float or integer
434+
# where a tuple is expected
435+
df[k] = sum(v.vals) # type: ignore
436+
# NOTE: The current implementation shows the sum for counts
437+
# if no STD is required, but the mean along with the
438+
# standard deviation if the latter is required.
439+
elif not include_std:
440+
df[k] = v.get_mean()
441+
else:
442+
# NOTE: the type:ignore comment _shouldn't_ be necessary
443+
# but mypy thinks we're setting a tuple
444+
# where a float or integer is expected
445+
df[k] = (v.get_mean(), v.get_std()) # type: ignore
446+
return (cui_fps, cui_fns, cui_tps, cui_prec, cui_rec, cui_f1,
447+
final_counts, all_examples)
408448

409449

410450
def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3,
411-
split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple:
451+
split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED,
452+
include_std: bool = False, *args, **kwargs) -> Tuple:
412453
"""Get the k-fold stats for the model with the specified data.
413454
414455
First this will split the MCT export into `k` folds. You can do
@@ -424,13 +465,15 @@ def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int
424465
mct_export_data (MedCATTrainerExport): The MCT export.
425466
k (int): The number of folds. Defaults to 3.
426467
split_type (SplitType): Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED.
468+
include_std (bool): Whether to include standard deviation. Defaults to False.
427469
*args: Arguments passed to the `CAT.train_supervised_raw` method.
428470
**kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method.
429471
430472
Returns:
431-
Tuple: The averaged metrics.
473+
Tuple: The averaged metrics. Potentially with their corresponding standard deviations.
432474
"""
433475
creator = get_fold_creator(mct_export_data, k, split_type=split_type)
434476
folds = creator.create_folds()
435477
per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs)
436-
return get_metrics_mean(per_fold_metrics)
478+
means = get_metrics_mean(per_fold_metrics, include_std)
479+
return means

tests/stats/test_kfold.py

+135-2
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ class KFoldWeightedDocsMetricsTests(KFoldMetricsTests):
216216

217217
class KFoldDuplicatedTests(KFoldCATTests):
218218
COPIES = 3
219+
INCLUDE_STD = False
219220

220221
@classmethod
221222
def setUpClass(cls) -> None:
@@ -232,8 +233,8 @@ def setUpClass(cls) -> None:
232233
project['documents'] = copies
233234
cls.docs_in_copy = kfold.count_all_docs(cls.data_copied)
234235
cls.anns_in_copy = kfold.count_all_annotations(cls.data_copied)
235-
cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
236-
cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES)
236+
cls.stats_copied = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES, include_std=cls.INCLUDE_STD)
237+
cls.stats_copied_2 = kfold.get_k_fold_stats(cls.cat, cls.data_copied, k=cls.COPIES, include_std=cls.INCLUDE_STD)
237238

238239
# some stats with real model/data will be e.g 0.99 vs 0.9747
239240
# so in that case, lower it to 1 or so
@@ -296,3 +297,135 @@ def test_metrics_3_fold(self):
296297
self.assertIn(cui, new.keys(), f"CUI '{cui}' ({cuiname}) not in new")
297298
v1, v2 = old[cui], new[cui]
298299
self.assertEqual(v1, v2, f"Values not equal for {cui} ({self.cat.cdb.cui2preferred_name.get(cui, cui)})")
300+
301+
302+
class MetricsMeanSTDTests(unittest.TestCase):
303+
METRICS = [
304+
# m1
305+
[
306+
# fps
307+
{"FPCUI": 3},
308+
# fns
309+
{"FNCUI": 4},
310+
# tps
311+
{"TPCUI": 5},
312+
# prec
313+
{"PREC_CUI": 0.3},
314+
# recall
315+
{"REC_CUI": 0.4},
316+
# f1
317+
{"F1_CUI": 0.5},
318+
# counts
319+
{"FPCUI": 3, "FNCUI": 4, "TPCUI": 5,
320+
"PREC_CUI": 3, "REC_CUI": 4, "F1_CUI": 5},
321+
# examples
322+
{}
323+
],
324+
# m2
325+
[
326+
# fps
327+
{"FPCUI": 13},
328+
# fns
329+
{"FNCUI": 14},
330+
# tps
331+
{"TPCUI": 15},
332+
# prec
333+
{"PREC_CUI": 0.9},
334+
# recall
335+
{"REC_CUI": 0.8},
336+
# f1
337+
{"F1_CUI": 0.7},
338+
# counts
339+
{"FPCUI": 13, "FNCUI": 14, "TPCUI": 15,
340+
"PREC_CUI": 13, "REC_CUI": 14, "F1_CUI": 15},
341+
# examples
342+
{}
343+
]
344+
]
345+
EXPECTED_METRICS = (
346+
# these are simple averages and std
347+
# fps
348+
{"FPCUI": (8, 5.0)},
349+
# fns
350+
{"FNCUI": (9, 5.0)},
351+
# tps
352+
{"TPCUI": (10, 5.0)},
353+
# these are WEIGHTED averages and std
354+
# NOTE: This within numerical precision,
355+
# but assertAlmostEqual should still work
356+
# prec
357+
{"PREC_CUI": (0.7875, 0.23418742493993994)},
358+
# recall
359+
{"REC_CUI": (0.7111111111111112, 0.16629588385661961)},
360+
# f1
361+
{"F1_CUI": (0.65, 0.08660254037844385)},
362+
# counts
363+
{"FPCUI": (8, 5.0), "FNCUI": (9, 5.0), "TPCUI": (10, 5.0),
364+
"PREC_CUI": (8, 5.0), "REC_CUI": (9, 5.0), "F1_CUI": (10, 5.0)},
365+
# examples
366+
{}
367+
)
368+
_names = ['fps', 'fns', 'tps', 'prec', 'rec', 'f1', 'counts', 'examples']
369+
370+
def setUp(self) -> None:
371+
self.metrics = kfold.get_metrics_mean(self.METRICS, include_std=True)
372+
373+
def test_gets_expected_means_and_std(self):
374+
for name, part, exp_part in zip(self._names, self.metrics, self.EXPECTED_METRICS):
375+
with self.subTest(f"{name}"):
376+
self.assertEqual(part.keys(), exp_part.keys())
377+
for cui in exp_part:
378+
got_val, exp_val = part[cui], exp_part[cui]
379+
with self.subTest(f"{name}-{cui}"):
380+
# iterating since this way I can use the assertAlmostEqual method
381+
for nr1, nr2 in zip(got_val, exp_val):
382+
self.assertAlmostEqual(nr1, nr2)
383+
384+
385+
class KFoldDuplicatedSTDTests(KFoldDuplicatedTests):
386+
INCLUDE_STD = True
387+
388+
@classmethod
389+
def setUpClass(cls) -> None:
390+
super().setUpClass()
391+
# NOTE: This would compare to the "regular" stats which
392+
# do not contain standard deviations
393+
# and as such the results are not directly comparable
394+
cls.test_metrics_3_fold = lambda _: None
395+
396+
def test_gets_std(self):
397+
for name, stat in zip(self._names, self.stats_copied):
398+
if name == 'examples':
399+
continue
400+
for cui, val in stat.items():
401+
cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui)
402+
with self.subTest(f"{name}-{cui} [{cuiname}]"):
403+
self.assertIsInstance(val, tuple)
404+
405+
def test_std_is_0(self):
406+
# NOTE: 0 because the data is copied and as such there's no variance
407+
for name, stat in zip(self._names, self.stats_copied):
408+
if name == 'examples':
409+
continue
410+
for cui, val in stat.items():
411+
cuiname = self.cat.cdb.cui2preferred_name.get(cui, cui)
412+
with self.subTest(f"{name}-{cui} [{cuiname}]"):
413+
self.assertEqual(val[1], 0, f'STD for CUI {cui} not 0')
414+
415+
def test_std_nonzero_diff_nr_of_folds(self):
416+
# NOTE: this will expect some standard deviations to be non-zero
417+
# but they will not all be non-zero (e.g.) in case only 1 of the
418+
# folds saw the example
419+
stats = kfold.get_k_fold_stats(self.cat, self.mct_export,
420+
k=self.COPIES - 1, include_std=True)
421+
total_cnt = 0
422+
std_0_cnt = 0
423+
for name, stat in zip(self._names, stats):
424+
if name == 'examples':
425+
continue
426+
for val in stat.values():
427+
total_cnt += 1
428+
if val[1] == 0:
429+
std_0_cnt += 1
430+
self.assertGreaterEqual(total_cnt - std_0_cnt, 4,
431+
"Expected some standard deviations to be nonzeros")

0 commit comments

Comments
 (0)