diff --git a/src/syngen/VERSION b/src/syngen/VERSION index 692b49eb..380550b4 100644 --- a/src/syngen/VERSION +++ b/src/syngen/VERSION @@ -1 +1 @@ -0.9.46 \ No newline at end of file +0.9.47 diff --git a/src/syngen/ml/metrics/accuracy_test/accuracy_test.py b/src/syngen/ml/metrics/accuracy_test/accuracy_test.py index 0615337a..436ab159 100644 --- a/src/syngen/ml/metrics/accuracy_test/accuracy_test.py +++ b/src/syngen/ml/metrics/accuracy_test/accuracy_test.py @@ -113,9 +113,11 @@ def __init__( paths: dict, table_name: str, infer_config: Dict, + columns_nan_labels: Dict, ): super().__init__(original, synthetic, paths, table_name, infer_config) self.reports_path = f"{self.paths['reports_path']}/accuracy" + self.columns_nan_labels = columns_nan_labels self.univariate = UnivariateMetric( self.original, self.synthetic, @@ -126,7 +128,8 @@ def __init__( self.original, self.synthetic, self.plot_exists, - self.reports_path + self.reports_path, + self.columns_nan_labels ) self.correlations = Correlations( self.original, diff --git a/src/syngen/ml/metrics/metrics_classes/metrics.py b/src/syngen/ml/metrics/metrics_classes/metrics.py index db8c5511..3a36a34a 100644 --- a/src/syngen/ml/metrics/metrics_classes/metrics.py +++ b/src/syngen/ml/metrics/metrics_classes/metrics.py @@ -338,9 +338,11 @@ def __init__( synthetic: pd.DataFrame, plot: bool, reports_path: str, + columns_nan_labels: Dict, ): super().__init__(original, synthetic, plot, reports_path) self.cmap = LinearSegmentedColormap.from_list("rg", ["#0D5598", "#3E92E0", "#E8F4FF"]) + self.columns_nan_labels = columns_nan_labels @staticmethod def _format_date_labels(heatmap_orig_data, heatmap_synthetic_data, axis): @@ -442,9 +444,15 @@ def calculate_all( heatmap_orig_data, heatmap_synthetic_data, "y" ) - self._plot_heatmap(heatmap_orig_data, 0, heatmap_min, heatmap_max, cbar=False) + self._plot_heatmap(heatmap_orig_data, plt_index=0, + vrange=(heatmap_min, heatmap_max), + features=(first_col, second_col), + cbar=False) - self._plot_heatmap(heatmap_synthetic_data, 1, heatmap_min, heatmap_max, cbar=True) + self._plot_heatmap(heatmap_synthetic_data, plt_index=1, + vrange=(heatmap_min, heatmap_max), + features=(first_col, second_col), + cbar=True) # first_col is x axis, second_col is y axis title = f"{first_col} vs. {second_col}" path_to_image = ( @@ -463,31 +471,31 @@ def get_common_min_max(original, synthetic): return vmin, vmax @staticmethod - def __format_float_tick_labels(labels: List) -> List: + def __format_float_tick_labels(labels: List, nan_label: str = 'NaN') -> List: + labels = [nan_label if pd.isna(l) else l for l in labels] if all([isinstance(i, float) for i in labels]) and ( max(labels) > 1e5 or min(labels) < 1e-03 ): - labels = [f"{label:.4e}" for label in labels] - return labels + return [f"{label:.4e}" for label in labels] if all([isinstance(i, float) for i in labels]): - labels = [f"{round(i, 4)}" for i in labels] - return labels - else: - return labels + return [f"{round(i, 4)}" for i in labels] + return labels def _plot_heatmap( self, heatmap_data: List, plt_index: int, - vmin: float, - vmax: float, + vrange: tuple[float], + features: tuple[str], cbar=True, ): + vmin, vmax = vrange + xfeature, yfeature = features ax = self._axes.flat[plt_index] ax.tick_params(labelsize=14) heatmap, x_tick_labels, y_tick_labels = heatmap_data - x_tick_labels = self.__format_float_tick_labels(x_tick_labels) - y_tick_labels = self.__format_float_tick_labels(y_tick_labels) + x_tick_labels = self.__format_float_tick_labels(x_tick_labels, self.columns_nan_labels.get(xfeature, 'NaN')) + y_tick_labels = self.__format_float_tick_labels(y_tick_labels, self.columns_nan_labels.get(yfeature, 'NaN')) ax = sns.heatmap( heatmap, xticklabels=x_tick_labels, diff --git a/src/syngen/ml/reporters/reporters.py b/src/syngen/ml/reporters/reporters.py index abd17c4f..11acf571 100644 --- a/src/syngen/ml/reporters/reporters.py +++ b/src/syngen/ml/reporters/reporters.py @@ -295,6 +295,7 @@ def report(self): self.paths, self.table_name, self.config, + self.columns_nan_labels, ) accuracy_test.report( cont_columns=list(float_columns | int_columns),