CU-8694fk90t (almost) only primitive config #425

Merged
merged 16 commits into from
May 22, 2024

Changes from 10 commits
38 changes: 22 additions & 16 deletions medcat/cdb.py
@@ -6,16 +6,19 @@
import aiofiles
import numpy as np
from typing import Dict, Set, Optional, List, Union, cast
from functools import partial
import os

from medcat import __version__
from medcat.utils.hasher import Hasher
from medcat.utils.matutils import unitvec
from medcat.utils.ml_utils import get_lr_linking
from medcat.utils.decorators import deprecated
from medcat.config import Config, weighted_average, workers
from medcat.config import Config, workers
from medcat.utils.saving.serializer import CDBSerializer
from medcat.utils.config_utils import get_and_del_weighted_average_from_config
from medcat.utils.config_utils import default_weighted_average
from medcat.utils.config_utils import ensure_backward_compatibility
from medcat.utils.config_utils import fix_waf_lambda, attempt_fix_weighted_average_function


logger = logging.getLogger(__name__)
@@ -98,6 +101,7 @@ def __init__(self, config: Union[Config, None] = None) -> None:
self.vocab: Dict = {} # Vocabulary of all words ever in our cdb
self._optim_params = None
self.is_dirty = False
self._init_waf_from_config()
self._hash: Optional[str] = None
# the config hash is kept track of here so that
# the CDB hash can be re-calculated when the config changes
@@ -107,6 +111,18 @@ def __init__(self, config: Union[Config, None] = None) -> None:
self._config_hash: Optional[str] = None
self._memory_optimised_parts: Set[str] = set()

def _init_waf_from_config(self):
waf = get_and_del_weighted_average_from_config(self.config)
if waf is not None:
logger.info("Using (potentially) custom value of weighed "
"average function")
self.weighted_average_function = attempt_fix_weighted_average_function(waf)
elif hasattr(self, 'weighted_average_function'):
# keep existing
pass
else:
self.weighted_average_function = default_weighted_average

def get_name(self, cui: str) -> str:
"""Returns preferred name if it exists, otherwise it will return
the longest name assigned to the concept.
@@ -558,6 +574,8 @@ def load_config(self, config_path: str) -> None:
# this should be the behaviour for all newer models
self.config = cast(Config, Config.load(config_path))
logger.debug("Loaded config from CDB from %s", config_path)
# new config, potentially new weighted_average_function to read
self._init_waf_from_config()
# mark config read from file
self._config_from_file = True

@@ -582,7 +600,8 @@ def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[
ser = CDBSerializer(path, json_path)
cdb = ser.deserialize(CDB)
cls._check_medcat_version(cdb.config.asdict())
cls._ensure_backward_compatibility(cdb.config)
fix_waf_lambda(cdb)
ensure_backward_compatibility(cdb.config, workers)

# Overwrite the config with new data
if config_dict is not None:
@@ -855,19 +874,6 @@ def most_similar(self,

return res

@staticmethod
def _ensure_backward_compatibility(config: Config) -> None:
# Hacky way of supporting old CDBs
weighted_average_function = config.linking.weighted_average_function
if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
# the following type ignoring is for mypy because it is unable to detect the signature
config.linking.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore
if config.general.workers is None:
config.general.workers = workers()
disabled_comps = config.general.spacy_disabled_components
if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
config.general.spacy_disabled_components.append('lemmatizer')

@classmethod
def _check_medcat_version(cls, config_data: Dict) -> None:
cdb_medcat_version = config_data.get('version', {}).get('medcat_version', None)
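After this change the weighted average function is owned by the CDB itself rather than by `config.linking`. A minimal usage sketch (the `./cdb.dat` path is hypothetical; with the default factor of 0.0004 the call below returns 0.9996):

from medcat.cdb import CDB

cdb = CDB.load("./cdb.dat")  # hypothetical model path
# The function now lives on the CDB, not on config.linking:
weight = cdb.weighted_average_function(step=1)
print(weight)  # 0.9996 with the default weighted_average(step, factor=0.0004)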
20 changes: 12 additions & 8 deletions medcat/config.py
@@ -1,17 +1,19 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import ModelField
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type
from multiprocessing import cpu_count
import logging
import jsonpickle
import json
from functools import partial
import re

from medcat.utils.hasher import Hasher
from medcat.utils.matutils import intersect_nonempty_set
from medcat.utils.config_utils import attempt_fix_weighted_average_function
from medcat.utils.config_utils import weighted_average
from medcat.utils.config_utils import weighted_average, is_old_type_config_dict
from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook


logger = logging.getLogger(__name__)
@@ -31,6 +33,7 @@ def __getitem__(self, arg: str) -> Any:
raise KeyError from e

def __setattr__(self, arg: str, val) -> None:
# TODO: remove this in the future when we stop supporting this in config
if isinstance(self, Linking) and arg == "weighted_average_function":
val = attempt_fix_weighted_average_function(val)
super().__setattr__(arg, val)
@@ -103,8 +106,8 @@ def save(self, save_path: str) -> None:
save_path(str): Where to save the created json file
"""
# We want to save the dict here, not the whole class
json_string = jsonpickle.encode(
{field: getattr(self, field) for field in self.fields()})
json_string = json.dumps(self.asdict(), cls=cast(Type[json.JSONEncoder],
CustomDelegatingEncoder.def_inst))

with open(save_path, 'w') as f:
f.write(json_string)
@@ -204,7 +207,11 @@ def load(cls, save_path: str) -> "MixingConfig":

# Read the jsonpickle string
with open(save_path) as f:
config_dict = jsonpickle.decode(f.read())
config_dict = json.load(f, object_hook=default_hook)
if is_old_type_config_dict(config_dict):
logger.warning("Loading an old type of config (jsonpickle) from '%s'",
save_path)
f.seek(0)  # json.load consumed the stream; rewind before re-reading
config_dict = jsonpickle.decode(f.read())

config.merge_config(config_dict)

@@ -511,9 +518,6 @@ class Linking(MixingConfig, BaseModel):
similarity calculation and will have a similarity of -1."""
always_calculate_similarity: bool = False
"""Do we want to calculate context similarity even for concepts that are not ambigous."""
weighted_average_function: Callable[..., Any] = _DEFAULT_PARTIAL
"""Weights for a weighted average
'weighted_average_function': partial(weighted_average, factor=0.02),"""
calculate_dynamic_threshold: bool = False
"""Concepts below this similarity will be ignored. Type can be static/dynamic - if dynamic each CUI has a different TH
and it is calcualted as the average confidence for that CUI * similarity_threshold. Take care that dynamic works only
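With this change `Config.save` emits plain JSON through the custom encoder, and `Config.load` falls back to jsonpickle only for old-style files. A quick round-trip sketch, mirroring the new `test_can_save_load` test further down:

import tempfile

from medcat.config import Config

config = Config()
with tempfile.NamedTemporaryFile() as f:
    config.save(f.name)           # plain JSON via CustomDelegatingEncoder
    loaded = Config.load(f.name)  # json.load with default_hook
assert config == loaded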
4 changes: 2 additions & 2 deletions medcat/linking/vector_context_model.py
@@ -71,7 +71,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:

values = []
# Add left
values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
for step, tkn in enumerate(tokens_left) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

if not self.config.linking['context_ignore_center_tokens']:
@@ -83,7 +83,7 @@ def get_context_vectors(self, entity: Span, doc: Doc, cui=None) -> Dict:
values.extend([self.vocab.vec(tkn.lower_) for tkn in tokens_center if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

# Add right
values.extend([self.config.linking['weighted_average_function'](step) * self.vocab.vec(tkn.lower_)
values.extend([self.cdb.weighted_average_function(step) * self.vocab.vec(tkn.lower_)
for step, tkn in enumerate(tokens_right) if tkn.lower_ in self.vocab and self.vocab.vec(tkn.lower_) is not None])

if len(values) > 0:
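For reference, the weight applied to each context token decays with its distance from the entity; an illustrative check of the default curve:

from medcat.utils.config_utils import default_weighted_average

# Weight for a token `step` positions away from the entity:
for step in range(4):
    print(step, default_weighted_average(step))
# 0 1.0
# 1 0.9996
# 2 0.9984
# 3 0.9964  (the weight is floored at 0.1)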
51 changes: 50 additions & 1 deletion medcat/utils/config_utils.py
@@ -1,15 +1,64 @@
from functools import partial
from typing import Callable
from typing import Callable, Optional, Protocol
import logging
from pydantic import BaseModel


class WAFCarrier(Protocol):

@property
def weighted_average_function(self) -> Callable[[int], float]:
pass


logger = logging.getLogger(__name__)


def is_old_type_config_dict(d: dict) -> bool:
if set(('py/object', 'py/state')) <= set(d.keys()):
return True
return False


def fix_waf_lambda(carrier: WAFCarrier) -> None:
weighted_average_function = carrier.weighted_average_function # type: ignore
if callable(weighted_average_function) and getattr(weighted_average_function, "__name__", None) == "<lambda>":
# the following type ignoring is for mypy because it is unable to detect the signature
carrier.weighted_average_function = partial(weighted_average, factor=0.0004) # type: ignore


# NOTE: This method is a hacky workaround. The type ignores are because I cannot
# import config here since it would produce a circular import
def ensure_backward_compatibility(config: BaseModel, workers: Callable[[], int]) -> None:
# Hacky way of supporting old CDBs
if hasattr(config.linking, 'weighted_average_function'): # type: ignore
fix_waf_lambda(config.linking) # type: ignore
if config.general.workers is None: # type: ignore
config.general.workers = workers() # type: ignore
disabled_comps = config.general.spacy_disabled_components # type: ignore
if 'tagger' in disabled_comps and 'lemmatizer' not in disabled_comps:
config.general.spacy_disabled_components.append('lemmatizer') # type: ignore


def get_and_del_weighted_average_from_config(config: BaseModel) -> Optional[Callable[[int], float]]:
if not hasattr(config, 'linking'):
return None
linking = config.linking
if not hasattr(linking, 'weighted_average_function'):
return None
waf = linking.weighted_average_function
delattr(linking, 'weighted_average_function')
return waf


def weighted_average(step: int, factor: float) -> float:
return max(0.1, 1 - (step ** 2 * factor))


def default_weighted_average(step: int) -> float:
return weighted_average(step, factor=0.0004)


def attempt_fix_weighted_average_function(waf: Callable[[int], float]
) -> Callable[[int], float]:
"""Attempf fix weighted_average_function.
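The detection heuristic is simply the presence of jsonpickle's bookkeeping keys; a minimal illustration (the dict contents here are made up):

from medcat.utils import config_utils

old_style = {'py/object': 'medcat.config.Config', 'py/state': {'__dict__': {}}}
new_style = {'general': {'workers': 4}}

assert config_utils.is_old_type_config_dict(old_style)
assert not config_utils.is_old_type_config_dict(new_style)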
32 changes: 30 additions & 2 deletions medcat/utils/saving/coding.py
@@ -1,6 +1,7 @@
from typing import Any, Protocol, runtime_checkable, List, Union, Type, Optional, Callable

import json
import re


@runtime_checkable
@@ -35,6 +36,7 @@ def try_encode(self, obj: object) -> Any:


SET_IDENTIFIER = '==SET=='
PATTERN_IDENTIFIER = "==PATTERN=="


class SetEncoder(PartEncoder):
@@ -79,10 +81,34 @@ def try_decode(self, dct: dict) -> Union[dict, set]:
return dct


class PatternEncoder(PartEncoder):

def try_encode(self, obj):
if isinstance(obj, re.Pattern):
return {PATTERN_IDENTIFIER: obj.pattern}
raise UnsuitableObject()


class PatternDecoder(PartDecoder):

def try_decode(self, dct: dict) -> Union[dict, re.Pattern]:
"""Decode re.Patttern from input dicts.

Args:
dct (dict): The input dict

Returns:
Union[dict, re.Pattern]: The original dict if this was not a serialized pattern, the pattern otherwise
"""
if PATTERN_IDENTIFIER in dct:
return re.compile(dct[PATTERN_IDENTIFIER])
return dct


PostProcessor = Callable[[Any], None] # CDB -> None

DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, ]
DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, ]
DEFAULT_ENCODERS: List[Type[PartEncoder]] = [SetEncoder, PatternEncoder]
DEFAULT_DECODERS: List[Type[PartDecoder]] = [SetDecoder, PatternDecoder]
LOADING_POSTPROCESSORS: List[PostProcessor] = []


@@ -133,6 +159,8 @@ def object_hook(self, dct: dict) -> Any:
def def_inst(cls) -> 'CustomDelegatingDecoder':
if cls._def_inst is None:
cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
elif len(cls._def_inst._delegates) < len(DEFAULT_DECODERS):
cls._def_inst = cls([_cls() for _cls in DEFAULT_DECODERS])
return cls._def_inst


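With the new encoder/decoder pair, compiled regexes survive a JSON round trip; an illustrative sketch (the dict key is arbitrary):

import json
import re

from medcat.utils.saving.coding import CustomDelegatingEncoder, default_hook

pattern = re.compile(r"[A-Z]+\d*")
encoded = json.dumps({"skip_pattern": pattern}, cls=CustomDelegatingEncoder.def_inst)
# Serialized as {"skip_pattern": {"==PATTERN==": "[A-Z]+\\d*"}}
decoded = json.loads(encoded, object_hook=default_hook)
assert decoded["skip_pattern"].pattern == pattern.pattern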
2 changes: 1 addition & 1 deletion tests/test_cat.py
@@ -712,7 +712,7 @@ class TestLoadingOldWeights(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.cdb = CDB.load(cls.cdb_path)
cls.wf = cls.cdb.config.linking.weighted_average_function
cls.wf = cls.cdb.weighted_average_function

def test_can_call_weights(self):
res = self.wf(step=1)
7 changes: 7 additions & 0 deletions tests/test_config.py
@@ -208,6 +208,13 @@ def test_config_hash_recalc_same_changed(self):
h2 = config.get_hash()
self.assertEqual(h1, h2)

def test_can_save_load(self):
config = Config()
with tempfile.NamedTemporaryFile() as file:
config.save(file.name)
config2 = Config.load(file.name)
self.assertEqual(config, config2)


class ConfigLinkingFiltersTests(unittest.TestCase):

4 changes: 0 additions & 4 deletions tests/utils/saving/test_serialization.py
@@ -117,10 +117,6 @@ def test_round_trip(self):
# The spacy model has full path in the loaded model, thus won't be equal
cat.config.general.spacy_model = os.path.basename(
cat.config.general.spacy_model)
# There can also be issues with loading the config.linking.weighted_average_function from file
# This should be fixed with newer models,
# but the example model is older, so has the older functionality
cat.config.linking.weighted_average_function = self.undertest.config.linking.weighted_average_function
self.assertEqual(cat.config.asdict(), self.undertest.config.asdict())
self.assertEqual(cat.cdb.config, self.undertest.cdb.config)
self.assertEqual(len(cat.vocab.vocab), len(self.undertest.vocab.vocab))
50 changes: 50 additions & 0 deletions tests/utils/test_config_utils.py
@@ -0,0 +1,50 @@
from medcat.config import Config
from medcat.utils.saving.coding import default_hook, CustomDelegatingEncoder
from medcat.utils import config_utils
import json

import unittest

OLD_STYLE_DICT = {'py/object': 'medcat.config.VersionInfo',
'py/state': {
'__dict__': {
'history': ['0c0de303b6dc0020',],
'meta_cats': [],
'cdb_info': {
'Number of concepts': 785910,
'Number of names': 2480049,
'Number of concepts that received training': 378746,
'Number of seen training examples in total': 1863973060,
'Average training examples per concept': {
'py/reduce': [{'py/function': 'numpy.core.multiarray.scalar'},]
}
},
'performance': {'ner': {}, 'meta': {}},
'description': 'No description',
'id': 'ff4f4e00bc97de58',
'last_modified': '26 April 2024',
'location': None,
'ontology': ['ONTOLOGY1'],
'medcat_version': '1.10.2'
},
'__fields_set__': {
'py/set': ['id', 'ontology', 'description', 'history',
'location', 'medcat_version', 'last_modified',
'meta_cats', 'cdb_info', 'performance']
},
'__private_attribute_values__': {}
}
}


NEW_STYLE_DICT = json.loads(json.dumps(Config().asdict(), cls=CustomDelegatingEncoder.def_inst),
object_hook=default_hook)


class ConfigUtilsTests(unittest.TestCase):

def test_identifies_old_style_dict(self):
self.assertTrue(config_utils.is_old_type_config_dict(OLD_STYLE_DICT))

def test_identifies_new_style_dict(self):
self.assertFalse(config_utils.is_old_type_config_dict(NEW_STYLE_DICT))