WIP: start adding detection of associated features
jeromedockes committed Jan 19, 2024
1 parent bc83f49 commit 4f200c8
Showing 7 changed files with 167 additions and 4 deletions.
16 changes: 14 additions & 2 deletions doc/make_doc.py
@@ -1,5 +1,6 @@
#! /usr/bin/env python3

import time
import shutil
from pathlib import Path

@@ -44,8 +45,19 @@
def add_report(df, name):
    print(f"making report for {name}")
    pretty_name = name.replace("_", " ").capitalize()
    report = Report(df, title=pretty_name)
    (reports_dir / f"{name}.html").write_text(report.html, "utf-8")
    start = time.time()
    html = Report(df, title=pretty_name).html
    stop = time.time()
    addition = f"""
    <div style="padding: 1rem; font-size: 0.9rem;">
      <p>
        Report generated in {stop - start:.2f} seconds by <a href="https://github.com/skrub-data/skrubview">skrubview</a>.
      </p>
      <p><a href="..">Back to homepage</a></p>
    </div>
    """
    html = html.replace("</body>", f"{addition}\n</body>")
    (reports_dir / f"{name}.html").write_text(html, "utf-8")
    return f"<li><a href='reports/{name}.html'>{pretty_name}</a></li>"


1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"matplotlib",
"numpy",
"dataframe-api-compat",
"scikit-learn",
]

[project.optional-dependencies]
32 changes: 32 additions & 0 deletions src/skrubview/_data/templates/dataframe-interactions.html
@@ -0,0 +1,32 @@
<article class="skrubview-wrapper">
  <h2>Column pairwise associations</h2>
  <div style="font-size: 1.5rem; padding: var(--skrubview-small);"><strong>🚧 Work In Progress 🚨</strong></div>
  {% if summary["top_associations"] %}
  <table class="pure-table">
    <thead>
      <tr>
        <th>Column 1</th>
        <th>Column 2</th>
        <th><a href="https://en.wikipedia.org/wiki/Cram%C3%A9r%27s_V">Cramér's V</a></th>
      </tr>
    </thead>
    <tbody>
      {% for association in summary["top_associations"] %}
      <tr>
        <td>{{ association["left_column"] }}</td>
        <td>{{ association["right_column"] }}</td>
        <td
          {% if association["cramer_v"] is gt 0.9 %}
          class="skrubview-critical"
          {%- endif -%}
        >
          {{ association["cramer_v"] | format_number }}
        </td>
      </tr>
      {% endfor %}
    </tbody>
  </table>
  {% else %}
  No strong associations between any pair of columns were identified by a quick screening of a subsample of the dataframe.
  {% endif %}
</article>
3 changes: 2 additions & 1 deletion src/skrubview/_data/templates/report.html
@@ -13,7 +13,8 @@ <h1>{{ summary.title }}</h1>
{% include "powerbar.html" %}
{% include "dataframe-sample.html" %}
{% include "dataframe-columns.html" %}
{% include "dataframe-interactions.html" %}
<script>
    updateSelectedColsSnippet("{{ report_id }}", false);
  </script>
</div>
96 changes: 96 additions & 0 deletions src/skrubview/_interactions.py
@@ -0,0 +1,96 @@
import warnings

import numpy as np
from sklearn.preprocessing import OneHotEncoder, KBinsDiscretizer

_N_BINS = 10
_CATEGORICAL_THRESHOLD = 30


def stack_symmetric_associations(associations, column_names):
    """Flatten the upper triangle of a symmetric association matrix.

    Returns (left_column, right_column, association) tuples sorted by
    decreasing association strength.
    """
    left_indices, right_indices = np.triu_indices_from(associations, 1)
    associations = associations[(left_indices, right_indices)]
    order = np.argsort(associations)[::-1]
    left_indices, right_indices, associations = (
        left_indices[order],
        right_indices[order],
        associations[order],
    )
    return [
        (column_names[left], column_names[right], a)
        for (left, right, a) in zip(left_indices, right_indices, associations)
    ]


def cramer_v(df):
    """Compute Cramér's V between every pair of columns of ``df``."""
    df = df.__dataframe_consortium_standard__().persist()
    encoded = _onehot_encode(df, _N_BINS)
    table = _contingency_table(encoded)
    stats = _compute_cramer(table, df.shape()[0])
    return stats


def _onehot_encode(df, n_bins):
    """One-hot encode each column into at most ``n_bins`` indicator rows.

    Categorical (or low-cardinality) columns are one-hot encoded directly;
    numeric columns are binned first.
    """
    n_rows, n_cols = df.shape()
    output = np.zeros((n_cols, n_bins, n_rows), dtype=bool)
    for col_idx, col_name in enumerate(df.column_names):
        values = np.asarray(df.col(col_name).to_array())
        if values.dtype.kind in "bOSU" or len(set(values)) <= _CATEGORICAL_THRESHOLD:
            _onehot_encode_categories(values, n_bins, output[col_idx])
        else:
            _onehot_encode_numbers(values, n_bins, output[col_idx])
    return output


def _onehot_encode_categories(values, n_bins, output):
    encoded = OneHotEncoder(max_categories=n_bins, sparse_output=False).fit_transform(
        values[:, None]
    )
    effective_n_bins = encoded.shape[1]
    output[:effective_n_bins] = encoded.T


def _onehot_encode_numbers(values, n_bins, output):
    values = values.astype(float)
    mask = ~np.isfinite(values)
    filled_na = np.array(values)
    # TODO pick a better value & non-uniform bins?
    filled_na[mask] = 0.0
    encoder = KBinsDiscretizer(
        n_bins=n_bins - 1,
        strategy="uniform",
        subsample=None,
        encode="onehot-dense",
    )
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        encoded = encoder.fit_transform(filled_na[:, None])
    encoded[mask] = 0
    effective_n_bins = encoded.shape[1]
    output[:effective_n_bins] = encoded.T
    # The last bin flags missing / non-finite values.
    output[effective_n_bins] = mask


def _contingency_table(encoded):
    # Joint bin counts for every pair of columns: entry [a, b, c, d] counts
    # the rows falling in bin c of column a and bin d of column b.
    n_cols, n_quantiles, _ = encoded.shape
    out = np.empty((n_cols, n_cols, n_quantiles, n_quantiles), dtype="int32")
    return np.einsum("ack,bdk", encoded, encoded, out=out)


def _compute_cramer(table, n_samples):
    """Compute Cramér's V from the pairwise contingency tables."""
    marginal_0 = table.sum(axis=-2)
    marginal_1 = table.sum(axis=-1)
    expected = (
        marginal_0[:, :, None, :]
        * marginal_1[:, :, :, None]
        / marginal_0.sum(axis=-1)[:, :, None, None]
    )
    diff = table - expected
    expected[expected == 0] = 1
    chi_stat = ((diff**2) / expected).sum(axis=-1).sum(axis=-1)
    min_dim = np.minimum(
        (marginal_0 > 0).sum(axis=-1) - 1, (marginal_1 > 0).sum(axis=-1) - 1
    )
    stat = np.sqrt(chi_stat / (n_samples * np.maximum(min_dim, 1)))
    stat[min_dim == 0] = 0.0
    return stat
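
For reference (not part of this commit): Cramér's V for a pair of discrete columns is sqrt(χ² / (n · min(r − 1, c − 1))), where χ² is the chi-squared statistic of their contingency table and r, c are the numbers of non-empty categories. The code above computes this for all column pairs at once; the following is a minimal, non-vectorized sketch of the same statistic for a single pair, usable as a sanity check. It assumes scipy is available, which is not a declared dependency of skrubview.

# Illustrative sanity check only -- not part of this commit. Assumes scipy.
import numpy as np
from scipy.stats import chi2_contingency


def cramer_v_pair(x, y):
    """Cramér's V between two 1-d arrays of discrete values."""
    # Build the contingency table of observed joint counts.
    _, codes_x = np.unique(x, return_inverse=True)
    _, codes_y = np.unique(y, return_inverse=True)
    table = np.zeros((codes_x.max() + 1, codes_y.max() + 1), dtype=int)
    np.add.at(table, (codes_x, codes_y), 1)
    chi2, _, _, _ = chi2_contingency(table, correction=False)
    min_dim = min(table.shape) - 1
    return np.sqrt(chi2 / (len(x) * max(min_dim, 1)))


# Two perfectly associated columns give V == 1; independent ones give V near 0.
x = np.array(["a", "a", "b", "b", "c", "c"])
print(cramer_v_pair(x, x))  # -> 1.0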
18 changes: 17 additions & 1 deletion src/skrubview/_summarize.py
@@ -2,9 +2,11 @@

import polars as pl

from . import _plotting, _utils
from . import _plotting, _utils, _interactions

_HIGH_CARDINALITY_THRESHOLD = 10
_SUBSAMPLE_SIZE = 3000
_ASSOCIATION_THRESHOLD = 0.2


def summarize_dataframe(
@@ -49,9 +51,23 @@ def summarize_dataframe(
summary["n_constant_columns"] = sum(
c["value_is_constant"] for c in summary["columns"]
)
_add_interactions(df, summary)
return summary


def _add_interactions(df, dataframe_summary):
    df = _utils.sample(df.dataframe, n=_SUBSAMPLE_SIZE)
    associations = _interactions.stack_symmetric_associations(
        _interactions.cramer_v(df),
        df.__dataframe_consortium_standard__().column_names,
    )[:20]
    dataframe_summary["top_associations"] = [
        dict(zip(("left_column", "right_column", "cramer_v"), a))
        for a in associations
        if a[2] > _ASSOCIATION_THRESHOLD
    ]


def _summarize_column(
    column, position, dataframe_summary, *, with_plots, order_by_column
):
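
To make the data flow concrete, here is a hypothetical illustration (column names and values invented, not from this commit) of what `_add_interactions` leaves on the summary dict; the dataframe-interactions.html template then renders one table row per entry and flags values above 0.9 as critical.

# Hypothetical example of the structure produced by _add_interactions.
summary = {}
summary["top_associations"] = [
    {"left_column": "city", "right_column": "zip_code", "cramer_v": 0.97},
    {"left_column": "age", "right_column": "income_bracket", "cramer_v": 0.41},
]
# Only pairs above _ASSOCIATION_THRESHOLD (0.2) are kept, at most 20 of them;
# the template marks anything above 0.9 with the skrubview-critical class.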
5 changes: 5 additions & 0 deletions src/skrubview/_utils.py
@@ -40,6 +40,11 @@ def get_dtype_name(column):
    return column.column.dtype.__class__.__name__


def sample(df, n, seed=0):
    # TODO pandas
    return df.sample(min(n, df.shape[0]), seed=seed)


def _to_html_via_pandas(df):
    return df.to_pandas().to_html(index=False)

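
The `# TODO pandas` above notes that `sample` currently assumes the polars API (its `seed=` keyword). One possible way to support both libraries, sketched here purely as an assumption rather than what the project will actually do:

# A sketch only: how sample() might dispatch between pandas and polars.
def sample(df, n, seed=0):
    n = min(n, df.shape[0])
    try:
        import pandas as pd
    except ImportError:
        pd = None
    if pd is not None and isinstance(df, pd.DataFrame):
        # pandas spells the random seed `random_state`
        return df.sample(n, random_state=seed)
    # polars (and similar APIs) use `seed`
    return df.sample(n, seed=seed)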
