diff --git a/medcat/config.py b/medcat/config.py index ba8aaef91..cdb30d0fe 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -208,9 +208,10 @@ def load(cls, save_path: str) -> "MixingConfig": # Read the jsonpickle string with open(save_path) as f: config_dict = json.load(f, object_hook=default_hook) - if is_old_type_config_dict(config_dict): - logger.warning("Loading an old type of config (jsonpickle) from '%s'", - save_path) + if is_old_type_config_dict(config_dict): + logger.warning("Loading an old type of config (jsonpickle) from '%s'", + save_path) + with open(save_path) as f: config_dict = jsonpickle.decode(f.read()) config.merge_config(config_dict) diff --git a/medcat/utils/config_utils.py b/medcat/utils/config_utils.py index 92ea111ed..09989b258 100644 --- a/medcat/utils/config_utils.py +++ b/medcat/utils/config_utils.py @@ -15,9 +15,24 @@ def weighted_average_function(self) -> Callable[[float], int]: def is_old_type_config_dict(d: dict) -> bool: - if set(('py/object', 'py/state')) <= set(d.keys()): - return True - return False + """Checks if the dict provided is an old style (jsonpickle) config. + + This checks for json-pickle specific keys such as py/object and py/state. + If both of those are keys somewhere within the 2 initial layers of the + nested dict, it's considered old style. + + Args: + d (dict): Loaded config. + + Returns: + bool: Whether it's an old style (jsonpickle) config. + """ + # all 2nd level keys + all_keys = set(sub_key for key in d for sub_key in (d[key] if isinstance(d[key], dict) else [key])) + # add 1st level keys + all_keys.update(d.keys()) + # is old if py/object and py/state somewhere in keys + return set(('py/object', 'py/state')) <= all_keys def fix_waf_lambda(carrier: WAFCarrier) -> None: diff --git a/tests/resources/jsonpickle_config.json b/tests/resources/jsonpickle_config.json new file mode 100644 index 000000000..784f933ce --- /dev/null +++ b/tests/resources/jsonpickle_config.json @@ -0,0 +1,274 @@ +{ + "version": { + "py/object": "medcat.config.VersionInfo", + "py/state": { + "__dict__": { + "history": ["0c0de303b6dc0020"], + "meta_cats": {}, + "cdb_info": {}, + "performance": { + "ner": {}, + "meta": {} + }, + "description": "No description", + "id": null, + "last_modified": null, + "location": null, + "ontology": null, + "medcat_version": null + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "cdb_maker": { + "py/object": "medcat.config.CDBMaker", + "py/state": { + "__dict__": { + "name_versions": [ + "LOWER", + "CLEAN" + ], + "multi_separator": "|", + "remove_parenthesis": 5, + "min_letters_required": 2 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "annotation_output": { + "py/object": "medcat.config.AnnotationOutput", + "py/state": { + "__dict__": { + "doc_extended_info": false, + "context_left": -1, + "context_right": -1, + "lowercase_context": true, + "include_text_in_output": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "general": { + "py/object": "medcat.config.General", + "py/state": { + "__dict__": { + "spacy_disabled_components": [ + "ner", + "parser", + "vectors", + "textcat", + "entity_linker", + "sentencizer", + "entity_ruler", + "merge_noun_chunks", + "merge_entities", + "merge_subtokens" + ], + "checkpoint": { + "py/object": "medcat.config.CheckPoint", + "py/state": { + "__dict__": { + "output_dir": "checkpoints", + "steps": null, + "max_to_keep": 1 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "log_level": 20, + "log_format": "%(levelname)s:%(name)s: %(message)s", + "log_path": "./medcat.log", + "spacy_model": "en_core_web_lg", + "separator": "~", + "spell_check": true, + "diacritics": false, + "spell_check_deep": false, + "spell_check_len_limit": 7, + "show_nested_entities": false, + "full_unlink": false, + "workers": 7, + "make_pretty_labels": null, + "map_cui_to_group": false + }, + "__fields_set__": { + "py/set": [ + "spacy_model" + ] + }, + "__private_attribute_values__": {} + } + }, + "preprocessing": { + "py/object": "medcat.config.Preprocessing", + "py/state": { + "__dict__": { + "words_to_skip": { + "py/set": [ + "nos" + ] + }, + "keep_punct": { + "py/set": [ + ".", + ":" + ] + }, + "do_not_normalize": { + "py/set": [ + "VBD", + "VBP", + "VBN", + "JJR", + "JJS", + "VBG" + ] + }, + "skip_stopwords": false, + "min_len_normalize": 5, + "stopwords": { + "py/set": [ + "three", + "two", + "one" + ] + }, + "max_document_length": 1000000 + }, + "__fields_set__": { + "py/set": [ + "stopwords" + ] + }, + "__private_attribute_values__": {} + } + }, + "ner": { + "py/object": "medcat.config.Ner", + "py/state": { + "__dict__": { + "min_name_len": 3, + "max_skip_tokens": 2, + "check_upper_case_names": false, + "upper_case_limit_len": 4, + "try_reverse_word_order": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "linking": { + "py/object": "medcat.config.Linking", + "py/state": { + "__dict__": { + "optim": { + "type": "linear", + "base_lr": 1, + "min_lr": 0.00005 + }, + "context_vector_sizes": { + "xlong": 27, + "long": 18, + "medium": 9, + "short": 3 + }, + "context_vector_weights": { + "xlong": 0.1, + "long": 0.4, + "medium": 0.4, + "short": 0.1 + }, + "filters": { + "py/object": "medcat.config.LinkingFilters", + "py/state": { + "__dict__": { + "cuis": { + "py/set": [] + }, + "cuis_exclude": { + "py/set": [] + } + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": true, + "random_replacement_unsupervised": 0.8, + "disamb_length_limit": 3, + "filter_before_disamb": false, + "train_count_threshold": 1, + "always_calculate_similarity": false, + "weighted_average_function": { + "py/object": "medcat.config._DefPartial", + "fun": { + "py/reduce": [ + { + "py/type": "functools.partial" + }, + { + "py/tuple": [ + { + "py/function": "medcat.utils.config_utils.weighted_average" + } + ] + }, + { + "py/tuple": [ + { + "py/function": "medcat.utils.config_utils.weighted_average" + }, + { + "py/tuple": [] + }, + { + "factor": 0.0004 + }, + {} + ] + } + ] + } + }, + "calculate_dynamic_threshold": false, + "similarity_threshold_type": "static", + "similarity_threshold": 0.25, + "negative_probability": 0.5, + "negative_ignore_punct_and_num": true, + "prefer_primary_name": 0.35, + "prefer_frequent_concepts": 0.35, + "subsample_after": 30000, + "devalue_linked_concepts": false, + "context_ignore_center_tokens": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "word_skipper": { + "py/object": "re.Pattern", + "pattern": "^(nos)$" + }, + "punct_checker": { + "py/object": "re.Pattern", + "pattern": "[^a-z0-9]+" + }, + "hash": null + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_meta_cat_config.json b/tests/resources/jsonpickle_meta_cat_config.json new file mode 100644 index 000000000..4da001c6c --- /dev/null +++ b/tests/resources/jsonpickle_meta_cat_config.json @@ -0,0 +1,89 @@ +{ + "general": { + "py/object": "medcat.config_meta_cat.General", + "py/state": { + "__dict__": { + "device": "cpu", + "disable_component_lock": false, + "seed": -100, + "description": "No description", + "category_name": null, + "category_value2id": {}, + "vocab_size": null, + "lowercase": true, + "cntx_left": 15, + "cntx_right": 10, + "replace_center": null, + "batch_size_eval": 5000, + "annotate_overlapping": false, + "tokenizer_name": "bbpe", + "save_and_reuse_tokens": false, + "pipe_batch_size_in_chars": 20000000, + "span_group": null + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "model": { + "py/object": "medcat.config_meta_cat.Model", + "py/state": { + "__dict__": { + "model_name": "lstm", + "model_variant": "bert-base-uncased", + "model_freeze_layers": true, + "num_layers": 2, + "input_size": 300, + "hidden_size": 300, + "dropout": 0.5, + "phase_number": 0, + "category_undersample": "", + "model_architecture_config": { + "fc2": true, + "fc3": false, + "lr_scheduler": true + }, + "num_directions": 2, + "nclasses": 2, + "padding_idx": -1, + "emb_grad": true, + "ignore_cpos": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": { + "py/object": "medcat.config_meta_cat.Train", + "py/state": { + "__dict__": { + "batch_size": 100, + "nepochs": 50, + "lr": 0.001, + "test_size": 0.1, + "shuffle_data": true, + "class_weights": null, + "compute_class_weights": false, + "score_average": "weighted", + "prerequisites": {}, + "cui_filter": null, + "auto_save_model": true, + "last_train_on": null, + "metric": { + "base": "weighted avg", + "score": "f1-score" + }, + "loss_funct": "cross_entropy", + "gamma": 2 + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_rel_cat_config.json b/tests/resources/jsonpickle_rel_cat_config.json new file mode 100644 index 000000000..411caaa52 --- /dev/null +++ b/tests/resources/jsonpickle_rel_cat_config.json @@ -0,0 +1,91 @@ +{ + "general": { + "py/object": "medcat.config_rel_cat.General", + "py/state": { + "__dict__": { + "device": "cpu", + "relation_type_filter_pairs": [], + "vocab_size": null, + "lowercase": true, + "cntx_left": 15, + "cntx_right": 15, + "window_size": 300, + "mct_export_max_non_rel_sample_size": 200, + "mct_export_create_addl_rels": false, + "tokenizer_name": "bert", + "model_name": "bert-base-uncased", + "log_level": 20, + "max_seq_length": 512, + "tokenizer_special_tokens": false, + "annotation_schema_tag_ids": [], + "labels2idx": {}, + "idx2labels": {}, + "pin_memory": true, + "seed": 13, + "task": "train" + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "model": { + "py/object": "medcat.config_rel_cat.Model", + "py/state": { + "__dict__": { + "input_size": 300, + "hidden_size": 768, + "hidden_layers": 3, + "model_size": 5120, + "dropout": 0.2, + "num_directions": 2, + "padding_idx": -1, + "emb_grad": true, + "ignore_cpos": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + }, + "train": { + "py/object": "medcat.config_rel_cat.Train", + "py/state": { + "__dict__": { + "nclasses": 2, + "batch_size": 25, + "nepochs": 1, + "lr": 100000, + "adam_epsilon": 0.0001, + "test_size": 0.2, + "gradient_acc_steps": 1, + "multistep_milestones": [ + 2, + 4, + 6, + 8, + 12, + 15, + 18, + 20, + 22, + 24, + 26, + 30 + ], + "multistep_lr_gamma": 0.8, + "max_grad_norm": 1, + "shuffle_data": true, + "class_weights": null, + "score_average": "weighted", + "auto_save_model": true + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/resources/jsonpickle_tner_config.json b/tests/resources/jsonpickle_tner_config.json new file mode 100644 index 000000000..eb3639453 --- /dev/null +++ b/tests/resources/jsonpickle_tner_config.json @@ -0,0 +1,23 @@ +{ + "general": { + "py/object": "medcat.config_transformers_ner.General", + "py/state": { + "__dict__": { + "name": "deid", + "model_name": "roberta-base", + "seed": 13, + "description": "No description", + "pipe_batch_size_in_chars": -100, + "ner_aggregation_strategy": "simple", + "chunking_overlap_window": 5, + "test_size": 0.2, + "last_train_on": null, + "verbose_metrics": false + }, + "__fields_set__": { + "py/set": [] + }, + "__private_attribute_values__": {} + } + } + } \ No newline at end of file diff --git a/tests/utils/test_config_utils.py b/tests/utils/test_config_utils.py index 713d15bb0..d1a7262e7 100644 --- a/tests/utils/test_config_utils.py +++ b/tests/utils/test_config_utils.py @@ -1,7 +1,12 @@ from medcat.config import Config from medcat.utils.saving.coding import default_hook, CustomDelegatingEncoder from medcat.utils import config_utils +from medcat import config as main_config +from medcat import config_meta_cat +from medcat import config_transformers_ner +from medcat import config_rel_cat import json +import os import unittest @@ -48,3 +53,69 @@ def test_identifies_old_style_dict(self): def test_identifies_new_style_dict(self): self.assertFalse(config_utils.is_old_type_config_dict(NEW_STYLE_DICT)) + + +class OldFormatJsonTests(unittest.TestCase): + + def assert_knows_old_format(self, file_path: str): + with open(file_path) as f: + d = json.load(f) + self.assertTrue(config_utils.is_old_type_config_dict(d)) + + +class OldConfigLoadTests(OldFormatJsonTests): + JSON_PICKLE_FILE_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_config.json" + ) + EXPECTED_VERSION_HISTORY = ['0c0de303b6dc0020',] + + def test_knows_is_old_format(self): + self.assert_knows_old_format(self.JSON_PICKLE_FILE_PATH) + + def test_loads_old_style_correctly(self): + cnf: main_config.Config = main_config.Config.load(self.JSON_PICKLE_FILE_PATH) + self.assertEqual(cnf.version.history, self.EXPECTED_VERSION_HISTORY) + + +class MetaCATConfigTests(OldFormatJsonTests): + META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_meta_cat_config.json" + ) + EXPECTED_TARGET = -100 + TARGET_CLASS = config_meta_cat.ConfigMetaCAT + + @classmethod + def get_target(cls, cnf): + return cnf.general.seed + + def test_knows_is_old_format(self): + self.assert_knows_old_format(self.META_CAT_OLD_PATH) + + def test_can_load_old_format_correctly(self): + cnf = self.TARGET_CLASS.load(self.META_CAT_OLD_PATH) + self.assertIsInstance(cnf, self.TARGET_CLASS) + self.assertEqual(self.get_target(cnf), self.EXPECTED_TARGET) + + +class TNERCATConfigTests(MetaCATConfigTests): + META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_tner_config.json" + ) + EXPECTED_TARGET = -100 + TARGET_CLASS = config_transformers_ner.ConfigTransformersNER + + @classmethod + def get_target(cls, cnf): + return cnf.general.pipe_batch_size_in_chars + + +class RelCATConfigTests(MetaCATConfigTests): + META_CAT_OLD_PATH = os.path.join( + os.path.dirname(__file__), "..", "resources", "jsonpickle_rel_cat_config.json" + ) + EXPECTED_TARGET = 100_000 + TARGET_CLASS = config_rel_cat.ConfigRelCAT + + @classmethod + def get_target(cls, cnf): + return cnf.train.lr