Skip to content

Commit

Permalink
resolve conflicts
Browse files: browse the repository at this point in the history
  • Loading branch information
Hanna Imshenetska authored and committed on Oct 23, 2024
2 parents ee12aa1 + ac1fd9c commit c7aec2e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/syngen/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.46
0.9.47
5 changes: 4 additions & 1 deletion src/syngen/ml/metrics/accuracy_test/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ def __init__(
paths: dict,
table_name: str,
infer_config: Dict,
columns_nan_labels: Dict,
):
super().__init__(original, synthetic, paths, table_name, infer_config)
self.reports_path = f"{self.paths['reports_path']}/accuracy"
self.columns_nan_labels = columns_nan_labels
self.univariate = UnivariateMetric(
self.original,
self.synthetic,
Expand All @@ -126,7 +128,8 @@ def __init__(
self.original,
self.synthetic,
self.plot_exists,
self.reports_path
self.reports_path,
self.columns_nan_labels
)
self.correlations = Correlations(
self.original,
Expand Down
34 changes: 21 additions & 13 deletions src/syngen/ml/metrics/metrics_classes/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,9 +338,11 @@ def __init__(
synthetic: pd.DataFrame,
plot: bool,
reports_path: str,
columns_nan_labels: Dict,
):
super().__init__(original, synthetic, plot, reports_path)
self.cmap = LinearSegmentedColormap.from_list("rg", ["#0D5598", "#3E92E0", "#E8F4FF"])
self.columns_nan_labels = columns_nan_labels

@staticmethod
def _format_date_labels(heatmap_orig_data, heatmap_synthetic_data, axis):
Expand Down Expand Up @@ -442,9 +444,15 @@ def calculate_all(
heatmap_orig_data, heatmap_synthetic_data, "y"
)

self._plot_heatmap(heatmap_orig_data, 0, heatmap_min, heatmap_max, cbar=False)
self._plot_heatmap(heatmap_orig_data, plt_index=0,
vrange=(heatmap_min, heatmap_max),
features=(first_col, second_col),
cbar=False)

self._plot_heatmap(heatmap_synthetic_data, 1, heatmap_min, heatmap_max, cbar=True)
self._plot_heatmap(heatmap_synthetic_data, plt_index=1,
vrange=(heatmap_min, heatmap_max),
features=(first_col, second_col),
cbar=True)
# first_col is x axis, second_col is y axis
title = f"{first_col} vs. {second_col}"
path_to_image = (
Expand All @@ -463,31 +471,31 @@ def get_common_min_max(original, synthetic):
return vmin, vmax

@staticmethod
def __format_float_tick_labels(labels: List) -> List:
def __format_float_tick_labels(labels: List, nan_label: str = 'NaN') -> List:
labels = [nan_label if pd.isna(l) else l for l in labels]
if all([isinstance(i, float) for i in labels]) and (
max(labels) > 1e5 or min(labels) < 1e-03
):
labels = [f"{label:.4e}" for label in labels]
return labels
return [f"{label:.4e}" for label in labels]
if all([isinstance(i, float) for i in labels]):
labels = [f"{round(i, 4)}" for i in labels]
return labels
else:
return labels
return [f"{round(i, 4)}" for i in labels]
return labels

def _plot_heatmap(
self,
heatmap_data: List,
plt_index: int,
vmin: float,
vmax: float,
vrange: tuple[float],
features: tuple[str],
cbar=True,
):
vmin, vmax = vrange
xfeature, yfeature = features
ax = self._axes.flat[plt_index]
ax.tick_params(labelsize=14)
heatmap, x_tick_labels, y_tick_labels = heatmap_data
x_tick_labels = self.__format_float_tick_labels(x_tick_labels)
y_tick_labels = self.__format_float_tick_labels(y_tick_labels)
x_tick_labels = self.__format_float_tick_labels(x_tick_labels, self.columns_nan_labels.get(xfeature, 'NaN'))
y_tick_labels = self.__format_float_tick_labels(y_tick_labels, self.columns_nan_labels.get(yfeature, 'NaN'))
ax = sns.heatmap(
heatmap,
xticklabels=x_tick_labels,
Expand Down
1 change: 1 addition & 0 deletions src/syngen/ml/reporters/reporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def report(self):
self.paths,
self.table_name,
self.config,
self.columns_nan_labels,
)
accuracy_test.report(
cont_columns=list(float_columns | int_columns),
Expand Down

0 comments on commit c7aec2e

Please sign in to comment.