diff --git a/src/syngen/ml/metrics/metrics_classes/metrics.py b/src/syngen/ml/metrics/metrics_classes/metrics.py index 3a36a34a..23532425 100644 --- a/src/syngen/ml/metrics/metrics_classes/metrics.py +++ b/src/syngen/ml/metrics/metrics_classes/metrics.py @@ -472,7 +472,7 @@ def get_common_min_max(original, synthetic): @staticmethod def __format_float_tick_labels(labels: List, nan_label: str = 'NaN') -> List: - labels = [nan_label if pd.isna(l) else l for l in labels] + labels = [nan_label if pd.isna(label) else label for label in labels] if all([isinstance(i, float) for i in labels]) and ( max(labels) > 1e5 or min(labels) < 1e-03 ): @@ -494,8 +494,14 @@ def _plot_heatmap( ax = self._axes.flat[plt_index] ax.tick_params(labelsize=14) heatmap, x_tick_labels, y_tick_labels = heatmap_data - x_tick_labels = self.__format_float_tick_labels(x_tick_labels, self.columns_nan_labels.get(xfeature, 'NaN')) - y_tick_labels = self.__format_float_tick_labels(y_tick_labels, self.columns_nan_labels.get(yfeature, 'NaN')) + x_tick_labels = self.__format_float_tick_labels( + x_tick_labels, + self.columns_nan_labels.get(xfeature, 'NaN') + ) + y_tick_labels = self.__format_float_tick_labels( + y_tick_labels, + self.columns_nan_labels.get(yfeature, 'NaN') + ) ax = sns.heatmap( heatmap, xticklabels=x_tick_labels,