Skip to content

Commit

Permalink
fix the representation of nan values in bivariate plots
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and Hanna Imshenetska committed Oct 23, 2024
1 parent ac1fd9c commit 0431c81
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.47
0.9.48rc2
7 changes: 3 additions & 4 deletions src/syngen/ml/metrics/accuracy_test/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,11 @@ def __init__(
synthetic: pd.DataFrame,
paths: dict,
table_name: str,
infer_config: Dict,
columns_nan_labels: Dict,
infer_config: Dict
):
super().__init__(original, synthetic, paths, table_name, infer_config)
self.reports_path = f"{self.paths['reports_path']}/accuracy"
self.columns_nan_labels = columns_nan_labels
self.dataset_pickle_path = self.paths["dataset_pickle_path"]
self.univariate = UnivariateMetric(
self.original,
self.synthetic,
Expand All @@ -129,7 +128,7 @@ def __init__(
self.synthetic,
self.plot_exists,
self.reports_path,
self.columns_nan_labels
self.dataset_pickle_path
)
self.correlations = Correlations(
self.original,
Expand Down
35 changes: 28 additions & 7 deletions src/syngen/ml/metrics/metrics_classes/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
from slugify import slugify
from loguru import logger

from syngen.ml.utils import timestamp_to_datetime, timing
from syngen.ml.utils import (
timestamp_to_datetime,
timing,
fetch_config
)
matplotlib.use("Agg")


Expand Down Expand Up @@ -338,11 +342,20 @@ def __init__(
synthetic: pd.DataFrame,
plot: bool,
reports_path: str,
columns_nan_labels: Dict,
dataset_pickle_path: str
):
super().__init__(original, synthetic, plot, reports_path)
self.cmap = LinearSegmentedColormap.from_list("rg", ["#0D5598", "#3E92E0", "#E8F4FF"])
self.columns_nan_labels = columns_nan_labels
self.cmap = LinearSegmentedColormap.from_list(
"rg", ["#0D5598", "#3E92E0", "#E8F4FF"]
)
self.dataset = fetch_config(dataset_pickle_path)
nan_labels_dict = self.dataset.nan_labels_dict
na_values = self.dataset.format.get("na_values", [])
self.missing_values: Dict[str, str] = (
{col: na_values[0] for col in self.dataset.order_of_columns}
if na_values
else nan_labels_dict
)

@staticmethod
def _format_date_labels(heatmap_orig_data, heatmap_synthetic_data, axis):
Expand Down Expand Up @@ -471,7 +484,7 @@ def get_common_min_max(original, synthetic):
return vmin, vmax

@staticmethod
def __format_float_tick_labels(labels: List, nan_label: str = 'NaN') -> List:
def __format_float_tick_labels(labels: List, nan_label: str = "nan") -> List:
labels = [nan_label if pd.isna(l) else l for l in labels]
if all([isinstance(i, float) for i in labels]) and (
max(labels) > 1e5 or min(labels) < 1e-03
Expand All @@ -494,8 +507,16 @@ def _plot_heatmap(
ax = self._axes.flat[plt_index]
ax.tick_params(labelsize=14)
heatmap, x_tick_labels, y_tick_labels = heatmap_data
x_tick_labels = self.__format_float_tick_labels(x_tick_labels, self.columns_nan_labels.get(xfeature, 'NaN'))
y_tick_labels = self.__format_float_tick_labels(y_tick_labels, self.columns_nan_labels.get(yfeature, 'NaN'))
print(f"!!!!!!!!!!!!!!!!!!!!!")
print(f"{self.missing_values}")
x_tick_labels = self.__format_float_tick_labels(
x_tick_labels,
self.missing_values.get(xfeature, "nan")
)
y_tick_labels = self.__format_float_tick_labels(
y_tick_labels,
self.missing_values.get(yfeature, "nan")
)
ax = sns.heatmap(
heatmap,
xticklabels=x_tick_labels,
Expand Down
5 changes: 2 additions & 3 deletions src/syngen/ml/reporters/reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
self.loader = loader
self.dataset = None
self.columns_nan_labels = dict()
self.na_values = dict()

def _fetch_dataframe(self) -> pd.DataFrame:
"""
Expand All @@ -63,7 +64,6 @@ def _extract_report_data(self) -> Tuple[pd.DataFrame, pd.DataFrame]:

def fetch_data_types(self):
self.dataset = fetch_config(self.paths["dataset_pickle_path"])
self.columns_nan_labels = self.dataset.nan_labels_dict
types = (
self.dataset.str_columns,
self.dataset.date_columns,
Expand Down Expand Up @@ -294,8 +294,7 @@ def report(self):
synthetic,
self.paths,
self.table_name,
self.config,
self.columns_nan_labels,
self.config
)
accuracy_test.report(
cont_columns=list(float_columns | int_columns),
Expand Down

0 comments on commit 0431c81

Please sign in to comment.