-from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any
+from typing import Protocol, Tuple, List, Dict, Optional, Set, Iterable, Callable, cast, Any, Union

 from abc import ABC, abstractmethod
 from enum import Enum, auto
 from copy import deepcopy
+from pydantic import BaseModel

 import numpy as np

@@ -299,52 +300,72 @@ def get_per_fold_metrics(cat: CATLike, folds: List[MedCATTrainerExport],
     return metrics


-def _update_all_weighted_average(joined: List[Dict[str, Tuple[int, float]]],
-                                 single: List[Dict[str, float]], cui2count: Dict[str, int]) -> None:
-    if len(joined) != len(single):
-        raise ValueError(f"Incompatible lists. Joined {len(joined)} and single {len(single)}")
-    for j, s in zip(joined, single):
-        _update_one_weighted_average(j, s, cui2count)
+def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
+    for ex_type, ex_dict in cur_examples.items():
+        if ex_type not in all_examples:
+            all_examples[ex_type] = {}
+        per_type_examples = all_examples[ex_type]
+        for ex_cui, cui_examples_list in ex_dict.items():
+            if ex_cui not in per_type_examples:
+                per_type_examples[ex_cui] = []
+            per_type_examples[ex_cui].extend(cui_examples_list)
+
+
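For reference, a minimal sketch (not part of the diff; the dicts and CUIs below are hypothetical) of how `_merge_examples` accumulates per-type, per-CUI example lists:

# Hypothetical data; 'fp'/'fn' and the CUIs are illustrative only.
all_examples: Dict = {'fp': {'C0011849': ['ex1']}}
cur_examples: Dict = {'fp': {'C0011849': ['ex2'], 'C0020538': ['ex3']}, 'fn': {}}
_merge_examples(all_examples, cur_examples)
# all_examples == {'fp': {'C0011849': ['ex1', 'ex2'], 'C0020538': ['ex3']}, 'fn': {}}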
+# helper types
+IntValuedMetric = Union[
+    Dict[str, int],
+    Dict[str, Tuple[int, float]]
+]
+FloatValuedMetric = Union[
+    Dict[str, float],
+    Dict[str, Tuple[float, float]]
+]
+

+class PerCUIMetrics(BaseModel):
+    weights: List[Union[int, float]] = []
+    vals: List[Union[int, float]] = []

-def _update_one_weighted_average(joined: Dict[str, Tuple[int, float]],
-                                 one: Dict[str, float],
-                                 cui2count: Dict[str, int]) -> None:
-    for k in one:
-        if k not in joined:
-            joined[k] = (0, 0)
-        prev_w, prev_val = joined[k]
-        new_w, new_val = cui2count[k], one[k]
-        total_w = prev_w + new_w
-        total_val = (prev_w * prev_val + new_w * new_val) / total_w
-        joined[k] = (total_w, total_val)
+    def add(self, val, weight: int = 1):
+        self.weights.append(weight)
+        self.vals.append(val)

+    def get_mean(self):
+        return sum(w * v for w, v in zip(self.weights, self.vals)) / sum(self.weights)

-def _update_all_add(joined: List[Dict[str, int]], single: List[Dict[str, int]]) -> None:
+    def get_std(self):
+        mean = self.get_mean()
+        return (sum(w * (v - mean)**2 for w, v in zip(self.weights, self.vals)) / sum(self.weights))**.5
+
+
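As a sanity check (again, not part of the diff; the numbers are made up), the count-weighted mean and standard deviation work out as follows:

# Hypothetical usage: one CUI's precision across two folds, weighted by count.
m = PerCUIMetrics()
m.add(0.8, weight=10)  # fold 1: precision 0.8 over 10 annotations
m.add(0.6, weight=30)  # fold 2: precision 0.6 over 30 annotations
assert m.get_mean() == (10 * 0.8 + 30 * 0.6) / 40  # 0.65
assert abs(m.get_std() - 0.0866) < 1e-4  # sqrt(0.3 / 40)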
+def _add_helper(joined: List[Dict[str, PerCUIMetrics]],
+                single: List[Dict[str, int]]) -> None:
     if len(joined) != len(single):
         raise ValueError(f"Incompatible lengths: {len(joined)} vs {len(single)}")
     for j, s in zip(joined, single):
         for k, v in s.items():
-            j[k] = j.get(k, 0) + v
+            if k not in j:
+                j[k] = PerCUIMetrics()
+            j[k].add(v)


-def _merge_examples(all_examples: Dict, cur_examples: Dict) -> None:
-    for ex_type, ex_dict in cur_examples.items():
-        if ex_type not in all_examples:
-            all_examples[ex_type] = {}
-        per_type_examples = all_examples[ex_type]
-        for ex_cui, cui_examples_list in ex_dict.items():
-            if ex_cui not in per_type_examples:
-                per_type_examples[ex_cui] = []
-            per_type_examples[ex_cui].extend(cui_examples_list)
+def _add_weighted_helper(joined: List[Dict[str, PerCUIMetrics]],
+                         single: List[Dict[str, float]],
+                         cui2count: Dict[str, int]) -> None:
+    for j, s in zip(joined, single):
+        for k, v in s.items():
+            if k not in j:
+                j[k] = PerCUIMetrics()
+            j[k].add(v, cui2count[k])


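A short sketch (hypothetical values, outside the diff) of how the two helpers accumulate per-fold results:

# Hypothetical: two folds' TP counts, then two folds' precisions.
joined_tps: List[Dict[str, PerCUIMetrics]] = [{}]
_add_helper(joined_tps, [{'C0011849': 4}])  # fold 1
_add_helper(joined_tps, [{'C0011849': 6}])  # fold 2
# joined_tps[0]['C0011849'].vals == [4, 6], each with weight 1

joined_prec: List[Dict[str, PerCUIMetrics]] = [{}]
_add_weighted_helper(joined_prec, [{'C0011849': 0.9}], {'C0011849': 10})
_add_weighted_helper(joined_prec, [{'C0011849': 0.7}], {'C0011849': 30})
# joined_prec[0]['C0011849'].get_mean() == (10*0.9 + 30*0.7) / 40 == 0.75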
-def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]
-                     ) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
+def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]],
+                     include_std: bool) -> Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]:
    """Get the mean of the provided metrics.

    Args:
        metrics (List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dict, Dict]]): The metrics.
+        include_std (bool): Whether to include the standard deviation.

    Returns:
        fps (dict):
@@ -365,15 +386,15 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
            Examples for each of the fp, fn, tp. Format will be examples['fp']['cui'][<list_of_examples>].
    """
    # additives
-    all_fps: Dict[str, int] = {}
-    all_fns: Dict[str, int] = {}
-    all_tps: Dict[str, int] = {}
+    all_fps: Dict[str, PerCUIMetrics] = {}
+    all_fns: Dict[str, PerCUIMetrics] = {}
+    all_tps: Dict[str, PerCUIMetrics] = {}
    # weighted-averages
-    all_cui_prec: Dict[str, Tuple[int, float]] = {}
-    all_cui_rec: Dict[str, Tuple[int, float]] = {}
-    all_cui_f1: Dict[str, Tuple[int, float]] = {}
+    all_cui_prec: Dict[str, PerCUIMetrics] = {}
+    all_cui_rec: Dict[str, PerCUIMetrics] = {}
+    all_cui_f1: Dict[str, PerCUIMetrics] = {}
    # additive
-    all_cui_counts: Dict[str, int] = {}
+    all_cui_counts: Dict[str, PerCUIMetrics] = {}
    # combined
    all_additives = [
        all_fps, all_fns, all_tps, all_cui_counts
@@ -386,29 +407,49 @@ def get_metrics_mean(metrics: List[Tuple[Dict, Dict, Dict, Dict, Dict, Dict, Dic
    for current in metrics:
        cur_wa: list = list(current[3:-2])
        cur_counts = current[-2]
-        _update_all_weighted_average(all_weighted_averages, cur_wa, cur_counts)
-        # update ones that just need to be added up
        cur_adds = list(current[:3]) + [cur_counts]
-        _update_all_add(all_additives, cur_adds)
-        # merge examples
+        _add_helper(all_additives, cur_adds)
+        _add_weighted_helper(all_weighted_averages, cur_wa, cur_counts)
        cur_examples = current[-1]
        _merge_examples(all_examples, cur_examples)
-    cui_prec: Dict[str, float] = {}
-    cui_rec: Dict[str, float] = {}
-    cui_f1: Dict[str, float] = {}
-    final_wa = [
-        cui_prec, cui_rec, cui_f1
+    # conversion from PerCUI metrics to int/float and (if needed) STD
+    cui_fps: IntValuedMetric = {}
+    cui_fns: IntValuedMetric = {}
+    cui_tps: IntValuedMetric = {}
+    cui_prec: FloatValuedMetric = {}
+    cui_rec: FloatValuedMetric = {}
+    cui_f1: FloatValuedMetric = {}
+    final_counts: IntValuedMetric = {}
+    to_change: List[Union[IntValuedMetric, FloatValuedMetric]] = [
+        cui_fps, cui_fns, cui_tps, final_counts,
+        cui_prec, cui_rec, cui_f1,
    ]
-    # just remove the weight / count
-    for df, d in zip(final_wa, all_weighted_averages):
+    # get the mean and/or std
+    for nr, (df, d) in enumerate(zip(to_change, all_additives + all_weighted_averages)):
        for k, v in d.items():
-            df[k] = v[1]  # only the value, ignore the weight
-    return (all_fps, all_fns, all_tps, final_wa[0], final_wa[1], final_wa[2],
-            all_cui_counts, all_examples)
+            if nr == 3 and not include_std:
+                # counts need to be added up
+                # NOTE: the type:ignore comment _shouldn't_ be necessary
+                # but mypy thinks we're setting a float or integer
+                # where a tuple is expected
+                df[k] = sum(v.vals)  # type: ignore
+                # NOTE: The current implementation shows the sum for counts
+                # if no STD is required, but the mean along with the
+                # standard deviation if the latter is required.
+            elif not include_std:
+                df[k] = v.get_mean()
+            else:
+                # NOTE: the type:ignore comment _shouldn't_ be necessary
+                # but mypy thinks we're setting a tuple
+                # where a float or integer is expected
+                df[k] = (v.get_mean(), v.get_std())  # type: ignore
+    return (cui_fps, cui_fns, cui_tps, cui_prec, cui_rec, cui_f1,
+            final_counts, all_examples)


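Sketched return shapes for `get_metrics_mean` (hypothetical values, continuing the numbers from the helper sketch above):

# include_std=False: plain values
#   cui_prec['C0011849'] == 0.75    (count-weighted mean over folds)
#   final_counts['C0011849'] == 40  (sum over folds)
# include_std=True: (mean, std) tuples per CUI
#   cui_prec['C0011849'] == (0.75, 0.087)  # std approximate
#   final_counts['C0011849'] == (20.0, 10.0)  # mean/std of per-fold counts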
def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int = 3,
-                     split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED, *args, **kwargs) -> Tuple:
+                     split_type: SplitType = SplitType.DOCUMENTS_WEIGHTED,
+                     include_std: bool = False, *args, **kwargs) -> Tuple:
    """Get the k-fold stats for the model with the specified data.

    First this will split the MCT export into `k` folds. You can do
@@ -424,13 +465,15 @@ def get_k_fold_stats(cat: CATLike, mct_export_data: MedCATTrainerExport, k: int
        mct_export_data (MedCATTrainerExport): The MCT export.
        k (int): The number of folds. Defaults to 3.
        split_type (SplitType): Whether to use annotations or docs. Defaults to DOCUMENTS_WEIGHTED.
+        include_std (bool): Whether to include the standard deviation. Defaults to False.
        *args: Arguments passed to the `CAT.train_supervised_raw` method.
        **kwargs: Keyword arguments passed to the `CAT.train_supervised_raw` method.

    Returns:
-        Tuple: The averaged metrics.
+        Tuple: The averaged metrics, potentially with their corresponding standard deviations.
    """
    creator = get_fold_creator(mct_export_data, k, split_type=split_type)
    folds = creator.create_folds()
    per_fold_metrics = get_per_fold_metrics(cat, folds, *args, **kwargs)
-    return get_metrics_mean(per_fold_metrics)
+    means = get_metrics_mean(per_fold_metrics, include_std)
+    return means
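End to end, the new flag would be used like this (a hypothetical call; `cat` and `mct_export` are assumed to already exist):

# Hypothetical usage of the updated API:
fps, fns, tps, prec, rec, f1, counts, examples = get_k_fold_stats(
    cat, mct_export, k=3, include_std=True)
# e.g. prec['C0011849'] -> (mean, std) across the 3 folds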