Skip to content

Commit

Permalink
add the method 'unregister_reporters' for the class Report, update un…
Browse files Browse the repository at this point in the history
…it tests
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Oct 28, 2024
2 parents f549636 + 630232d commit 12e53d1
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.48
0.9.49rc0
10 changes: 10 additions & 0 deletions src/syngen/ml/handlers/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from syngen.ml.vae import * # noqa: F403
from syngen.ml.data_loaders import DataLoader, DataFrameFetcher
from syngen.ml.reporters import Report
from syngen.ml.vae.models.dataset import Dataset
from syngen.ml.utils import (
fetch_config,
Expand Down Expand Up @@ -492,6 +493,15 @@ def handle(self, **kwargs):
else pd.DataFrame()
)
prepared_data = self._restore_empty_columns(prepared_data)
# workaround for the case when all columns are dropped
# with technical column
if self.dataset.tech_columns:
tech_columns = list(self.dataset.tech_columns)
prepared_data = prepared_data.drop(tech_columns, axis=1)
Report().unregister_reporters(tech_columns)
logger.debug(f"Technical columns "
f"{self.dataset.tech_columns} are dropped"
f" from the generated table")

is_pk = self._is_pk()

Expand Down
19 changes: 17 additions & 2 deletions src/syngen/ml/reporters/reporters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
from abc import abstractmethod
from typing import Dict, Tuple, Optional, Callable
from typing import (
Dict,
Tuple,
Optional,
Callable,
Union,
List
)
import itertools
from collections import defaultdict

Expand Down Expand Up @@ -188,7 +195,7 @@ class Report:
Singleton metaclass for registration all needed reporters
"""

_reporters: Dict[str, Reporter] = {}
_reporters: Dict[str, Union[Reporter, List]] = {}

def __new__(cls):
if not hasattr(cls, "instance"):
Expand All @@ -204,6 +211,14 @@ def register_reporter(cls, table: str, reporter: Reporter):
list_of_reporters.append(reporter)
cls._reporters[table] = list_of_reporters

@classmethod
def unregister_reporters(cls, tables: List[str]):
"""
Unregister all reporters for tables
"""
for table in tables:
cls._reporters[table] = list()

@classmethod
def clear_report(cls):
"""
Expand Down
15 changes: 15 additions & 0 deletions src/syngen/ml/vae/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
self.uuid_columns: Set = set()
self.uuid_columns_types: Dict = dict()
self.dropped_columns: Set = set()
self.tech_columns: Set = set()
self.order_of_columns: List = list()
self.custom_categorical_columns: Set = set()
self.categorical_columns: Set = set()
Expand Down Expand Up @@ -1320,6 +1321,20 @@ def pipeline(self) -> pd.DataFrame:
elif column in self.uuid_columns:
logger.info(f"Column '{column}' defined as UUID column")
self._assign_uuid_null_feature(column)

# workaround for the case when all columns are dropped
# add a technical column to proceed with the training process
if not self.features:
tech_column = "syngen_tech_column"
logger.warning(
f"There are no columns left to train on for '{self.table_name}'. "
f"Adding a technical column '{tech_column}' to proceed "
f"with the training process."
)
self.df[tech_column] = 1
self._assign_float_feature(tech_column)
self.tech_columns.add(tech_column)

self.fit()

# The end of the run related to the preprocessing stage
Expand Down
1 change: 1 addition & 0 deletions src/tests/unit/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def test_save_dataset(rp_logger):
"float_columns",
"int_columns",
"date_columns",
"tech_columns",
"date_mapping",
"binary_columns",
"email_columns",
Expand Down

0 comments on commit 12e53d1

Please sign in to comment.