Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU-8694cd9t2: Allow merging config into model pack config before init #462

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def load_model_pack(cls,
zip_path: str,
meta_cat_config_dict: Optional[Dict] = None,
ner_config_dict: Optional[Dict] = None,
medcat_config_dict: Optional[Dict] = None,
load_meta_models: bool = True,
load_addl_ner: bool = True,
load_rel_models: bool = True) -> "CAT":
Expand All @@ -371,6 +372,10 @@ def load_model_pack(cls,
A config dict that will overwrite existing configs in transformers ner.
e.g. ner_config_dict = {'general': {'chunking_overlap_window': 6}}.
Defaults to None.
medcat_config_dict (Optional[Dict]):
A config dict that will overwrite existing configs in the main medcat config
before pipe initialisation. This can be useful if wanting to change something
that only takes effect at init time (e.g. the spaCy model). Defaults to None.
load_meta_models (bool):
Whether to load MetaCAT models if present (Default value True).
load_addl_ner (bool):
Expand All @@ -393,7 +398,7 @@ def load_model_pack(cls,

# load config
config_path = os.path.join(model_pack_path, "config.json")
cdb.load_config(config_path)
cdb.load_config(config_path, medcat_config_dict)

# TODO load addl_ner

Expand Down
14 changes: 13 additions & 1 deletion medcat/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,17 @@ async def save_async(self, path: str) -> None:
}
await f.write(dill.dumps(to_save))

def load_config(self, config_path: str) -> None:
def load_config(self, config_path: str, config_dict: Optional[Dict] = None) -> None:
"""Load the config from disk.

Args:
config_path (str): The path to the config file.
config_dict (Optional[Dict]): A config dict to merge into the loaded config. Defaults to None.

Raises:
ValueError: If a config was found neither in the CDB nor as a separate json,
or if a config was found both in the CDB and as a separate json.
"""
if not os.path.exists(config_path):
if not self._config_from_file:
# if there's no config defined anywhere
Expand Down Expand Up @@ -544,6 +554,8 @@ def load_config(self, config_path: str) -> None:
# new config, potentially new weighted_average_function to read
self._init_waf_from_config()
# mark config read from file
if config_dict:
self.config.merge_config(config_dict)
self._config_from_file = True

@classmethod
Expand Down
Loading