From 2c98512853eb433eda3c0de46f9c3f7407714f4b Mon Sep 17 00:00:00 2001 From: Hanna Imshenetska Date: Mon, 9 Dec 2024 14:31:51 +0000 Subject: [PATCH] refactor the code in the class Dataset, minor changes in 'ml/validation_schema' --- src/syngen/VERSION | 2 +- src/syngen/ml/vae/models/dataset.py | 12 ++++++------ src/syngen/ml/validation_schema/validation_schema.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/syngen/VERSION b/src/syngen/VERSION index e80174e8..3e3a86e4 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.9.53rc7 +0.9.53rc9 diff --git a/src/syngen/ml/vae/models/dataset.py b/src/syngen/ml/vae/models/dataset.py index c4a7374c..ebe9d2fd 100644 --- a/src/syngen/ml/vae/models/dataset.py +++ b/src/syngen/ml/vae/models/dataset.py @@ -132,12 +132,6 @@ def _preprocess_df(self, excluded_columns: Set[str]): """ self._cast_to_numeric(excluded_columns) self.nan_labels_dict = get_nan_labels(self.df, excluded_columns) - if self.nan_labels_dict and self.format.get("na_values", []): - logger.info( - f"Despite the fact that data loading utilized the 'format' section " - f"for handling NA values, some values have been detected by the algorithm " - f"as NA labels in the columns - {self.nan_labels_dict}" - ) self.df = nan_labels_to_float(self.df, self.nan_labels_dict) def _preparation_step(self): @@ -415,6 +409,12 @@ def _common_detection(self): self._set_uuid_columns() self._set_long_text_columns() self._set_email_columns() + if self.nan_labels_dict and self.format.get("na_values", []): + logger.info( + f"Despite the fact that data loading utilized the 'format' section " + f"for handling NA values, some values have been detected by the algorithm " + f"as NA labels in the columns - {self.nan_labels_dict}" + ) def _update_schema(self): """ diff --git a/src/syngen/ml/validation_schema/validation_schema.py b/src/syngen/ml/validation_schema/validation_schema.py index 658c1215..ac52c7fe 100644 --- a/src/syngen/ml/validation_schema/validation_schema.py +++ b/src/syngen/ml/validation_schema/validation_schema.py @@ -110,10 +110,10 @@ def validate_reports(x): "The value 'all' or 'none' might not be passed in the list." ) if not ( - isinstance(x, list) - and all( - isinstance(elem, str) - and elem in ReportTypes().train_report_types for elem in x + isinstance(x, list) + and all( + isinstance(elem, str) + and elem in ReportTypes().train_report_types for elem in x ) ): raise ValidationError("Invalid value.")