diff --git a/medcat/cat.py b/medcat/cat.py index 2323cd737..0fb6b1167 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -271,6 +271,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M cdb_path = os.path.join(save_dir_path, "cdb.dat") self.cdb.save(cdb_path, json_path) + # Save the config + config_path = os.path.join(save_dir_path, "config.json") + self.cdb.config.save(config_path) + # Save the Vocab vocab_path = os.path.join(save_dir_path, "vocab.dat") if self.vocab is not None: @@ -362,6 +366,10 @@ def load_model_pack(cls, logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format') cdb = CDB.load(cdb_path, json_path) + # load config + config_path = os.path.join(model_pack_path, "config.json") + cdb.load_config(config_path) + # TODO load addl_ner # Modify the config to contain full path to spacy model diff --git a/medcat/cdb.py b/medcat/cdb.py index 44d4fd9dd..5a648f4af 100644 --- a/medcat/cdb.py +++ b/medcat/cdb.py @@ -5,8 +5,9 @@ import logging import aiofiles import numpy as np -from typing import Dict, Set, Optional, List, Union +from typing import Dict, Set, Optional, List, Union, cast from functools import partial +import os from medcat import __version__ from medcat.utils.hasher import Hasher @@ -61,8 +62,10 @@ class CDB(object): def __init__(self, config: Union[Config, None] = None) -> None: if config is None: self.config = Config() + self._config_from_file = False else: self.config = config + self._config_from_file = True self.name2cuis: Dict = {} self.name2cuis2status: Dict = {} @@ -95,6 +98,12 @@ def __init__(self, config: Union[Config, None] = None) -> None: self._optim_params = None self.is_dirty = False self._hash: Optional[str] = None + # the config hash is kept track of here so that + # the CDB hash can be re-calculated when the config changes + # it can also be used to make sure the config loaded with + # a CDB matches the config it was saved with + # since the config is now saved 
def load_config(self, config_path: str) -> None:
    """Attach the config saved next to the CDB (``config.json``) to this CDB.

    Exactly one config source is accepted: either the config this CDB
    was constructed with (``self._config_from_file`` is True in that
    case) or the separate file at ``config_path``.

    Args:
        config_path (str): Expected location of the separate config file.

    Raises:
        ValueError: If neither source provides a config, or if both do
            (which would be ambiguous).
    """
    model_dir = os.path.dirname(config_path)

    if not os.path.exists(config_path):
        if not self._config_from_file:
            # no config defined anywhere
            # NOTE: the message parts must be concatenated into ONE string;
            # separating them with a comma would pass several arguments
            # to ValueError and the message would render as a tuple
            raise ValueError("Could not find a config in the CDB nor "
                             "in the config.json for this model "
                             f"({model_dir})")
        # there is a config, but it came with the cdb.dat file
        logger.warning("Could not find config.json in model pack folder "
                       "(%s). "
                       "This is probably an older model. Please save the model "
                       "again in the new format to avoid potential issues.",
                       model_dir)
        return

    if self._config_from_file:
        # there's a config.json AND a config defined in the cdb.dat file
        raise ValueError("Found a config in the CDB and in the config.json "
                         f"for model ({model_dir}) - "
                         "this is ambiguous. Please either remove the "
                         "config.json or load the CDB without the config.json "
                         "in the folder and re-save in the newer format "
                         "(the default save in this version)")

    # the only config is in the separate config.json file
    # this should be the behaviour for all newer models
    self.config = cast(Config, Config.load(config_path))
    logger.debug("Loaded config for CDB from %s", config_path)
    # mark config as read from file
    self._config_from_file = True
def _should_recalc_hash(self, force_recalc: bool) -> bool:
    """Decide whether the CDB hash needs to be calculated again.

    The config hash is refreshed first, on every path. ``calculate_hash``
    caches ``self.config.hash`` as ``self._config_hash``; refreshing it
    only on some paths would let a missing or stale value be cached
    after a forced (or dirty) recalculation, which in turn triggers a
    second, needless recalculation on the next call.

    Side effect: ``self.config.get_hash()`` stores its result in
    ``self.config.hash``.

    Args:
        force_recalc (bool): Whether recalculation was explicitly requested.

    Returns:
        bool: True if the hash should be recalculated, False if the
            saved hash can be reused.
    """
    # refresh the config hash once so every branch below sees the
    # current value (and so does whoever caches it afterwards)
    self.config.get_hash()

    if force_recalc:
        return True

    if not self._hash or self.is_dirty:
        # if no hash saved or is dirty
        # need to calculate
        logger.debug("Recalculating hash due to %s",
                     "no hash saved" if not self._hash else "CDB is dirty")
        return True

    if self._config_hash is None or self._config_hash != self.config.hash:
        # if no config hash saved
        # or if the config hash is different from one saved in here
        logger.debug("Recalculating hash due to %s",
                     "no config hash saved" if not self._config_hash
                     else "config hash has changed")
        return True

    return False
a/medcat/config.py b/medcat/config.py index 87c6d34f5..e60c2eafc 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -548,6 +548,7 @@ class Config(MixingConfig, BaseModel): linking: Linking = Linking() word_skipper: re.Pattern = re.compile('') # empty pattern gets replaced upon init punct_checker: re.Pattern = re.compile('') # empty pattern gets replaced upon init + hash: Optional[str] = None class Config: # this if for word_skipper and punct_checker which would otherwise @@ -572,6 +573,9 @@ def rebuild_re(self) -> None: def get_hash(self): hasher = Hasher() for k, v in self.dict().items(): + if k in ['hash', ]: + # ignore hash + continue if k not in ['version', 'general', 'linking']: hasher.update(v, length=True) elif k == 'general': @@ -587,5 +591,5 @@ def get_hash(self): hasher.update(v2, length=False) else: hasher.update(v2, length=True) - - return hasher.hexdigest() + self.hash = hasher.hexdigest() + return self.hash diff --git a/medcat/utils/saving/serializer.py b/medcat/utils/saving/serializer.py index d82df751c..25529c778 100644 --- a/medcat/utils/saving/serializer.py +++ b/medcat/utils/saving/serializer.py @@ -135,13 +135,12 @@ def serialize(self, cdb, overwrite: bool = False) -> None: raise ValueError(f'Unable to overwrite shelf path "{self.json_path}"' ' - specify overrwrite=True if you wish to overwrite') to_save = {} - to_save['config'] = cdb.config.asdict() # This uses different names so as to not be ambiguous # when looking at files whether the json parts should # exist separately or not to_save['cdb_main' if self.jsons is not None else 'cdb'] = dict( ((key, val) for key, val in cdb.__dict__.items() if - key != 'config' and + key not in ('config', '_config_from_file') and (self.jsons is None or key not in SPECIALITY_NAMES))) logger.info('Dumping CDB to %s', self.main_path) with open(self.main_path, 'wb') as f: @@ -165,7 +164,17 @@ def deserialize(self, cdb_cls): logger.info('Reading CDB data from %s', self.main_path) with open(self.main_path, 
'rb') as f: data = dill.load(f) - config = cast(Config, Config.from_dict(data['config'])) + if 'config' in data: + logger.warning("Found config in CDB for model (%s). " + "This is an old format. Please re-save the " + "model in the new format to avoid potential issues", + os.path.dirname(self.main_path)) + config = cast(Config, Config.from_dict(data['config'])) + else: + # by passing None as config to constructor + # the CDB should identify that there has been + # no config loaded + config = None cdb = cdb_cls(config=config) if self.jsons is None: cdb_main = data['cdb'] diff --git a/tests/test_cat.py b/tests/test_cat.py index 0baa0d35d..62db4d44d 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -367,7 +367,7 @@ def test_load_model_pack(self): meta_cat = _get_meta_cat(self.meta_cat_dir) cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name") - cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) + cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip")) self.assertTrue(isinstance(cat, CAT)) self.assertIsNotNone(cat.config.version.medcat_version) self.assertEqual(repr(cat._meta_cats), repr([meta_cat])) @@ -377,7 +377,7 @@ def test_load_model_pack_without_meta_cat(self): meta_cat = _get_meta_cat(self.meta_cat_dir) cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name") - cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False) + cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False) self.assertTrue(isinstance(cat, CAT)) self.assertIsNotNone(cat.config.version.medcat_version) self.assertEqual(cat._meta_cats, 
class ModelWithTwoConfigsLoadTests(unittest.TestCase):
    """Loading must fail when a config exists both in the CDB and as config.json.

    The test expects the example ``cdb.dat`` to carry its own config
    (presumably the older save format); a ``config.json`` is written
    next to it for the duration of the test class.
    """

    @classmethod
    def setUpClass(cls) -> None:
        tests_dir = os.path.dirname(os.path.realpath(__file__))
        cls.model_path = os.path.join(tests_dir, "..", "examples")
        example_cdb = CDB.load(os.path.join(cls.model_path, "cdb.dat"))
        # write the CDB's config out as a separate file beside the CDB
        cls.config_path = os.path.join(cls.model_path, 'config.json')
        example_cdb.config.save(cls.config_path)

    @classmethod
    def tearDownClass(cls) -> None:
        # delete the config.json written in setUpClass
        os.remove(cls.config_path)

    def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(self):
        self.assertRaises(ValueError, CAT.load_model_pack, self.model_path)


class ModelWithZeroConfigsLoadTests(unittest.TestCase):
    """Loading must fail when no config exists in the CDB nor as config.json.

    A temporary model folder is populated with a re-saved CDB and a
    copy of the example vocab, but no ``config.json``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        examples_dir = (os.path.dirname(os.path.realpath(__file__)), "..", "examples")
        source_cdb = CDB.load(os.path.join(*examples_dir, "cdb.dat"))
        source_vocab = os.path.join(*examples_dir, "vocab.dat")
        # the temp dir acts as the model folder
        cls.temp_dir = tempfile.TemporaryDirectory()
        model_dir = cls.temp_dir.name
        cls.cdb_path = os.path.join(model_dir, 'cdb.dat')
        # re-saving is expected to drop the config embedded in the CDB
        source_cdb.save(cls.cdb_path)
        cls.vocab_path = os.path.join(model_dir, 'vocab.dat')
        shutil.copyfile(source_vocab, cls.vocab_path)

    @classmethod
    def tearDownClass(cls) -> None:
        # remove the temporary model folder
        cls.temp_dir.cleanup()

    def test_loading_model_pack_without_any_config_raises_exception(self):
        self.assertRaises(ValueError, CAT.load_model_pack, self.temp_dir.name)
class CDBHashingWithConfigTests(unittest.TestCase):
    """Checks when ``CDB._should_recalc_hash`` asks for a recalculation.

    The example CDB is loaded once per class. Before each test its
    config is replaced with a fresh copy of the initial config and the
    cached config hashes are reset, so config changes made by one test
    do not leak into the next.
    """

    @classmethod
    def setUpClass(cls) -> None:
        cls.cdb = CDB.load(os.path.join(os.path.dirname(
            os.path.realpath(__file__)), "..", "..", "examples", "cdb.dat"))
        # ensure config has hash (only the side effect is needed here)
        cls.cdb.get_hash()
        cls.config = cls.config_copy(cls.cdb.config)
        cls._config_hash = cls.cdb.config.hash

    @classmethod
    def config_copy(cls, config: Optional[Config] = None) -> Config:
        """Return a new Config built from ``config`` (default: the class-level one)."""
        if config is None:
            config = cls.config
        return Config(**config.asdict())

    def setUp(self) -> None:
        # reset config
        self.cdb.config = self.config_copy()
        # reset config hash
        self.cdb._config_hash = self._config_hash
        self.cdb.config.hash = self._config_hash

    def test_CDB_same_hash_no_need_recalc(self):
        self.assertFalse(self.cdb._should_recalc_hash(force_recalc=False))

    def test_CDB_hash_recalc_if_no_config_hash(self):
        self.cdb._config_hash = None
        self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False))

    def test_CDB_hash_recalc_after_config_change(self):
        self.cdb.config.linking.filters.cuis = {"a", "b", "c"}
        self.assertTrue(self.cdb._should_recalc_hash(force_recalc=False))