Add zero-one loss for classification #22

Merged 6 commits on Nov 29, 2023
1 change: 1 addition & 0 deletions README.md
@@ -68,6 +68,7 @@ Check out the following notebooks to get started:
- [MNIST](https://github.com/iancovert/sage/blob/master/notebooks/mnist.ipynb): shows strategies to accelerate convergence for datasets with many features (feature grouping, different imputing setups)
- [Consistency](https://github.com/iancovert/sage/blob/master/notebooks/consistency.ipynb): verifies that our various Shapley value estimators return the same results (see the estimators listed below)
- [Calibration](https://github.com/iancovert/sage/blob/master/notebooks/calibration.ipynb): verifies that SAGE's confidence intervals are representative of the uncertainty across runs
- [Losses](https://github.com/iancovert/sage/blob/master/notebooks/losses.ipynb): shows how SAGE can be used in classification with alternative loss functions

If you want to replicate the experiments described in our paper, see this separate [repository](https://github.com/iancovert/sage-experiments).

373 changes: 373 additions & 0 deletions notebooks/losses.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sage/iterated_estimator.py
@@ -62,7 +62,7 @@ class IteratedEstimator:

Args:
imputer: model that accommodates held out features.
loss: loss function ('mse', 'cross entropy').
loss: loss function ('mse', 'cross entropy', 'zero one').
random_state: random seed, enables reproducibility.
'''

2 changes: 1 addition & 1 deletion sage/kernel_estimator.py
@@ -80,7 +80,7 @@ class KernelEstimator:

Args:
imputer: model that accommodates held out features.
loss: loss function ('mse', 'cross entropy').
loss: loss function ('mse', 'cross entropy', 'zero one').
random_state: random seed, enables reproducibility.
'''

2 changes: 1 addition & 1 deletion sage/permutation_estimator.py
@@ -10,7 +10,7 @@ class PermutationEstimator:

Args:
imputer: model that accommodates held out features.
loss: loss function ('mse', 'cross entropy').
loss: loss function ('mse', 'cross entropy', 'zero one').
n_jobs: number of jobs for parallel processing.
random_state: random seed, enables reproducibility.
'''
2 changes: 1 addition & 1 deletion sage/sign_estimator.py
@@ -64,7 +64,7 @@ class SignEstimator:

Args:
imputer: model that accommodates held out features.
loss: loss function ('mse', 'cross entropy').
loss: loss function ('mse', 'cross entropy', 'zero one').
random_state: random seed, enables reproducibility.
'''

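All four estimators now list 'zero one' alongside 'mse' and 'cross entropy' as an accepted loss string. A minimal sketch of how a classifier might be analyzed with the new loss (the `MarginalImputer`/`PermutationEstimator` calls follow the existing SAGE README; `model`, `X_train`, `X_test`, and `Y_test` are placeholder names):

```python
import sage

# model: a fitted classifier; X_train, X_test, Y_test: numpy arrays (placeholders).
imputer = sage.MarginalImputer(model, X_train[:512])        # handles held-out features
estimator = sage.PermutationEstimator(imputer, 'zero one')  # misclassification-rate loss
sage_values = estimator(X_test, Y_test)                     # SAGE values under zero-one loss
```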
35 changes: 32 additions & 3 deletions sage/utils.py
@@ -2,6 +2,7 @@
import numpy as np



def model_conversion(model):
'''Convert model to callable.'''
if safe_isinstance(model, 'sklearn.base.ClassifierMixin'):
@@ -69,13 +70,13 @@ def verify_model_data(imputer, X, Y, loss, batch_size):
Y = dataset_output(imputer, X, batch_size)

# Fix output shape for classification tasks.
if isinstance(loss, CrossEntropyLoss):
if isinstance(loss, (CrossEntropyLoss, ZeroOneLoss)):
if Y.shape == (len(X),):
Y = Y[:, np.newaxis]
if Y.shape[1] == 1:
Y = np.concatenate([1 - Y, Y], axis=1)

if isinstance(loss, CrossEntropyLoss):
if isinstance(loss, (CrossEntropyLoss, ZeroOneLoss)):
x = X[:batch_size]
probs = imputer(x, np.ones((len(x), imputer.num_groups), dtype=bool))

@@ -91,6 +92,7 @@ def verify_model_data(imputer, X, Y, loss, batch_size):
raise ValueError('labels shape should be (batch,) or (batch, 1)'
' for cross entropy loss')


if (probs.ndim == 1) or (probs.shape[1] == 1):
# Check label encoding.
if check_labels:
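The reshaping in `verify_model_data` turns a binary output of shape `(batch,)` or `(batch, 1)` into an explicit two-column probability matrix, so both the cross-entropy and the new zero-one loss always see one column per class. A small illustration (a sketch only; `Y` here is a hypothetical array of binary labels or positive-class probabilities, not a variable from the diff):

```python
import numpy as np

Y = np.array([0.1, 0.8, 0.3])           # binary outputs, shape (3,)
Y = Y[:, np.newaxis]                    # shape (3, 1)
Y = np.concatenate([1 - Y, Y], axis=1)  # shape (3, 2): columns for class 0 and class 1
```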
@@ -199,6 +201,33 @@ def __call__(self, pred, target):
else:
return loss

class ZeroOneLoss:
    '''Zero-one loss that expects probabilities.'''

    def __init__(self, reduction='mean'):
        assert reduction in ('none', 'mean')
        self.reduction = reduction

    def __call__(self, pred, target):
        # Add a dimension to prediction probabilities if necessary.
        if pred.ndim == 1:
            pred = pred[:, np.newaxis]
        if pred.shape[1] == 1:
            pred = np.append(1 - pred, pred, axis=1)

        if target.ndim == 1:
            # Class labels.
            loss = (np.argmax(pred, axis=1) != target).astype(float)
        elif target.ndim == 2:
            # Probabilistic labels.
            loss = (np.argmax(pred, axis=1) != np.argmax(target, axis=1)).astype(float)
        else:
            raise ValueError('incorrect labels shape for zero-one loss')

        if self.reduction == 'mean':
            return np.mean(loss)
        return loss

class CrossEntropyLoss:
'''Cross entropy loss that expects probabilities.'''
@@ -271,4 +300,4 @@ def safe_isinstance(obj, class_str):
class_type = getattr(module, class_name, None)
if class_type is None:
return False
return isinstance(obj, class_type)
return isinstance(obj, class_type)
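
For reference, a quick standalone check of what the new loss computes, per sample and under the default mean reduction (the import path is assumed from the `sage/utils.py` file shown above; the arrays are made up for illustration):

```python
import numpy as np
from sage.utils import ZeroOneLoss  # import path assumed from this diff

pred = np.array([[0.9, 0.1],
                 [0.2, 0.8],
                 [0.6, 0.4]])  # predicted class probabilities
target = np.array([0, 1, 1])   # integer class labels

print(ZeroOneLoss(reduction='none')(pred, target))  # [0. 0. 1.]: only the third sample is misclassified
print(ZeroOneLoss()(pred, target))                  # 0.333...: mean zero-one loss
```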