Skip to content

Commit

Permalink
fixing issue #156
Browse files Browse the repository at this point in the history
  • Loading branch information
shakedzy committed Jan 27, 2024
1 parent 0dd3d4f commit e81e4a1
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ If you wish to install from source:
pip install git+https://github.com/shakedzy/dython.git
```

**Dependencies:** `numpy`, `pandas`, `seaborn`, `scipy`, `matplotlib`, `sklearn`, `scikit-plot`
**Dependencies:** `numpy`, `pandas`, `seaborn`, `scipy`, `matplotlib`, `sklearn`

## Contributing:
Contributions are always welcome - if you found something you can fix, or have an idea for a new feature, feel free to write it and open a pull request. Please make sure to go over the [contribution guidelines](https://github.com/shakedzy/dython/blob/master/CONTRIBUTING.md).
Expand Down
114 changes: 111 additions & 3 deletions dython/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from scikitplot.helpers import binary_ks_curve
from sklearn.preprocessing import LabelEncoder
from typing import List, Union, Optional, Tuple, Dict, Any, Iterable
from numpy.typing import NDArray
from .typing import Number, OneDimArray
Expand Down Expand Up @@ -536,8 +536,8 @@ def ks_abc(
)
)

thresholds, nr, pr, ks_statistic, max_distance_at, _ = binary_ks_curve(
y_t, y_p
thresholds, nr, pr, ks_statistic, max_distance_at, _ = _binary_ks_curve(
y_t, y_p # type: ignore
)
if ax is None:
plt.figure(figsize=figsize)
Expand Down Expand Up @@ -584,3 +584,111 @@ def ks_abc(
"eopt": max_distance_at,
"ax": axis,
}


def _binary_ks_curve(
    y_true: OneDimArray, y_probas: OneDimArray
) -> Tuple[NDArray, NDArray, NDArray, Number, Number, NDArray]:
    """Generate the points needed to draw a KS Statistic plot.

    Adapted from scikit-plot:
    https://github.com/reiinakano/scikit-plot/blob/master/scikitplot/helpers.py

    The scores in `y_probas` are split by the class in `y_true`; for every
    distinct score the fraction of each class's scores that are less than or
    equal to it is computed, giving two cumulative curves over a shared
    X-axis. The curves are padded with a (0, 0, 0) point at the start when
    the smallest score is not 0, and with a (1, 1, 1) point at the end when
    the largest score is not 1.

    Args:
        y_true (array-like, shape (n_samples)): True labels of the data.
        y_probas (array-like, shape (n_samples)): Probability predictions of
            the positive class.

    Returns:
        thresholds (numpy.ndarray): The X-axis values of the plot, i.e. the
            sorted distinct scores (plus the padding points).
        pct1 (numpy.ndarray): The Y-axis values of the curve of the first
            class in `classes`.
        pct2 (numpy.ndarray): The Y-axis values of the curve of the second
            class in `classes`.
        ks_statistic (float): The KS Statistic, i.e. the maximum absolute
            vertical distance between the two curves.
        max_distance_at (float): The X-axis value at which that maximum
            distance is first reached.
        classes (np.ndarray, shape (2)): The two sorted, distinct labels
            found in `y_true`.

    Raises:
        ValueError: If `y_true` is not composed of exactly 2 classes (the KS
            Statistic is only relevant in binary classification), or if
            `y_probas` contains NaN.
    """
    y_true, y_probas = np.asarray(y_true), np.asarray(y_probas)

    # np.unique yields the sorted distinct labels together with the integer
    # code (index into `classes`) of every sample.
    classes, encoded_labels = np.unique(y_true.ravel(), return_inverse=True)
    if len(classes) != 2:
        raise ValueError(
            "Cannot calculate KS statistic for data with "
            "{} category/ies".format(len(classes))
        )

    # NaN never compares equal to itself, so it cannot be placed on the
    # threshold axis; reject it explicitly instead of producing garbage.
    if np.issubdtype(y_probas.dtype, np.floating) and np.isnan(y_probas).any():
        raise ValueError("Cannot calculate KS statistic: y_probas contains NaN")

    is_first_class = encoded_labels == 0
    data1 = np.sort(y_probas[is_first_class])
    data2 = np.sort(y_probas[np.logical_not(is_first_class)])

    # One threshold per distinct score; side="right" counts the scores that
    # are <= the threshold, which is the empirical CDF numerator.
    thresholds = np.unique(np.concatenate((data1, data2)))
    pct1 = np.searchsorted(data1, thresholds, side="right") / float(len(data1))
    pct2 = np.searchsorted(data2, thresholds, side="right") / float(len(data2))

    # Anchor both curves at (0, 0) and (1, 1) so the plot spans [0, 1].
    if thresholds[0] != 0:
        thresholds = np.insert(thresholds, 0, [0.0])
        pct1 = np.insert(pct1, 0, [0.0])
        pct2 = np.insert(pct2, 0, [0.0])
    if thresholds[-1] != 1:
        thresholds = np.append(thresholds, [1.0])
        pct1 = np.append(pct1, [1.0])
        pct2 = np.append(pct2, [1.0])

    # The KS statistic is a distance, so the sign of the gap (which depends
    # only on which label happens to sort first) must not matter.
    differences = np.abs(pct1 - pct2)
    widest = int(np.argmax(differences))
    ks_statistic, max_distance_at = differences[widest], thresholds[widest]

    return thresholds, pct1, pct2, ks_statistic, max_distance_at, classes  # type: ignore
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,4 @@ seaborn>=0.12.0
scipy>=1.7.1
matplotlib>=3.6.0
scikit-learn>=0.24.2
scikit-plot>=0.3.7
psutil>=5.9.1

0 comments on commit e81e4a1

Please sign in to comment.