Skip to content

Commit

Permalink
refactor the code in the class Dataset, minor changes in 'ml/validation_schema'
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Dec 9, 2024
1 parent d64e39f commit 2c98512
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
-0.9.53rc7
+0.9.53rc9
12 changes: 6 additions & 6 deletions src/syngen/ml/vae/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,6 @@ def _preprocess_df(self, excluded_columns: Set[str]):
"""
self._cast_to_numeric(excluded_columns)
self.nan_labels_dict = get_nan_labels(self.df, excluded_columns)
if self.nan_labels_dict and self.format.get("na_values", []):
logger.info(
f"Despite the fact that data loading utilized the 'format' section "
f"for handling NA values, some values have been detected by the algorithm "
f"as NA labels in the columns - {self.nan_labels_dict}"
)
self.df = nan_labels_to_float(self.df, self.nan_labels_dict)

def _preparation_step(self):
Expand Down Expand Up @@ -415,6 +409,12 @@ def _common_detection(self):
self._set_uuid_columns()
self._set_long_text_columns()
self._set_email_columns()
if self.nan_labels_dict and self.format.get("na_values", []):
logger.info(
f"Despite the fact that data loading utilized the 'format' section "
f"for handling NA values, some values have been detected by the algorithm "
f"as NA labels in the columns - {self.nan_labels_dict}"
)

def _update_schema(self):
"""
Expand Down
8 changes: 4 additions & 4 deletions src/syngen/ml/validation_schema/validation_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ def validate_reports(x):
"The value 'all' or 'none' might not be passed in the list."
)
if not (
isinstance(x, list)
and all(
isinstance(elem, str)
and elem in ReportTypes().train_report_types for elem in x
isinstance(x, list)
and all(
isinstance(elem, str)
and elem in ReportTypes().train_report_types for elem in x
)
):
raise ValidationError("Invalid value.")
Expand Down

0 comments on commit 2c98512

Please sign in to comment.