Skip to content

Commit

Permalink
CU-86956duhb: Add method to backport a model pack from 1.12 to previo…
Browse files Browse the repository at this point in the history
…us version (#465)

* CU-86956duhb: Add method to backport a model pack from 1.12 to previous version

* CU-86956duhb: Fix some doc string issues

* CU-86956duhb: Add deprecation decorator to old config-fix

* CU-86956duhb: Mark backporting method as deprecated and to be removed in 1.14
  • Loading branch information
mart-r authored Aug 12, 2024
1 parent b7658ee commit 005796a
Showing 1 changed file with 129 additions and 5 deletions.
134 changes: 129 additions & 5 deletions medcat/utils/versioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
import shutil
import argparse
import logging
from functools import partial

import dill
import json

from medcat.cat import CAT
from medcat.utils.decorators import deprecated
from medcat.utils.config_utils import default_weighted_average

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -209,15 +213,18 @@ def upgrade(self, new_path: str, overwrite: bool = False) -> None:
Raises:
ValueError: If one of the target files exists and cannot be overwritten.
ValueError: If model pack does not need an upgrade
IncorrectModel: If model pack does not need an upgrade
"""
if not self.needs_upgrade():
raise ValueError(f"Model pack does not need ugprade: {self.model_pack_path} "
f"since it's at version: {self.current_version}")
raise IncorrectModel(f"Model pack does not need ugprade: {self.model_pack_path} "
f"since it's at version: {self.current_version}")
logger.info("Starting to upgrade %s at (version %s)",
self.model_pack_path, self.current_version)
files_to_copy = self._get_relevant_files()
self._check_existance(files_to_copy, new_path, overwrite)
try:
self._check_existance(files_to_copy, new_path, overwrite)
except ValueError as e:
raise e
logger.debug("Copying files from %s", self.model_pack_path)
self._copy_files(files_to_copy, new_path)
logger.info("Going to try and fix CDB")
Expand Down Expand Up @@ -259,7 +266,8 @@ def parse_args() -> argparse.Namespace:
"""
parser = argparse.ArgumentParser()
parser.add_argument(
"action", help="The action. Currently, only 'fix-config' is available.", choices=['fix-config'], type=str.lower)
"action", help="The action. Currently, only 'fix-config' or 'allow-pre-1.12' are available.",
choices=['fix-config', 'allow-pre-1.12'], type=str.lower)
parser.add_argument("modelpack", help="MedCAT modelpack zip path")
parser.add_argument("newpath", help="The path for the new modelpack")
parser.add_argument(
Expand All @@ -283,6 +291,9 @@ def setup_logging(args: argparse.Namespace) -> None:
logger.setLevel(logging.DEBUG)


@deprecated("This is no longer needed. Since medcat 1.10 (PR #352) "
"this dealt with automatically upon model load.",
depr_version=(1, 10, 0), removal_version=(1, 14, 0))
def fix_config(args: argparse.Namespace) -> None:
"""Perform the fix-config action based on the CLI arguments.
Expand All @@ -295,6 +306,117 @@ def fix_config(args: argparse.Namespace) -> None:
upgrader.upgrade(args.newpath, overwrite=args.overwrite)


def _do_pre_1_12_fix(model_pack_path: str) -> CAT:
cat = CAT.load_model_pack(model_pack_path)
waf = cat.cdb.weighted_average_function
is_def = waf is default_weighted_average
is_partial = (isinstance(waf, partial)
and waf.func is default_weighted_average)
if is_def:
factor = 0.0004
logger.info("Was using default weighted average")
elif is_partial:
pargs = waf.args
pkwargs = waf.keywords
factor = pargs[0] if pargs else pkwargs['factor']
logger.info("Was using a (near) default weighted average")
else:
raise IncorrectModel("Model does not have fixable weighted_average tied to its CDB, "
f"found: {waf}")
cat.cdb.weighted_average_function = lambda step: max(0.1, 1 - (step ** 2 * factor))
return cat


def _set_change(val: dict):
return {"py/set": val["==SET=="]}


def _pattern_change(val: dict):
return {
"py/object": "re.Pattern",
"pattern": val["==PATTERN=="]
}


TO_CHANGE = {
"preprocessing.words_to_skip": _set_change,
"preprocessing.keep_punct": _set_change,
"preprocessing.do_not_normalize": _set_change,
"linking.filters.cuis": _set_change,
"linking.filters.cuis_exclude": _set_change,
"word_skipper": _pattern_change,
"punct_checker": _pattern_change,
}


def _fix_config_for_pre_1_12(folder: str):
config_path = os.path.join(folder, 'config.json')
with open(config_path) as f:
data = json.load(f)
for fix_path, fixer in TO_CHANGE.items():
logger.info("[Pre 1.12 fix] Changing %s", fix_path)
cur_path = fix_path
last_dict = data
while "." in cur_path:
cur_key, cur_path = cur_path.split(".", 1)
last_dict = last_dict[cur_key]
last_key = cur_path
last_value = last_dict[last_key]
last_dict[last_key] = fixer(last_value)
logger.info("[Pre 1.12 fix] Saving config back to %s", config_path)
with open(config_path, 'w') as f:
json.dump(data, f)
logger.info("[Pre 1.12 fix] Recreating archive for %s", folder)
shutil.make_archive(folder, 'zip', root_dir=folder)


@deprecated("This is only really needed for 1.12+ models "
"to be converted to lower versions of medcat. "
"It should not be needed in the long run.",
depr_version=(1, 13, 0), removal_version=(1, 14, 0))
def allow_loading_with_pre_1_12(args: argparse.Namespace):
"""This method converts a model created after medcat 1.12
such that it can be loaded in previous versions.
The main two things it does:
- Simplifies the weighted average function attached to the CDB.
- Makes the config json-compatible
Expected / used arguments in CLI:
- modelpack: The input model pack path
- newpath: The output model pack path
- overwrite: Whether to overwrite the new model
Raises:
ValueError: If the file already exists
Args:
args (argparse.Namespace): The CLI arguments.
"""
# this will fix the weighted_average function if possible
# since 1.12 this is within the CDB and generally refers
# to a method on medcat.utils.config_utils and the method
# and/or the module do not exist in previous version
cat = _do_pre_1_12_fix(args.modelpack)
if not args.overwrite and os.path.exists(args.newpath):
raise ValueError(f"File already exists: {args.newpath}. "
"Set --overwrite to overwrite")
mpn = cat.create_model_pack(args.newpath)
full_path = os.path.join(args.newpath, mpn)
logger.info("Saving model to: %s", full_path)
# now that the model has saved, we also need to do make
# some changes to the config to allow it to be properly
# loaded by jsonpickle (used before 1.12) rather than
# just json (used by 1.12+)
_fix_config_for_pre_1_12(full_path)


class IncorrectModel(ValueError):

def __init__(self, *args: object) -> None:
super().__init__(*args)


def main() -> None:
"""Run the CLI associated with this module.
Expand All @@ -306,6 +428,8 @@ def main() -> None:
logger.debug("Will attempt to perform action %s", args.action)
if args.action == 'fix-config':
fix_config(args)
elif args.action == 'allow-pre-1.12':
allow_loading_with_pre_1_12(args)
else:
raise ValueError(f"Unknown action: {args.action}")

Expand Down

0 comments on commit 005796a

Please sign in to comment.