Merge pull request #605 from aai-institute/feature/refactor-msr-banzhaf
Refactor MSR Banzhaf valuation
janosg committed Jul 9, 2024
2 parents 4c34312 + 735c908 commit 1cff43b
Showing 15 changed files with 829 additions and 105 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,8 @@

### Added

- Refactoring MSR Banzhaf semivalues with the new sampler architecture.
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)
- Refactoring of group-testing Shapley values with the new sampler architecture
[PR #602](https://github.com/aai-institute/pyDVL/pull/602)
- Refactoring of least-core data valuation methods with more supported sampling methods
@@ -32,6 +34,9 @@
- Fix a bug in pydvl.utils.numeric.random_subset where 1 - q was used instead of q
as the probability of an element being sampled
[PR #597](https://github.com/aai-institute/pyDVL/pull/597)
- Fix a bug in the calculation of variance estimates for MSR Banzhaf
[PR #605](https://github.com/aai-institute/pyDVL/pull/605)



### Changed
2 changes: 1 addition & 1 deletion src/pydvl/valuation/methods/__init__.py
@@ -8,6 +8,6 @@
from .knn_shapley import *
from .least_core import *
from .loo import *
from .msr_banzhaf import *
from .owen_shapley import *
from .semivalue import *
from .weighted_banzhaf import *
196 changes: 196 additions & 0 deletions src/pydvl/valuation/methods/msr_banzhaf.py
@@ -0,0 +1,196 @@
r"""
This module implements the MSR-Banzhaf valuation method, as described in
Wang and Jia (2023)<sup><a href="#wang_data_2023">1</a></sup>.
## References
[^1]: <a name="wang_data_2023"></a>Wang, J.T. and Jia, R., 2023.
[Data Banzhaf: A Robust Data Valuation Framework for Machine Learning](
https://proceedings.mlr.press/v206/wang23e.html).
In: Proceedings of The 26th International Conference on Artificial Intelligence and
Statistics, pp. 6388-6421.
"""
from __future__ import annotations

import numpy as np
from joblib import Parallel, delayed
from typing_extensions import Self

from pydvl.utils.progress import Progress
from pydvl.valuation.dataset import Dataset
from pydvl.valuation.methods.semivalue import SemivalueValuation
from pydvl.valuation.result import ValuationResult
from pydvl.valuation.samplers import MSRSampler
from pydvl.valuation.stopping import StoppingCriterion
from pydvl.valuation.utility.base import UtilityBase
from pydvl.valuation.utils import (
ensure_backend_has_generator_return,
make_parallel_flag,
)

__all__ = ["MSRBanzhafValuation"]


class MSRBanzhafValuation(SemivalueValuation):
"""Class to compute Maximum Sample Re-use (MSR) Banzhaf values.
See [Data Valuation][data-valuation] for an overview.
The MSR Banzhaf valuation approximates the Banzhaf valuation and is much more
sample-efficient than traditional Monte Carlo approaches, because every utility
evaluation is reused to update the estimates of all data points.
Args:
utility: Utility object with model, data and scoring function.
sampler: Sampling scheme to use. Currently, only the MSRSampler is
implemented. In the future, weighted MSR samplers will be supported.
is_done: Stopping criterion to use.
progress: Whether to show a progress bar.
"""

algorithm_name = "MSR-Banzhaf"

def __init__(
self,
utility: UtilityBase,
sampler: MSRSampler,
is_done: StoppingCriterion,
progress: bool = True,
):
super().__init__(
utility=utility,
sampler=sampler,
is_done=is_done,
progress=progress,
)

def coefficient(self, n: int, k: int) -> float:
return 1.0

def fit(self, data: Dataset) -> Self:
"""Calculate the MSR Banzhaf valuation on a dataset.
This method has to be called before calling `values()`.
Calculating the Banzhaf valuation is a computationally expensive task that
can be parallelized. To do so, call the `fit()` method inside a
`joblib.parallel_config` context manager as follows:
```python
from joblib import parallel_config
with parallel_config(n_jobs=4):
valuation.fit(data)
```
"""
pos_result = ValuationResult.zeros(
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

neg_result = ValuationResult.zeros(
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

self.result = ValuationResult.zeros(
indices=data.indices,
data_names=data.data_names,
algorithm=self.algorithm_name,
)

ensure_backend_has_generator_return()

self.utility.training_data = data

strategy = self.sampler.make_strategy(self.utility, self.coefficient)
processor = delayed(strategy.process)

with Parallel(return_as="generator_unordered") as parallel:
with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
if evaluation.is_positive:
pos_result.update(evaluation.idx, evaluation.update)
else:
neg_result.update(evaluation.idx, evaluation.update)

self.result = self._combine_results(
pos_result, neg_result, data=data
)

if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
break

return self

@staticmethod
def _combine_results(
pos_result: ValuationResult, neg_result: ValuationResult, data: Dataset
) -> ValuationResult:
"""Combine the positive and negative running means into a final result.
Since MSR-Banzhaf values are not a mean over marginals, both the variances of
the marginals and the update counts are ill-defined. We use the following
conventions:
1. The counts are defined as the minimum of the two counts. This definition
lets stopping criteria enforce a minimum number of updates for both running
means and correctly reflects that no actual update has taken place if one of
the counts is zero.
2. We reverse-engineer the variances such that they yield correct standard
errors given our convention for the counts and the usual calculation of
standard errors in the valuation result.
Note that we cannot use the addition or subtraction defined by ValuationResult
because those operations are weighted with counts. If we were to simply
subtract the negative result from the positive one, we would get wrong variance
estimates, misleading update counts, and even wrong values if no further
precaution were taken.
TODO: Verify that the two running means are statistically independent (which is
assumed in the aggregation of variances).
Args:
pos_result: The result of the positive updates.
neg_result: The result of the negative updates.
data: The dataset used for the valuation. Used for indices and names.
Returns:
The combined valuation result.
"""
# define counts as minimum of the two counts (see docstring)
counts = np.minimum(pos_result.counts, neg_result.counts)

values = pos_result.values - neg_result.values
values[counts == 0] = np.nan

# define variances that yield correct standard errors (see docstring)
pos_var = pos_result.variances / np.clip(pos_result.counts, 1, np.inf)
neg_var = neg_result.variances / np.clip(neg_result.counts, 1, np.inf)
variances = np.where(counts != 0, (pos_var + neg_var) * counts, np.inf)

result = ValuationResult(
values=values,
variances=variances,
counts=counts,
indices=data.indices,
data_names=data.data_names,
algorithm=pos_result.algorithm,
)

return result
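
To see how the new module fits together, here is a minimal usage sketch. It only relies on names visible in this diff (`MSRBanzhafValuation`, `MSRSampler`, `Dataset`) plus a few assumed helpers: `ModelUtility`, `MaxUpdates` and `Dataset.from_arrays` are illustrative names whose exact signatures may differ from the actual API, and `x_train`, `y_train`, `model` and `scorer` are assumed to be defined by the caller.

```python
from joblib import parallel_config

from pydvl.valuation.dataset import Dataset
from pydvl.valuation.methods.msr_banzhaf import MSRBanzhafValuation
from pydvl.valuation.samplers import MSRSampler

# The two imports below are assumptions for illustration: any UtilityBase
# subclass wrapping your model and scorer, and any StoppingCriterion, will do.
from pydvl.valuation.stopping import MaxUpdates   # assumed criterion name
from pydvl.valuation.utility import ModelUtility  # assumed utility class

train = Dataset.from_arrays(x_train, y_train)  # assumed constructor; any Dataset works
utility = ModelUtility(model, scorer)          # assumed signature

valuation = MSRBanzhafValuation(
    utility=utility,
    sampler=MSRSampler(),        # default arguments assumed
    is_done=MaxUpdates(5000),
    progress=True,
)

# fit() parallelizes the utility evaluations with joblib, as its docstring shows.
with parallel_config(n_jobs=4):
    valuation.fit(train)

result = valuation.result  # ValuationResult populated by fit()
print(result.values[:5], result.stderr[:5])
```

The variance convention documented in `_combine_results` can also be checked in a few lines: storing `(pos_var / n_pos + neg_var / n_neg) * count` as the variance makes the standard-error formula used by `ValuationResult` (`sqrt(variance / count)`) reproduce the standard error of a difference of two independent means. A small sketch with illustrative numbers:

```python
import numpy as np

# Illustrative sample variances and update counts of the two running means for
# a single data point (not taken from a real run).
pos_var, pos_count = 4.0, 100
neg_var, neg_count = 9.0, 50

count = min(pos_count, neg_count)  # convention 1: minimum of the two counts
# Standard error of a difference of independent means:
se_diff = np.sqrt(pos_var / pos_count + neg_var / neg_count)
# Convention 2: reverse-engineered variance stored in the combined result
variance = (pos_var / pos_count + neg_var / neg_count) * count

# ValuationResult.stderr computes sqrt(variance / count), which recovers se_diff
assert np.isclose(np.sqrt(variance / count), se_diff)
```
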
43 changes: 14 additions & 29 deletions src/pydvl/valuation/methods/semivalue.py
@@ -104,39 +104,24 @@ def fit(self, data: Dataset):

self.utility.training_data = data

parallel = Parallel(return_as="generator_unordered")
strategy = self.sampler.make_strategy(self.utility, self.coefficient)
processor = delayed(strategy.process)

with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
self.result.update(evaluation.idx, evaluation.update)
with Parallel(return_as="generator_unordered") as parallel:
with make_parallel_flag() as flag:
delayed_evals = parallel(
processor(batch=list(batch), is_interrupted=flag)
for batch in self.sampler.generate_batches(data.indices)
)
for batch in Progress(delayed_evals, self.is_done, **self.tqdm_args):
for evaluation in batch:
self.result.update(evaluation.idx, evaluation.update)
if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
break

if self.is_done(self.result):
break

#####################

# FIXME: remove NaN checking after fit()?
import logging

import numpy as np

logger = logging.getLogger(__name__)
nans = np.isnan(self.result.values).sum()
if nans > 0:
logger.warning(
f"{nans} NaN values in current result. "
"Consider setting a default value for the Scorer"
)

return self
28 changes: 0 additions & 28 deletions src/pydvl/valuation/methods/weighted_banzhaf.py

This file was deleted.

24 changes: 16 additions & 8 deletions src/pydvl/valuation/result.py
@@ -22,8 +22,9 @@
[ValuationResult.variances][pydvl.valuation.result.ValuationResult.variances],
[ValuationResult.counts][pydvl.valuation.result.ValuationResult.counts],
[ValuationResult.indices][pydvl.valuation.result.ValuationResult.indices],
[ValuationResult.names][pydvl.valuation.result.ValuationResult.names] are sorted in
the same way.
[ValuationResult.stderr][pydvl.valuation.result.ValuationResult.stderr],
[ValuationResult.names][pydvl.valuation.result.ValuationResult.names]
are sorted in the same way.
Indexing and slicing of results is supported and
[ValueItem][pydvl.valuation.result.ValueItem] objects are returned. These objects
@@ -93,8 +94,7 @@ class ValueItem:
[Dataset][pydvl.utils.dataset.Dataset]
name: Name of the sample if it was provided. Otherwise, `str(index)`
value: The value
variance: Variance of the value if it was computed with an approximate
method
variance: Variance of the marginals from which the value was computed.
count: Number of updates for this value
"""

@@ -187,7 +187,8 @@ class ValuationResult(collections.abc.Sequence, Iterable[ValueItem]):
common to pass the indices of a [Dataset][pydvl.utils.dataset.Dataset]
here. Attention must be paid in a parallel context to copy them to
the local process. Just do `indices=np.copy(data.indices)`.
variances: An optional array of variances in the computation of each value.
variances: An optional array of variances of the marginals from which the values
are computed.
counts: An optional array with the number of updates for each value.
Defaults to an array of ones.
data_names: Names for the data points. Defaults to index numbers if not set.
@@ -283,6 +284,7 @@ def sort(
properties
[ValuationResult.values][pydvl.valuation.result.ValuationResult.values],
[ValuationResult.variances][pydvl.valuation.result.ValuationResult.variances],
[ValuationResult.stderr][pydvl.valuation.result.ValuationResult.stderr],
[ValuationResult.counts][pydvl.valuation.result.ValuationResult.counts],
[ValuationResult.indices][pydvl.valuation.result.ValuationResult.indices]
and [ValuationResult.names][pydvl.valuation.result.ValuationResult.names]
@@ -298,6 +300,7 @@ def sort(
"value": "_values",
"variance": "_variances",
"name": "_names",
"stderr": "stderr",
}
self._sort_positions = np.argsort(getattr(self, keymap[key]))
if reverse:
@@ -311,14 +314,19 @@ def values(self) -> NDArray[np.float_]:

@property
def variances(self) -> NDArray[np.float_]:
"""The variances, possibly sorted."""
"""Variances of the marginals from which values were computed, possibly sorted.
Note that this is not the variance of the value estimate, but the sample
variance of the marginals used to compute it.
"""
return self._variances[self._sort_positions]

@property
def stderr(self) -> NDArray[np.float_]:
"""The raw standard errors, possibly sorted."""
"""Standard errors of the value estimates, possibly sorted."""
return cast(
NDArray[np.float_], np.sqrt(self.variances / np.maximum(1, self.counts))
NDArray[np.float_], np.sqrt(self._variances / np.maximum(1, self.counts))
)

@property
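
With `stderr` now exposed both as a property and as a sort key, the intended call pattern is roughly as follows. This is a sketch: it assumes `result` is a populated `ValuationResult` (for example `valuation.result` after fitting) and that `sort()` accepts `key` and `reverse` arguments as suggested by the keymap above.

```python
# Sort in place by decreasing standard error to inspect the most uncertain
# estimates first ("stderr" is the sort key added in this change).
result.sort(key="stderr", reverse=True)

print(result.stderr[:3])  # standard errors, aligned with result.values

# Indexing and slicing return ValueItem objects, as described in the docstring.
for item in result[:3]:
    print(item.name, item.value, item.count)
```
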
1 change: 1 addition & 0 deletions src/pydvl/valuation/samplers/__init__.py
@@ -81,6 +81,7 @@ def fit(self, data: Dataset):
"""
from typing import Union

from .msr import *
from .permutation import *
from .powerset import *
