Skip to content

Commit

Permalink
refactor the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Nov 20, 2024
1 parent d202b7f commit cfd9dd2
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.52rc26
0.9.52rc28
22 changes: 14 additions & 8 deletions src/syngen/ml/config/configurations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Dict, Tuple, Set, List, Callable
from typing import Optional, Dict, Tuple, Set, List, Callable, Literal
import os
import shutil
from datetime import datetime
Expand Down Expand Up @@ -84,14 +84,12 @@ def _check_sample_report(self):
"""
Check whether it is necessary to generate a certain report
"""
reports = self.metadata[self.table_name].get("train_settings").get("reports", [])
if "sample" in reports and self.initial_data_shape[0] == self.row_subset:
if "sample" in self.reports and self.initial_data_shape[0] == self.row_subset:
logger.warning(
"The generation of sampling report is unnecessary and will not be produced "
"as the source data and sampled data sizes are identical."
)
reports.remove("sample")
self.metadata[self.table_name]["train_settings"]["reports"] = reports
self.reports.remove("sample")

def _check_reports(self):
"""
Expand Down Expand Up @@ -318,6 +316,7 @@ class InferConfig:
both_keys: bool
log_level: str
loader: Optional[Callable[[str], pd.DataFrame]]
type_of_process: Literal["train", "infer"]
slugify_table_name: str = field(init=False)

def __post_init__(self):
Expand Down Expand Up @@ -367,12 +366,13 @@ def to_dict(self) -> Dict:
"reports": self.reports,
}

def _check_reports(self):
def _check_required_artifacts(self):
"""
Check whether it is possible to generate reports
Check whether required artifacts exists
"""
if (
self.reports
self.type_of_process == "infer"
and self.reports
and (
not DataLoader(self.paths["input_data_path"]).has_existed_path
and not self.loader
Expand All @@ -389,6 +389,12 @@ def _check_reports(self):
logger.warning(log_message)
self.reports = list()

def _check_reports(self):
"""
Check whether it is possible to generate reports
"""
self._check_required_artifacts()

def _set_up_size(self):
"""
Set up "size" of generated data
Expand Down
1 change: 1 addition & 0 deletions src/syngen/ml/reporters/reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def _launch_reporter(cls, reporter, delta: float):

if (
reporter.__class__.report_type == "accuracy"
and "accuracy" not in reporter.config["reports"]
and "metrics_only" in reporter.config["reports"]
):
message = (
Expand Down
20 changes: 3 additions & 17 deletions src/syngen/ml/strategies/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,24 +200,10 @@ def run(self, **kwargs):
"""
table_name = kwargs["table_name"]
try:
self.set_config(
destination=kwargs["destination"],
size=kwargs["size"],
table_name=kwargs["table_name"],
metadata=kwargs["metadata"],
metadata_path=kwargs["metadata_path"],
run_parallel=kwargs["run_parallel"],
batch_size=kwargs["batch_size"],
random_seed=kwargs["random_seed"],
reports=kwargs["reports"],
log_level=kwargs["log_level"],
both_keys=kwargs["both_keys"],
loader=kwargs["loader"]
)
self.set_config(**kwargs)
MlflowTracker().log_params(self.config.to_dict())
self.add_reporters().add_handler(
type_of_process=kwargs["type_of_process"]
)
self.add_reporters()
self.add_handler(type_of_process=kwargs["type_of_process"])
self.handler.handle()
except Exception:
logger.error(
Expand Down
12 changes: 0 additions & 12 deletions src/syngen/ml/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,11 @@ def _update_metadata_for_tables(self):
self._update_table_settings(table_settings, global_settings)
self._update_table_settings(table_settings, self.settings)

def _update_reports_in_metadata(self):
"""
Remove the report 'metrics_only' from the list of reports
in case both 'accuracy' and 'metrics_only' reports have been mentioned
"""
for table, config in self.metadata.items():
reports = config.get(f"{self.type_of_process}_settings", {}).get("reports", [])
if "accuracy" in reports and "metrics_only" in reports:
reports.remove("metrics_only")
config[f"{self.type_of_process}_settings"]["reports"] = reports

def _update_metadata(self) -> None:
if self.metadata_path:
self._update_metadata_for_tables()
if self.table_name:
self._update_metadata_for_table()
self._update_reports_in_metadata()

def __fetch_metadata(self) -> Dict:
"""
Expand Down

0 comments on commit cfd9dd2

Please sign in to comment.