From cfd9dd2448cce78b7c8ab037421ac7c77dd72441 Mon Sep 17 00:00:00 2001 From: Hanna Imshenetska Date: Wed, 20 Nov 2024 19:18:59 +0000 Subject: [PATCH] refactor the code --- src/syngen/VERSION | 2 +- src/syngen/ml/config/configurations.py | 22 ++++++++++++++-------- src/syngen/ml/reporters/reporters.py | 1 + src/syngen/ml/strategies/strategies.py | 20 +++----------------- src/syngen/ml/worker/worker.py | 12 ------------ 5 files changed, 19 insertions(+), 38 deletions(-) diff --git a/src/syngen/VERSION b/src/syngen/VERSION index 86612fcf..ccade86c 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.9.52rc26 +0.9.52rc28 diff --git a/src/syngen/ml/config/configurations.py b/src/syngen/ml/config/configurations.py index bdb7c898..9afd479e 100644 --- a/src/syngen/ml/config/configurations.py +++ b/src/syngen/ml/config/configurations.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Dict, Tuple, Set, List, Callable +from typing import Optional, Dict, Tuple, Set, List, Callable, Literal import os import shutil from datetime import datetime @@ -84,14 +84,12 @@ def _check_sample_report(self): """ Check whether it is necessary to generate a certain report """ - reports = self.metadata[self.table_name].get("train_settings").get("reports", []) - if "sample" in reports and self.initial_data_shape[0] == self.row_subset: + if "sample" in self.reports and self.initial_data_shape[0] == self.row_subset: logger.warning( "The generation of sampling report is unnecessary and will not be produced " "as the source data and sampled data sizes are identical." ) - reports.remove("sample") - self.metadata[self.table_name]["train_settings"]["reports"] = reports + self.reports.remove("sample") def _check_reports(self): """ @@ -318,6 +316,7 @@ class InferConfig: both_keys: bool log_level: str loader: Optional[Callable[[str], pd.DataFrame]] + type_of_process: Literal["train", "infer"] slugify_table_name: str = field(init=False) def __post_init__(self): @@ -367,12 +366,13 @@ def to_dict(self) -> Dict: "reports": self.reports, } - def _check_reports(self): + def _check_required_artifacts(self): """ - Check whether it is possible to generate reports + Check whether required artifacts exists """ if ( - self.reports + self.type_of_process == "infer" + and self.reports and ( not DataLoader(self.paths["input_data_path"]).has_existed_path and not self.loader @@ -389,6 +389,12 @@ def _check_reports(self): logger.warning(log_message) self.reports = list() + def _check_reports(self): + """ + Check whether it is possible to generate reports + """ + self._check_required_artifacts() + def _set_up_size(self): """ Set up "size" of generated data diff --git a/src/syngen/ml/reporters/reporters.py b/src/syngen/ml/reporters/reporters.py index 094f6b53..d161efc4 100644 --- a/src/syngen/ml/reporters/reporters.py +++ b/src/syngen/ml/reporters/reporters.py @@ -274,6 +274,7 @@ def _launch_reporter(cls, reporter, delta: float): if ( reporter.__class__.report_type == "accuracy" + and "accuracy" not in reporter.config["reports"] and "metrics_only" in reporter.config["reports"] ): message = ( diff --git a/src/syngen/ml/strategies/strategies.py b/src/syngen/ml/strategies/strategies.py index 3ed0988a..aca9457d 100644 --- a/src/syngen/ml/strategies/strategies.py +++ b/src/syngen/ml/strategies/strategies.py @@ -200,24 +200,10 @@ def run(self, **kwargs): """ table_name = kwargs["table_name"] try: - self.set_config( - destination=kwargs["destination"], - size=kwargs["size"], - table_name=kwargs["table_name"], - metadata=kwargs["metadata"], - metadata_path=kwargs["metadata_path"], - run_parallel=kwargs["run_parallel"], - batch_size=kwargs["batch_size"], - random_seed=kwargs["random_seed"], - reports=kwargs["reports"], - log_level=kwargs["log_level"], - both_keys=kwargs["both_keys"], - loader=kwargs["loader"] - ) + self.set_config(**kwargs) MlflowTracker().log_params(self.config.to_dict()) - self.add_reporters().add_handler( - type_of_process=kwargs["type_of_process"] - ) + self.add_reporters() + self.add_handler(type_of_process=kwargs["type_of_process"]) self.handler.handle() except Exception: logger.error( diff --git a/src/syngen/ml/worker/worker.py b/src/syngen/ml/worker/worker.py index 54fabade..88556c3e 100644 --- a/src/syngen/ml/worker/worker.py +++ b/src/syngen/ml/worker/worker.py @@ -114,23 +114,11 @@ def _update_metadata_for_tables(self): self._update_table_settings(table_settings, global_settings) self._update_table_settings(table_settings, self.settings) - def _update_reports_in_metadata(self): - """ - Remove the report 'metrics_only' from the list of reports - in case both 'accuracy' and 'metrics_only' reports have been mentioned - """ - for table, config in self.metadata.items(): - reports = config.get(f"{self.type_of_process}_settings", {}).get("reports", []) - if "accuracy" in reports and "metrics_only" in reports: - reports.remove("metrics_only") - config[f"{self.type_of_process}_settings"]["reports"] = reports - def _update_metadata(self) -> None: if self.metadata_path: self._update_metadata_for_tables() if self.table_name: self._update_metadata_for_table() - self._update_reports_in_metadata() def __fetch_metadata(self) -> Dict: """