Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor MSR Banzhaf valuation #605

Merged
merged 14 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Added

- Refactoring MSR Banzhaf semivalues with the new sampler architecture.
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)
- Refactoring group-testing shapley values with new sampler architecture
[PR #602](https://github.com/aai-institute/pyDVL/pull/602)
- Refactoring of least-core data valuation methods with more supported sampling methods
Expand Down Expand Up @@ -32,6 +34,9 @@
- Fix a bug in pydvl.utils.numeric.random_subset where 1 - q was used instead of q
as the probability of an element being sampled
[PR #597](https://github.com/aai-institute/pyDVL/pull/597)
- Fix a bug in the calculation of variance estimates for MSR Banzhaf
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)



### Changed
Expand Down
2 changes: 1 addition & 1 deletion src/pydvl/valuation/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
from .knn_shapley import *
from .least_core import *
from .loo import *
from .msr_banzhaf import *
from .owen_shapley import *
from .semivalue import *
from .weighted_banzhaf import *
186 changes: 186 additions & 0 deletions src/pydvl/valuation/methods/msr_banzhaf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
"""
janosg marked this conversation as resolved.
Show resolved Hide resolved
This module implements the MSR-Banzhaf valuation method, as described in
(Wang et. al.)<sup><a href="wang_data_2023">3</a></sup>.
janosg marked this conversation as resolved.
Show resolved Hide resolved

## References

[^1]: <a name="wang_data_2023"></a>Wang, J.T. and Jia, R., 2023.
[Data Banzhaf: A Robust Data Valuation Framework for Machine Learning](https://proceedings.mlr.press/v206/wang23e.html).
In: Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, pp. 6388-6421.

"""
from __future__ import annotations

import numpy as np
from joblib import Parallel, delayed
from typing_extensions import Self

from pydvl.utils.progress import Progress
from pydvl.valuation.dataset import Dataset
from pydvl.valuation.methods.semivalue import SemivalueValuation
from pydvl.valuation.result import ValuationResult
from pydvl.valuation.samplers import MSRSampler
from pydvl.valuation.stopping import StoppingCriterion
from pydvl.valuation.types import ValueUpdateKind
from pydvl.valuation.utility.base import UtilityBase
from pydvl.valuation.utils import (
ensure_backend_has_generator_return,
make_parallel_flag,
)

__all__ = ["MSRBanzhafValuation"]


class MSRBanzhafValuation(SemivalueValuation):
"""Class to compute Maximum Sample Re-use (MSR) Banzhaf values.

See [Data Valuation][data-valuation] for an overview.

The MSR Banzhaf valuation approximates the Banzhaf valuation and is much more
efficient than traditional Montecarlo approaches.

Args:
utility: Utility object with model, data and scoring function.
sampler: Sampling scheme to use. Currently, only one MSRSampler is implemented.
In the future, weighted MSRSamplers will be supported.
is_done: Stopping criterion to use.
progress: Whether to show a progress bar.

"""

algorithm_name = "MSR-Banzhaf"

def __init__(
self,
utility: UtilityBase,
sampler: MSRSampler,
is_done: StoppingCriterion,
janosg marked this conversation as resolved.
Show resolved Hide resolved
progress: bool = True,
):
sampler = MSRSampler()
janosg marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(
utility=utility,
sampler=sampler,
is_done=is_done,
progress=progress,
)

@staticmethod
def coefficient(n: int, k: int) -> float:
janosg marked this conversation as resolved.
Show resolved Hide resolved
return 1.0

def fit(self, data: Dataset) -> Self:
"""Calculate the MSR Banzhaf valuation on a dataset.

This method has to be called before calling `values()`.

Calculating the Banzhaf valuation is a computationally expensive task that
can be parallelized. To do so, call the `fit()` method inside a
`joblib.parallel_config` context manager as follows:

```python
from joblib import parallel_config

with parallel_config(n_jobs=4):
valuation.fit(data)
```

"""
self._pos_result = ValuationResult.zeros(
janosg marked this conversation as resolved.
Show resolved Hide resolved
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

self._neg_result = ValuationResult.zeros(
janosg marked this conversation as resolved.
Show resolved Hide resolved
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

self.result = ValuationResult.zeros(
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

ensure_backend_has_generator_return()

self.utility.training_data = data

strategy = self.sampler.make_strategy(self.utility, self.coefficient)
processor = delayed(strategy.process)

with Parallel(return_as="generator_unordered") as parallel:
with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
if evaluation.kind == ValueUpdateKind.POSITVE:
self._pos_result.update(evaluation.idx, evaluation.update)
elif evaluation.kind == ValueUpdateKind.NEGATIVE:
self._neg_result.update(evaluation.idx, evaluation.update)
else:
raise ValueError(
"Invalid ValueUpdateKind: {evaluation.kind}"
janosg marked this conversation as resolved.
Show resolved Hide resolved
)

self.result = _combine_results(
self._pos_result, self._neg_result, data=data
)

if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
break

return self


def _combine_results(
    pos_result: ValuationResult, neg_result: ValuationResult, data: Dataset
) -> ValuationResult:
    """Merge the positive and negative running means into one final result.

    Simply subtracting the negative result from the positive one would yield
    wrong variance estimates, misleading update counts, and potentially wrong
    values, so counts, values and variances are each combined explicitly.

    TODO: Verify that the two running means are statistically independent (which is
    assumed in the aggregation of variances).

    Args:
        pos_result: Result accumulated from the positive updates.
        neg_result: Result accumulated from the negative updates.
        data: The dataset used for the valuation. Used for indices and names.

    Returns:
        The combined valuation result.

    """
    # Use the element-wise minimum of the two update counts. This way, stopping
    # criteria can guarantee a minimal number of updates for *both* means.
    counts = np.minimum(pos_result.counts, neg_result.counts)
    never_updated = counts == 0

    # A value is only meaningful once both means have at least one update.
    values = pos_result.values - neg_result.values
    values[never_updated] = np.nan

    # Sum of variances assumes independence of the two means (see TODO above).
    variances = pos_result.variances + neg_result.variances
    variances[never_updated] = np.inf

    return ValuationResult(
        values=values,
        variances=variances,
        counts=counts,
        indices=data.indices,
        data_names=data.data_names,
        algorithm=pos_result.algorithm,
    )
43 changes: 14 additions & 29 deletions src/pydvl/valuation/methods/semivalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,39 +104,24 @@ def fit(self, data: Dataset):

self.utility.training_data = data

parallel = Parallel(return_as="generator_unordered")
strategy = self.sampler.make_strategy(self.utility, self.coefficient)
processor = delayed(strategy.process)

with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
self.result.update(evaluation.idx, evaluation.update)
with Parallel(return_as="generator_unordered") as parallel:
with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
self.result.update(evaluation.idx, evaluation.update)
if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
break

#####################

# FIXME: remove NaN checking after fit()?
import logging

import numpy as np

logger = logging.getLogger(__name__)
nans = np.isnan(self.result.values).sum()
if nans > 0:
logger.warning(
f"{nans} NaN values in current result. "
"Consider setting a default value for the Scorer"
)

return self
28 changes: 0 additions & 28 deletions src/pydvl/valuation/methods/weighted_banzhaf.py

This file was deleted.

15 changes: 10 additions & 5 deletions src/pydvl/valuation/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ class ValueItem:
[Dataset][pydvl.utils.dataset.Dataset]
name: Name of the sample if it was provided. Otherwise, `str(index)`
value: The value
variance: Variance of the value if it was computed with an approximate
method
variance: Variance of the marginals from which the value was computed.
count: Number of updates for this value
"""

Expand Down Expand Up @@ -187,7 +186,8 @@ class ValuationResult(collections.abc.Sequence, Iterable[ValueItem]):
common to pass the indices of a [Dataset][pydvl.utils.dataset.Dataset]
here. Attention must be paid in a parallel context to copy them to
the local process. Just do `indices=np.copy(data.indices)`.
variances: An optional array of variances in the computation of each value.
variances: An optional array of variances of the marginals from which the values
are computed.
counts: An optional array with the number of updates for each value.
Defaults to an array of ones.
data_names: Names for the data points. Defaults to index numbers if not set.
Expand Down Expand Up @@ -311,12 +311,17 @@ def values(self) -> NDArray[np.float_]:

@property
def variances(self) -> NDArray[np.float_]:
"""The variances, possibly sorted."""
"""Variances of the marginals from which values were computed, possibly sorted.

Note that this is not the variance of the value estimate, but the sample
variance of the marginals used to compute it.

"""
return self._variances[self._sort_positions]

@property
def stderr(self) -> NDArray[np.float_]:
"""The raw standard errors, possibly sorted."""
"""Standard errors of the value estimates, possibly sorted."""
return cast(
NDArray[np.float_], np.sqrt(self.variances / np.maximum(1, self.counts))
)
Expand Down
1 change: 1 addition & 0 deletions src/pydvl/valuation/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def fit(self, data: Dataset):
"""
from typing import Union

from .msr import *
from .permutation import *
from .powerset import *

Expand Down
Loading