Skip to content

Commit

Permalink
add the class ReportTypes in 'ml/validation_schema'
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Nov 14, 2024
1 parent 490cec8 commit 7d591c7
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 24 deletions.
6 changes: 3 additions & 3 deletions src/syngen/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
check_if_logs_available
)
from syngen.ml.utils import validate_parameter_reports
from syngen.ml.validation_schema import INFER_REPORT_TYPES
from syngen.ml.validation_schema import ReportTypes


validate_reports = validate_parameter_reports(
report_types=INFER_REPORT_TYPES,
full_list=["accuracy"]
report_types=ReportTypes().infer_report_types,
full_list=ReportTypes().full_list_of_infer_report_types
)


Expand Down
4 changes: 2 additions & 2 deletions src/syngen/ml/config/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from slugify import slugify
from loguru import logger
from syngen.ml.data_loaders import MetadataLoader, DataLoader
from syngen.ml.validation_schema import ValidationSchema, INFER_REPORT_TYPES
from syngen.ml.validation_schema import ValidationSchema, ReportTypes


@dataclass
Expand Down Expand Up @@ -69,7 +69,7 @@ def _check_conditions(self, metadata: Dict) -> bool:
self.type_of_process == "infer"
or (
self.type_of_process == "train" and
any([item in INFER_REPORT_TYPES for item in reports])
any([item in ReportTypes().infer_report_types for item in reports])
)
)

Expand Down
5 changes: 3 additions & 2 deletions src/syngen/ml/data_loaders/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from syngen.ml.validation_schema import (
ExcelFormatSettingsSchema,
CSVFormatSettingsSchema,
ReportTypes
)

DELIMITERS = {"\\t": "\t"}
Expand Down Expand Up @@ -409,8 +410,8 @@ class YAMLLoader(BaseDataLoader):
Class for loading and saving data in YAML format
"""
metadata_sections = ["train_settings", "infer_settings", "format", "keys"]
infer_reports = ["accuracy"]
train_reports = infer_reports + ["sample"]
infer_reports = ReportTypes().full_list_of_infer_report_types
train_reports = ReportTypes().full_list_of_train_report_types

def __init__(self, path: str):
super().__init__(path)
Expand Down
3 changes: 1 addition & 2 deletions src/syngen/ml/validation_schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,5 @@
KeysSchema,
ValidationSchema,
SUPPORTED_EXCEL_EXTENSIONS,
TRAIN_REPORT_TYPES,
INFER_REPORT_TYPES
ReportTypes
)
30 changes: 26 additions & 4 deletions src/syngen/ml/validation_schema/validation_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,30 @@
from loguru import logger

SUPPORTED_EXCEL_EXTENSIONS = [".xls", ".xlsx"]
INFER_REPORT_TYPES = ["accuracy", "metrics_only"]
TRAIN_REPORT_TYPES = INFER_REPORT_TYPES + ["sample"]


class ReportTypes:
    """
    Container of the report types accepted by the training and inference processes,
    together with the full lists used when the parameter 'reports' is set to 'all'
    """

    def __init__(self):
        # Report types accepted for the inference process
        self.infer_report_types = ["accuracy", "metrics_only"]
        # Report types accepted for the training process
        self.train_report_types = self.infer_report_types + ["sample"]
        # Report types that are filtered out of the full lists built below
        self.excluded_reports = ["metrics_only"]
        self.full_list_of_train_report_types = self.get_list_of_report_types("train")
        self.full_list_of_infer_report_types = self.get_list_of_report_types("infer")

    def get_list_of_report_types(self, type_of_process: Literal["train", "infer"]):
        """
        Get the full list of reports that should be generated
        if the parameter 'reports' is set to 'all'

        The result is a new list containing the report types of the selected process
        (the training ones for "train", the inference ones for any other value)
        without the entries listed in 'excluded_reports'.
        """
        report_types = (
            self.train_report_types
            if type_of_process == "train"
            else self.infer_report_types
        )
        # Filtering instead of calling 'list.remove' keeps the method from raising
        # ValueError when an excluded report is absent from the selected list
        # (e.g. a train-only report excluded while building the infer list)
        return [report for report in report_types if report not in self.excluded_reports]


class ReferenceSchema(Schema):
Expand Down Expand Up @@ -82,7 +104,7 @@ class TrainingSettingsSchema(Schema):
required=False,
validate=(
lambda x: isinstance(x, list) and
all(isinstance(elem, str) and elem in TRAIN_REPORT_TYPES for elem in x)
all(isinstance(elem, str) and elem in ReportTypes().train_report_types for elem in x)
)
)

Expand All @@ -109,7 +131,7 @@ class InferSettingsSchema(Schema):
required=False,
validate=(
lambda x: isinstance(x, list) and
all(isinstance(elem, str) and elem in INFER_REPORT_TYPES for elem in x)
all(isinstance(elem, str) and elem in ReportTypes().infer_report_types for elem in x)
)
)

Expand Down
14 changes: 8 additions & 6 deletions src/syngen/ml/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from syngen.ml.context.context import global_context
from syngen.ml.utils import ProgressBarHandler
from syngen.ml.mlflow_tracker import MlflowTracker
from syngen.ml.validation_schema import INFER_REPORT_TYPES
from syngen.ml.validation_schema import ReportTypes


@define
Expand Down Expand Up @@ -245,8 +245,7 @@ def _split_pk_fk_metadata(self, config, tables):
@staticmethod
def _should_generate_data(
config_of_tables: Dict,
type_of_process: str,
list_of_reports: List[str]
type_of_process: str
):
"""
Determine whether the synthetic data should be generated
Expand All @@ -255,7 +254,8 @@ def _should_generate_data(
return any(
[
report in config.get(f"{type_of_process}_settings", {}).get("reports", [])
for report in list_of_reports for config in config_of_tables.values()
for report in ReportTypes().infer_report_types
for config in config_of_tables.values()
]
)

Expand Down Expand Up @@ -486,7 +486,8 @@ def launch_train(self):
) = metadata_for_inference

generation_of_reports = self._should_generate_data(
metadata_for_training, "train", list_of_reports=INFER_REPORT_TYPES
metadata_for_training,
"train"
)

self.__train_tables(
Expand Down Expand Up @@ -521,7 +522,8 @@ def launch_infer(self):
tables, config_of_tables = self._prepare_metadata_for_process(type_of_process="infer")

generation_of_reports = self._should_generate_data(
config_of_tables, "infer", list_of_reports=INFER_REPORT_TYPES
config_of_tables,
"infer"
)
delta = 0.25 / len(tables) if generation_of_reports else 0.5 / len(tables)

Expand Down
6 changes: 3 additions & 3 deletions src/syngen/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
check_if_logs_available
)
from syngen.ml.utils import validate_parameter_reports
from syngen.ml.validation_schema import TRAIN_REPORT_TYPES
from syngen.ml.validation_schema import ReportTypes


validate_reports = validate_parameter_reports(
report_types=TRAIN_REPORT_TYPES,
full_list=["accuracy", "sample"]
report_types=ReportTypes().train_report_types,
full_list=ReportTypes().full_list_of_train_report_types
)


Expand Down
3 changes: 2 additions & 1 deletion src/tests/unit/launchers/test_launch_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from syngen.infer import launch_infer
from syngen.ml.worker import Worker
from syngen.ml.validation_schema import INFER_REPORT_TYPES
from syngen.ml.validation_schema import ReportTypes
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME


TABLE_NAME = "test_table"
PATH_TO_METADATA = f"{DIR_NAME}/unit/launchers/fixtures/metadata.yaml"
INFER_REPORT_TYPES = ReportTypes().infer_report_types


@patch.object(Worker, "launch_infer")
Expand Down
3 changes: 2 additions & 1 deletion src/tests/unit/launchers/test_launch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from syngen.train import launch_train
from syngen.ml.worker import Worker
from syngen.ml.validation_schema import TRAIN_REPORT_TYPES
from syngen.ml.validation_schema import ReportTypes
from tests.conftest import SUCCESSFUL_MESSAGE, DIR_NAME

TABLE_NAME = "test_table"
PATH_TO_TABLE = f"{DIR_NAME}/unit/launchers/fixtures/table_with_data.csv"
PATH_TO_METADATA = f"{DIR_NAME}/unit/launchers/fixtures/metadata.yaml"
TRAIN_REPORT_TYPES = ReportTypes().train_report_types


@patch.object(Worker, "launch_train")
Expand Down

0 comments on commit 7d591c7

Please sign in to comment.