implement handling of JSON columns #485

Open: wants to merge 65 commits into base: main

Commits (65):
211551b  implement the flattening of json columns (Feb 16, 2024)
d9847e2  refactor the process of flattening json columns before training proce… (Feb 19, 2024)
c6dbd6e  update 'VERSION' (Feb 19, 2024)
ee0606d  update 'setup.cfg' (Feb 19, 2024)
b38ff9f  update 'setup.cfg' (Feb 19, 2024)
671cee9  refactor the code (Feb 20, 2024)
00b00f6  revert changes related to flattening, fix handling of UUID columns (Feb 20, 2024)
c9ea15a  refactor the process of unflattening (Feb 22, 2024)
f93f575  update the methods 'restore_empty_values', 'check_none_values' in 'sy… (Feb 22, 2024)
46801b2  refactor the code (Feb 22, 2024)
8edfab9  fix issues raised by 'flake8' (Feb 22, 2024)
17218cf  resolve conflicts (Feb 22, 2024)
e12a8e0  fix issues raised by 'flake8' (Feb 22, 2024)
cd2a46b  add unit tests to check the flattening process (Feb 22, 2024)
0c479ec  add unit tests to check the unflattening process (Feb 23, 2024)
24211cf  revert changes (Feb 23, 2024)
110e220  refactor the code (Feb 23, 2024)
8fed3f1  fix the method 'handle' of the class VaeInferHandler (Feb 23, 2024)
6638216  move the flattening process to the preprocess step before the start of traini… (Feb 25, 2024)
fe4191c  update unit tests, fix issues raised by 'flake8' (Feb 25, 2024)
dd17234  update 'VERSION' (Feb 25, 2024)
953d7d9  move the logic of preprocessing to the separate class 'PreprocessHand… (Feb 25, 2024)
a3778c5  update unit tests, update 'VERSION' (Feb 25, 2024)
fd3c174  refactor the class PreprocessHandler in 'syngen/ml/preprocess' (Feb 25, 2024)
5471640  refactor the code in 'syngen/ml/preprocess', update unit tests (Feb 26, 2024)
da3642f  refactor the code (Feb 26, 2024)
7eb2f33  refactor the code (Feb 27, 2024)
7ec2e90  refactor the code, update unit tests (Feb 27, 2024)
9a953b6  update unit tests (Feb 27, 2024)
50b4118  update 'VERSION' (Feb 27, 2024)
9ad5f23  refactor the method '_restore_empty_columns' of the class Postprocess… (Feb 27, 2024)
eb8bcc7  refactor 'syngen/ml/processors' (Feb 27, 2024)
81f4634  refactor the code in 'syngen/ml/processors' (Feb 28, 2024)
9b9bafa  update 'VERSION' (Feb 28, 2024)
d737c4e  refactor processors (Feb 28, 2024)
1899247  minor change in 'syngen/ml/config' (Feb 28, 2024)
b43644a  update unit tests (Feb 29, 2024)
82cb7ec  minor changes in 'syngen/ml/processors', 'syngen/ml/strategies' (Feb 29, 2024)
00867f6  refactor the process of the validation of the metadata (Mar 1, 2024)
0e222f3  refactor the validation process (Mar 1, 2024)
bad389d  resolve conflicts (Mar 5, 2024)
6ac393b  resolve conflicts, update 'VERSION' (Mar 14, 2024)
2e287a1  resolve conflicts (Mar 19, 2024)
c37f811  update the method '_post_process_generated_data' of the class Postpro… (Mar 19, 2024)
886af22  fix the method '_post_process_generated_data' of the class Postproces… (Mar 19, 2024)
5af043f  fix the method '_post_process_generated_data' of the class Postproces… (Mar 19, 2024)
83c3155  refactor the code related to handling of JSON columns, resolve conflicts (Jan 7, 2025)
7cc1e69  refactor the code (Jan 13, 2025)
ed1cf30  minor changes in unit tests (Jan 13, 2025)
88129a1  minor changes in the class Validator (Jan 13, 2025)
705e83c  refactor the method 'check_empty_df' of the class Utility (Jan 13, 2025)
3a63e59  fix issues detected by 'flake8' (Jan 13, 2025)
dde2e48  refactor the code (Jan 13, 2025)
45947f6  minor changes in 'ml/metrics/metrics.py' (Jan 14, 2025)
ad39712  refactor the class Worker, unit tests (Jan 14, 2025)
3e2d091  update 'VERSION' (Jan 14, 2025)
955da34  refactor the class Worker (Jan 14, 2025)
3f44fbf  refactor the class TrainConfig (Jan 15, 2025)
3478150  refactor the code in 'ml/processors', 'ml/worker' (Jan 16, 2025)
f89340b  refactor the code in 'ml/processors' (Jan 16, 2025)
787003d  refactor the code in 'ml/processors' (Jan 17, 2025)
92d2de2  minor changes in 'tests/unit/processors' (Jan 17, 2025)
3056a43  refactor the code in 'ml/processors' (Jan 17, 2025)
e8643a3  update 'VERSION' (Jan 17, 2025)
d0e6668  refactor the code to reach compatibility with the EE version (Jan 17, 2025)
Files changed:
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ boto3
category_encoders==2.6.3
click
Jinja2
flatten_json
keras==2.15.*
lazy==1.4
loguru
1 change: 1 addition & 0 deletions setup.cfg
@@ -33,6 +33,7 @@ install_requires =
category_encoders==2.6.3
click
Jinja2
flatten_json
keras==2.15.*
lazy==1.4
loguru
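Both dependency lists above gain 'flatten_json', the package this PR uses to turn nested JSON values into flat columns before training and to rebuild them after generation. A minimal sketch of that round trip, assuming only the package's public flatten/unflatten_list API (the sample record is illustrative):

from flatten_json import flatten, unflatten_list

# A nested record such as one found in a JSON column
record = {"user": {"name": "Ann", "tags": ["a", "b"]}}

# Flatten to scalar key/value pairs that can serve as table columns
flat = flatten(record, separator=".")
# {'user.name': 'Ann', 'user.tags.0': 'a', 'user.tags.1': 'b'}

# Restore the original nested structure, including the list
restored = unflatten_list(flat, separator=".")
assert restored == record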
2 changes: 1 addition & 1 deletion src/syngen/VERSION
@@ -1 +1 @@
0.10.2
0.10.3rc15
2 changes: 0 additions & 2 deletions src/syngen/__init__.py
@@ -1,8 +1,6 @@
import os
import argparse

from syngen.train import preprocess_data # noqa: F401


base_dir = os.path.dirname(__file__)
version_file = os.path.join(base_dir, "VERSION")
11 changes: 8 additions & 3 deletions src/syngen/infer.py
@@ -9,9 +9,9 @@
from syngen.ml.utils import (
setup_logger,
set_log_path,
check_if_logs_available
check_if_logs_available,
validate_parameter_reports
)
from syngen.ml.utils import validate_parameter_reports
from syngen.ml.validation_schema import ReportTypes


@@ -22,7 +22,12 @@


@click.command()
@click.option("--metadata_path", type=str, default=None, help="Path to the metadata file")
@click.option(
"--metadata_path",
type=str,
default=None,
help="Path to the metadata file"
)
@click.option(
"--size",
default=100,
51 changes: 18 additions & 33 deletions src/syngen/ml/config/configurations.py
@@ -10,7 +10,8 @@
from slugify import slugify

from syngen.ml.data_loaders import DataLoader, DataFrameFetcher
from syngen.ml.utils import slugify_attribute
from syngen.ml.utils import slugify_attribute, fetch_unique_root
from syngen.ml.convertor import CSVConvertor


@dataclass
@@ -25,6 +26,7 @@ class TrainConfig:
row_limit: Optional[int]
table_name: Optional[str]
metadata: Dict
metadata_path: Optional[str]
reports: List[str]
batch_size: int
loader: Optional[Callable[[str], pd.DataFrame]]
@@ -40,8 +42,6 @@ class TrainConfig:

def __post_init__(self):
self._set_paths()
self._remove_existed_artifacts()
self._prepare_dirs()

def __getstate__(self) -> Dict:
"""
@@ -57,7 +57,6 @@ def __getstate__(self) -> Dict:
def preprocess_data(self):
self._extract_data()
self._save_original_schema()
self.columns = list(self.data.columns)
self._remove_empty_columns()
self._mark_removed_columns()
self._prepare_data()
@@ -100,31 +99,6 @@ def _check_reports(self):
"""
self._check_sample_report()

def _remove_existed_artifacts(self):
"""
Remove existed artifacts from previous train process
"""
if os.path.exists(self.paths["resources_path"]):
shutil.rmtree(self.paths["resources_path"])
logger.info(
f"The artifacts located in the path - '{self.paths['resources_path']}' "
f"were removed"
)
if os.path.exists(self.paths["tmp_store_path"]):
shutil.rmtree(self.paths["tmp_store_path"])
logger.info(
f"The artifacts located in the path - '{self.paths['tmp_store_path']}' "
f"were removed"
)

def _prepare_dirs(self):
"""
Create main directories for saving original, synthetic data and model artifacts
"""
os.makedirs(self.paths["model_artifacts_path"], exist_ok=True)
os.makedirs(self.paths["state_path"], exist_ok=True)
os.makedirs(self.paths["tmp_store_path"], exist_ok=True)

def _fetch_dataframe(self) -> Tuple[pd.DataFrame, Dict]:
"""
Fetch the dataframe using the callback function
@@ -140,8 +114,11 @@ def _load_source(self) -> Tuple[pd.DataFrame, Dict]:
"""
Return dataframe and schema of original data
"""
if self.loader is not None:
return self._fetch_dataframe()
if os.path.exists(self.paths["path_to_flatten_metadata"]):
data, schema = DataLoader(self.paths["input_data_path"]).load_data()
self.original_schema = DataLoader(self.paths["input_data_path"]).original_schema
schema = CSVConvertor.schema
return data, schema
else:
data_loader = DataLoader(self.source)
self.original_schema = data_loader.original_schema
@@ -159,8 +136,9 @@ def _remove_empty_columns(self):
self.data = self.data.dropna(how="all", axis=1)

self.dropped_columns = data_columns - set(self.data.columns)
if len(self.dropped_columns) > 0:
logger.info(f"Empty columns - {', '.join(self.dropped_columns)} were removed")
list_of_dropped_columns = [f"'{column}'" for column in self.dropped_columns]
if len(list_of_dropped_columns) > 0:
logger.info(f"Empty columns - {', '.join(list_of_dropped_columns)} were removed")

def _mark_removed_columns(self):
"""
@@ -189,6 +167,7 @@ def _extract_data(self):
"""
self.data, self.schema = self._load_source()
self.initial_data_shape = self.data.shape
self.columns = list(self.data.columns)
self._check_if_data_is_empty()

def _preprocess_data(self):
@@ -297,6 +276,9 @@ def _set_paths(self):
f"merged_infer_{self.slugify_table_name}.csv",
"no_ml_state_path":
f"model_artifacts/resources/{self.slugify_table_name}/no_ml/checkpoints/",
"path_to_flatten_metadata":
f"model_artifacts/tmp_store/flatten_configs/"
f"flatten_metadata_{fetch_unique_root(self.table_name, self.metadata_path)}.json",
"losses_path": f"model_artifacts/tmp_store/losses/{slugify(losses_file_name)}.csv"
}

@@ -449,4 +431,7 @@ def _set_paths(self):
"fk_kde_path": f"model_artifacts/resources/{dynamic_name}/vae/checkpoints/stat_keys/",
"path_to_no_ml":
f"model_artifacts/resources/{dynamic_name}/no_ml/checkpoints/kde_params.pkl",
"path_to_flatten_metadata":
f"model_artifacts/tmp_store/flatten_configs/"
f"flatten_metadata_{fetch_unique_root(self.table_name, self.metadata_path)}.json"
}
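Both _set_paths variants above now build 'path_to_flatten_metadata' with fetch_unique_root, a helper imported from syngen.ml.utils whose body is not part of this diff. A plausible sketch of what it does, inferred only from its call sites (the names and logic below are assumptions, not the PR's code):

import os
from slugify import slugify

def fetch_unique_root(table_name, metadata_path):
    # Assumed behavior: derive the unique root from the metadata file name
    # when a metadata file is supplied, so every table described by that file
    # shares one flatten_metadata_*.json; fall back to the table name otherwise.
    if metadata_path:
        return slugify(os.path.splitext(os.path.basename(metadata_path))[0])
    return slugify(table_name)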
28 changes: 23 additions & 5 deletions src/syngen/ml/config/validation.py
@@ -9,6 +9,7 @@
from loguru import logger
from syngen.ml.data_loaders import MetadataLoader, DataLoader
from syngen.ml.validation_schema import ValidationSchema, ReportTypes
from syngen.ml.utils import fetch_unique_root


@dataclass
@@ -286,9 +287,8 @@ def _fetch_existed_columns(self, table_name: str) -> List[str]:
"""
metadata_of_table = self.merged_metadata[table_name]
format_settings = metadata_of_table.get("format", {})
return DataLoader(
metadata_of_table["train_settings"]["source"]
).get_columns(**format_settings)
path_to_source = self._fetch_path_to_source(table_name)
return DataLoader(path_to_source).get_columns(**format_settings)

def _gather_existed_columns(self, table_name: str):
"""
@@ -298,16 +298,34 @@ def _gather_existed_columns(self, table_name: str):
existed_columns = self._fetch_existed_columns(table_name)
self.existed_columns_mapping[table_name] = existed_columns

def _run(self):
def preprocess_metadata(self):
"""
Run the validation process
Preprocess the metadata, set the metadata and the merged metadata
"""
self._launch_validation_of_schema()
self._define_mapping()
self._merge_metadata()
self.merged_metadata.pop("global", None)
self.metadata.pop("global", None)

def _fetch_path_to_source(self, table_name):
"""
Fetch the path to the source of the certain table
"""
if os.path.exists(
f"{os.getcwd()}/model_artifacts/tmp_store/flatten_configs/flatten_metadata_"
f"{fetch_unique_root(table_name, self.metadata_path)}.json"
):
return (f"{os.getcwd()}/model_artifacts/tmp_store/{slugify(table_name)}/"
f"input_data_{slugify(table_name)}.pkl")
return self.metadata[table_name]["train_settings"]["source"]

def _run(self):
"""
Run the validation process
"""
self.preprocess_metadata()

if self.type_of_process == "train" and self.validation_source:
for table_name in self.merged_metadata.keys():
self._gather_existed_columns(table_name)
2 changes: 1 addition & 1 deletion src/syngen/ml/convertor/convertor.py
@@ -106,7 +106,7 @@ class CSVConvertor(Convertor):
"""
Class for supporting custom schema for csv files
"""
schema = {"fields": {}, "format": "CSV"}
schema: Dict = {"fields": {}, "format": "CSV"}

def __init__(self, df):
schema = {"fields": {}, "format": "CSV"}
4 changes: 4 additions & 0 deletions src/syngen/ml/data_loaders/data_loaders.py
@@ -497,6 +497,10 @@ def _load_data(self) -> pd.DataFrame:
with open(self.path, "rb") as f:
return pkl.load(f)

def get_columns(self) -> List[str]:
data, schema = self.load_data()
return data.columns.tolist()

def load_data(self) -> Tuple[pd.DataFrame, None]:
return self._load_data(), None
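The get_columns method added above gives the pickle-backed loader the same column-listing interface the CSV loaders already expose; _fetch_path_to_source in validation.py relies on it when it redirects validation to the preprocessed '.pkl' input. A short usage sketch (the path is hypothetical):

from syngen.ml.data_loaders import DataLoader

# Columns of a preprocessed (flattened) input saved by the training step
columns = DataLoader(
    "model_artifacts/tmp_store/users/input_data_users.pkl"
).get_columns()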

38 changes: 15 additions & 23 deletions src/syngen/ml/handlers/handlers.py
@@ -465,6 +465,15 @@ def _restore_empty_columns(self, df: pd.DataFrame) -> pd.DataFrame:

return df

def _save_data(self, generated_data):
"""
Save generated data to the path
"""
DataLoader(self.paths["path_to_merged_infer"]).save_data(
generated_data,
format=get_context().get_config(),
)

def handle(self, **kwargs):
self._prepare_dir()
list_of_reports = [f'"{report}"' for report in self.reports]
@@ -504,9 +513,7 @@ def handle(self, **kwargs):
if tech_columns:
prepared_data = prepared_data.drop(tech_columns, axis=1)
logger.debug(
"Technical columns "
f"{tech_columns} were removed "
"from the generated table."
f"Technical columns {tech_columns} were removed from the generated table."
)
Report().unregister_reporters(self.table_name)
logger.info(
@@ -525,27 +532,12 @@ def handle(self, **kwargs):
generated_data = generated_data[self.dataset.order_of_columns]

if generated_data is None:
DataLoader(self.paths["path_to_merged_infer"]).save_data(
prepared_data,
schema=self.original_schema,
format=get_context().get_config(),
)
self._save_data(prepared_data)
else:
DataLoader(self.paths["path_to_merged_infer"]).save_data(
generated_data,
schema=self.original_schema,
format=get_context().get_config(),
)
self._save_data(generated_data)
else:
DataLoader(self.paths["path_to_merged_infer"]).save_data(
prepared_data,
schema=self.original_schema,
format=get_context().get_config(),
)
self._save_data(prepared_data)
if self.metadata_path is None:
prepared_data = prepared_data[self.dataset.order_of_columns]
DataLoader(self.paths["path_to_merged_infer"]).save_data(
prepared_data,
schema=self.original_schema,
format=get_context().get_config(),
)

self._save_data(prepared_data)
19 changes: 14 additions & 5 deletions src/syngen/ml/metrics/accuracy_test/accuracy_test.py
@@ -268,8 +268,12 @@ def _generate_report(
utility_barplot=transform_to_base64(
f"{self.reports_path}/utility_barplot.svg"
),
utility_table=utility_result.to_html(),
is_data_available=False if utility_result.empty else True,
utility_table=utility_result.to_html() if utility_result is not None else None,
is_data_available=(
False
if utility_result is None or (utility_result is not None and utility_result.empty)
else True
),
table_name=self.table_name,
training_config=train_config,
inference_config=infer_config,
@@ -296,9 +300,14 @@ def report(self, *args, **kwargs):
) = metrics
MlflowTracker().log_metrics(
{
"Utility_avg": utility_result["Synth to orig ratio"].mean(),
"Clustering": clustering_result if clustering_result is not None
else np.NaN,
"Utility_avg": (
utility_result["Synth to orig ratio"].mean()
if utility_result is not None else None
),
"Clustering": (
clustering_result
if clustering_result is not None else np.NaN
),
"Accuracy": acc_median,
"Correlation": corr_result,
}
22 changes: 20 additions & 2 deletions src/syngen/ml/metrics/metrics_classes/metrics.py
@@ -1,4 +1,4 @@
from typing import Union, List, Optional, Dict
from typing import Union, List, Optional, Dict, Literal
from abc import ABC
from itertools import combinations
from collections import Counter
@@ -1097,6 +1097,19 @@ def __init__(

self.sample_size = sample_size

@staticmethod
def check_empty_df(df: pd.DataFrame, df_type: Literal["original", "synthetic"]) -> bool:
"""
Check if the dataframe is empty after dropping rows with missing values
"""
if df.empty:
logger.warning(
f"Utility metric calculation is skipped: the {df_type} dataframe is empty "
"after dropping rows with missing values (dropna() function is applied)"
)
return True
return False

def calculate_all(self, categorical_columns: List[str], cont_columns: List[str]):
logger.info("Calculating utility metric")

@@ -1115,9 +1128,14 @@ def calculate_all(self, categorical_columns: List[str], cont_columns: List[str])
self.synthetic = self.synthetic[cont_columns + categorical_columns].apply(
pd.to_numeric, axis=0, errors="ignore"
)

self.original = self.original.select_dtypes(include="number").dropna()
self.synthetic = self.synthetic.select_dtypes(include="number").dropna()

if self.check_empty_df(self.original, "original"):
return
if self.check_empty_df(self.synthetic, "synthetic"):
return

self.synthetic = self.synthetic[self.original.columns]

excluded_cols = [
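The guard matters because select_dtypes(include="number").dropna() can leave an empty frame even when the source table has rows. A small illustration of the case check_empty_df now catches (the data is made up):

import pandas as pd

# Every row carries at least one missing value, so dropna() removes them all
original = pd.DataFrame({"age": [31, None], "income": [None, 50000.0]})
numeric = original.select_dtypes(include="number").dropna()
print(numeric.empty)  # True -> calculate_all logs a warning and returns early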
1 change: 1 addition & 0 deletions src/syngen/ml/processors/__init__.py
@@ -0,0 +1 @@
from syngen.ml.processors.processors import PreprocessHandler, PostprocessHandler # noqa: F401