Skip to content

Commit

Permalink
CU-8692t3fdf separate config on save (#350)
Browse files Browse the repository at this point in the history
* CU-8692t3fdf Move saving config outside of the cdb.dat; Add test to make sure the config does not get saved with the CDB; patch a few existing tests

* CU-8692t3fdf Use class methods on class instead of instance in a few tests

* CU-8692t3fdf Fix typing issue

* CU-8692t3fdf Add additional tests for 2 configs and zero configs when loading model pack

* CU-8692t3fdf: Make sure CDB is linked to the correct config; Treat incorrect configs as dirty CDBs and force a recalc of the hash
  • Loading branch information
mart-r authored Oct 30, 2023
1 parent d377f0b commit ad67048
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 13 deletions.
8 changes: 8 additions & 0 deletions medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,10 @@ def create_model_pack(self, save_dir_path: str, model_pack_name: str = DEFAULT_M
cdb_path = os.path.join(save_dir_path, "cdb.dat")
self.cdb.save(cdb_path, json_path)

# Save the config
config_path = os.path.join(save_dir_path, "config.json")
self.cdb.config.save(config_path)

# Save the Vocab
vocab_path = os.path.join(save_dir_path, "vocab.dat")
if self.vocab is not None:
Expand Down Expand Up @@ -362,6 +366,10 @@ def load_model_pack(cls,
logger.info('Loading model pack with %s', 'JSON format' if json_path else 'dill format')
cdb = CDB.load(cdb_path, json_path)

# load config
config_path = os.path.join(model_pack_path, "config.json")
cdb.load_config(config_path)

# TODO load addl_ner

# Modify the config to contain full path to spacy model
Expand Down
73 changes: 70 additions & 3 deletions medcat/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import logging
import aiofiles
import numpy as np
from typing import Dict, Set, Optional, List, Union
from typing import Dict, Set, Optional, List, Union, cast
from functools import partial
import os

from medcat import __version__
from medcat.utils.hasher import Hasher
Expand Down Expand Up @@ -61,8 +62,10 @@ class CDB(object):
def __init__(self, config: Union[Config, None] = None) -> None:
if config is None:
self.config = Config()
self._config_from_file = False
else:
self.config = config
self._config_from_file = True
self.name2cuis: Dict = {}
self.name2cuis2status: Dict = {}

Expand Down Expand Up @@ -95,6 +98,12 @@ def __init__(self, config: Union[Config, None] = None) -> None:
self._optim_params = None
self.is_dirty = False
self._hash: Optional[str] = None
# the config hash is kept track of here so that
# the CDB hash can be re-calculated when the config changes
# it can also be used to make sure the config loaded with
# a CDB matches the config it was saved with
# since the config is now saved separately
self._config_hash: Optional[str] = None
self._memory_optimised_parts: Set[str] = set()

def get_name(self, cui: str) -> str:
Expand Down Expand Up @@ -458,6 +467,35 @@ async def save_async(self, path: str) -> None:
}
await f.write(dill.dumps(to_save))

def load_config(self, config_path: str) -> None:
if not os.path.exists(config_path):
if not self._config_from_file:
# if there's no config defined anywhere
raise ValueError("Could not find a config in the CDB nor ",
"in the config.json for this model "
f"({os.path.dirname(config_path)})",
)
# if there is a config, but it's defined in the cdb.dat file
logger.warning("Could not find config.json in model pack folder "
f"({os.path.dirname(config_path)}). "
"This is probably an older model. Please save the model "
"again in the new format to avoid potential issues.")
else:
if self._config_from_file:
# if there's a config.json and one defined in the cbd.dat file
raise ValueError("Found a config in the CDB and in the config.json "
f"for model ({os.path.dirname(config_path)}) - "
"this is ambiguous. Please either remove the "
"config.json or load the CDB without the config.json "
"in the folder and re-save in the newer format "
"(the default save in this version)")
# if the only config is in the separate config.json file
# this should be the behaviour for all newer models
self.config = cast(Config, Config.load(config_path))
logger.debug("Loaded config from CDB from %s", config_path)
# mark config read from file
self._config_from_file = True

@classmethod
def load(cls, path: str, json_path: Optional[str] = None, config_dict: Optional[Dict] = None) -> "CDB":
"""Load and return a CDB. This allows partial loads in probably not the right way at all.
Expand Down Expand Up @@ -777,8 +815,34 @@ def _check_medcat_version(cls, config_data: Dict) -> None:
or download the compatible model."""
)

def _should_recalc_hash(self, force_recalc: bool) -> bool:
if force_recalc:
return True
if self.config.hash is None:
# TODO - perhaps this is not the best?
# as this is a side effect
# get and save result in config
self.config.get_hash()
if not self._hash or self.is_dirty:
# if no hash saved or is dirty
# need to calculate
logger.debug("Recalculating hash due to %s",
"no hash saved" if not self._hash else "CDB is dirty")
return True
# recalc config hash in case it changed
self.config.get_hash()
if self._config_hash is None or self._config_hash != self.config.hash:
# if no config hash saved
# or if the config hash is different from one saved in here
logger.debug("Recalculating hash due to %s",
"no config hash saved" if not self._config_hash
else "config hash has changed")
return True
return False

def get_hash(self, force_recalc: bool = False):
if not force_recalc and self._hash and not self.is_dirty:
should_recalc = self._should_recalc_hash(force_recalc)
if not should_recalc:
logger.info("Reusing old hash of CDB since the CDB has not changed: %s", self._hash)
return self._hash
self.is_dirty = False
Expand All @@ -791,14 +855,17 @@ def calculate_hash(self):
for k,v in self.__dict__.items():
if k in ['cui2countext_vectors', 'name2cuis']:
hasher.update(v, length=False)
elif k in ['_hash', 'is_dirty']:
elif k in ['_hash', 'is_dirty', '_config_hash']:
# ignore _hash since if it previously didn't exist, the
# new hash would be different when the value does exist
# and ignore is_dirty so that we get the same hash as previously
continue
elif k != 'config':
hasher.update(v, length=True)

# set cached config hash
self._config_hash = self.config.hash

self._hash = hasher.hexdigest()
logger.info("Found new CDB hash: %s", self._hash)
return self._hash
8 changes: 6 additions & 2 deletions medcat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ class Config(MixingConfig, BaseModel):
linking: Linking = Linking()
word_skipper: re.Pattern = re.compile('') # empty pattern gets replaced upon init
punct_checker: re.Pattern = re.compile('') # empty pattern gets replaced upon init
hash: Optional[str] = None

class Config:
# this if for word_skipper and punct_checker which would otherwise
Expand All @@ -572,6 +573,9 @@ def rebuild_re(self) -> None:
def get_hash(self):
hasher = Hasher()
for k, v in self.dict().items():
if k in ['hash', ]:
# ignore hash
continue
if k not in ['version', 'general', 'linking']:
hasher.update(v, length=True)
elif k == 'general':
Expand All @@ -587,5 +591,5 @@ def get_hash(self):
hasher.update(v2, length=False)
else:
hasher.update(v2, length=True)

return hasher.hexdigest()
self.hash = hasher.hexdigest()
return self.hash
15 changes: 12 additions & 3 deletions medcat/utils/saving/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,12 @@ def serialize(self, cdb, overwrite: bool = False) -> None:
raise ValueError(f'Unable to overwrite shelf path "{self.json_path}"'
' - specify overrwrite=True if you wish to overwrite')
to_save = {}
to_save['config'] = cdb.config.asdict()
# This uses different names so as to not be ambiguous
# when looking at files whether the json parts should
# exist separately or not
to_save['cdb_main' if self.jsons is not None else 'cdb'] = dict(
((key, val) for key, val in cdb.__dict__.items() if
key != 'config' and
key not in ('config', '_config_from_file') and
(self.jsons is None or key not in SPECIALITY_NAMES)))
logger.info('Dumping CDB to %s', self.main_path)
with open(self.main_path, 'wb') as f:
Expand All @@ -165,7 +164,17 @@ def deserialize(self, cdb_cls):
logger.info('Reading CDB data from %s', self.main_path)
with open(self.main_path, 'rb') as f:
data = dill.load(f)
config = cast(Config, Config.from_dict(data['config']))
if 'config' in data:
logger.warning("Found config in CDB for model (%s). "
"This is an old format. Please re-save the "
"model in the new format to avoid potential issues",
os.path.dirname(self.main_path))
config = cast(Config, Config.from_dict(data['config']))
else:
# by passing None as config to constructor
# the CDB should identify that there has been
# no config loaded
config = None
cdb = cdb_cls(config=config)
if self.jsons is None:
cdb_main = data['cdb']
Expand Down
52 changes: 49 additions & 3 deletions tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ def test_load_model_pack(self):
meta_cat = _get_meta_cat(self.meta_cat_dir)
cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
self.assertTrue(isinstance(cat, CAT))
self.assertIsNotNone(cat.config.version.medcat_version)
self.assertEqual(repr(cat._meta_cats), repr([meta_cat]))
Expand All @@ -377,18 +377,64 @@ def test_load_model_pack_without_meta_cat(self):
meta_cat = _get_meta_cat(self.meta_cat_dir)
cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat])
full_model_pack_name = cat.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False)
cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"), load_meta_models=False)
self.assertTrue(isinstance(cat, CAT))
self.assertIsNotNone(cat.config.version.medcat_version)
self.assertEqual(cat._meta_cats, [])

def test_hashing(self):
save_dir_path = tempfile.TemporaryDirectory()
full_model_pack_name = self.undertest.create_model_pack(save_dir_path.name, model_pack_name="mp_name")
cat = self.undertest.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
cat = CAT.load_model_pack(os.path.join(save_dir_path.name, f"{full_model_pack_name}.zip"))
self.assertEqual(cat.get_hash(), cat.config.version.id)


class ModelWithTwoConfigsLoadTests(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
cls.model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
cdb = CDB.load(os.path.join(cls.model_path, "cdb.dat"))
# save config next to the CDB
cls.config_path = os.path.join(cls.model_path, 'config.json')
cdb.config.save(cls.config_path)


@classmethod
def tearDownClass(cls) -> None:
# REMOVE config next to the CDB
os.remove(cls.config_path)

def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(self):
with self.assertRaises(ValueError):
CAT.load_model_pack(self.model_path)


class ModelWithZeroConfigsLoadTests(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb.dat")
cdb = CDB.load(cdb_path)
vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")
# copy the CDB and vocab to a temp dir
cls.temp_dir = tempfile.TemporaryDirectory()
cls.cdb_path = os.path.join(cls.temp_dir.name, 'cdb.dat')
cdb.save(cls.cdb_path) # save without internal config
cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat')
shutil.copyfile(vocab_path, cls.vocab_path)


@classmethod
def tearDownClass(cls) -> None:
# REMOVE temp dir
cls.temp_dir.cleanup()

def test_loading_model_pack_without_any_config_raises_exception(self):
with self.assertRaises(ValueError):
CAT.load_model_pack(self.temp_dir.name)


def _get_meta_cat(meta_cat_dir):
config = ConfigMetaCAT()
config.general["category_name"] = "Status"
Expand Down
8 changes: 8 additions & 0 deletions tests/test_cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from medcat.config import Config
from medcat.cdb_maker import CDBMaker
from medcat.cdb import CDB


class CDBTests(unittest.TestCase):
Expand Down Expand Up @@ -53,6 +54,13 @@ def test_save_and_load(self):
self.undertest.save(f.name)
self.undertest.load(f.name)

def test_load_has_no_config(self):
with tempfile.NamedTemporaryFile() as f:
self.undertest.save(f.name)
cdb = CDB.load(f.name)
self.assertFalse(cdb._config_from_file)


def test_save_async_and_load(self):
with tempfile.NamedTemporaryFile() as f:
asyncio.run(self.undertest.save_async(f.name))
Expand Down
29 changes: 29 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,35 @@ def test_from_dict(self):
config = Config.from_dict({"key": "value"})
self.assertEqual("value", config.key)

def test_config_no_hash_before_get(self):
config = Config()
self.assertIsNone(config.hash)

def test_config_has_hash_after_get(self):
config = Config()
config.get_hash()
self.assertIsNotNone(config.hash)

def test_config_hash_recalc_same_def(self):
config = Config()
h1 = config.get_hash()
h2 = config.get_hash()
self.assertEqual(h1, h2)

def test_config_hash_changes_after_change(self):
config = Config()
h1 = config.get_hash()
config.linking.filters.cuis = {"a", "b"}
h2 = config.get_hash()
self.assertNotEqual(h1, h2)

def test_config_hash_recalc_same_changed(self):
config = Config()
config.linking.filters.cuis = {"a", "b"}
h1 = config.get_hash()
h2 = config.get_hash()
self.assertEqual(h1, h2)


class ConfigLinkingFiltersTests(unittest.TestCase):

Expand Down
2 changes: 1 addition & 1 deletion tests/utils/saving/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_dill_to_json(self):
model_pack_folder = os.path.join(
self.json_model_pack.name, model_pack_path)
json_path = os.path.join(model_pack_folder, "*.json")
jsons = glob.glob(json_path)
jsons = [fn for fn in glob.glob(json_path) if not fn.endswith("config.json")]
# there is also a model_card.json
# but nothing for cui2many or name2many
# so can remove the length of ONE2MANY
Expand Down
Loading

0 comments on commit ad67048

Please sign in to comment.