Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CU-8693bpq82 fallback spacy model #384

Merged
merged 3 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion medcat/cdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def load_config(self, config_path: str) -> None:
if not os.path.exists(config_path):
if not self._config_from_file:
# if there's no config defined anywhere
raise ValueError("Could not find a config in the CDB nor ",
raise ValueError("Could not find a config in the CDB nor "
"in the config.json for this model "
f"({os.path.dirname(config_path)})",
)
Expand Down
23 changes: 22 additions & 1 deletion medcat/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
logger = logging.getLogger(__name__) # different logger from the package-level one


DEFAULT_SPACY_MODEL = 'en_core_web_md'


class Pipe(object):
"""A wrapper around the standard spacy pipeline.

Expand All @@ -38,7 +41,22 @@ class Pipe(object):
"""

def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
self._nlp = spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)
try:
self._nlp = self._init_nlp(config)
except Exception as e:
if config.general.spacy_model == DEFAULT_SPACY_MODEL:
raise e
logger.warning("Could not load spacy model from '%s'. "
"Falling back to installed en_core_web_md. "
"For best compatibility, we'd recommend "
"packaging and using your model pack with "
"the spacy model it was designed for",
config.general.spacy_model, exc_info=e)
# we're changing the config value so that this propages
# to other places that try to load the model. E.g:
# medcat.utils.normalizers.TokenNormalizer.__init__
config.general.spacy_model = DEFAULT_SPACY_MODEL
self._nlp = self._init_nlp(config)
if config.preprocessing.stopwords is not None:
self._nlp.Defaults.stop_words = set(config.preprocessing.stopwords)
self._nlp.tokenizer = tokenizer(self._nlp, config)
Expand All @@ -48,6 +66,9 @@ def __init__(self, tokenizer: Tokenizer, config: Config) -> None:
# Set log level
logger.setLevel(self.config.general.log_level)

def _init_nlp(selef, config: Config) -> Language:
return spacy.load(config.general.spacy_model, disable=config.general.spacy_disabled_components)

def add_tagger(self, tagger: Callable, name: Optional[str] = None, additional_fields: List[str] = []) -> None:
"""Add any kind of a tagger for tokens.

Expand Down
29 changes: 29 additions & 0 deletions tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from medcat.vocab import Vocab
from medcat.cdb import CDB, logger as cdb_logger
from medcat.cat import CAT, logger as cat_logger
from medcat.pipe import logger as pipe_logger
from medcat.utils.checkpoint import Checkpoint
from medcat.meta_cat import MetaCAT
from medcat.config_meta_cat import ConfigMetaCAT
Expand Down Expand Up @@ -499,6 +500,34 @@ def test_loading_model_pack_with_cdb_config_and_config_json_raises_exception(sel
CAT.load_model_pack(self.model_path)


class ModelLoadsUnreadableSpacy(unittest.TestCase):

@classmethod
def setUpClass(cls) -> None:
cls.temp_dir = tempfile.TemporaryDirectory()
model_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples")
cdb = CDB.load(os.path.join(model_path, 'cdb.dat'))
cdb.config.general.spacy_model = os.path.join(cls.temp_dir.name, "en_core_web_md")
# save CDB in new location
cdb.save(os.path.join(cls.temp_dir.name, 'cdb.dat'))
# save config in new location
cdb.config.save(os.path.join(cls.temp_dir.name, 'config.json'))
# copy vocab into new location
vocab_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "vocab.dat")
cls.vocab_path = os.path.join(cls.temp_dir.name, 'vocab.dat')
shutil.copyfile(vocab_path, cls.vocab_path)

@classmethod
def tearDownClass(cls) -> None:
# REMOVE temp dir
cls.temp_dir.cleanup()

def test_loads_without_specified_spacy_model(self):
with self.assertLogs(logger=pipe_logger, level=logging.WARNING):
cat = CAT.load_model_pack(self.temp_dir.name)
self.assertTrue(isinstance(cat, CAT))


class ModelWithZeroConfigsLoadTests(unittest.TestCase):

@classmethod
Expand Down