diff --git a/src/pydvl/value/shapley/__init__.py b/src/pydvl/value/shapley/__init__.py index c841873b7..b70a16e3b 100644 --- a/src/pydvl/value/shapley/__init__.py +++ b/src/pydvl/value/shapley/__init__.py @@ -10,21 +10,19 @@ from pydvl.value.shapley.gt import group_testing_shapley from pydvl.value.shapley.knn import knn_shapley from pydvl.value.shapley.montecarlo import ( - NoTruncation, - OwenAlgorithm, combinatorial_montecarlo_shapley, - owen_sampling_shapley, permutation_montecarlo_shapley, - truncated_montecarlo_shapley, ) from pydvl.value.shapley.naive import ( combinatorial_exact_shapley, permutation_exact_shapley, ) +from pydvl.value.shapley.owen import OwenAlgorithm, owen_sampling_shapley +from pydvl.value.shapley.truncated import NoTruncation, truncated_montecarlo_shapley from pydvl.value.shapley.types import ShapleyMode from pydvl.value.stopping import MaxUpdates, StoppingCriterion -__all__ = ["compute_shapley_values"] +__all__ = ["compute_shapley_values", "ShapleyMode"] def compute_shapley_values( diff --git a/src/pydvl/value/shapley/actor.py b/src/pydvl/value/shapley/actor.py index a17ef21db..ea993f06a 100644 --- a/src/pydvl/value/shapley/actor.py +++ b/src/pydvl/value/shapley/actor.py @@ -17,7 +17,7 @@ from pydvl.utils.parallel.actor import Coordinator, RayActorWrapper, Worker from pydvl.utils.utility import Utility from pydvl.value.result import ValuationResult -from pydvl.value.shapley.montecarlo import TruncationPolicy +from pydvl.value.shapley.truncated import TruncationPolicy from pydvl.value.stopping import MaxChecks, StoppingCriterion __all__ = ["get_shapley_coordinator", "get_shapley_worker"] diff --git a/src/pydvl/value/shapley/montecarlo.py b/src/pydvl/value/shapley/montecarlo.py index 36a831339..d462cf599 100644 --- a/src/pydvl/value/shapley/montecarlo.py +++ b/src/pydvl/value/shapley/montecarlo.py @@ -30,272 +30,33 @@ the algorithms mentioned above, including Group Testing, can work to valuate groups of samples as units. """ -import abc import logging import math import operator -from enum import Enum from functools import reduce from itertools import cycle, takewhile -from time import sleep from typing import Sequence -from warnings import warn import numpy as np from numpy.typing import NDArray from tqdm import tqdm from pydvl.utils.config import ParallelConfig -from pydvl.utils.numeric import random_powerset, running_moments -from pydvl.utils.parallel import MapReduceJob, init_parallel_backend -from pydvl.utils.progress import maybe_progress +from pydvl.utils.numeric import random_powerset +from pydvl.utils.parallel import MapReduceJob from pydvl.utils.utility import Utility from pydvl.value.result import ValuationResult +from pydvl.value.shapley.truncated import NoTruncation, TruncationPolicy from pydvl.value.stopping import StoppingCriterion logger = logging.getLogger(__name__) __all__ = [ - "truncated_montecarlo_shapley", "permutation_montecarlo_shapley", "combinatorial_montecarlo_shapley", - "owen_sampling_shapley", ] -class TruncationPolicy(abc.ABC): - """A policy for deciding whether to stop computing marginals in a - permutation. - - Statistics are kept on the number of calls and truncations as :attr:`n_calls` - and :attr:`n_truncations` respectively. - - .. todo:: - Because the policy objects are copied to the workers, the statistics - are not accessible from the - :class:`~pydvl.value.shapley.actor.ShapleyCoordinator`. We need to add - methods for this. 
- """ - - def __init__(self): - self.n_calls: int = 0 - self.n_truncations: int = 0 - - @abc.abstractmethod - def _check(self, idx: int, score: float) -> bool: - """Implement the policy.""" - ... - - @abc.abstractmethod - def reset(self): - """Reset the policy to a state ready for a new permutation.""" - ... - - def __call__(self, idx: int, score: float) -> bool: - """Check whether the computation should be interrupted. - - :param idx: Position in the permutation currently being computed. - :param score: Last utility computed. - :return: ``True`` if the computation should be interrupted. - """ - ret = self._check(idx, score) - self.n_calls += 1 - self.n_truncations += 1 if ret else 0 - return ret - - -class NoTruncation(TruncationPolicy): - """A policy which never interrupts the computation.""" - - def _check(self, idx: int, score: float) -> bool: - return False - - def reset(self): - pass - - -class FixedTruncation(TruncationPolicy): - """Break a permutation after computing a fixed number of marginals. - - :param u: Utility object with model, data, and scoring function - :param fraction: Fraction of marginals in a permutation to compute before - stopping (e.g. 0.5 to compute half of the marginals). - """ - - def __init__(self, u: Utility, fraction: float): - super().__init__() - if fraction <= 0 or fraction > 1: - raise ValueError("fraction must be in (0, 1]") - self.max_marginals = len(u.data) * fraction - self.count = 0 - - def _check(self, idx: int, score: float) -> bool: - self.count += 1 - return self.count >= self.max_marginals - - def reset(self): - self.count = 0 - - -class RelativeTruncation(TruncationPolicy): - """Break a permutation if the marginal utility is too low. - - This is called "performance tolerance" in :footcite:t:`ghorbani_data_2019`. - - :param u: Utility object with model, data, and scoring function - :param rtol: Relative tolerance. The permutation is broken if the - last computed utility is less than ``total_utility * rtol``. - """ - - def __init__(self, u: Utility, rtol: float): - super().__init__() - self.rtol = rtol - logger.info("Computing total utility for permutation truncation.") - self.total_utility = u(u.data.indices) - - def _check(self, idx: int, score: float) -> bool: - return np.allclose(score, self.total_utility, rtol=self.rtol) - - def reset(self): - pass - - -class BootstrapTruncation(TruncationPolicy): - """Break a permutation if the last computed utility is close to the total - utility, measured as a multiple of the standard deviation of the utilities. - - :param u: Utility object with model, data, and scoring function - :param n_samples: Number of bootstrap samples to use to compute the variance - of the utilities. - :param sigmas: Number of standard deviations to use as a threshold. 
- """ - - def __init__(self, u: Utility, n_samples: int, sigmas: float = 1): - super().__init__() - self.n_samples = n_samples - logger.info("Computing total utility for permutation truncation.") - self.total_utility = u(u.data.indices) - self.count: int = 0 - self.variance: float = 0 - self.mean: float = 0 - self.sigmas: float = sigmas - - def _check(self, idx: int, score: float) -> bool: - self.mean, self.variance = running_moments( - self.mean, self.variance, self.count, score - ) - self.count += 1 - logger.info( - f"Bootstrap truncation: {self.count} samples, {self.variance:.2f} variance" - ) - if self.count < self.n_samples: - return False - return abs(score - self.total_utility) < float( - self.sigmas * np.sqrt(self.variance) - ) - - def reset(self): - self.count = 0 - self.variance = self.mean = 0 - - -def truncated_montecarlo_shapley( - u: Utility, - *, - done: StoppingCriterion, - truncation: TruncationPolicy, - n_jobs: int = 1, - config: ParallelConfig = ParallelConfig(), - coordinator_update_period: int = 10, - worker_update_period: int = 5, -) -> ValuationResult: - """Monte Carlo approximation to the Shapley value of data points. - - This implements the permutation-based method described in - :footcite:t:`ghorbani_data_2019`. It is a Monte Carlo estimate of the sum - over all possible permutations of the index set, with a double stopping - criterion. - - .. todo:: - Think of how to add Robin-Gelman or some other more principled stopping - criterion. - - Instead of naively implementing the expectation, we sequentially add points - to a dataset from a permutation and incrementally compute marginal utilities. - We stop computing marginals for a given permutation based on a - :class:`TruncationPolicy`. :footcite:t:`ghorbani_data_2019` mention two - policies: one that stops after a certain fraction of marginals are computed, - implemented in :class:`FixedTruncation`, and one that stops if the last - computed utility ("score") is close to the total utility using the standard - deviation of the utility as a measure of proximity, implemented in - :class:`BootstrapTruncation`. - - We keep sampling permutations and updating all shapley values - until the :class:`StoppingCriterion` returns ``True``. - - :param u: Utility object with model, data, and scoring function - :param done: Check on the results which decides when to stop - sampling permutations. - :param truncation: callable that decides whether to stop computing - marginals for a given permutation. - :param n_jobs: number of jobs processing permutations. If None, it will be - set to :func:`available_cpus`. - :param config: Object configuring parallel computation, with cluster - address, number of cpus, etc. - :param coordinator_update_period: in seconds. How often to check the - accumulated results from the workers for convergence. - :param worker_update_period: interval in seconds between different - updates to and from the coordinator - :return: Object with the data values. - - """ - # Avoid circular imports - from .actor import get_shapley_coordinator, get_shapley_worker - - if config.backend == "sequential": - raise NotImplementedError( - "Truncated MonteCarlo Shapley does not work with " - "the Sequential parallel backend." 
- ) - - parallel_backend = init_parallel_backend(config) - n_jobs = parallel_backend.effective_n_jobs(n_jobs) - u_id = parallel_backend.put(u) - - coordinator = get_shapley_coordinator(config=config, done=done) # type: ignore - - workers = [ - get_shapley_worker( # type: ignore - u=u_id, - coordinator=coordinator, - truncation=truncation, - worker_id=worker_id, - update_period=worker_update_period, - config=config, - ) - for worker_id in range(n_jobs) - ] - for worker in workers: - worker.run(block=False) - - while not coordinator.check_convergence(): - sleep(coordinator_update_period) - - return coordinator.accumulate() - - # Something like this would be nicer, but it doesn't seem to be possible - # to start the workers from the coordinator. - # coordinator.add_workers( - # n_workers=n_jobs, - # u=u_id, - # update_period=worker_update_period, - # config=config, - # truncation=truncation, - # ) - # - # return coordinator.run(delay=coordinator_update_period) - - def _permutation_montecarlo_shapley( u: Utility, *, @@ -480,146 +241,3 @@ def combinatorial_montecarlo_shapley( config=config, ) return map_reduce_job() - - -class OwenAlgorithm(Enum): - Standard = "standard" - Antithetic = "antithetic" - - -def _owen_sampling_shapley( - indices: Sequence[int], - u: Utility, - method: OwenAlgorithm, - n_iterations: int, - max_q: int, - *, - progress: bool = False, - job_id: int = 1, -) -> ValuationResult: - r"""This is the algorithm as detailed in the paper: to compute the outer - integral over q ∈ [0,1], use uniformly distributed points for evaluation - of the integrand. For the integrand (the expected marginal utility over the - power set), use Monte Carlo. - - .. todo:: - We might want to try better quadrature rules like Gauss or Rombert or - use Monte Carlo for the double integral. - - :param indices: Indices to compute the value for - :param u: Utility object with model, data, and scoring function - :param method: Either :attr:`~OwenAlgorithm.Full` for $q \in [0,1]$ or - :attr:`~OwenAlgorithm.Halved` for $q \in [0,0.5]$ and correlated samples - :param n_iterations: Number of subsets to sample to estimate the integrand - :param max_q: number of subdivisions for the integration over $q$ - :param progress: Whether to display progress bars for each job - :param job_id: For positioning of the progress bar - :return: Object with the data values, errors. 
- """ - values = np.zeros(len(u.data)) - - q_stop = {OwenAlgorithm.Standard: 1.0, OwenAlgorithm.Antithetic: 0.5} - q_steps = np.linspace(start=0, stop=q_stop[method], num=max_q) - - index_set = set(indices) - for i in maybe_progress(indices, progress, position=job_id): - e = np.zeros(max_q) - subset = np.array(list(index_set.difference({i}))) - for j, q in enumerate(q_steps): - for s in random_powerset(subset, n_samples=n_iterations, q=q): - marginal = u({i}.union(s)) - u(s) - if method == OwenAlgorithm.Antithetic and q != 0.5: - s_complement = index_set.difference(s) - marginal += u({i}.union(s_complement)) - u(s_complement) - e[j] += marginal - e /= n_iterations - # values[i] = e.mean() - # Trapezoidal rule - values[i] = (e[:-1] + e[1:]).sum() / (2 * max_q) - - return ValuationResult( - algorithm="owen_sampling_shapley_" + str(method), values=values - ) - - -def owen_sampling_shapley( - u: Utility, - n_iterations: int, - max_q: int, - *, - method: OwenAlgorithm = OwenAlgorithm.Standard, - n_jobs: int = 1, - config: ParallelConfig = ParallelConfig(), - progress: bool = False, -) -> ValuationResult: - r"""Owen sampling of Shapley values as described in - :footcite:t:`okhrati_multilinear_2021`. - - .. warning:: - Antithetic sampling is unstable and not properly tested - - This function computes a Monte Carlo approximation to - - $$v_u(i) = \int_0^1 \mathbb{E}_{S \sim P_q(D_{\backslash \{i\}})} - [u(S \cup \{i\}) - u(S)]$$ - - using one of two methods. The first one, selected with the argument ``mode = - OwenAlgorithm.Standard``, approximates the integral with: - - $$\hat{v}_u(i) = \frac{1}{Q M} \sum_{j=0}^Q \sum_{m=1}^M [u(S^{(q_j)}_m - \cup \{i\}) - u(S^{(q_j)}_m)],$$ - - where $q_j = \frac{j}{Q} \in [0,1]$ and the sets $S^{(q_j)}$ are such that a - sample $x \in S^{(q_j)}$ if a draw from a $Ber(q_j)$ distribution is 1. - - The second method, selected with the argument ``mode = - OwenAlgorithm.Anthithetic``, - uses correlated samples in the inner sum to reduce the variance: - - $$\hat{v}_u(i) = \frac{1}{Q M} \sum_{j=0}^Q \sum_{m=1}^M [u(S^{(q_j)}_m - \cup \{i\}) - u(S^{(q_j)}_m) + u((S^{(q_j)}_m)^c \cup \{i\}) - u((S^{( - q_j)}_m)^c)],$$ - - where now $q_j = \frac{j}{2Q} \in [0,\frac{1}{2}]$, and $S^c$ is the - complement of $S$. - - :param u: :class:`~pydvl.utils.utility.Utility` object holding data, model - and scoring function. - :param n_iterations: Numer of sets to sample for each value of q - :param max_q: Number of subdivisions for q ∈ [0,1] (the element sampling - probability) used to approximate the outer integral. - :param method: Selects the algorithm to use, see the description. Either - :attr:`~OwenAlgorithm.Full` for $q \in [0,1]$ or - :attr:`~OwenAlgorithm.Halved` for $q \in [0,0.5]$ and correlated samples - :param n_jobs: Number of parallel jobs to use. Each worker receives a chunk - of the total of `max_q` values for q. - :param config: Object configuring parallel computation, with cluster - address, number of cpus, etc. - :param progress: Whether to display progress bars for each job. - :return: Object with the data values. - - .. 
versionadded:: 0.3.0 - - """ - if n_jobs > 1: - raise NotImplementedError("Parallel Owen sampling not implemented yet") - - if OwenAlgorithm(method) == OwenAlgorithm.Antithetic: - warn("Owen antithetic sampling not tested and probably bogus") - - map_reduce_job: MapReduceJob[NDArray, ValuationResult] = MapReduceJob( - u.data.indices, - map_func=_owen_sampling_shapley, - reduce_func=lambda results: reduce(operator.add, results), - map_kwargs=dict( - u=u, - method=OwenAlgorithm(method), - n_iterations=n_iterations, - max_q=max_q, - progress=progress, - ), - n_jobs=n_jobs, - config=config, - ) - - return map_reduce_job() diff --git a/src/pydvl/value/shapley/owen.py b/src/pydvl/value/shapley/owen.py new file mode 100644 index 000000000..7450413d5 --- /dev/null +++ b/src/pydvl/value/shapley/owen.py @@ -0,0 +1,160 @@ +import operator +from enum import Enum +from functools import reduce +from typing import Sequence + +import numpy as np +from warnings import warn +from numpy.typing import NDArray + +from pydvl.utils import ( + MapReduceJob, + ParallelConfig, + Utility, + maybe_progress, + random_powerset, +) +from pydvl.value import ValuationResult + + +class OwenAlgorithm(Enum): + Standard = "standard" + Antithetic = "antithetic" + + +def _owen_sampling_shapley( + indices: Sequence[int], + u: Utility, + method: OwenAlgorithm, + n_iterations: int, + max_q: int, + *, + progress: bool = False, + job_id: int = 1, +) -> ValuationResult: + r"""This is the algorithm as detailed in the paper: to compute the outer + integral over q ∈ [0,1], use uniformly distributed points for evaluation + of the integrand. For the integrand (the expected marginal utility over the + power set), use Monte Carlo. + + .. todo:: + We might want to try better quadrature rules like Gauss or Romberg or + use Monte Carlo for the double integral. + + :param indices: Indices to compute the value for + :param u: Utility object with model, data, and scoring function + :param method: Either :attr:`~OwenAlgorithm.Standard` for $q \in [0,1]$ or + :attr:`~OwenAlgorithm.Antithetic` for $q \in [0,0.5]$ and correlated samples + :param n_iterations: Number of subsets to sample to estimate the integrand + :param max_q: number of subdivisions for the integration over $q$ + :param progress: Whether to display progress bars for each job + :param job_id: For positioning of the progress bar + :return: Object with the data values, errors.
+ """ + values = np.zeros(len(u.data)) + + q_stop = {OwenAlgorithm.Standard: 1.0, OwenAlgorithm.Antithetic: 0.5} + q_steps = np.linspace(start=0, stop=q_stop[method], num=max_q) + + index_set = set(indices) + for i in maybe_progress(indices, progress, position=job_id): + e = np.zeros(max_q) + subset = np.array(list(index_set.difference({i}))) + for j, q in enumerate(q_steps): + for s in random_powerset(subset, n_samples=n_iterations, q=q): + marginal = u({i}.union(s)) - u(s) + if method == OwenAlgorithm.Antithetic and q != 0.5: + s_complement = index_set.difference(s) + marginal += u({i}.union(s_complement)) - u(s_complement) + e[j] += marginal + e /= n_iterations + # values[i] = e.mean() + # Trapezoidal rule + values[i] = (e[:-1] + e[1:]).sum() / (2 * max_q) + + return ValuationResult( + algorithm="owen_sampling_shapley_" + str(method), values=values + ) + + +def owen_sampling_shapley( + u: Utility, + n_iterations: int, + max_q: int, + *, + method: OwenAlgorithm = OwenAlgorithm.Standard, + n_jobs: int = 1, + config: ParallelConfig = ParallelConfig(), + progress: bool = False, +) -> ValuationResult: + r"""Owen sampling of Shapley values as described in + :footcite:t:`okhrati_multilinear_2021`. + + .. warning:: + Antithetic sampling is unstable and not properly tested + + This function computes a Monte Carlo approximation to + + $$v_u(i) = \int_0^1 \mathbb{E}_{S \sim P_q(D_{\backslash \{i\}})} + [u(S \cup \{i\}) - u(S)]$$ + + using one of two methods. The first one, selected with the argument ``mode = + OwenAlgorithm.Standard``, approximates the integral with: + + $$\hat{v}_u(i) = \frac{1}{Q M} \sum_{j=0}^Q \sum_{m=1}^M [u(S^{(q_j)}_m + \cup \{i\}) - u(S^{(q_j)}_m)],$$ + + where $q_j = \frac{j}{Q} \in [0,1]$ and the sets $S^{(q_j)}$ are such that a + sample $x \in S^{(q_j)}$ if a draw from a $Ber(q_j)$ distribution is 1. + + The second method, selected with the argument ``mode = + OwenAlgorithm.Anthithetic``, + uses correlated samples in the inner sum to reduce the variance: + + $$\hat{v}_u(i) = \frac{1}{Q M} \sum_{j=0}^Q \sum_{m=1}^M [u(S^{(q_j)}_m + \cup \{i\}) - u(S^{(q_j)}_m) + u((S^{(q_j)}_m)^c \cup \{i\}) - u((S^{( + q_j)}_m)^c)],$$ + + where now $q_j = \frac{j}{2Q} \in [0,\frac{1}{2}]$, and $S^c$ is the + complement of $S$. + + :param u: :class:`~pydvl.utils.utility.Utility` object holding data, model + and scoring function. + :param n_iterations: Numer of sets to sample for each value of q + :param max_q: Number of subdivisions for q ∈ [0,1] (the element sampling + probability) used to approximate the outer integral. + :param method: Selects the algorithm to use, see the description. Either + :attr:`~OwenAlgorithm.Full` for $q \in [0,1]$ or + :attr:`~OwenAlgorithm.Halved` for $q \in [0,0.5]$ and correlated samples + :param n_jobs: Number of parallel jobs to use. Each worker receives a chunk + of the total of `max_q` values for q. + :param config: Object configuring parallel computation, with cluster + address, number of cpus, etc. + :param progress: Whether to display progress bars for each job. + :return: Object with the data values. + + .. 
versionadded:: 0.3.0 + + """ + if n_jobs > 1: + raise NotImplementedError("Parallel Owen sampling not implemented yet") + + if OwenAlgorithm(method) == OwenAlgorithm.Antithetic: + warn("Owen antithetic sampling not tested and probably bogus") + + map_reduce_job: MapReduceJob[NDArray, ValuationResult] = MapReduceJob( + u.data.indices, + map_func=_owen_sampling_shapley, + reduce_func=lambda results: reduce(operator.add, results), + map_kwargs=dict( + u=u, + method=OwenAlgorithm(method), + n_iterations=n_iterations, + max_q=max_q, + progress=progress, + ), + n_jobs=n_jobs, + config=config, + ) + + return map_reduce_job() diff --git a/src/pydvl/value/shapley/truncated.py b/src/pydvl/value/shapley/truncated.py new file mode 100644 index 000000000..28bf055b8 --- /dev/null +++ b/src/pydvl/value/shapley/truncated.py @@ -0,0 +1,254 @@ +import abc +import logging +from time import sleep + +import numpy as np + +from pydvl.utils import ParallelConfig, Utility, init_parallel_backend, running_moments +from pydvl.value import ValuationResult +from pydvl.value.stopping import StoppingCriterion + +__all__ = [ + "TruncationPolicy", + "NoTruncation", + "FixedTruncation", + "BootstrapTruncation", + "RelativeTruncation", + "truncated_montecarlo_shapley", +] + + +logger = logging.getLogger(__name__) + + +class TruncationPolicy(abc.ABC): + """A policy for deciding whether to stop computing marginals in a + permutation. + + Statistics are kept on the number of calls and truncations as :attr:`n_calls` + and :attr:`n_truncations` respectively. + + .. todo:: + Because the policy objects are copied to the workers, the statistics + are not accessible from the + :class:`~pydvl.value.shapley.actor.ShapleyCoordinator`. We need to add + methods for this. + """ + + def __init__(self): + self.n_calls: int = 0 + self.n_truncations: int = 0 + + @abc.abstractmethod + def _check(self, idx: int, score: float) -> bool: + """Implement the policy.""" + ... + + @abc.abstractmethod + def reset(self): + """Reset the policy to a state ready for a new permutation.""" + ... + + def __call__(self, idx: int, score: float) -> bool: + """Check whether the computation should be interrupted. + + :param idx: Position in the permutation currently being computed. + :param score: Last utility computed. + :return: ``True`` if the computation should be interrupted. + """ + ret = self._check(idx, score) + self.n_calls += 1 + self.n_truncations += 1 if ret else 0 + return ret + + +class NoTruncation(TruncationPolicy): + """A policy which never interrupts the computation.""" + + def _check(self, idx: int, score: float) -> bool: + return False + + def reset(self): + pass + + +class FixedTruncation(TruncationPolicy): + """Break a permutation after computing a fixed number of marginals. + + :param u: Utility object with model, data, and scoring function + :param fraction: Fraction of marginals in a permutation to compute before + stopping (e.g. 0.5 to compute half of the marginals). + """ + + def __init__(self, u: Utility, fraction: float): + super().__init__() + if fraction <= 0 or fraction > 1: + raise ValueError("fraction must be in (0, 1]") + self.max_marginals = len(u.data) * fraction + self.count = 0 + + def _check(self, idx: int, score: float) -> bool: + self.count += 1 + return self.count >= self.max_marginals + + def reset(self): + self.count = 0 + + +class RelativeTruncation(TruncationPolicy): + """Break a permutation if the marginal utility is too low. + + This is called "performance tolerance" in :footcite:t:`ghorbani_data_2019`. 
+ + :param u: Utility object with model, data, and scoring function + :param rtol: Relative tolerance. The permutation is broken if the + last computed utility is less than ``total_utility * rtol``. + """ + + def __init__(self, u: Utility, rtol: float): + super().__init__() + self.rtol = rtol + logger.info("Computing total utility for permutation truncation.") + self.total_utility = u(u.data.indices) + + def _check(self, idx: int, score: float) -> bool: + return np.allclose(score, self.total_utility, rtol=self.rtol) + + def reset(self): + pass + + +class BootstrapTruncation(TruncationPolicy): + """Break a permutation if the last computed utility is close to the total + utility, measured as a multiple of the standard deviation of the utilities. + + :param u: Utility object with model, data, and scoring function + :param n_samples: Number of bootstrap samples to use to compute the variance + of the utilities. + :param sigmas: Number of standard deviations to use as a threshold. + """ + + def __init__(self, u: Utility, n_samples: int, sigmas: float = 1): + super().__init__() + self.n_samples = n_samples + logger.info("Computing total utility for permutation truncation.") + self.total_utility = u(u.data.indices) + self.count: int = 0 + self.variance: float = 0 + self.mean: float = 0 + self.sigmas: float = sigmas + + def _check(self, idx: int, score: float) -> bool: + self.mean, self.variance = running_moments( + self.mean, self.variance, self.count, score + ) + self.count += 1 + logger.info( + f"Bootstrap truncation: {self.count} samples, {self.variance:.2f} variance" + ) + if self.count < self.n_samples: + return False + return abs(score - self.total_utility) < float( + self.sigmas * np.sqrt(self.variance) + ) + + def reset(self): + self.count = 0 + self.variance = self.mean = 0 + + +def truncated_montecarlo_shapley( + u: Utility, + *, + done: StoppingCriterion, + truncation: TruncationPolicy, + n_jobs: int = 1, + config: ParallelConfig = ParallelConfig(), + coordinator_update_period: int = 10, + worker_update_period: int = 5, +) -> ValuationResult: + """Monte Carlo approximation to the Shapley value of data points. + + This implements the permutation-based method described in + :footcite:t:`ghorbani_data_2019`. It is a Monte Carlo estimate of the sum + over all possible permutations of the index set, with a double stopping + criterion. + + .. todo:: + Think of how to add Gelman-Rubin or some other more principled stopping + criterion. + + Instead of naively implementing the expectation, we sequentially add points + to a dataset from a permutation and incrementally compute marginal utilities. + We stop computing marginals for a given permutation based on a + :class:`TruncationPolicy`. :footcite:t:`ghorbani_data_2019` mention two + policies: one that stops after a certain fraction of marginals are computed, + implemented in :class:`FixedTruncation`, and one that stops if the last + computed utility ("score") is close to the total utility using the standard + deviation of the utility as a measure of proximity, implemented in + :class:`BootstrapTruncation`. + + We keep sampling permutations and updating all Shapley values + until the :class:`StoppingCriterion` returns ``True``. + + :param u: Utility object with model, data, and scoring function + :param done: Check on the results which decides when to stop + sampling permutations. + :param truncation: callable that decides whether to stop computing + marginals for a given permutation. + :param n_jobs: number of jobs processing permutations.
If None, it will be + set to :func:`available_cpus`. + :param config: Object configuring parallel computation, with cluster + address, number of cpus, etc. + :param coordinator_update_period: in seconds. How often to check the + accumulated results from the workers for convergence. + :param worker_update_period: interval in seconds between different + updates to and from the coordinator + :return: Object with the data values. + + """ + # Avoid circular imports + from .actor import get_shapley_coordinator, get_shapley_worker + + if config.backend == "sequential": + raise NotImplementedError( + "Truncated MonteCarlo Shapley does not work with " + "the Sequential parallel backend." + ) + + parallel_backend = init_parallel_backend(config) + n_jobs = parallel_backend.effective_n_jobs(n_jobs) + u_id = parallel_backend.put(u) + + coordinator = get_shapley_coordinator(config=config, done=done) # type: ignore + + workers = [ + get_shapley_worker( # type: ignore + u=u_id, + coordinator=coordinator, + truncation=truncation, + worker_id=worker_id, + update_period=worker_update_period, + config=config, + ) + for worker_id in range(n_jobs) + ] + for worker in workers: + worker.run(block=False) + + while not coordinator.check_convergence(): + sleep(coordinator_update_period) + + return coordinator.accumulate() + + # Something like this would be nicer, but it doesn't seem to be possible + # to start the workers from the coordinator. + # coordinator.add_workers( + # n_workers=n_jobs, + # u=u_id, + # update_period=worker_update_period, + # config=config, + # truncation=truncation, + # ) + # + # return coordinator.run(delay=coordinator_update_period) diff --git a/tests/value/shapley/test_montecarlo.py b/tests/value/shapley/test_montecarlo.py index 37e99b67d..1afa65506 100644 --- a/tests/value/shapley/test_montecarlo.py +++ b/tests/value/shapley/test_montecarlo.py @@ -15,8 +15,8 @@ ) from pydvl.value import compute_shapley_values from pydvl.value.shapley import ShapleyMode -from pydvl.value.shapley.montecarlo import NoTruncation from pydvl.value.shapley.naive import combinatorial_exact_shapley +from pydvl.value.shapley.truncated import NoTruncation from pydvl.value.stopping import HistoryDeviation, MaxChecks, MaxUpdates from .. import check_rank_correlation, check_total_value, check_values
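
Note for reviewers: a minimal usage sketch of the two relocated entry points under their new module paths. The `owen_sampling_shapley` and `truncated_montecarlo_shapley` signatures are taken from the code above; `Dataset.from_sklearn`, the `Utility(model, data, scoring)` call with a scorer string, and the `MaxUpdates` argument are assumptions not shown in this diff.

# Usage sketch only -- not part of this PR. Dataset.from_sklearn, the
# Utility(model, data, scoring) signature and MaxUpdates(n) are assumed here.
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

from pydvl.utils import Dataset, Utility
from pydvl.value.shapley.owen import OwenAlgorithm, owen_sampling_shapley
from pydvl.value.shapley.truncated import NoTruncation, truncated_montecarlo_shapley
from pydvl.value.stopping import MaxUpdates

data = Dataset.from_sklearn(load_breast_cancer(), train_size=0.7)
u = Utility(LogisticRegression(max_iter=1000), data, "accuracy")

# Owen sampling: max_q nodes for the outer integral over q, n_iterations
# subsets sampled per node (signature as in owen.py above).
owen_values = owen_sampling_shapley(
    u, n_iterations=100, max_q=50, method=OwenAlgorithm.Standard
)

# Truncated Monte Carlo: requires a non-"sequential" parallel backend (see the
# guard in truncated.py) and a stopping criterion on the value updates.
tmcs_values = truncated_montecarlo_shapley(
    u, done=MaxUpdates(1000), truncation=NoTruncation(), n_jobs=2
)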
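
The `TruncationPolicy` ABC moved into `truncated.py` only asks subclasses for `_check` and `reset`; call and truncation counting happens in `__call__`. As a sketch of how a downstream user could plug in their own policy, the class below is hypothetical and not part of this PR:

# Hypothetical custom policy built on the TruncationPolicy ABC added in
# src/pydvl/value/shapley/truncated.py -- illustration only, not part of this PR.
from pydvl.utils import Utility
from pydvl.value.shapley.truncated import TruncationPolicy


class AbsoluteTruncation(TruncationPolicy):
    """Stop a permutation once the last marginal utility is within ``atol``
    of the total utility (an absolute-tolerance cousin of RelativeTruncation)."""

    def __init__(self, u: Utility, atol: float):
        super().__init__()  # initialises the n_calls / n_truncations counters
        self.atol = atol
        # Cache the utility of the full index set, as RelativeTruncation does.
        self.total_utility = u(u.data.indices)

    def _check(self, idx: int, score: float) -> bool:
        # Truncate when the running utility is close enough to the total.
        return abs(score - self.total_utility) <= self.atol

    def reset(self):
        # No per-permutation state to clear for this policy.
        pass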