From 46b66cfe4be839c7b74badb84292bacb34ef3ae5 Mon Sep 17 00:00:00 2001 From: mart-r Date: Mon, 19 Aug 2024 12:25:49 +0100 Subject: [PATCH] CU-86956du3q: Add usage of regression suite name from the name of the file being read --- medcat/utils/regression/checking.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/medcat/utils/regression/checking.py b/medcat/utils/regression/checking.py index 46e9b307..754cf805 100644 --- a/medcat/utils/regression/checking.py +++ b/medcat/utils/regression/checking.py @@ -4,6 +4,7 @@ import logging import tqdm import datetime +import os from pydantic import BaseModel, Field @@ -285,9 +286,9 @@ class RegressionSuite: use_report (bool): Whether or not to use the report functionality (defaults to False) """ - def __init__(self, cases: List[RegressionCase], metadata: MetaData) -> None: + def __init__(self, cases: List[RegressionCase], metadata: MetaData, name: str) -> None: self.cases: List[RegressionCase] = cases - self.report = MultiDescriptor(name='ALL') # TODO - allow setting names + self.report = MultiDescriptor(name=name) self.metadata = metadata for case in self.cases: self.report.parts.append(case.report) @@ -385,7 +386,7 @@ def __eq__(self, other: object) -> bool: return self.cases == other.cases @classmethod - def from_dict(cls, in_dict: dict) -> 'RegressionSuite': + def from_dict(cls, in_dict: dict, name: str) -> 'RegressionSuite': """Construct a RegressionChecker from a dict. Most of the parsing is handled in RegressionChecker.from_dict. @@ -393,7 +394,8 @@ def from_dict(cls, in_dict: dict) -> 'RegressionSuite': and each value describes a RegressionCase. Args: - in_dict (dict): The input dict + in_dict (dict): The input dict. + name (str): The name of the regression suite. Returns: RegressionChecker: The built regression checker @@ -409,7 +411,7 @@ def from_dict(cls, in_dict: dict) -> 'RegressionSuite': metadata = MetaData.unknown() else: metadata = MetaData.parse_obj(in_dict['meta']) - return RegressionSuite(cases=cases, metadata=metadata) + return RegressionSuite(cases=cases, metadata=metadata, name=name) @classmethod def from_yaml(cls, file_name: str) -> 'RegressionSuite': @@ -425,14 +427,14 @@ def from_yaml(cls, file_name: str) -> 'RegressionSuite': """ with open(file_name) as f: data = yaml.safe_load(f) - return RegressionSuite.from_dict(data) + return RegressionSuite.from_dict(data, name=os.path.basename(file_name)) @classmethod def from_mct_export(cls, file_name: str) -> 'RegressionSuite': with open(file_name) as f: data = json.load(f) converted = MedCATTrainerExportConverter(data).convert() - return RegressionSuite.from_dict(converted) + return RegressionSuite.from_dict(converted, name=os.path.basename(file_name)) class MalformedRegressionCaseException(ValueError):