diff --git a/docs/user_guide/correlations.rst b/docs/user_guide/correlations.rst index 6404347f..e9d3c61f 100644 --- a/docs/user_guide/correlations.rst +++ b/docs/user_guide/correlations.rst @@ -32,8 +32,16 @@ Let's first look at how we would go about detecting correlations inside a datafr import fairlens as fl columns = ["gender", "random", "score"] - data = [["male", 10, 50], ["female", 20, 80], ["male", 20, 60], ["female", 10, 90]] - + data = [ + ["male", 10, 60], + ["female", 10, 80], + ["male", 10, 60], + ["female", 10, 80], + ["male", 9, 59], + ["female", 11, 80], + ["male", 12, 61], + ["female", 10, 83], + ] df = pd.DataFrame(data, columns=columns) Here the score seems to be correlated with gender, with females leaning towards somewhat higher scores. @@ -65,7 +73,7 @@ Correlation Heatmaps ^^^^^^^^^^^^^^^^^^^^ The :code:`plot` module allows users to generate a correlation heatmap of any dataset by simply -passing the dataframe to the :code:`two_column_heatmap()` function, which will plot a heatmap from the +passing the dataframe to the :code:`heatmap()` function, which will plot a heatmap from the matrix of the correlation coefficients computed by using the Pearson Coefficient, the Kruskal-Wallis Test and Cramer's V between each two of the columns (for numerical-numerical, categorical-numerical and categorical-categorical associations, respectively). @@ -92,19 +100,17 @@ This will automatically choose different methods for different types of data, ho are configurable. .. ipython:: python - :okwarning: @savefig corr_heatmap_1.png - fl.plot.two_column_heatmap(df) + fl.plot.heatmap(df) Let's try generating a heatmap of the same dataset, but using some non-linear metrics for numerical-numerical and numerical-categorical associations for added precision. .. ipython:: python - :okwarning: from fairlens.metrics import distance_nn_correlation, distance_cn_correlation, cramers_v @savefig corr_heatmap_2.png - fl.plot.two_column_heatmap(df, distance_nn_correlation, distance_cn_correlation, cramers_v) + fl.plot.heatmap(df, distance_nn_correlation, distance_cn_correlation, cramers_v) diff --git a/src/fairlens/metrics/__init__.py b/src/fairlens/metrics/__init__.py index 27edbd41..5c4b2a35 100644 --- a/src/fairlens/metrics/__init__.py +++ b/src/fairlens/metrics/__init__.py @@ -23,6 +23,7 @@ cramers_v, distance_cn_correlation, distance_nn_correlation, + pearson, r2_mcfadden, kruskal_wallis, kruskal_wallis_boolean, @@ -58,6 +59,7 @@ "cramers_v", "distance_cn_correlation", "distance_nn_correlation", + "pearson", "r2_mcfadden", "kruskal_wallis", "kruskal_wallis_boolean", diff --git a/src/fairlens/metrics/correlation.py b/src/fairlens/metrics/correlation.py index 4b1ca0ba..4e12756c 100644 --- a/src/fairlens/metrics/correlation.py +++ b/src/fairlens/metrics/correlation.py @@ -11,6 +11,9 @@ from sklearn import linear_model from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler +EPSILON = 1e-6 +MIN_MEAN_SAMPLE_SIZE = 20 + def cramers_v(sr_a: pd.Series, sr_b: pd.Series) -> float: """Metric that calculates the corrected Cramer's V statistic for categorical-categorical @@ -23,43 +26,50 @@ def cramers_v(sr_a: pd.Series, sr_b: pd.Series) -> float: Second categorical series to analyze. Returns: - float: Value of the statistic. + float: + Value of the statistic. 
""" - if len(sr_a.value_counts()) == 1: + if sr_a.equals(sr_b): + return 1 + + confusion_matrix = pd.crosstab(sr_a, sr_b) + r, k = confusion_matrix.shape + n = confusion_matrix.to_numpy().sum() + + if r < 2 or k < 2: return 0 - if len(sr_b.value_counts()) == 1: + + chi2 = ss.chi2_contingency(confusion_matrix, correction=(confusion_matrix.shape[0] != 2))[0] + phi2 = chi2 / n + + phi2corr = phi2 - ((k - 1) * (r - 1)) / (n - 1) + + if phi2corr <= EPSILON: return 0 - else: - confusion_matrix = pd.crosstab(sr_a, sr_b) - if confusion_matrix.shape[0] == 2: - correct = False - else: - correct = True + rcorr = r - ((r - 1) ** 2) / (n - 1) + kcorr = k - ((k - 1) ** 2) / (n - 1) - chi2 = ss.chi2_contingency(confusion_matrix, correction=correct)[0] - n = sum(confusion_matrix.sum()) - phi2 = chi2 / n - r, k = confusion_matrix.shape - phi2corr = max(0, phi2 - ((k - 1) * (r - 1)) / (n - 1)) - rcorr = r - ((r - 1) ** 2) / (n - 1) - kcorr = k - ((k - 1) ** 2) / (n - 1) - return np.sqrt(phi2corr / min((kcorr - 1), (rcorr - 1))) + return np.sqrt(phi2corr / min((kcorr - 1), (rcorr - 1))) def pearson(sr_a: pd.Series, sr_b: pd.Series) -> float: - """Metric that calculates Pearson's correlation coefficent for numerical-numerical + """Calculates the Pearson's correlation coefficent for numerical-numerical pairs of series, used in heatmap generation. Args: - sr_a (pd.Series): First numerical series to analyze. - sr_b (pd.Series): Second numerical series to analyze. + sr_a (pd.Series): + First numerical series to analyze. + sr_b (pd.Series): + Second numerical series to analyze. Returns: - float: Value of the coefficient. + float: + Value of the coefficient. """ - return abs(sr_a.corr(sr_b)) + + return sr_a.corr(sr_b, method="pearson") def r2_mcfadden(sr_a: pd.Series, sr_b: pd.Series) -> float: @@ -78,6 +88,7 @@ def r2_mcfadden(sr_a: pd.Series, sr_b: pd.Series) -> float: Returns: float: Value of the pseudo-R2 McFadden score. """ + x = sr_b.to_numpy().reshape(-1, 1) x = StandardScaler().fit_transform(x) y = sr_a.to_numpy() @@ -120,16 +131,17 @@ def kruskal_wallis(sr_a: pd.Series, sr_b: pd.Series) -> float: p-value is the probability that the two columns are not correlated. """ - sr_a = sr_a.astype("category").cat.codes groups = sr_b.groupby(sr_a) - arrays = [groups.get_group(category) for category in sr_a.unique()] + if len(groups) < 2: + return 0 + + args = [groups.get_group(category).array for category in sr_a.unique()] - args = [group.array for group in arrays] - try: - _, p_val = ss.kruskal(*args, nan_policy="omit") - except ValueError: + if np.mean([len(values) for values in args]) <= MIN_MEAN_SAMPLE_SIZE: return 0 + _, p_val = ss.kruskal(*args, nan_policy="omit") + return p_val @@ -147,7 +159,8 @@ def kruskal_wallis_boolean(sr_a: pd.Series, sr_b: pd.Series, p_cutoff: float = 0 The maximum admitted p-value for the distributions to be considered independent. Returns: - bool: Bool value representing whether or not the two series are correlated. + bool: + Bool value representing whether or not the two series are correlated. """ sr_a = sr_a.astype("category").cat.codes @@ -181,8 +194,6 @@ def distance_nn_correlation(sr_a: pd.Series, sr_b: pd.Series) -> float: The correlation coefficient. 
""" - warnings.filterwarnings(action="ignore", category=UserWarning) - if sr_a.size < sr_b.size: sr_a = sr_a.append(pd.Series(sr_a.mean()).repeat(sr_b.size - sr_a.size), ignore_index=True) elif sr_a.size > sr_b.size: diff --git a/src/fairlens/metrics/unified.py b/src/fairlens/metrics/unified.py index ce49c459..df059104 100644 --- a/src/fairlens/metrics/unified.py +++ b/src/fairlens/metrics/unified.py @@ -2,9 +2,9 @@ Collection of helper methods which can be used as to interface metrics. """ -import multiprocessing as mp -from typing import Any, Callable, List, Mapping, Optional, Tuple, Type, Union +from typing import Any, Callable, List, Mapping, Tuple, Type, Union +import numpy as np import pandas as pd from .. import utils @@ -118,8 +118,6 @@ def correlation_matrix( num_num_metric: Callable[[pd.Series, pd.Series], float] = pearson, cat_num_metric: Callable[[pd.Series, pd.Series], float] = kruskal_wallis, cat_cat_metric: Callable[[pd.Series, pd.Series], float] = cramers_v, - columns_x: Optional[List[str]] = None, - columns_y: Optional[List[str]] = None, ) -> pd.DataFrame: """This function creates a correlation matrix out of a dataframe, using a correlation metric for each possible type of pair of series (i.e. numerical-numerical, categorical-numerical, categorical-categorical). @@ -135,60 +133,62 @@ def correlation_matrix( cat_cat_metric (Callable[[pd.Series, pd.Series], float], optional): The correlation metric used for categorical-categorical series pairs. Defaults to corrected Cramer's V statistic. - columns_x (Optional[List[str]]): - The column names that determine the rows of the matrix. - columns_y (Optional[List[str]]): - The column names that determine the columns of the matrix. Returns: pd.DataFrame: The correlation matrix to be used in heatmap generation. 
""" - if columns_x is None: - columns_x = df.columns + df = df.copy() - if columns_y is None: - columns_y = df.columns + distr_types = [utils.infer_distr_type(df[col]) for col in df.columns] - pool = mp.Pool(mp.cpu_count()) + for col in df.columns: + df[col] = utils.infer_dtype(df[col]) - series_list = [ - pd.Series( - pool.starmap( - _correlation_matrix_helper, - [(df[col_x], df[col_y], num_num_metric, cat_num_metric, cat_cat_metric) for col_x in columns_x], - ), - index=columns_x, - name=col_y, - ) - for col_y in columns_y - ] + if df[col].dtype.kind == "O": + df[col] = pd.Series(pd.factorize(df[col], na_sentinel=-1)[0]).replace(-1, np.nan) + + df = df.append(pd.DataFrame({col: [i] for i, col in enumerate(df.columns)})) - pool.close() + def corr(a: np.ndarray, b: np.ndarray): + return _correlation_matrix_helper( + a, + b, + distr_types=distr_types, + num_num_metric=num_num_metric, + cat_num_metric=cat_num_metric, + cat_cat_metric=cat_cat_metric, + ) - return pd.concat(series_list, axis=1, keys=[series.name for series in series_list]) + return df.corr(method=corr) def _correlation_matrix_helper( - sr_a: pd.Series, - sr_b: pd.Series, + a: np.ndarray, + b: np.ndarray, + distr_types: List[utils.DistrType], num_num_metric: Callable[[pd.Series, pd.Series], float] = pearson, cat_num_metric: Callable[[pd.Series, pd.Series], float] = kruskal_wallis, cat_cat_metric: Callable[[pd.Series, pd.Series], float] = cramers_v, ) -> float: - a_type = utils.infer_distr_type(sr_a) - b_type = utils.infer_distr_type(sr_b) + a_type = distr_types[int(a[-1])] + b_type = distr_types[int(b[-1])] + + sr_a = pd.Series(a[:-1]) + sr_b = pd.Series(b[:-1]) + + df = pd.DataFrame({"a": sr_a, "b": sr_b}).dropna().reset_index() if a_type.is_continuous() and b_type.is_continuous(): - return num_num_metric(sr_a, sr_b) + return num_num_metric(df["a"], df["b"]) elif b_type.is_continuous(): - return cat_num_metric(sr_a, sr_b) + return cat_num_metric(df["a"], df["b"]) elif a_type.is_continuous(): - return cat_num_metric(sr_b, sr_a) + return cat_num_metric(df["b"], df["a"]) else: - return cat_cat_metric(sr_a, sr_b) + return cat_cat_metric(df["a"], df["b"]) diff --git a/src/fairlens/plot/__init__.py b/src/fairlens/plot/__init__.py index b7c81aed..bdf3b29a 100644 --- a/src/fairlens/plot/__init__.py +++ b/src/fairlens/plot/__init__.py @@ -3,8 +3,8 @@ """ +from .correlation import heatmap from .distr import attr_distr_plot, distr_plot, mult_distr_plot -from .heatmap import two_column_heatmap from .style import reset_style, use_style -__all__ = ["use_style", "reset_style", "distr_plot", "attr_distr_plot", "mult_distr_plot", "two_column_heatmap"] +__all__ = ["use_style", "reset_style", "distr_plot", "attr_distr_plot", "mult_distr_plot", "heatmap"] diff --git a/src/fairlens/plot/correlation.py b/src/fairlens/plot/correlation.py new file mode 100644 index 00000000..e897e748 --- /dev/null +++ b/src/fairlens/plot/correlation.py @@ -0,0 +1,64 @@ +""" +Plot correlation heatmaps for datasets. 
+""" + +from typing import Callable, Optional, Sequence, Tuple + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from matplotlib.axes import Axes + +from ..metrics import correlation, unified + + +def heatmap( + df: pd.DataFrame, + num_num_metric: Callable[[pd.Series, pd.Series], float] = correlation.pearson, + cat_num_metric: Callable[[pd.Series, pd.Series], float] = correlation.kruskal_wallis, + cat_cat_metric: Callable[[pd.Series, pd.Series], float] = correlation.cramers_v, + cmap: Optional[Sequence[Tuple[float, float, float]]] = None, + annotate: bool = False, +) -> Axes: + """This function creates a correlation heatmap out of a dataframe, using user provided or default correlation + metrics for all possible types of pairs of series (i.e. numerical-numerical, categorical-numerical, + categorical-categorical). + + Args: + df (pd.DataFrame): + The dataframe used for computing correlations and producing a heatmap. + num_num_metric (Callable[[pd.Series, pd.Series], float], optional): + The correlation metric used for numerical-numerical series pairs. Defaults to Pearson's correlation + coefficient. + cat_num_metric (Callable[[pd.Series, pd.Series], float], optional): + The correlation metric used for categorical-numerical series pairs. Defaults to Kruskal-Wallis' H Test. + cat_cat_metric (Callable[[pd.Series, pd.Series], float], optional): + The correlation metric used for categorical-categorical series pairs. Defaults to corrected Cramer's V + statistic. + cmap (Optional[Sequence[Tuple[float, float, float]]], optional): + A sequence of RGB tuples used to colour the histograms. If None seaborn's default pallete + will be used. Defaults to None. + annotate (bool, optional): + Annotate the heatmap. + + Returns: + matplotlib.axes.Axes: + The matplotlib axis containing the plot. + + Examples: + >>> df = pd.read_csv("datasets/german_credit_data.csv") + >>> heatmap(df) + >>> plt.show() + + .. image:: ../../savefig/corr_heatmap_1.png + """ + + corr_matrix = unified.correlation_matrix(df, num_num_metric, cat_num_metric, cat_cat_metric) + + cmap = cmap or sns.cubehelix_palette(start=0.2, rot=-0.2, dark=0.3, as_cmap=True) + annot = annotate or None + + ax = sns.heatmap(corr_matrix, vmin=0, vmax=1, square=True, cmap=cmap, linewidth=0.5, annot=annot, fmt=".1f") + plt.tight_layout() + + return ax diff --git a/src/fairlens/plot/heatmap.py b/src/fairlens/plot/heatmap.py deleted file mode 100644 index 1223b0a6..00000000 --- a/src/fairlens/plot/heatmap.py +++ /dev/null @@ -1,82 +0,0 @@ -""" -Plot correlation heatmaps for datasets. -""" - -from typing import Callable, List, Optional - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import seaborn as sns - -from ..metrics import correlation, unified - - -def two_column_heatmap( - df: pd.DataFrame, - num_num_metric: Callable[[pd.Series, pd.Series], float] = correlation.pearson, - cat_num_metric: Callable[[pd.Series, pd.Series], float] = correlation.kruskal_wallis, - cat_cat_metric: Callable[[pd.Series, pd.Series], float] = correlation.cramers_v, - columns_x: Optional[List[str]] = None, - columns_y: Optional[List[str]] = None, -): - """This function creates a correlation heatmap out of a dataframe, using user provided or default correlation - metrics for all possible types of pairs of series (i.e. numerical-numerical, categorical-numerical, - categorical-categorical). - - Args: - df (pd.DataFrame): - The dataframe used for computing correlations and producing a heatmap. 
-        num_num_metric (Callable[[pd.Series, pd.Series], float], optional):
-            The correlation metric used for numerical-numerical series pairs. Defaults to Pearson's correlation
-            coefficient.
-        cat_num_metric (Callable[[pd.Series, pd.Series], float], optional):
-            The correlation metric used for categorical-numerical series pairs. Defaults to Kruskal-Wallis' H Test.
-        cat_cat_metric (Callable[[pd.Series, pd.Series], float], optional):
-            The correlation metric used for categorical-categorical series pairs. Defaults to corrected Cramer's V
-            statistic.
-        columns_x (Optional[List[str]]):
-            The sensitive dataframe column names that will be used in generating the correlation heatmap.
-        columns_y (Optional[List[str]]):
-            The non-sensitive dataframe column names that will be used in generating the correlation heatmap.
-    """
-
-    if columns_x is None:
-        columns_x = df.columns
-
-    if columns_y is None:
-        columns_y = df.columns
-
-    corr_matrix = unified.correlation_matrix(
-        df, num_num_metric, cat_num_metric, cat_cat_metric, columns_x, columns_y
-    ).round(2)
-
-    fig_width = 20.0
-    margin_top = 0.8
-    margin_bot = 0.8
-    margin_left = 0.8
-    margin_right = 0.8
-
-    cell_size = (fig_width - margin_left - margin_right) / float(len(columns_y))
-    fig_height = cell_size * len(columns_x) + margin_bot + margin_top
-
-    plt.figure(figsize=(fig_width, fig_height), tight_layout=True)
-    plt.subplots_adjust(
-        bottom=margin_bot / fig_height,
-        top=1.0 - margin_top / fig_height,
-        left=margin_left / fig_width,
-        right=1.0 - margin_right / fig_width,
-    )
-
-    g = sns.heatmap(
-        corr_matrix,
-        vmin=0,
-        vmax=1,
-        annot=True,
-        annot_kws={"size": 35 / np.sqrt(len(corr_matrix))},
-        square=True,
-        cbar=True,
-    )
-
-    g.set_xticklabels(g.get_xticklabels(), rotation=90, horizontalalignment="right", fontdict={"fontsize": 14})
-    g.set_yticklabels(g.get_yticklabels(), rotation=0, horizontalalignment="right", fontdict={"fontsize": 14})
diff --git a/src/fairlens/sensitive/correlation.py b/src/fairlens/sensitive/correlation.py
index e72f2416..7d330a84 100644
--- a/src/fairlens/sensitive/correlation.py
+++ b/src/fairlens/sensitive/correlation.py
@@ -7,6 +7,7 @@
 
 import pandas as pd
 
+from .. import utils
 from ..metrics import correlation as cm
 from ..sensitive import detection as dt
 
@@ -148,18 +149,17 @@ def find_column_correlation(
 def _compute_series_correlation(
     sr_a: pd.Series, sr_b: pd.Series, corr_cutoff: float = 0.75, p_cutoff: float = 0.1
 ) -> bool:
-    a_categorical = sr_a.map(type).eq(str).all()
-    b_categorical = sr_b.map(type).eq(str).all()
-
-    if a_categorical and b_categorical:
-        # If both columns are categorical, we use Cramer's V.
-        if cm.cramers_v(sr_a, sr_b) > corr_cutoff:
-            return True
-    elif not a_categorical and b_categorical:
-        # If just one column is categorical, we can group by it and use Kruskal-Wallis H Test.
-        return cm.kruskal_wallis_boolean(sr_b, sr_a, p_cutoff=p_cutoff)
-    elif a_categorical and not b_categorical:
+    a_type = utils.infer_distr_type(sr_a)
+    b_type = utils.infer_distr_type(sr_b)
+
+    if a_type.is_continuous() and b_type.is_continuous():
+        # pearson now returns a signed value, so compare its absolute value to
+        # keep catching strong negative correlations against the cutoff.
+        return abs(cm.pearson(sr_a, sr_b)) > corr_cutoff
+
+    elif b_type.is_continuous():
         return cm.kruskal_wallis_boolean(sr_a, sr_b, p_cutoff=p_cutoff)
-    # If both columns are numeric, we use standard Pearson correlation and the correlation cutoff.
- return cm.pearson(sr_a, sr_b) > corr_cutoff + elif a_type.is_continuous(): + return cm.kruskal_wallis_boolean(sr_b, sr_a, p_cutoff=p_cutoff) + + else: + return cm.cramers_v(sr_a, sr_b) > corr_cutoff diff --git a/tests/test_correlation.py b/tests/test_correlation.py index e6213ddd..f055d2b4 100644 --- a/tests/test_correlation.py +++ b/tests/test_correlation.py @@ -1,6 +1,15 @@ import pandas as pd - -from fairlens.metrics.correlation import distance_cn_correlation, distance_nn_correlation +import pytest + +from fairlens import utils +from fairlens.metrics.correlation import ( + cramers_v, + distance_cn_correlation, + distance_nn_correlation, + kruskal_wallis, + pearson, +) +from fairlens.metrics.unified import correlation_matrix from fairlens.sensitive.correlation import find_column_correlation, find_sensitive_correlations pair_race = "race", "Ethnicity" @@ -9,6 +18,8 @@ pair_gender = "gender", "Gender" pair_nationality = "nationality", "Nationality" +epsilon = 1e-6 + def test_correlation(): col_names = ["gender", "random", "score"] @@ -21,6 +32,14 @@ def test_correlation(): ["female", 11, 80], ["male", 12, 61], ["female", 10, 83], + ["male", 10, 60], + ["female", 10, 80], + ["male", 10, 60], + ["female", 10, 80], + ["male", 9, 59], + ["female", 11, 80], + ["male", 12, 61], + ["female", 10, 83], ] df = pd.DataFrame(data, columns=col_names) res = {"score": [pair_gender]} @@ -64,11 +83,18 @@ def test_common_correlation(): ["carribean", 40, 10, 2000, "single", 10, 90, 220], ["indo-european", 42, 10, 2500, "widowed", 10, 120, 200], ["arabian", 19, 10, 2200, "married", 10, 60, 115], + ["arabian", 21, 10, 2000, "married", 10, 60, 120], + ["carribean", 20, 10, 3000, "single", 10, 90, 130], + ["indo-european", 41, 10, 1900, "widowed", 10, 120, 210], + ["carribean", 40, 10, 2000, "single", 10, 90, 220], + ["indo-european", 42, 10, 2500, "widowed", 10, 120, 200], + ["arabian", 19, 10, 2200, "married", 10, 60, 115], ] df = pd.DataFrame(data, columns=col_names) res = { "corr1": [pair_race, pair_age, pair_marital], - "corr2": [pair_age], + "corr2": [pair_race, pair_age, pair_marital], + "entries": [pair_age], } assert find_sensitive_correlations(df) == res @@ -97,14 +123,20 @@ def test_series_correlation(): ["carribean", 40, 10, 2000, "single", 10], ["indo-european", 42, 10, 2500, "widowed", 10], ["arabian", 19, 10, 2200, "married", 10], + ["arabian", 21, 10, 2000, "married", 10], + ["carribean", 20, 10, 3000, "single", 10], + ["indo-european", 41, 10, 1900, "widowed", 10], + ["carribean", 40, 10, 2000, "single", 10], + ["indo-european", 42, 10, 2500, "widowed", 10], + ["arabian", 19, 10, 2200, "married", 10], ] df = pd.DataFrame(data, columns=col_names) - s1 = pd.Series([60, 90, 120, 90, 120, 60]) - s2 = pd.Series([120, 130, 210, 220, 200, 115]) - res1 = [pair_race, pair_marital] - res2 = [pair_age] - assert set(find_column_correlation(s1, df, corr_cutoff=0.9)) == set(res1) - assert set(find_column_correlation(s2, df, corr_cutoff=0.9)) == set(res2) + s1 = pd.Series([60, 90, 120, 90, 120, 60, 60, 90, 120, 90, 120, 60]) + s2 = pd.Series([120, 130, 210, 220, 200, 115, 120, 130, 210, 220, 200, 115]) + res1 = [pair_age, pair_race, pair_marital] + res2 = [pair_age, pair_race, pair_marital] + assert set(find_column_correlation(s1, df)) == set(res1) + assert set(find_column_correlation(s2, df)) == set(res2) def test_basic_nn_distance_corr(): @@ -133,3 +165,38 @@ def test_cn_unequal_series_corr(): sr_b = pd.Series([100, 200, 99, 101, 201, 199, 299, 300, 301, 500, 501, 505, 10, 12, 1001, 1050]) assert 
distance_cn_correlation(sr_a, sr_b) > 0.7
+
+
+@pytest.mark.parametrize("dataset", ["titanic", "german_credit_data"])
+def test_correlation_matrix(dataset):
+    df = pd.read_csv(f"datasets/{dataset}.csv")
+    num_num_metric = pearson
+    cat_num_metric = kruskal_wallis
+    cat_cat_metric = cramers_v
+
+    matrix = correlation_matrix(
+        df, num_num_metric=num_num_metric, cat_num_metric=cat_num_metric, cat_cat_metric=cat_cat_metric
+    ).to_numpy()
+
+    for i, r in enumerate(df.columns):
+        for j, c in enumerate(df.columns):
+            sr_a = utils.infer_dtype(df[r])
+            sr_b = utils.infer_dtype(df[c])
+            a_type = utils.infer_distr_type(sr_a)
+            b_type = utils.infer_distr_type(sr_b)
+
+            d = pd.DataFrame({"a": sr_a, "b": sr_b}).dropna().reset_index()
+
+            if a_type.is_continuous() and b_type.is_continuous():
+                corr = num_num_metric(d["a"], d["b"])
+
+            elif b_type.is_continuous():
+                corr = cat_num_metric(d["a"], d["b"])
+
+            elif a_type.is_continuous():
+                corr = cat_num_metric(d["b"], d["a"])
+
+            else:
+                corr = cat_cat_metric(d["a"], d["b"])
+
+            assert abs(matrix[i][j] - corr) < epsilon
diff --git a/tests/test_plot.py b/tests/test_plot.py
index 47c8e007..112ccb0b 100644
--- a/tests/test_plot.py
+++ b/tests/test_plot.py
@@ -1,6 +1,7 @@
 import pandas as pd
 import seaborn as sns
 
+from fairlens.plot.correlation import heatmap
 from fairlens.plot.distr import attr_distr_plot, distr_plot, mult_distr_plot
 
 dfa = pd.read_csv("datasets/adult.csv")
@@ -33,3 +34,15 @@ def test_mult_distr_plot_german():
 
 def test_mult_distr_plot_titanic():
     mult_distr_plot(dft, "Survived", ["Sex", "Age"])
+
+
+def test_heatmap_adult():
+    heatmap(dfa)
+
+
+def test_heatmap_german():
+    heatmap(dfg)
+
+
+def test_heatmap_titanic():
+    heatmap(dft)
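
A few notes on the changes above. The restructured ``cramers_v`` computes the
same bias-corrected Cramer's V as the code it replaces. For an :math:`r \times k`
contingency table over :math:`n` samples, with :math:`\varphi^2 = \chi^2 / n`,
the value returned is

.. math::

    \tilde{V} = \sqrt{\frac{\tilde{\varphi}^2}{\min(\tilde{k} - 1,\; \tilde{r} - 1)}},
    \qquad
    \tilde{\varphi}^2 = \max\left(0,\; \varphi^2 - \frac{(k - 1)(r - 1)}{n - 1}\right),

where :math:`\tilde{k} = k - (k - 1)^2 / (n - 1)` and
:math:`\tilde{r} = r - (r - 1)^2 / (n - 1)`. The early ``return 0`` when
:math:`\tilde{\varphi}^2 \leq` ``EPSILON`` stands in for the ``max(0, ...)``
clamp of the removed implementation.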
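
The ``correlation_matrix`` rewrite leans on a subtle trick worth spelling out:
``DataFrame.corr(method=callable)`` hands the callable two bare 1-d arrays, so
column identity (and hence the inferred distribution type) would otherwise be
lost. A minimal sketch of the sentinel-row workaround, on toy data rather than
fairlens internals::

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [2.0, 4.0, 6.0, 8.0]})

    # Sentinel row: the last entry of column i holds the value i. DataFrame.append
    # mirrors the diff; it was removed in pandas 2.0, where pd.concat does the job.
    df = df.append(pd.DataFrame({col: [i] for i, col in enumerate(df.columns)}))

    def corr(a: np.ndarray, b: np.ndarray) -> float:
        i, j = int(a[-1]), int(b[-1])  # recover the column indices
        print(f"comparing column {i} with column {j}")
        return float(np.corrcoef(a[:-1], b[:-1])[0, 1])  # strip the sentinel

    print(df.corr(method=corr))

On the pandas versions this PR targets, ``corr`` with a callable fills the
diagonal with 1.0 and mirrors the upper triangle itself, so the callable runs
once per distinct pair of columns.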
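
Finally, a usage sketch for the renamed entry point, mirroring the updated docs
(the dataset path is the one used in the tests)::

    import matplotlib.pyplot as plt
    import pandas as pd

    import fairlens as fl
    from fairlens.metrics import cramers_v, distance_cn_correlation, distance_nn_correlation

    df = pd.read_csv("datasets/german_credit_data.csv")

    # Default metrics: Pearson, Kruskal-Wallis H test and corrected Cramer's V.
    fl.plot.heatmap(df, annotate=True)
    plt.show()

    # Non-linear alternatives for numerical-numerical and categorical-numerical pairs.
    fl.plot.heatmap(df, distance_nn_correlation, distance_cn_correlation, cramers_v)
    plt.show()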