diff --git a/medcat/utils/versioning.py b/medcat/utils/versioning.py index 4fac8571..09c53d4f 100644 --- a/medcat/utils/versioning.py +++ b/medcat/utils/versioning.py @@ -4,10 +4,14 @@ import shutil import argparse import logging +from functools import partial import dill +import json from medcat.cat import CAT +from medcat.utils.decorators import deprecated +from medcat.utils.config_utils import default_weighted_average logger = logging.getLogger(__name__) @@ -209,15 +213,18 @@ def upgrade(self, new_path: str, overwrite: bool = False) -> None: Raises: ValueError: If one of the target files exists and cannot be overwritten. - ValueError: If model pack does not need an upgrade + IncorrectModel: If model pack does not need an upgrade """ if not self.needs_upgrade(): - raise ValueError(f"Model pack does not need ugprade: {self.model_pack_path} " - f"since it's at version: {self.current_version}") + raise IncorrectModel(f"Model pack does not need ugprade: {self.model_pack_path} " + f"since it's at version: {self.current_version}") logger.info("Starting to upgrade %s at (version %s)", self.model_pack_path, self.current_version) files_to_copy = self._get_relevant_files() - self._check_existance(files_to_copy, new_path, overwrite) + try: + self._check_existance(files_to_copy, new_path, overwrite) + except ValueError as e: + raise e logger.debug("Copying files from %s", self.model_pack_path) self._copy_files(files_to_copy, new_path) logger.info("Going to try and fix CDB") @@ -259,7 +266,8 @@ def parse_args() -> argparse.Namespace: """ parser = argparse.ArgumentParser() parser.add_argument( - "action", help="The action. Currently, only 'fix-config' is available.", choices=['fix-config'], type=str.lower) + "action", help="The action. Currently, only 'fix-config' or 'allow-pre-1.12' are available.", + choices=['fix-config', 'allow-pre-1.12'], type=str.lower) parser.add_argument("modelpack", help="MedCAT modelpack zip path") parser.add_argument("newpath", help="The path for the new modelpack") parser.add_argument( @@ -283,6 +291,9 @@ def setup_logging(args: argparse.Namespace) -> None: logger.setLevel(logging.DEBUG) +@deprecated("This is no longer needed. Since medcat 1.10 (PR #352) " + "this dealt with automatically upon model load.", + depr_version=(1, 10, 0), removal_version=(1, 14, 0)) def fix_config(args: argparse.Namespace) -> None: """Perform the fix-config action based on the CLI arguments. @@ -295,6 +306,117 @@ def fix_config(args: argparse.Namespace) -> None: upgrader.upgrade(args.newpath, overwrite=args.overwrite) +def _do_pre_1_12_fix(model_pack_path: str) -> CAT: + cat = CAT.load_model_pack(model_pack_path) + waf = cat.cdb.weighted_average_function + is_def = waf is default_weighted_average + is_partial = (isinstance(waf, partial) + and waf.func is default_weighted_average) + if is_def: + factor = 0.0004 + logger.info("Was using default weighted average") + elif is_partial: + pargs = waf.args + pkwargs = waf.keywords + factor = pargs[0] if pargs else pkwargs['factor'] + logger.info("Was using a (near) default weighted average") + else: + raise IncorrectModel("Model does not have fixable weighted_average tied to its CDB, " + f"found: {waf}") + cat.cdb.weighted_average_function = lambda step: max(0.1, 1 - (step ** 2 * factor)) + return cat + + +def _set_change(val: dict): + return {"py/set": val["==SET=="]} + + +def _pattern_change(val: dict): + return { + "py/object": "re.Pattern", + "pattern": val["==PATTERN=="] + } + + +TO_CHANGE = { + "preprocessing.words_to_skip": _set_change, + "preprocessing.keep_punct": _set_change, + "preprocessing.do_not_normalize": _set_change, + "linking.filters.cuis": _set_change, + "linking.filters.cuis_exclude": _set_change, + "word_skipper": _pattern_change, + "punct_checker": _pattern_change, +} + + +def _fix_config_for_pre_1_12(folder: str): + config_path = os.path.join(folder, 'config.json') + with open(config_path) as f: + data = json.load(f) + for fix_path, fixer in TO_CHANGE.items(): + logger.info("[Pre 1.12 fix] Changing %s", fix_path) + cur_path = fix_path + last_dict = data + while "." in cur_path: + cur_key, cur_path = cur_path.split(".", 1) + last_dict = last_dict[cur_key] + last_key = cur_path + last_value = last_dict[last_key] + last_dict[last_key] = fixer(last_value) + logger.info("[Pre 1.12 fix] Saving config back to %s", config_path) + with open(config_path, 'w') as f: + json.dump(data, f) + logger.info("[Pre 1.12 fix] Recreating archive for %s", folder) + shutil.make_archive(folder, 'zip', root_dir=folder) + + +@deprecated("This is only really needed for 1.12+ models " + "to be converted to lower versions of medcat. " + "It should not be needed in the long run.", + depr_version=(1, 13, 0), removal_version=(1, 14, 0)) +def allow_loading_with_pre_1_12(args: argparse.Namespace): + """This method converts a model created after medcat 1.12 + such that it can be loaded in previous versions. + + The main two things it does: + - Simplifies the weighted average function attached to the CDB. + - Makes the config json-compatible + + Expected / used arguments in CLI: + - modelpack: The input model pack path + - newpath: The output model pack path + - overwrite: Whether to overwrite the new model + + Raises: + ValueError: If the file already exists + + Args: + args (argparse.Namespace): The CLI arguments. + """ + # this will fix the weighted_average function if possible + # since 1.12 this is within the CDB and generally refers + # to a method on medcat.utils.config_utils and the method + # and/or the module do not exist in previous version + cat = _do_pre_1_12_fix(args.modelpack) + if not args.overwrite and os.path.exists(args.newpath): + raise ValueError(f"File already exists: {args.newpath}. " + "Set --overwrite to overwrite") + mpn = cat.create_model_pack(args.newpath) + full_path = os.path.join(args.newpath, mpn) + logger.info("Saving model to: %s", full_path) + # now that the model has saved, we also need to do make + # some changes to the config to allow it to be properly + # loaded by jsonpickle (used before 1.12) rather than + # just json (used by 1.12+) + _fix_config_for_pre_1_12(full_path) + + +class IncorrectModel(ValueError): + + def __init__(self, *args: object) -> None: + super().__init__(*args) + + def main() -> None: """Run the CLI associated with this module. @@ -306,6 +428,8 @@ def main() -> None: logger.debug("Will attempt to perform action %s", args.action) if args.action == 'fix-config': fix_config(args) + elif args.action == 'allow-pre-1.12': + allow_loading_with_pre_1_12(args) else: raise ValueError(f"Unknown action: {args.action}")