diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b278dc32..fda9ec64 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,52 +1,51 @@ -exclude: '^docs/conf.py' +exclude: "^docs/conf.py" repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.0.1 - hooks: - - id: trailing-whitespace - - id: check-ast - - id: check-json - - id: check-merge-conflict - - id: check-xml - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: requirements-txt-fixer - - id: mixed-line-ending - args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - language_version: python3 -- repo: https://github.com/pycqa/isort - rev: 5.12.0 - hooks: - - id: isort -- repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 - hooks: - - id: flake8 - args: [ - "--max-line-length=480", - "--extend-ignore=E203,W503" - ] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.991 - hooks: - - id: mypy - args: [ - "--ignore-missing-imports", - "--scripts-are-modules", - "--disallow-incomplete-defs", - "--no-implicit-optional", - "--warn-unused-ignores", - "--warn-redundant-casts", - "--strict-equality", - "--warn-unreachable", - "--disallow-untyped-defs", - "--disallow-untyped-calls", - "--install-types", - "--non-interactive", - ] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: trailing-whitespace + - id: check-ast + - id: check-json + - id: check-merge-conflict + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: requirements-txt-fixer + - id: mixed-line-ending + args: ["--fix=auto"] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + language_version: python3 + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + - repo: https://github.com/PyCQA/flake8 + rev: 6.0.0 + hooks: + - id: flake8 + args: ["--max-line-length=480", "--extend-ignore=E203,W503"] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.991 + hooks: + - id: mypy + args: [ + "--ignore-missing-imports", + "--scripts-are-modules", + "--disallow-incomplete-defs", + "--no-implicit-optional", + "--warn-unused-ignores", + "--warn-redundant-casts", + "--strict-equality", + "--warn-unreachable", + "--disallow-untyped-defs", + "--disallow-untyped-calls", + "--install-types", + "--non-interactive", + "--follow-imports=skip", # This is temporary until the mbi directory is not excluded + ] + exclude: ^src/synthcity/plugins/core/models/mbi/ # This is temporary until the mbi directory is fully typed diff --git a/README.md b/README.md index b4fcae94..23912fc9 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@

- A library for generating and evaluating synthetic tabular data.. + A library for generating and evaluating synthetic tabular data.

@@ -20,7 +20,7 @@ [![](https://pepy.tech/badge/synthcity)](https://pypi.org/project/synthcity/) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/vanderschaarlab/synthcity/blob/main/LICENSE) -[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/) +[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) [![about](https://img.shields.io/badge/about-The%20van%20der%20Schaar%20Lab-blue)](https://www.vanderschaar-lab.com/) [![slack](https://img.shields.io/badge/chat-on%20slack-purple?logo=slack)](https://join.slack.com/t/vanderschaarlab/shared_invite/zt-1pzy8z7ti-zVsUPHAKTgCd1UoY8XtTEw) diff --git a/docs/examples.rst b/docs/examples.rst index 034281c2..e247b6fd 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -26,6 +26,9 @@ General-purpose generators CTGAN Normalizing Flows TVAE + GOGGLE + ARF + GReaT Time-series generators @@ -49,6 +52,7 @@ Privacy-related generators AdsGAN PATEGAN PrivBayes + AIM Domain adaptation generators ------------------------------ diff --git a/docs/generators.rst b/docs/generators.rst index d78a30fd..b18e35c4 100644 --- a/docs/generators.rst +++ b/docs/generators.rst @@ -12,6 +12,9 @@ General purpose Normalizing Flows RTVAE TVAE + GOGGLE + ARF + GReaT Privacy-focused ----------------- @@ -24,6 +27,7 @@ Privacy-focused PrivBayes DP-GAN DECAF + AIM Domain adaptation ------------------- diff --git a/src/synthcity/plugins/core/models/aim.py b/src/synthcity/plugins/core/models/aim.py new file mode 100644 index 00000000..e5165ad9 --- /dev/null +++ b/src/synthcity/plugins/core/models/aim.py @@ -0,0 +1,367 @@ +# stdlib +import itertools +import math +import platform +from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast + +if platform.python_version() < "3.9": + # stdlib + from typing import Iterable +else: + from collections.abc import Iterable + +# third party +import numpy as np +from scipy.special import softmax + +# synthcity absolute +import synthcity.logger as log + +# synthcity relative +from .mbi.dataset import Dataset +from .mbi.domain import Domain +from .mbi.graphical_model import GraphicalModel +from .mbi.identity import Identity +from .mbi.inference import FactoredInference + + +def powerset(iterable: Iterable) -> Iterable: + "powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)" + s = list(iterable) + return itertools.chain.from_iterable( + itertools.combinations(s, r) for r in range(1, len(s) + 1) + ) + + +def downward_closure(Ws: List[Tuple]) -> List: + ans: Set = set() + for proj in Ws: + ans.update(powerset(proj)) + return list(sorted(ans, key=len)) + + +def hypothetical_model_size(domain: Domain, cliques: List[Union[Tuple, List]]) -> float: + model = GraphicalModel(domain, cliques) + return model.size * 8 / 2**20 + + +def compile_workload(workload: List[Tuple]) -> Dict: + def score(cl: Tuple) -> int: + return sum(len(set(cl) & set(ax)) for ax in workload) + + return {cl: score(cl) for cl in downward_closure(workload)} + + +def filter_candidates( + candidates: Dict, model: GraphicalModel, size_limit: Union[float, int] +) -> Dict: + ans = {} + free_cliques = downward_closure(model.cliques) + for cl in candidates: + cond1 = ( + hypothetical_model_size(model.domain, model.cliques + [cl]) <= size_limit + ) + cond2 = cl in free_cliques + if cond1 or cond2: + ans[cl] = candidates[cl] + return ans + + +def cdp_delta(rho: Union[float, int], eps: Union[float, 
int]) -> Union[float, int]: + if rho < 0: + raise ValueError("rho must be non-negative") + if eps < 0: + raise ValueError("eps must be non-negative") + if rho == 0: + return 0 # degenerate case + + # search for best alpha + # Note that any alpha in (1,infty) yields a valid upper bound on delta + # Thus if this search is slightly "incorrect" it will only result in larger delta (still valid) + # This code has two "hacks". + # First the binary search is run for a pre-specified number of iterations. + # 1000 iterations should be sufficient to converge to a good solution. + # Second, we set a minimum value of alpha to avoid numerical stability issues. + # Note that the optimal alpha is at least (1+eps/rho)/2. Thus we only hit this constraint + # when eps<=rho or close to it. This is not an interesting parameter regime, as you will + # inherently get large delta in this regime. + amin = 1.01 # don't let alpha be too small, due to numerical stability + amax = (eps + 1) / (2 * rho) + 2 + for i in range(1000): # should be enough iterations + alpha = (amin + amax) / 2 + derivative = (2 * alpha - 1) * rho - eps + math.log1p(-1.0 / alpha) + if derivative < 0: + amin = alpha + else: + amax = alpha + # now calculate delta + delta = math.exp( + (alpha - 1) * (alpha * rho - eps) + alpha * math.log1p(-1 / alpha) + ) / (alpha - 1.0) + return min(delta, 1.0) # delta<=1 always + + +def cdp_rho(eps: float, delta: float) -> float: + if eps < 0: + raise ValueError("eps must be non-negative") + if delta <= 0: + raise ValueError("delta must be positive") + + if delta >= 1: + return 0.0 # if delta>=1 anything goes + rho_min = 0.0 # maintain cdp_delta(rho,eps)<=delta + rho_max = eps + 1 # maintain cdp_delta(rho_max,eps)>delta + for i in range(1000): + rho = (rho_min + rho_max) / 2 + if cdp_delta(rho, eps) <= delta: + rho_min = rho + else: + rho_max = rho + return rho_min + + +class Mechanism: + def __init__(self, epsilon: float, delta: float): + """ + Base class for a mechanism.
+ :param epsilon: privacy parameter + :param delta: privacy parameter + :param prng: pseudo random number generator + """ + self.epsilon = epsilon + self.delta = delta + self.rho = 0 if delta == 0 else cdp_rho(epsilon, delta) + self.prng = np.random + + def run(self, dataset: Dataset, workload: List[Tuple]) -> Any: + pass + + # def generalized_exponential_mechanism( + # self, qualities, sensitivities, epsilon, t=None, base_measure=None + # ): + # def generalized_em_scores(q, ds, t): + # def pareto_efficient(costs: np.ndarray) -> int: + # eff = np.ones(costs.shape[0], dtype=bool) + # for i, c in enumerate(costs): + # if eff[i]: + # eff[eff] = np.any( + # costs[eff] <= c, axis=1 + # ) # Keep any point with a lower cost + # return np.nonzero(eff)[0] + + # q = -q + # idx = pareto_efficient(np.vstack([q, ds]).T) + # r = q + t * ds + # r = r[:, None] - r[idx][None, :] + # z = ds[:, None] + ds[idx][None, :] + # s = (r / z).max(axis=1) + # return -s + + # if t is None: + # t = 2 * np.log(len(qualities) / 0.5) / epsilon + # if isinstance(qualities, dict): + # keys = list(qualities.keys()) + # qualities = np.array([qualities[key] for key in keys]) + # sensitivities = np.array([sensitivities[key] for key in keys]) + # if base_measure is not None: + # base_measure = np.log([base_measure[key] for key in keys]) + # else: + # keys = np.arange(qualities.size) + # scores = generalized_em_scores(qualities, sensitivities, t) + # key = self.exponential_mechanism( + # scores, epsilon, 1.0, base_measure=base_measure + # ) + # return keys[key] + + # def permute_and_flip(self, qualities, epsilon, sensitivity=1.0): + # """Sample a candidate from the permute-and-flip mechanism""" + # q = qualities - qualities.max() + # p = np.exp(0.5 * epsilon / sensitivity * q) + # for i in np.random.permutation(p.size): + # if np.random.rand() <= p[i]: + # return i + + def exponential_mechanism( + self, + qualities: Union[Dict, np.ndarray, Any], + epsilon: float, + sensitivity: Union[float, int] = 1.0, + base_measure: Optional[Dict] = None, + ) -> np.ndarray: + if isinstance(qualities, dict): + keys = list(qualities.keys()) + qualities = cast(np.ndarray, np.array([qualities[key] for key in keys])) + if base_measure is not None: + base_measure = np.log([base_measure[key] for key in keys]) + else: + qualities = cast(np.ndarray, np.array(qualities)) + keys = np.arange(qualities.size) + + """ Sample a candidate from the permute-and-flip mechanism """ + q = qualities - qualities.max() + if base_measure is None: + p = softmax(0.5 * epsilon / sensitivity * q) + else: + p = softmax(0.5 * epsilon / sensitivity * q + base_measure) + + return keys[self.prng.choice(p.size, p=p)] + + # def gaussian_noise_scale(self, l2_sensitivity, epsilon, delta): + # """Return the Gaussian noise necessary to attain (epsilon, delta)-DP""" + # if self.bounded: + # l2_sensitivity *= 2.0 + # return ( + # l2_sensitivity + # * privacy_calibrator.ana_gaussian_mech(epsilon, delta)["sigma"] + # ) + + # def laplace_noise_scale(self, l1_sensitivity, epsilon): + # """Return the Laplace noise necessary to attain epsilon-DP""" + # if self.bounded: + # l1_sensitivity *= 2.0 + # return l1_sensitivity / epsilon + + def gaussian_noise(self, sigma: float, size: Union[int, Tuple]) -> np.ndarray: + """Generate iid Gaussian noise of a given scale and size""" + return self.prng.normal(0, sigma, size) + + # def laplace_noise(self, b, size): + # """Generate iid Laplace noise of a given scale and size""" + # return self.prng.laplace(0, b, size) + + # def 
best_noise_distribution(self, l1_sensitivity, l2_sensitivity, epsilon, delta): + # """Adaptively determine if Laplace or Gaussian noise will be better, and + # return a function that samples from the appropriate distribution""" + # b = self.laplace_noise_scale(l1_sensitivity, epsilon) + # sigma = self.gaussian_noise_scale(l2_sensitivity, epsilon, delta) + # if np.sqrt(2) * b < sigma: + # return partial(self.laplace_noise, b) + # return partial(self.gaussian_noise, sigma) + + +class AIM(Mechanism): + def __init__( + self, + epsilon: float, + delta: float, + rounds: Optional[Union[int, float]] = None, + max_model_size: int = 80, + structural_zeros: Dict = {}, + ): + super(AIM, self).__init__(epsilon, delta) + self.rounds = rounds + self.max_model_size = max_model_size + self.structural_zeros = structural_zeros + + def worst_approximated( + self, + candidates: Dict, + answers: Dict, + model: GraphicalModel, + eps: float, + sigma: float, + ) -> np.ndarray: + errors = {} + sensitivity = {} + for cl in candidates: + wgt = candidates[cl] + x = answers[cl] + bias = np.sqrt(2 / np.pi) * sigma * model.domain.size(cl) + xest = model.project(cl).datavector() + errors[cl] = wgt * (np.linalg.norm(x - xest, 1) - bias) + sensitivity[cl] = abs(wgt) + max_sensitivity = max( + sensitivity.values() + ) # if all weights are 0, could be a problem + return self.exponential_mechanism(errors, eps, max_sensitivity) + + def run(self, data: Dataset, W: List) -> Dataset: + rounds = self.rounds or 16 * len(data.domain) + workload = [cl for cl, _ in W] + candidates = compile_workload(workload) + answers = {cl: data.project(cl).datavector() for cl in candidates} + + oneway = [cl for cl in candidates if len(cl) == 1] + + sigma = np.sqrt(rounds / (2 * 0.9 * self.rho)) + epsilon = np.sqrt(8 * 0.1 * self.rho / rounds) + + measurements = [] + log.info("Initial Sigma", sigma) + rho_used = len(oneway) * 0.5 / sigma**2 + for cl in oneway: + x = data.project(cl).datavector() + y = x + self.gaussian_noise(sigma, x.size) + identity_I = Identity(y.size) + measurements.append((identity_I, y, sigma, cl)) + + # backend = "torch" if torch.cuda.is_available() else "cpu" # TODO: fix torch backend option + zeros = self.structural_zeros + engine = FactoredInference( + data.domain, + # backend=backend, + iters=1000, + warm_start=True, + structural_zeros=zeros, + ) + model = engine.estimate(measurements) + + t = 0 + terminate = False + while not terminate: + t += 1 + if self.rho - rho_used < 2 * (0.5 / sigma**2 + 1.0 / 8 * epsilon**2): + # Just use up whatever remaining budget there is for one last round + remaining = self.rho - rho_used + sigma = np.sqrt(1 / (2 * 0.9 * remaining)) + epsilon = np.sqrt(8 * 0.1 * remaining) + terminate = True + + rho_used += 1.0 / 8 * epsilon**2 + 0.5 / sigma**2 + size_limit = self.max_model_size * rho_used / self.rho + small_candidates = filter_candidates(candidates, model, size_limit) + cl = self.worst_approximated( + small_candidates, answers, model, epsilon, sigma + ) + + n = data.domain.size(cl) + Q = Identity(n) + x = data.project(cl).datavector() + y = x + self.gaussian_noise(sigma, n) + measurements.append((Q, y, sigma, cl)) + z = model.project(cl).datavector() + + model = engine.estimate(measurements) + w = model.project(cl).datavector() + log.info("Selected", cl, "Size", n, "Budget Used", rho_used / self.rho) + if np.linalg.norm(w - z, 1) <= sigma * np.sqrt(2 / np.pi) * n: + log.warning(f"Reducing sigma: {sigma/2}") + sigma /= 2 + epsilon *= 2 + + log.info("Generating Data...") + engine.iters = 2500 + 
model = engine.estimate(measurements) + synth = model.synthetic_data() + + return synth + + +def default_params() -> Dict[str, Any]: + """ + Return default parameters to run this program + + :returns: a dictionary of default parameter settings for each command line argument + """ + params: Dict[str, Any] = {} + params["dataset"] = "../data/adult.csv" # TODO: Generalize + params["domain"] = "../data/adult-domain.json" # TODO: Generalize + params["epsilon"] = 1.0 + params["delta"] = 1e-9 + params["noise"] = "laplace" + params["max_model_size"] = 80 + params["degree"] = 2 + params["num_marginals"] = None + params["max_cells"] = 10000 + + return params diff --git a/src/synthcity/plugins/core/models/mbi/__init__.py b/src/synthcity/plugins/core/models/mbi/__init__.py new file mode 100644 index 00000000..606066c6 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/__init__.py @@ -0,0 +1,12 @@ +""" +This module contains the core functionality of the MBI package, from +https://github.com/ryan112358/private-pgm/tree/master/src/mbi, which is +licensed under Apache-2.0. + +It is further modified to include mbi.identity, which contains some functionality +from Ektelo (https://github.com/ektelo/ektelo/blob/master/ektelo/matrix.py), which +is also licensed under Apache-2.0. + +The code is also edited in order to be compatible with the synthcity codebase and its +thorough type checking and code style. +""" diff --git a/src/synthcity/plugins/core/models/mbi/callbacks.py b/src/synthcity/plugins/core/models/mbi/callbacks.py new file mode 100644 index 00000000..dd340c67 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/callbacks.py @@ -0,0 +1,113 @@ +# stdlib +import time + +# third party +import numpy as np +import pandas as pd + +# synthcity absolute +import synthcity.logger as log + + +class CallBack: + """A CallBack is a function called after every iteration of an iterative optimization procedure. + It is useful for tracking loss and other metrics over time. + """ + + def __init__(self, engine, frequency=50): + """Initialize the callback object + + :param engine: the FactoredInference object that is performing the optimization + :param frequency: the number of iterations to perform before computing the callback function + """ + self.engine = engine + self.frequency = frequency + self.calls = 0 + + def run(self, marginals): + pass + + def __call__(self, marginals): + if self.calls == 0: + self.start = time.time() + if self.calls % self.frequency == 0: + self.run(marginals) + self.calls += 1 + + +class Logger(CallBack): + """Logger is the default callback function. It tracks the time, L1 loss, L2 loss, and + optionally the total variation distance to the true query answers (when available). + The last is for debugging purposes only - in practice the true answers cannot be observed. + """ + + def __init__(self, engine, true_answers=None, frequency=50): + """Initialize the callback object + + :param engine: the FactoredInference object that is performing the optimization + :param true_answers: a dictionary containing true answers to the measurement queries.
+ :param frequency: the number of iterations to perform before computing the callback function + """ + CallBack.__init__(self, engine, frequency) + self.true_answers = true_answers + self.idx = 0 + + def setup(self): + model = self.engine.model + total = sum(model.domain.size(cl) for cl in model.cliques) + log.debug("Total clique size:", total, flush=True) + # cl = max(model.cliques, key=lambda cl: model.domain.size(cl)) + cols = ["iteration", "time", "l1_loss", "l2_loss", "feasibility"] + if self.true_answers is not None: + cols.append("variation") + self.results = pd.DataFrame(columns=cols) + log.debug("\t\t".join(cols), flush=True) + + def variational_distances(self, marginals): + errors = [] + for Q, y, proj in self.true_answers: + for cl in marginals: + if set(proj) <= set(cl): + mu = marginals[cl].project(proj) + x = mu.values.flatten() + diff = Q.dot(x) - y + err = 0.5 * np.abs(diff).sum() / y.sum() + errors.append(err) + break + return errors + + def primal_feasibility(self, mu): + ans = 0 + count = 0 + for r in mu: + for s in mu: + if r == s: + break + d = tuple(set(r) & set(s)) + if len(d) > 0: + x = mu[r].project(d).datavector() + y = mu[s].project(d).datavector() + err = np.linalg.norm(x - y, 1) + ans += err + count += 1 + try: + return ans / count + except BaseException: + return 0 + + def run(self, marginals): + if self.idx == 0: + self.setup() + + t = time.time() - self.start + l1_loss = self.engine._marginal_loss(marginals, metric="L1")[0] + l2_loss = self.engine._marginal_loss(marginals, metric="L2")[0] + feasibility = self.primal_feasibility(marginals) + row = [self.calls, t, l1_loss, l2_loss, feasibility] + if self.true_answers is not None: + variational = np.mean(self.variational_distances(marginals)) + row.append(100 * variational) + self.results.loc[self.idx] = row + self.idx += 1 + + log.debug("\t\t".join(["%.2f" % v for v in row]), flush=True) diff --git a/src/synthcity/plugins/core/models/mbi/clique_vector.py b/src/synthcity/plugins/core/models/mbi/clique_vector.py new file mode 100644 index 00000000..8005a83f --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/clique_vector.py @@ -0,0 +1,107 @@ +# third party +import numpy as np + + +class CliqueVector(dict): + """This is a convenience class for simplifying arithmetic over the + concatenated vector of marginals and potentials. 
+ + These vectors are represented as a dictionary mapping cliques (subsets of attributes) + to marginals/potentials (Factor objects) + """ + + def __init__(self, dictionary): + self.dictionary = dictionary + dict.__init__(self, dictionary) + + @staticmethod + def zeros(domain, cliques): + # synthcity relative + from .factor import Factor + + return CliqueVector({cl: Factor.zeros(domain.project(cl)) for cl in cliques}) + + @staticmethod + def ones(domain, cliques): + # synthcity relative + from .factor import Factor + + return CliqueVector({cl: Factor.ones(domain.project(cl)) for cl in cliques}) + + @staticmethod + def uniform(domain, cliques): + # synthcity relative + from .factor import Factor + + return CliqueVector({cl: Factor.uniform(domain.project(cl)) for cl in cliques}) + + @staticmethod + def random(domain, cliques, prng=np.random): + # synthcity relative + from .factor import Factor + + return CliqueVector( + {cl: Factor.random(domain.project(cl), prng) for cl in cliques} + ) + + @staticmethod + def normal(domain, cliques, prng=np.random): + # synthcity relative + from .factor import Factor + + return CliqueVector( + {cl: Factor.normal(domain.project(cl), prng) for cl in cliques} + ) + + @staticmethod + def from_data(data, cliques): + # synthcity relative + from .factor import Factor + + ans = {} + for cl in cliques: + mu = data.project(cl) + ans[cl] = Factor(mu.domain, mu.datavector()) + return CliqueVector(ans) + + def combine(self, other): + # combines this CliqueVector with other, even if they do not share the same set of factors + # used for warm-starting optimization + # Important note: if other contains factors not defined within this CliqueVector, they + # are ignored and *not* combined into this CliqueVector + for cl in other: + for cl2 in self: + if set(cl) <= set(cl2): + self[cl2] += other[cl] + break + + def __mul__(self, const): + ans = {cl: const * self[cl] for cl in self} + return CliqueVector(ans) + + def __rmul__(self, const): + return self.__mul__(const) + + def __add__(self, other): + if np.isscalar(other): + ans = {cl: self[cl] + other for cl in self} + else: + ans = {cl: self[cl] + other[cl] for cl in self} + return CliqueVector(ans) + + def __sub__(self, other): + return self + -1 * other + + def exp(self): + ans = {cl: self[cl].exp() for cl in self} + return CliqueVector(ans) + + def log(self): + ans = {cl: self[cl].log() for cl in self} + return CliqueVector(ans) + + def dot(self, other): + return sum((self[cl] * other[cl]).sum() for cl in self) + + def size(self): + return sum(self[cl].domain.size() for cl in self) diff --git a/src/synthcity/plugins/core/models/mbi/dataset.py b/src/synthcity/plugins/core/models/mbi/dataset.py new file mode 100644 index 00000000..ad1347a8 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/dataset.py @@ -0,0 +1,72 @@ +# stdlib +import json + +# third party +import numpy as np +import pandas as pd + +# synthcity relative +from .domain import Domain + + +class Dataset: + def __init__(self, df, domain, weights=None): + """create a Dataset object + + :param df: a pandas dataframe + :param domain: a domain object + :param weight: weight for each row + """ + if set(domain.attrs) > set(df.columns): + raise AssertionError("data must contain domain attributes") + if weights is not None and df.shape[0] != weights.size: + raise AssertionError("weights must be the same size as the data") + self.domain = domain + self.df = df.loc[:, domain.attrs] + self.weights = weights + + @staticmethod + def synthetic(domain, N): + """Generate 
synthetic data conforming to the given domain + + :param domain: The domain object + :param N: the number of individuals + """ + arr = [np.random.randint(low=0, high=n, size=N) for n in domain.shape] + values = np.array(arr).T + df = pd.DataFrame(values, columns=domain.attrs) + return Dataset(df, domain) + + @staticmethod + def load(path, domain): + """Load data into a dataset object + + :param path: path to csv file + :param domain: path to json file encoding the domain information + """ + df = pd.read_csv(path) + config = json.load(open(domain)) + domain = Domain(config.keys(), config.values()) + return Dataset(df, domain) + + def project(self, cols): + """project dataset onto a subset of columns""" + if type(cols) in [str, int]: + cols = [cols] + data = self.df.loc[:, cols] + domain = self.domain.project(cols) + return Dataset(data, domain, self.weights) + + def drop(self, cols): + proj = [c for c in self.domain if c not in cols] + return self.project(proj) + + @property + def records(self): + return self.df.shape[0] + + def datavector(self, flatten=True): + """return the database in vector-of-counts form""" + bins = [range(n + 1) for n in self.domain.shape] + ans = np.histogramdd(self.df.values, bins, weights=self.weights)[0] + return ans.flatten() if flatten else ans diff --git a/src/synthcity/plugins/core/models/mbi/domain.py b/src/synthcity/plugins/core/models/mbi/domain.py new file mode 100644 index 00000000..5bc700bd --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/domain.py @@ -0,0 +1,121 @@ +# stdlib +from functools import reduce + + +class Domain: + def __init__(self, attrs, shape): + """Construct a Domain object + + :param attrs: a list or tuple of attribute names + :param shape: a list or tuple of domain sizes for each attribute + """ + if len(attrs) != len(shape): + raise AssertionError("dimensions must be equal") + self.attrs = tuple(attrs) + self.shape = tuple(shape) + self.config = dict(zip(attrs, shape)) + + @staticmethod + def fromdict(config): + """Construct a Domain object from a dictionary of { attr : size } values""" + return Domain(config.keys(), config.values()) + + def project(self, attrs): + """project the domain onto a subset of attributes + + :param attrs: the attributes to project onto + :return: the projected Domain object + """ + # return the projected domain + if type(attrs) is str: + attrs = [attrs] + shape = tuple(self.config[a] for a in attrs) + return Domain(attrs, shape) + + def marginalize(self, attrs): + """marginalize out some attributes from the domain (opposite of project) + + :param attrs: the attributes to marginalize out + :return: the marginalized Domain object + """ + proj = [a for a in self.attrs if a not in attrs] + return self.project(proj) + + def axes(self, attrs): + """return the axes tuple for the given attributes + + :param attrs: the attributes + :return: a tuple with the corresponding axes + """ + return tuple(self.attrs.index(a) for a in attrs) + + def transpose(self, attrs): + """reorder the attributes in the domain object""" + return self.project(attrs) + + def invert(self, attrs): + """returns the attributes in the domain not in the list""" + return [a for a in self.attrs if a not in attrs] + + def merge(self, other): + """merge this domain object with another + + :param other: another Domain object + :return: a new domain object covering the full domain + + Example: + >>> D1 = Domain(['a','b'], [10,20]) + >>> D2 = Domain(['b','c'], [20,30]) + >>> D1.merge(D2) + Domain(['a','b','c'], [10,20,30]) + """ + extra = 
other.marginalize(self.attrs) + return Domain(self.attrs + extra.attrs, self.shape + extra.shape) + + def contains(self, other): + """determine if this domain contains another""" + return set(other.attrs) <= set(self.attrs) + + def size(self, attrs=None): + """return the total size of the domain""" + if attrs is None: + return reduce(lambda x, y: x * y, self.shape, 1) + return self.project(attrs).size() + + def sort(self, how="size"): + """return a new domain object, sorted by attribute size or attribute name""" + if how == "size": + attrs = sorted(self.attrs, key=self.size) + elif how == "name": + attrs = sorted(self.attrs) + return self.project(attrs) + + def canonical(self, attrs): + """return the canonical ordering of the attributes""" + return tuple(a for a in self.attrs if a in attrs) + + def __contains__(self, attr): + return attr in self.attrs + + def __getitem__(self, a): + """return the size of an individual attribute + :param a: the attribute + """ + return self.config[a] + + def __iter__(self): + """iterator for the attributes in the domain""" + return self.attrs.__iter__() + + def __len__(self): + return len(self.attrs) + + def __eq__(self, other): + return self.attrs == other.attrs and self.shape == other.shape + + def __repr__(self): + inner = ", ".join(["%s: %d" % x for x in zip(self.attrs, self.shape)]) + return "Domain(%s)" % inner + + def __str__(self): + return self.__repr__() diff --git a/src/synthcity/plugins/core/models/mbi/factor.py b/src/synthcity/plugins/core/models/mbi/factor.py new file mode 100644 index 00000000..5cd5d850 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/factor.py @@ -0,0 +1,208 @@ +# third party +import numpy as np +import torch +from scipy.special import logsumexp + + +class Factor: + def __init__(self, domain, values): + """Initialize a factor over the given domain + + :param domain: the domain of the factor + :param values: the ndarray of factor values (for each element of the domain) + + Note: values may be a flattened 1d array or a ndarray with same shape as domain + """ + if isinstance(domain, torch.Tensor): + domain = domain.detach().cpu().numpy() + if isinstance(values, torch.Tensor): + values = values.detach().cpu().numpy() + if domain.size() != values.size: + raise AssertionError( + f"domain size ({domain.size()}) does not match values size ({values.size})" + ) + if values.ndim != 1 and values.shape != domain.shape: + raise AssertionError("invalid shape for values array") + self.domain = domain + self.values = values.reshape(domain.shape) + + @staticmethod + def zeros(domain): + return Factor(domain, np.zeros(domain.shape)) + + @staticmethod + def ones(domain): + return Factor(domain, np.ones(domain.shape)) + + @staticmethod + def random(domain): + return Factor(domain, np.random.rand(*domain.shape)) + + @staticmethod + def uniform(domain): + return Factor.ones(domain) / domain.size() + + @staticmethod + def active(domain, structural_zeros): + """create a factor that is 0 everywhere except in positions present in + 'structural_zeros', where it is -infinity + + :param: domain: the domain of this factor + :param: structural_zeros: a list of values that are not possible + """ + idx = tuple(np.array(structural_zeros).T) + vals = np.zeros(domain.shape) + vals[idx] = -np.inf + return Factor(domain, vals) + + def expand(self, domain): + if not domain.contains(self.domain): + raise AssertionError("expanded domain must contain current domain") + dims = len(domain) - len(self.domain) + values = self.values.reshape(self.domain.shape + 
tuple([1] * dims)) + ax = domain.axes(self.domain.attrs) + values = np.moveaxis(values, range(len(ax)), ax) + values = np.broadcast_to(values, domain.shape) + return Factor(domain, values) + + def transpose(self, attrs): + if set(attrs) != set(self.domain.attrs): + raise AssertionError("attrs must be same as domain attributes") + newdom = self.domain.project(attrs) + ax = newdom.axes(self.domain.attrs) + values = np.moveaxis(self.values, range(len(ax)), ax) + return Factor(newdom, values) + + def project(self, attrs, agg="sum"): + """ + project the factor onto a list of attributes (in order) + using either sum or logsumexp to aggregate along other attributes + """ + if agg not in ["sum", "logsumexp"]: + raise AssertionError("agg must be sum or logsumexp") + marginalized = self.domain.marginalize(attrs) + if agg == "sum": + ans = self.sum(marginalized.attrs) + elif agg == "logsumexp": + ans = self.logsumexp(marginalized.attrs) + return ans.transpose(attrs) + + def sum(self, attrs=None): + if attrs is None: + return np.sum(self.values) + axes = self.domain.axes(attrs) + values = np.sum(self.values, axis=axes) + newdom = self.domain.marginalize(attrs) + return Factor(newdom, values) + + def logsumexp(self, attrs=None): + if attrs is None: + return logsumexp(self.values) + axes = self.domain.axes(attrs) + values = logsumexp(self.values, axis=axes) + newdom = self.domain.marginalize(attrs) + return Factor(newdom, values) + + def logaddexp(self, other): + newdom = self.domain.merge(other.domain) + factor1 = self.expand(newdom) + factor2 = other.expand(newdom) + return Factor(newdom, np.logaddexp(factor1.values, factor2.values)) + + def max(self, attrs=None): + if attrs is None: + return self.values.max() + axes = self.domain.axes(attrs) + values = np.max(self.values, axis=axes) + newdom = self.domain.marginalize(attrs) + return Factor(newdom, values) + + def condition(self, evidence): + """evidence is a dictionary where + keys are attributes, and + values are elements of the domain for that attribute""" + slices = [evidence[a] if a in evidence else slice(None) for a in self.domain] + newdom = self.domain.marginalize(evidence.keys()) + values = self.values[tuple(slices)] + return Factor(newdom, values) + + def copy(self, out=None): + if out is None: + return Factor(self.domain, self.values.copy()) + np.copyto(out.values, self.values) + return out + + def __mul__(self, other): + if np.isscalar(other): + new_values = np.nan_to_num(other * self.values) + return Factor(self.domain, new_values) + newdom = self.domain.merge(other.domain) + factor1 = self.expand(newdom) + factor2 = other.expand(newdom) + return Factor(newdom, factor1.values * factor2.values) + + def __add__(self, other): + if np.isscalar(other): + return Factor(self.domain, other + self.values) + newdom = self.domain.merge(other.domain) + factor1 = self.expand(newdom) + factor2 = other.expand(newdom) + return Factor(newdom, np.add(factor1.values, factor2.values)) + + def __iadd__(self, other): + if np.isscalar(other): + self.values += other + return self + factor2 = other.expand(self.domain) + self.values += factor2.values + return self + + def __imul__(self, other): + if np.isscalar(other): + self.values *= other + return self + factor2 = other.expand(self.domain) + self.values *= factor2.values + return self + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __sub__(self, other): + if np.isscalar(other): + return Factor(self.domain, self.values - other) + 
other = Factor( + other.domain, np.where(other.values == -np.inf, 0, -other.values) + ) + return self + other + + def __truediv__(self, other): + if np.isscalar(other): + new_values = self.values / other + new_values = np.nan_to_num(new_values) + return Factor(self.domain, new_values) + tmp = other.expand(self.domain) + vals = np.divide(self.values, tmp.values, where=tmp.values > 0) + vals[tmp.values <= 0] = 0.0 + return Factor(self.domain, vals) + + def exp(self, out=None): + if out is None: + return Factor(self.domain, np.exp(self.values)) + np.exp(self.values, out=out.values) + return out + + def log(self, out=None): + if out is None: + return Factor(self.domain, np.log(self.values + 1e-100)) + np.log(self.values, out=out.values) + return out + + def datavector(self, flatten=True): + """Materialize the data vector""" + if flatten: + return self.values.flatten() + return self.values diff --git a/src/synthcity/plugins/core/models/mbi/factor_graph.py b/src/synthcity/plugins/core/models/mbi/factor_graph.py new file mode 100644 index 00000000..cd18ab13 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/factor_graph.py @@ -0,0 +1,287 @@ +# stdlib +from collections import defaultdict + +# third party +import numpy as np + +# synthcity relative +from .clique_vector import CliqueVector +from .factor import Factor + + +class FactorGraph: + def __init__(self, domain, cliques, total=1.0, convex=False, iters=25): + self.domain = domain + self.cliques = cliques + self.total = total + self.convex = convex + self.iters = iters + + if convex: + self.counting_numbers = self.get_counting_numbers() + self.belief_propagation = self.convergent_belief_propagation + else: + counting_numbers = {} + for cl in cliques: + counting_numbers[cl] = 1.0 + for a in domain: + counting_numbers[a] = 1.0 - len([cl for cl in cliques if a in cl]) + self.counting_numbers = None, None, counting_numbers + self.belief_propagation = self.loopy_belief_propagation + + self.potentials = None + self.marginals = None + self.messages = self.init_messages() + self.beliefs = {i: Factor.zeros(domain.project(i)) for i in domain} + + def datavector(self, flatten=True): + """Materialize the explicit representation of the distribution as a data vector.""" + logp = sum(self.potentials[cl] for cl in self.cliques) + ans = np.exp(logp - logp.logsumexp()) + wgt = ans.domain.size() / self.domain.size() + return ans.expand(self.domain).datavector(flatten) * wgt * self.total + + def init_messages(self): + mu_n = defaultdict(dict) + mu_f = defaultdict(dict) + for cl in self.cliques: + for v in cl: + mu_f[cl][v] = Factor.zeros(self.domain.project(v)) + mu_n[v][cl] = Factor.zeros(self.domain.project(v)) + return mu_n, mu_f + + def primal_feasibility(self, mu): + ans = 0 + count = 0 + for r in mu: + for s in mu: + if r == s: + break + d = tuple(set(r) & set(s)) + if len(d) > 0: + x = mu[r].project(d).datavector() + y = mu[s].project(d).datavector() + err = np.linalg.norm(x - y, 1) + ans += err + count += 1 + try: + return ans / count + except BaseException: + return 0 + + def project(self, attrs): + if type(attrs) is list: + attrs = tuple(attrs) + + if self.marginals is not None: + # we will average all ways to obtain the given marginal, + # since there may be more than one + ans = Factor.zeros(self.domain.project(attrs)) + terminate = False + for cl in self.cliques: + if set(attrs) <= set(cl): + ans += self.marginals[cl].project(attrs) + terminate = True + if terminate: + return ans * (self.total / ans.sum()) + + belief = sum(self.beliefs[i] for i 
in attrs) + belief += np.log(self.total) - belief.logsumexp() + return belief.transpose(attrs).exp() + + def loopy_belief_propagation(self, potentials, callback=None): + mu_n, mu_f = self.messages + self.potentials = potentials + + for i in range(self.iters): + # factor to variable BP + for cl in self.cliques: + pre = sum(mu_n[c][cl] for c in cl) + for v in cl: + complement = [var for var in cl if var is not v] + mu_f[cl][v] = potentials[cl] + pre - mu_n[v][cl] + mu_f[cl][v] = mu_f[cl][v].logsumexp(complement) + mu_f[cl][v] -= mu_f[cl][v].logsumexp() + + # variable to factor BP + for v in self.domain: + fac = [cl for cl in self.cliques if v in cl] + pre = sum(mu_f[cl][v] for cl in fac) + for f in fac: + complement = [var for var in fac if var is not f] + # mu_n[v][f] = Factor.zeros(self.domain.project(v)) + mu_n[v][f] = pre - mu_f[f][v] # sum(mu_f[c][v] for c in complement) + # mu_n[v][f] += sum(mu_f[c][v] for c in complement) + # mu_n[v][f] -= mu_n[v][f].logsumexp() + + if callback is not None: + mg = self.clique_marginals(mu_n, mu_f, potentials) + callback(mg) + + self.beliefs = { + v: sum(mu_f[cl][v] for cl in self.cliques if v in cl) for v in self.domain + } + self.messages = mu_n, mu_f + self.marginals = self.clique_marginals(mu_n, mu_f, potentials) + return self.marginals + + def convergent_belief_propagation(self, potentials, callback=None): + # Algorithm 11.2 in Koller & Friedman (modified to work in log space) + + v, vhat, k = self.counting_numbers + sigma, delta = self.messages + # sigma, delta = self.init_messages() + + for it in range(self.iters): + # pre = {} + # for r in self.cliques: + # pre[r] = sum(sigma[j][r] for j in r) + + for i in self.domain: + nbrs = [r for r in self.cliques if i in r] + for r in nbrs: + comp = [j for j in r if i != j] + delta[r][i] = potentials[r] + sum(sigma[j][r] for j in comp) + # delta[r][i] = potentials[r] + pre[r] - sigma[i][r] + delta[r][i] /= vhat[i, r] + delta[r][i] = delta[r][i].logsumexp(comp) + belief = Factor.zeros(self.domain.project(i)) + belief += sum(delta[r][i] * vhat[i, r] for r in nbrs) / vhat[i] + belief -= belief.logsumexp() + self.beliefs[i] = belief + for r in nbrs: + comp = [j for j in r if i != j] + A = -v[i, r] / vhat[i, r] + B = v[r] + sigma[i][r] = A * (potentials[r] + sum(sigma[j][r] for j in comp)) + # sigma[i][r] = A*(potentials[r] + pre[r] - sigma[i][r]) + sigma[i][r] += B * (belief - delta[r][i]) + if callback is not None: + mg = self.clique_marginals(sigma, delta, potentials) + callback(mg) + + self.messages = sigma, delta + return self.clique_marginals(sigma, delta, potentials) + + def clique_marginals(self, mu_n, mu_f, potentials): + if self.convex: + v, _, _ = self.counting_numbers + marginals = {} + for cl in self.cliques: + belief = potentials[cl] + sum(mu_n[n][cl] for n in cl) + if self.convex: + belief *= 1.0 / v[cl] + belief += np.log(self.total) - belief.logsumexp() + marginals[cl] = belief.exp() + return CliqueVector(marginals) + + def mle(self, marginals): + return -self.bethe_entropy(marginals)[1] + + def bethe_entropy(self, marginals): + """ + Return the Bethe Entropy and the gradient with respect to the marginals + + """ + _, _, weights = self.counting_numbers + entropy = 0 + dmarginals = {} + attributes = set() + for cl in self.cliques: + mu = marginals[cl] / self.total + entropy += weights[cl] * (mu * mu.log()).sum() + dmarginals[cl] = weights[cl] * (1 + mu.log()) / self.total + for a in set(cl) - set(attributes): + p = mu.project(a) + entropy += weights[a] * (p * p.log()).sum() + dmarginals[cl] += 
weights[a] * (1 + p.log()) / self.total + attributes.update(a) + + return -entropy, -1 * CliqueVector(dmarginals) + + def get_counting_numbers(self): + # third party + from cvxopt import matrix, solvers + + solvers.options["show_progress"] = False + index = {} + idx = 0 + + for i in self.domain: + index[i] = idx + idx += 1 + for r in self.cliques: + index[r] = idx + idx += 1 + + for r in self.cliques: + for i in r: + index[r, i] = idx + idx += 1 + + vectors = {} + for r in self.cliques: + v = np.zeros(idx) + v[index[r]] = 1 + for i in r: + v[index[r, i]] = 1 + vectors[r] = v + + for i in self.domain: + v = np.zeros(idx) + v[index[i]] = 1 + for r in self.cliques: + if i in r: + v[index[r, i]] = -1 + vectors[i] = v + + constraints = [] + for i in self.domain: + con = vectors[i].copy() + for r in self.cliques: + if i in r: + con += vectors[r] + constraints.append(con) + A = np.array(constraints) + b = np.ones(len(self.domain)) + + X = np.vstack([vectors[r] for r in self.cliques]) + y = np.ones(len(self.cliques)) + P = X.T @ X + q = -X.T @ y + G = -np.eye(q.size) + h = np.zeros(q.size) + minBound = 1.0 / len(self.domain) + for r in self.cliques: + h[index[r]] = -minBound + + P = matrix(P) + q = matrix(q) + G = matrix(G) + h = matrix(h) + A = matrix(A) + b = matrix(b) + + ans = solvers.qp(P, q, G, h, A, b) + x = np.array(ans["x"]).flatten() + + counting_v = {} + for r in self.cliques: + counting_v[r] = x[index[r]] + for i in r: + counting_v[i, r] = x[index[r, i]] + for i in self.domain: + counting_v[i] = x[index[i]] + + counting_vhat = {} + counting_k = {} + for i in self.domain: + nbrs = [r for r in self.cliques if i in r] + counting_vhat[i] = counting_v[i] + sum(counting_v[r] for r in nbrs) + counting_k[i] = counting_v[i] - sum(counting_v[i, r] for r in nbrs) + for r in nbrs: + counting_vhat[i, r] = counting_v[r] + counting_v[i, r] + for r in self.cliques: + counting_k[r] = counting_v[r] + sum(counting_v[i, r] for i in r) + + return counting_v, counting_vhat, counting_k diff --git a/src/synthcity/plugins/core/models/mbi/graphical_model.py b/src/synthcity/plugins/core/models/mbi/graphical_model.py new file mode 100644 index 00000000..bbe913de --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/graphical_model.py @@ -0,0 +1,332 @@ +# stdlib +import itertools +from functools import reduce + +# third party +import cloudpickle +import networkx as nx +import numpy as np +import pandas as pd + +# synthcity relative +from .clique_vector import CliqueVector +from .dataset import Dataset +from .domain import Domain +from .junction_tree import JunctionTree + + +class GraphicalModel: + def __init__(self, domain, cliques, total=1.0, elimination_order=None): + """Constructor for a GraphicalModel + + :param domain: a Domain object + :param total: the normalization constant for the distribution + :param cliques: a list of cliques (not necessarily maximal cliques) + - each clique is a subset of attributes, represented as a tuple or list + :param elimination_order: an elimination order for the JunctionTree algorithm + - Elimination order will impact the efficiency but not correctness.
+ By default, a greedy elimination order is used + """ + self.domain = domain + self.total = total + tree = JunctionTree(domain, cliques, elimination_order) + self.junction_tree = tree + + self.cliques = tree.maximal_cliques() # maximal cliques + self.message_order = tree.mp_order() + self.sep_axes = tree.separator_axes() + self.neighbors = tree.neighbors() + self.elimination_order = tree.elimination_order + + self.size = sum(domain.size(cl) for cl in self.cliques) + if self.size * 8 > 4 * 10**9: + # stdlib + import warnings + + message = "Size of parameter vector is %.2f GB. " % ( + self.size * 8 / 10**9 + ) + message += "Consider removing some measurements or finding a better elimination order" + warnings.warn(message) + + @staticmethod + def save(model, path): + cloudpickle.dump(model, open(path, "wb")) + + @staticmethod + def load(path): + return cloudpickle.load(open(path, "rb")) + + def project(self, attrs): + """Project the distribution onto a subset of attributes. + I.e., compute the marginal of the distribution + + :param attrs: a subset of attributes in the domain, represented as a list or tuple + :return: a Factor object representing the marginal distribution + """ + # use precalculated marginals if possible + if type(attrs) is list: + attrs = tuple(attrs) + if hasattr(self, "marginals"): + for cl in self.cliques: + if set(attrs) <= set(cl): + return self.marginals[cl].project(attrs) + + elim = self.domain.invert(attrs) + elim_order = greedy_order(self.domain, self.cliques + [attrs], elim) + pots = list(self.potentials.values()) + ans = variable_elimination_logspace(pots, elim_order, self.total) + return ans.project(attrs) + + def krondot(self, matrices): + """Compute the answer to the set of queries Q1 x Q2 X ... x Qd, where + Qi is a query matrix on the ith attribute and "x" is the Kronecker product + This may be more efficient than computing a supporting marginal then multiplying that by Q. + In particular, if each Qi has only a few rows. 
+ + :param matrices: a list of matrices for each attribute in the domain + :return: the vector of query answers + """ + if any(M.shape[1] != n for M, n in zip(matrices, self.domain.shape)): + raise ValueError("matrices must conform to the shape of the domain") + logZ = self.belief_propagation(self.potentials, logZ=True) + factors = [self.potentials[cl].exp() for cl in self.cliques] + Factor = type(factors[0]) # infer the type of the factors + elim = self.domain.attrs + for attr, Q in zip(elim, matrices): + d = Domain(["%s-answer" % attr, attr], Q.shape) + factors.append(Factor(d, Q)) + result = variable_elimination(factors, elim) + result = result.transpose(["%s-answer" % a for a in elim]) + return result.datavector(flatten=False) * self.total / np.exp(logZ) + + def calculate_many_marginals(self, projections): + """Calculates marginals for all the projections in the list using + Algorithm for answering many out-of-clique queries (section 10.3 in Koller and Friedman) + + This method may be faster than calling project many times + + :param projections: a list of projections, where + each projection is a subset of attributes (represented as a list or tuple) + :return: a list of marginals, where each marginal is represented as a Factor + """ + + self.marginals = self.belief_propagation(self.potentials) + sep = self.sep_axes + neighbors = self.neighbors + # first calculate P(Cj | Ci) for all neighbors Ci, Cj + conditional = {} + for Ci in neighbors: + for Cj in neighbors[Ci]: + Sij = sep[(Cj, Ci)] + Z = self.marginals[Cj] + conditional[(Cj, Ci)] = Z / Z.project(Sij) + + # now iterate through pairs of cliques in order of distance + pred, dist = nx.floyd_warshall_predecessor_and_distance( + self.junction_tree.tree, weight=False + ) + results = {} + for Ci, Cj in sorted( + itertools.combinations(self.cliques, 2), key=lambda X: dist[X[0]][X[1]] + ): + Cl = pred[Ci][Cj] + Y = conditional[(Cj, Cl)] + if Cl == Ci: + X = self.marginals[Ci] + results[(Ci, Cj)] = results[(Cj, Ci)] = X * Y + else: + X = results[(Ci, Cl)] + S = set(Cl) - set(Ci) - set(Cj) + results[(Ci, Cj)] = results[(Cj, Ci)] = (X * Y).sum(S) + + results = { + self.domain.canonical(key[0] + key[1]): results[key] for key in results + } + + answers = {} + for proj in projections: + for attr in results: + if set(proj) <= set(attr): + answers[proj] = results[attr].project(proj) + break + if proj not in answers: + # just use variable elimination + answers[proj] = self.project(proj) + + return answers + + def datavector(self, flatten=True): + """Materialize the explicit representation of the distribution as a data vector.""" + logp = sum(self.potentials[cl] for cl in self.cliques) + ans = np.exp(logp - logp.logsumexp()) + wgt = ans.domain.size() / self.domain.size() + return ans.expand(self.domain).datavector(flatten) * wgt * self.total + + def belief_propagation(self, potentials, logZ=False): + """Compute the marginals of the graphical model with given parameters + + Note this is an efficient, numerically stable implementation of belief propagation + + :param potentials: the (log-space) parameters of the graphical model + :param logZ: flag to return logZ instead of marginals + :return marginals: the marginals of the graphical model + """ + beliefs = {cl: potentials[cl].copy() for cl in potentials} + messages = {} + for i, j in self.message_order: + sep = beliefs[i].domain.invert(self.sep_axes[(i, j)]) + if (j, i) in messages: + tau = beliefs[i] - messages[(j, i)] + else: + tau = beliefs[i] + messages[(i, j)] = tau.logsumexp(sep) + beliefs[j] 
+= messages[(i, j)] + + cl = self.cliques[0] + if logZ: + return beliefs[cl].logsumexp() + + logZ = beliefs[cl].logsumexp() + for cl in self.cliques: + beliefs[cl] += np.log(self.total) - logZ + beliefs[cl] = beliefs[cl].exp(out=beliefs[cl]) + + return CliqueVector(beliefs) + + def mle(self, marginals): + """Compute the model parameters from the given marginals + + :param marginals: target marginals of the distribution + :param: the potentials of the graphical model with the given marginals + """ + potentials = {} + variables = set() + for cl in self.cliques: + new = tuple(variables & set(cl)) + # factor = marginals[cl] / marginals[cl].project(new) + variables.update(cl) + potentials[cl] = marginals[cl].log() - marginals[cl].project(new).log() + return CliqueVector(potentials) + + def fit(self, data): + # synthcity relative + from .factor import Factor + + if not data.domain.contains(self.domain): + raise ValueError("data domain not compatible with model domain") + marginals = {} + for cl in self.cliques: + x = data.project(cl).datavector() + dom = self.domain.project(cl) + marginals[cl] = Factor(dom, x) + self.potentials = self.mle(marginals) + + def synthetic_data(self, rows=None, method="round"): + """Generate synthetic tabular data from the distribution. + Valid options for method are 'round' and 'sample'.""" + total = int(self.total) if rows is None else rows + cols = self.domain.attrs + data = np.zeros((total, len(cols)), dtype=int) + df = pd.DataFrame(data, columns=cols) + cliques = [set(cl) for cl in self.cliques] + + def synthetic_col(counts, total): + if method == "sample": + probas = counts / counts.sum() + return np.random.choice(counts.size, total, True, probas) + counts *= total / counts.sum() + frac, integ = np.modf(counts) + integ = integ.astype(int) + extra = total - integ.sum() + if extra > 0: + idx = np.random.choice(counts.size, extra, False, frac / frac.sum()) + integ[idx] += 1 + vals = np.repeat(np.arange(counts.size), integ) + np.random.shuffle(vals) + return vals + + order = self.elimination_order[::-1] + col = order[0] + marg = self.project([col]).datavector(flatten=False) + df.loc[:, col] = synthetic_col(marg, total) + used = {col} + + for col in order[1:]: + relevant = [cl for cl in cliques if col in cl] + relevant = used.intersection(set.union(*relevant)) + proj = tuple(relevant) + used.add(col) + marg = self.project(proj + (col,)).datavector(flatten=False) + + def foo(group): + idx = group.name + vals = synthetic_col(marg[idx], group.shape[0]) + group[col] = vals + return group + + if len(proj) >= 1: + df = df.groupby(list(proj), group_keys=False).apply(foo) + else: + df[col] = synthetic_col(marg, df.shape[0]) + + return Dataset(df, self.domain) + + +def variable_elimination_logspace(potentials, elim, total): + """run variable elimination on a list of **logspace** factors""" + k = len(potentials) + psi = dict(zip(range(k), potentials)) + for z in elim: + psi2 = [psi.pop(i) for i in list(psi.keys()) if z in psi[i].domain] + phi = reduce(lambda x, y: x + y, psi2, 0) + tau = phi.logsumexp([z]) + psi[k] = tau + k += 1 + ans = reduce(lambda x, y: x + y, psi.values(), 0) + return (ans - ans.logsumexp() + np.log(total)).exp() + + +def variable_elimination(factors, elim): + """run variable elimination on a list of (non-logspace) factors""" + k = len(factors) + psi = dict(zip(range(k), factors)) + for z in elim: + psi2 = [psi.pop(i) for i in list(psi.keys()) if z in psi[i].domain] + phi = reduce(lambda x, y: x * y, psi2, 1) + tau = phi.sum([z]) + psi[k] = tau + k += 1 + 
return reduce(lambda x, y: x * y, psi.values(), 1) + + +def greedy_order(domain, cliques, elim): + order = [] + unmarked = set(elim) + cliques = set(cliques) + total_cost = 0 + for k in range(len(elim)): + cost = {} + for a in unmarked: + # all cliques that have a + neighbors = list(filter(lambda cl: a in cl, cliques)) + # variables in this "super-clique" + variables = tuple(set.union(set(), *map(set, neighbors))) + # domain for the resulting factor + newdom = domain.project(variables) + # cost of removing a + cost[a] = newdom.size() + + # find the best variable to eliminate + a = min(cost, key=lambda a: cost[a]) + + # do some cleanup + order.append(a) + unmarked.remove(a) + neighbors = list(filter(lambda cl: a in cl, cliques)) + variables = tuple(set.union(set(), *map(set, neighbors)) - {a}) + cliques -= set(neighbors) + cliques.add(variables) + total_cost += cost[a] + + return order diff --git a/src/synthcity/plugins/core/models/mbi/identity.py b/src/synthcity/plugins/core/models/mbi/identity.py new file mode 100644 index 00000000..5d97306d --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/identity.py @@ -0,0 +1,346 @@ +# third party +import numpy as np +from scipy import sparse +from scipy.sparse.linalg import LinearOperator + + +# util function +def class_to_dict(inst, ignore_list=[], attr_prefix=""): + """Writes state of class instance as a dict + Includes both attributes and properties (i.e. those methods labeled with @property) + Note: because this capture properties, it should be viewed as a snapshot of instance state + :param inst: instance to represent as dict + :param ignore_list: list of attr + :return: dict + """ + output = vars(inst).copy() # captures regular variables + cls = type(inst) # class of instance + properties = [p for p in dir(cls) if isinstance(getattr(cls, p), property)] + for p in properties: + prop = getattr(cls, p) # get property object by name + output[p] = prop.fget(inst) # call its fget + for k in list(output.keys()): # filter out dict keys mentioned in ignore-list + if k in ignore_list: + del output[k] + else: # prepend attr_prefix + output[attr_prefix + k] = output.pop(k) + + output[attr_prefix + "class"] = cls.__name__ + return output + + +# Main Matrix classes +class EkteloMatrix(LinearOperator): + """ + An EkteloMatrix is a linear transformation that can compute matrix-vector products + """ + + # must implement: _matmat, _transpose, matrix + # can implement: gram, sensitivity, sum, dense_matrix, sparse_matrix, __abs__ + + def __init__(self, matrix): + """Instantiate an EkteloMatrix from an explicitly represented backing matrix + + :param matrix: a 2d numpy array or a scipy sparse matrix + """ + self.matrix = matrix + self.dtype = matrix.dtype + self.shape = matrix.shape + + def asDict(self): + d = class_to_dict(self, ignore_list=[]) + return d + + def _transpose(self): + return EkteloMatrix(self.matrix.T) + + def _matmat(self, V): + """ + Matrix multiplication of a m x n matrix Q + + :param V: a n x p numpy array + :return Q*V: a m x p numpy aray + """ + return self.matrix @ V + + def gram(self): + """ + Compute the Gram matrix of the given matrix. 
+ For a matrix Q, the gram matrix is defined as Q^T Q + """ + return self.T @ self # works for subclasses too + + def sensitivity(self): + # note: this works because np.abs calls self.__abs__ + return np.max(np.abs(self).sum(axis=0)) + + def sum(self, axis=None): + # this implementation works for all subclasses too + # (as long as they define _matmat and _transpose) + if axis == 0: + return self.T.dot(np.ones(self.shape[0])) + ans = self.dot(np.ones(self.shape[1])) + return ans if axis == 1 else np.sum(ans) + + def inv(self): + return EkteloMatrix(np.linalg.inv(self.dense_matrix())) + + def pinv(self): + return EkteloMatrix(np.linalg.pinv(self.dense_matrix())) + + def trace(self): + return self.diag().sum() + + def diag(self): + return np.diag(self.dense_matrix()) + + def _adjoint(self): + return self._transpose() + + def __mul__(self, other): + if np.isscalar(other): + return Weighted(self, other) # :noqa F821 + if type(other) == np.ndarray: + return self.dot(other) + if isinstance(other, EkteloMatrix): + return Product(self, other) + # note: this expects both matrix types to be compatible (e.g., sparse and sparse) + # todo: make it work for different backing representations + else: + raise TypeError( + "incompatible type %s for multiplication with EkteloMatrix" + % type(other) + ) + + def __add__(self, other): + if np.isscalar(other): + other = Weighted(Ones(self.shape), other) # :noqa F821 + return Sum([self, other]) + + def __sub__(self, other): + return self + -1 * other + + def __rmul__(self, other): + if np.isscalar(other): + return Weighted(self, other) # :noqa F821 + return NotImplemented + + def __getitem__(self, key): + """ + return a given row from the matrix + + :param key: the index of the row to return + :return: a 1xN EkteloMatrix + """ + # row indexing, subclasses may provide more efficient implementation + m = self.shape[0] + v = np.zeros(m) + v[key] = 1.0 + return EkteloMatrix(self.T.dot(v).reshape(1, self.shape[1])) + + def dense_matrix(self): + """ + return the dense representation of this matrix, as a 2D numpy array + """ + if sparse.issparse(self.matrix): + return self.matrix.toarray() + return self.matrix + + def sparse_matrix(self): + """ + return the sparse representation of this matrix, as a scipy matrix + """ + if sparse.issparse(self.matrix): + return self.matrix + return sparse.csr_matrix(self.matrix) + + @property + def ndim(self): + # todo: deprecate if possible + return 2 + + def __abs__(self): + return EkteloMatrix(self.matrix.__abs__()) + + def __sqr__(self): + if sparse.issparse(self.matrix): + return EkteloMatrix(self.matrix.power(2)) + return EkteloMatrix(self.matrix**2) + + +class Ones(EkteloMatrix): + """A m x n matrix of all ones""" + + def __init__(self, m, n, dtype=np.float64): + self.m = m + self.n = n + self.shape = (m, n) + self.dtype = dtype + + def _matmat(self, V): + ans = V.sum(axis=0, keepdims=True) + return np.repeat(ans, self.m, axis=0) + + def _transpose(self): + return Ones(self.n, self.m, self.dtype) + + def gram(self): + return self.m * Ones(self.n, self.n, self.dtype) + + def pinv(self): + c = 1.0 / (self.m * self.n) + return c * Ones(self.n, self.m, self.dtype) + + def trace(self): + if self.n != self.m: + raise ValueError("matrix is not square") + return self.n + + @property + def matrix(self): + return np.ones(self.shape, dtype=self.dtype) + + def __abs__(self): + return self + + def __sqr__(self): + return self + + +class Sum(EkteloMatrix): + """Class for the Sum of matrices""" + + def __init__(self, matrices): + # all must have 
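# Illustration (not part of the diff): what EkteloMatrix.gram() and
# EkteloMatrix.sensitivity() compute, spelled out on a small dense matrix.
# gram() is Q^T Q and sensitivity() is the largest L1 norm over the columns.
import numpy as np

Q = np.array([[1.0, 0.0, 1.0],
              [0.0, 1.0, 1.0]])
gram = Q.T @ Q                              # 3x3 Gram matrix
sensitivity = np.abs(Q).sum(axis=0).max()   # max abs column sum -> 2.0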
same shape + self.matrices = matrices + self.shape = matrices[0].shape + self.dtype = np.result_type(*[Q.dtype for Q in matrices]) + + def _matmat(self, V): + return sum(Q.dot(V) for Q in self.matrices) + + def _transpose(self): + return Sum([Q.T for Q in self.matrices]) + + def __mul__(self, other): + if isinstance(other, EkteloMatrix): + return Sum( + [Q @ other for Q in self.matrices] + ) # should use others rmul though + return EkteloMatrix.__mul__(self, other) + + def diag(self): + return sum(Q.diag() for Q in self.matrices) + + @property + def matrix(self): + def _any_sparse(matrices): + return any(sparse.issparse(Q.matrix) for Q in matrices) + + if _any_sparse(self.matrices): + return sum(Q.sparse_matrix() for Q in self.matrices) + return sum(Q.dense_matrix() for Q in self.matrices) + + +class Weighted(EkteloMatrix): + """Class for multiplication by a constant""" + + def __init__(self, base, weight): + if isinstance(base, Weighted): + weight *= base.weight + base = base.base + self.base = base + self.weight = weight + self.shape = base.shape + self.dtype = base.dtype + + def _matmat(self, V): + return self.weight * self.base.dot(V) + + def _transpose(self): + return Weighted(self.base.T, self.weight) + + def gram(self): + return Weighted(self.base.gram(), self.weight**2) + + def pinv(self): + return Weighted(self.base.pinv(), 1.0 / self.weight) + + def inv(self): + return Weighted(self.base.inv(), 1.0 / self.weight) + + def trace(self): + return self.weight * self.base.trace() + + def __abs__(self): + return Weighted(self.base.__abs__(), np.abs(self.weight)) + + def __sqr__(self): + return Weighted(self.base.__sqr__(), self.weight**2) + + @property + def matrix(self): + return self.weight * self.base.matrix + + +class Product(EkteloMatrix): + def __init__(self, A, B): + if A.shape[1] != B.shape[0]: + raise ValueError("inner dimensions do not match") + self._A = A + self._B = B + self.shape = (A.shape[0], B.shape[1]) + self.dtype = np.result_type(A.dtype, B.dtype) + + def _matmat(self, X): + return self._A.dot(self._B.dot(X)) + + def _transpose(self): + return Product(self._B.T, self._A.T) + + @property + def matrix(self): + return self._A.matrix @ self._B.matrix + + def gram(self): + return Product(self.T, self) + + def inv(self): + return Product(self._B.inv(), self._A.inv()) + + +class Identity(EkteloMatrix): + def __init__(self, n, dtype=np.float64): + self.n = n + self.shape = (n, n) + self.dtype = dtype + + def _matmat(self, V): + return V + + def _transpose(self): + return self + + @property + def matrix(self): + return sparse.eye(self.n, dtype=self.dtype) + + def __mul__(self, other): + if other.shape[0] != self.n: + raise ValueError("dimension mismatch") + return other + + def inv(self): + return self + + def pinv(self): + return self + + def trace(self): + return self.n + + def __abs__(self): + return self + + def __sqr__(self): + return self diff --git a/src/synthcity/plugins/core/models/mbi/inference.py b/src/synthcity/plugins/core/models/mbi/inference.py new file mode 100644 index 00000000..4fb3dd6d --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/inference.py @@ -0,0 +1,419 @@ +# stdlib +from collections import defaultdict + +# third party +import numpy as np +from scipy import sparse +from scipy.sparse.linalg import aslinearoperator, eigsh, lsmr + +# synthcity absolute +from synthcity import logger + +# synthcity relative +from . 
import callbacks +from .clique_vector import CliqueVector +from .factor import Factor +from .graphical_model import GraphicalModel + + +class FactoredInference: + def __init__( + self, + domain, + # backend="numpy", + structural_zeros={}, + metric="L2", + log=False, + iters=1000, + warm_start=False, + elim_order=None, + ): + """ + Class for learning a GraphicalModel from noisy measurements on a data distribution + + :param domain: The domain information (A Domain object) + :param backend: numpy or torch backend + :param structural_zeros: An encoding of the known (structural) zeros in the distribution. + Specified as a dictionary where + - each key is a subset of attributes of size r + - each value is a list of r-tuples corresponding to impossible attribute settings + :param metric: The optimization metric. May be L1, L2 or a custom callable function + - custom callable function must consume the marginals and produce the loss and gradient + - see FactoredInference._marginal_loss for more information + :param log: flag to log iterations of optimization + :param iters: number of iterations to optimize for + :param warm_start: initialize new model or reuse last model when calling infer multiple times + :param elim_order: an elimination order for the JunctionTree algorithm + - Elimination order will impact the efficiency by not correctness. + By default, a greedy elimination order is used + """ + self.domain = domain + self.backend = "numpy" + # self.backend = backend + self.metric = metric + self.log = log + self.iters = iters + self.warm_start = warm_start + self.history = [] + self.elim_order = elim_order + self.Factor = Factor + # if backend == "torch": + # # synthcity relative + # from .torch_factor import Factor + + # self.Factor = Factor + # else: + # # synthcity relative + # from .factor import Factor + + # self.Factor = Factor + + self.structural_zeros = CliqueVector({}) + for cl in structural_zeros: + dom = self.domain.project(cl) + fact = structural_zeros[cl] + self.structural_zeros[cl] = self.Factor.active(dom, fact) + + def estimate( + self, measurements, total=None, engine="MD", callback=None, options={} + ): + """ + Estimate a GraphicalModel from the given measurements + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param total: The total number of records (if known) + :param engine: the optimization algorithm to use, options include: + MD - Mirror Descent with armijo line search + RDA - Regularized Dual Averaging + IG - Interior Gradient + :param callback: a function to be called after each iteration of optimization + :param options: solver specific options passed as a dictionary + { param_name : param_value } + + :return model: A GraphicalModel that best matches the measurements taken + """ + measurements = self.fix_measurements(measurements) + options["callback"] = callback + if callback is None and self.log: + options["callback"] = callbacks.Logger(self) + if engine == "MD": + self.mirror_descent(measurements, total, **options) + elif engine == "RDA": + self.dual_averaging(measurements, total, **options) + elif engine == "IG": + self.interior_gradient(measurements, total, **options) + return self.model + + def fix_measurements(self, measurements): + if not 
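# Illustration (not part of the diff): the measurement format consumed by
# FactoredInference.estimate. Each entry is a (Q, y, noise, proj) tuple, where
# Q maps the marginal over `proj` to query answers, y is the noisy answer
# vector and `noise` is the noise standard deviation. The marginal values and
# the attribute name "A" are toy assumptions; `domain` would be an mbi Domain.
import numpy as np
from scipy import sparse

x = np.array([10.0, 30.0, 60.0])            # true 1-way marginal over attribute "A"
sigma = 1.0
Q = sparse.eye(x.size)                      # identity workload: measure every cell
y = Q @ x + np.random.normal(0, sigma, x.size)
measurements = [(Q, y, sigma, ("A",))]
# model = FactoredInference(domain).estimate(measurements, engine="MD")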
isinstance(measurements, list): + raise TypeError("measurements must be a list") + if any(len(m) != 4 for m in measurements): + raise ValueError("each measurement must be a 4-tuple (Q, y, noise,proj)") + ans = [] + for Q, y, noise, proj in measurements: + if Q is not None and Q.shape[0] != y.size: + raise ValueError("shapes of Q and y are not compatible") + if type(proj) is list: + proj = tuple(proj) + if type(proj) is not tuple: + proj = (proj,) + if Q is None: + Q = sparse.eye(self.domain.size(proj)) + if not np.isscalar(noise): + raise TypeError("noise must be a real value, given " + str(noise)) + if any(a not in self.domain for a in proj): + raise ValueError(str(proj) + " not contained in domain") + if Q.shape[1] != self.domain.size(proj): + raise ValueError("shapes of Q and proj are not compatible") + ans.append((Q, y, noise, proj)) + return ans + + def interior_gradient( + self, measurements, total, lipschitz=None, c=1, sigma=1, callback=None + ): + """Use the interior gradient algorithm to estimate the GraphicalModel + See https://epubs.siam.org/doi/pdf/10.1137/S1052623403427823 for more information + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param total: The total number of records (if known) + :param lipschitz: the Lipchitz constant of grad L(mu) + - automatically calculated for metric=L2 + - doesn't exist for metric=L1 + - must be supplied for custom callable metrics + :param c, sigma: parameters of the algorithm + :param callback: a function to be called after each iteration of optimization + """ + if self.metric == "L1": + raise ValueError("dual_averaging cannot be used with metric=L1") + if callable(self.metric) and lipschitz is None: + raise ValueError("lipschitz constant must be supplied") + self._setup(measurements, total) + # what are c and sigma? 
For now using 1 + model = self.model + total = model.total + L = self._lipschitz(measurements) if lipschitz is None else lipschitz + if self.log: + logger.debug(f"Lipchitz constant: {L}") + + theta = model.potentials + x = y = z = model.belief_propagation(theta) + sigma_over_L = sigma / L + for k in range(1, self.iters + 1): + a = ( + np.sqrt((c * sigma_over_L) ** 2 + 4 * c * sigma_over_L) + - sigma_over_L * c + ) / 2 + y = (1 - a) * x + a * z + c *= 1 - a + _, g = self._marginal_loss(y) + theta = theta - a / c / total * g + z = model.belief_propagation(theta) + x = (1 - a) * x + a * z + if callback is not None: + callback(x) + + model.marginals = x + model.potentials = model.mle(x) + + def dual_averaging(self, measurements, total=None, lipschitz=None, callback=None): + """Use the regularized dual averaging algorithm to estimate the GraphicalModel + See https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/xiao10JMLR.pdf + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param total: The total number of records (if known) + :param lipschitz: the Lipchitz constant of grad L(mu) + - automatically calculated for metric=L2 + - doesn't exist for metric=L1 + - must be supplied for custom callable metrics + :param callback: a function to be called after each iteration of optimization + """ + if self.metric == "L1": + raise ValueError("dual_averaging cannot be used with metric=L1") + if callable(self.metric) and lipschitz is None: + raise ValueError("lipschitz constant must be supplied") + self._setup(measurements, total) + model = self.model + domain, cliques, total = model.domain, model.cliques, model.total + L = self._lipschitz(measurements) if lipschitz is None else lipschitz + logger.debug(f"Lipchitz constant: {L}") + if L == 0: + return + + theta = model.potentials + gbar = CliqueVector( + {cl: self.Factor.zeros(domain.project(cl)) for cl in cliques} + ) + w = v = model.belief_propagation(theta) + beta = 0 + + for t in range(1, self.iters + 1): + c = 2.0 / (t + 1) + u = (1 - c) * w + c * v + _, g = self._marginal_loss(u) # not interested in loss of this query point + gbar = (1 - c) * gbar + c * g + theta = -t * (t + 1) / (4 * L + beta) / self.model.total * gbar + v = model.belief_propagation(theta) + w = (1 - c) * w + c * v + + if callback is not None: + callback(w) + + model.marginals = w + model.potentials = model.mle(w) + + def mirror_descent(self, measurements, total=None, stepsize=None, callback=None): + """Use the mirror descent algorithm to estimate the GraphicalModel + See https://web.iem.technion.ac.il/images/user-files/becka/papers/3.pdf + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param stepsize: The step size function for the optimization (None or scalar or function) + if None, will perform line search at each iteration (requires smooth objective) + if scalar, will use constant step size + if function, will be called with the iteration number + :param total: The total 
number of records (if known) + :param callback: a function to be called after each iteration of optimization + """ + if self.metric == "L1" and stepsize is None: + raise ValueError("loss function not smooth, cannot use line search") + + self._setup(measurements, total) + model = self.model + theta = model.potentials + mu = model.belief_propagation(theta) + ans = self._marginal_loss(mu) + if ans[0] == 0: + return ans[0] + + nols = stepsize is not None + if np.isscalar(stepsize): + alpha = float(stepsize) + stepsize = alpha # changed from lambda func: stepsize = lambda t: alpha + + if stepsize is None: + alpha = 1.0 / self.model.total**2 + stepsize = ( + 2.0 * alpha + ) # changed from lambda func: stepsize = lambda t: 2.0 * alpha + + for t in range(1, self.iters + 1): + if callback is not None: + callback(mu) + omega, nu = theta, mu + curr_loss, dL = ans + alpha = stepsize + for i in range(25): + theta = omega - alpha * dL + mu = model.belief_propagation(theta) + ans = self._marginal_loss(mu) + if nols or curr_loss - ans[0] >= 0.5 * alpha * dL.dot(nu - mu): + break + alpha *= 0.5 + + model.potentials = theta + model.marginals = mu + + return ans[0] + + def _marginal_loss(self, marginals, metric=None): + """Compute the loss and gradient for a given dictionary of marginals + + :param marginals: A dictionary with keys as projections and values as Factors + :return loss: the loss value + :return grad: A dictionary with gradient for each marginal + """ + if metric is None: + metric = self.metric + + if callable(metric): + return metric(marginals) + + loss = 0.0 + gradient = {} + + for cl in marginals: + mu = marginals[cl] + gradient[cl] = self.Factor.zeros(mu.domain) + for Q, y, noise, proj in self.groups[cl]: + c = 1.0 / noise + mu2 = mu.project(proj) + x = mu2.datavector() + diff = c * (Q @ x - y) + if metric == "L1": + loss += abs(diff).sum() + sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff) + grad = c * (Q.T @ sign) + else: + loss += 0.5 * (diff @ diff) + grad = c * (Q.T @ diff) + gradient[cl] += self.Factor(mu2.domain, grad) + return float(loss), CliqueVector(gradient) + + def _setup(self, measurements, total): + """Perform necessary setup for running estimation algorithms + + 1. If total is None, find the minimum variance unbiased estimate for total and use that + 2. Construct the GraphicalModel + * If there are structural_zeros in the distribution, initialize factors appropriately + 3. 
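# Illustration (not part of the diff): the backtracking step-size rule used in
# the mirror_descent loop above, shown on a plain quadratic instead of model
# potentials and marginals. A step is accepted when the achieved decrease is
# at least half of the predicted linear decrease; otherwise alpha is halved.
# The target vector is a toy assumption.
import numpy as np

target = np.array([1.0, 2.0, 3.0])
x = np.zeros(3)

def loss_and_grad(v):
    d = v - target
    return 0.5 * d @ d, d

for _ in range(20):
    curr_loss, g = loss_and_grad(x)
    alpha = 1.0
    for _ in range(25):
        x_new = x - alpha * g
        new_loss, _ = loss_and_grad(x_new)
        if curr_loss - new_loss >= 0.5 * alpha * g @ (x - x_new):
            break
        alpha *= 0.5
    x = x_new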
Pre-process measurements into groups so that _marginal_loss may be evaluated efficiently + """ + if total is None: + # find the minimum variance estimate of the total given the measurements + variances = np.array([]) + estimates = np.array([]) + for Q, y, noise, proj in measurements: + o = np.ones(Q.shape[1]) + v = lsmr(Q.T, o, atol=0, btol=0)[0] + if np.allclose(Q.T.dot(v), o): + variances = np.append(variances, noise**2 * np.dot(v, v)) + estimates = np.append(estimates, np.dot(v, y)) + if estimates.size == 0: + total = 1 + else: + variance = 1.0 / np.sum(1.0 / variances) + estimate = variance * np.sum(estimates / variances) + total = max(1, estimate) + + # if not self.warm_start or not hasattr(self, 'model'): + # initialize the model and parameters + cliques = [m[3] for m in measurements] + if self.structural_zeros is not None: + cliques += list(self.structural_zeros.keys()) + + model = GraphicalModel( + self.domain, cliques, total, elimination_order=self.elim_order + ) + + model.potentials = CliqueVector.zeros(self.domain, model.cliques) + model.potentials.combine(self.structural_zeros) + if self.warm_start and hasattr(self, "model"): + model.potentials.combine(self.model.potentials) + self.model = model + + # group the measurements into model cliques + cliques = self.model.cliques + # self.groups = { cl : [] for cl in cliques } + self.groups = defaultdict(lambda: []) + for Q, y, noise, proj in measurements: + if self.backend == "torch": + # third party + import torch + + device = self.Factor.device + y = torch.tensor(y, dtype=torch.float32, device=device) + if isinstance(Q, np.ndarray): + Q = torch.tensor(Q, dtype=torch.float32, device=device) + elif sparse.issparse(Q): + Q = Q.tocoo() + idx = torch.LongTensor([Q.row, Q.col]) + vals = torch.FloatTensor(Q.data) + Q = torch.sparse.FloatTensor(idx, vals).to(device) + + # else Q is a Linear Operator, must be compatible with torch + m = (Q, y, noise, proj) + for cl in sorted(cliques, key=model.domain.size): + # (Q, y, noise, proj) tuple + if set(proj) <= set(cl): + self.groups[cl].append(m) + break + + def _lipschitz(self, measurements): + """compute lipschitz constant for L2 loss + + Note: must be called after _setup + """ + eigs = {cl: 0.0 for cl in self.model.cliques} + for Q, _, noise, proj in measurements: + for cl in self.model.cliques: + if set(proj) <= set(cl): + n = self.domain.size(cl) + p = self.domain.size(proj) + Q = aslinearoperator(Q) + Q.dtype = np.dtype(Q.dtype) + eig = eigsh(Q.H * Q, 1)[0][0] + eigs[cl] += eig * n / p / noise**2 + break + return max(eigs.values()) + + def infer(self, measurements, total=None, engine="MD", callback=None, options={}): + # stdlib + import warnings + + message = "Function infer is deprecated. Please use estimate instead." + warnings.warn(message, DeprecationWarning) + return self.estimate(measurements, total, engine, callback, options) diff --git a/src/synthcity/plugins/core/models/mbi/junction_tree.py b/src/synthcity/plugins/core/models/mbi/junction_tree.py new file mode 100644 index 00000000..4656c505 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/junction_tree.py @@ -0,0 +1,131 @@ +# stdlib +import itertools +from collections import OrderedDict + +# third party +import networkx as nx +import numpy as np + + +class JunctionTree: + """A JunctionTree is a transformation of a GraphicalModel into a tree structure. It is used + to find the maximal cliques in the graphical model, and for specifying the message passing + order for belief propagation. 
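# Illustration (not part of the diff): the inverse-variance pooling that
# _setup uses to estimate the total record count from several noisy totals
# when `total` is not supplied. The estimates and variances are toy values.
import numpy as np

estimates = np.array([980.0, 1010.0, 1050.0])   # noisy estimates of the total
variances = np.array([100.0, 25.0, 400.0])      # variance of each estimate

variance = 1.0 / np.sum(1.0 / variances)
estimate = variance * np.sum(estimates / variances)
total = max(1, estimate)    # pooled estimate, dominated by the lowest-variance term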
The JunctionTree is characterized by an elimination_order, + which is chosen greedily by default, but may be passed in if desired. + """ + + def __init__(self, domain, cliques, elimination_order=None): + self.cliques = [tuple(cl) for cl in cliques] + self.domain = domain + self.graph = self._make_graph() + self.tree, self.order = self._make_tree(elimination_order) + + def maximal_cliques(self): + """return the list of maximal cliques in the model""" + # return list(self.tree.nodes()) + return list(nx.dfs_preorder_nodes(self.tree)) + + def mp_order(self): + """return a valid message passing order""" + edges = set() + messages = [(a, b) for a, b in self.tree.edges()] + [ + (b, a) for a, b in self.tree.edges() + ] + for m1 in messages: + for m2 in messages: + if m1[1] == m2[0] and m1[0] != m2[1]: + edges.add((m1, m2)) + G = nx.DiGraph() + G.add_nodes_from(messages) + G.add_edges_from(edges) + return list(nx.topological_sort(G)) + + def separator_axes(self): + return {(i, j): tuple(set(i) & set(j)) for i, j in self.mp_order()} + + def neighbors(self): + return {i: set(self.tree.neighbors(i)) for i in self.maximal_cliques()} + + def _make_graph(self): + G = nx.Graph() + G.add_nodes_from(self.domain.attrs) + for cl in self.cliques: + G.add_edges_from(itertools.combinations(cl, 2)) + return G + + def _triangulated(self, order): + edges = set() + G = nx.Graph(self.graph) + for node in order: + tmp = set(itertools.combinations(G.neighbors(node), 2)) + edges |= tmp + G.add_edges_from(tmp) + G.remove_node(node) + tri = nx.Graph(self.graph) + tri.add_edges_from(edges) + cliques = [tuple(c) for c in nx.find_cliques(tri)] + cost = sum(self.domain.project(cl).size() for cl in cliques) + return tri, cost + + def _greedy_order(self, stochastic=True): + order = [] + domain, cliques = self.domain, self.cliques + unmarked = list(domain.attrs) + cliques = set(cliques) + total_cost = 0 + for k in range(len(domain)): + cost = OrderedDict() + for a in unmarked: + # all cliques that have a + neighbors = list(filter(lambda cl: a in cl, cliques)) + # variables in this "super-clique" + variables = tuple(set.union(set(), *map(set, neighbors))) + # domain for the resulting factor + newdom = domain.project(variables) + # cost of removing a + cost[a] = newdom.size() + + # find the best variable to eliminate + if stochastic: + choices = list(unmarked) + costs = np.array([cost[a] for a in choices], dtype=float) + probas = np.max(costs) - costs + 1 + probas /= probas.sum() + i = np.random.choice(probas.size, p=probas) + a = choices[i] + else: + a = min(cost, key=lambda a: cost[a]) + + # do some cleanup + order.append(a) + unmarked.remove(a) + neighbors = list(filter(lambda cl: a in cl, cliques)) + variables = tuple(set.union(set(), *map(set, neighbors)) - {a}) + cliques -= set(neighbors) + cliques.add(variables) + total_cost += cost[a] + + return order, total_cost + + def _make_tree(self, order=None): + if order is None: + # orders = [self._greedy_order(stochastic=True) for _ in range(1000)] + # orders.append(self._greedy_order(stochastic=False)) + # order = min(orders, key=lambda x: x[1])[0] + order = self._greedy_order(stochastic=False)[0] + elif type(order) is int: + orders = [self._greedy_order(stochastic=False)] + [ + self._greedy_order(stochastic=True) for _ in range(order) + ] + order = min(orders, key=lambda x: x[1])[0] + self.elimination_order = order + tri, cost = self._triangulated(order) + # cliques = [tuple(c) for c in nx.find_cliques(tri)] + cliques = sorted([self.domain.canonical(c) for c in 
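# Illustration (not part of the diff): the triangulation step performed by
# _triangulated, on a toy 4-cycle. Eliminating nodes in a fixed order adds
# fill-in edges between their remaining neighbors; the maximal cliques of the
# triangulated graph become the junction-tree nodes.
import itertools
import networkx as nx

G = nx.Graph([("A", "B"), ("B", "C"), ("C", "D"), ("D", "A")])   # 4-cycle, not chordal
order = ["A", "B", "C", "D"]

fill_in = set()
H = nx.Graph(G)
for node in order:
    new_edges = set(itertools.combinations(H.neighbors(node), 2))
    fill_in |= new_edges
    H.add_edges_from(new_edges)
    H.remove_node(node)

tri = nx.Graph(G)
tri.add_edges_from(fill_in)
cliques = [tuple(c) for c in nx.find_cliques(tri)]   # two maximal cliques: {A,B,D} and {B,C,D}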
nx.find_cliques(tri)]) + complete = nx.Graph() + complete.add_nodes_from(cliques) + for c1, c2 in itertools.combinations(cliques, 2): + wgt = len(set(c1) & set(c2)) + complete.add_edge(c1, c2, weight=-wgt) + spanning = nx.minimum_spanning_tree(complete) + return spanning, order diff --git a/src/synthcity/plugins/core/models/mbi/local_inference.py b/src/synthcity/plugins/core/models/mbi/local_inference.py new file mode 100644 index 00000000..5acc0ac7 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/local_inference.py @@ -0,0 +1,293 @@ +# stdlib +from collections import defaultdict +from copy import deepcopy + +# third party +import numpy as np +from scipy import sparse +from scipy.sparse.linalg import lsmr + +# synthcity absolute +from synthcity import logger + +# synthcity relative +from . import callbacks +from .clique_vector import CliqueVector +from .factor_graph import FactorGraph +from .region_graph import RegionGraph + +""" +This file implements Approx-Private-PGM from the following paper: + +Relaxed Marginal Consistency for Differentially Private Query Answering +https://arxiv.org/pdf/2109.06153.pdf +""" + + +class LocalInference: + def __init__( + self, + domain, + backend="numpy", + structural_zeros={}, + metric="L2", + log=False, + iters=1000, + warm_start=False, + marginal_oracle="convex", + inner_iters=1, + ): + """ + Class for learning a GraphicalModel from noisy measurements on a data distribution + + :param domain: The domain information (A Domain object) + :param backend: numpy or torch backend + :param structural_zeros: An encoding of the known (structural) zeros in the distribution. + Specified as a dictionary where + - each key is a subset of attributes of size r + - each value is a list of r-tuples corresponding to impossible attribute settings + :param metric: The optimization metric. 
May be L1, L2 or a custom callable function + - custom callable function must consume the marginals and produce the loss and gradient + - see FactoredInference._marginal_loss for more information + :param log: flag to log iterations of optimization + :param iters: number of iterations to optimize for + :param warm_start: initialize new model or reuse last model when calling infer multiple times + :param marginal_oracle: One of + - convex (Region graph, convex Kikuchi entropy) + - approx (Region graph, Kikuchi entropy) + - pairwise-convex (Factor graph, convex Bethe entropy) + - pairwise (Factor graph, Bethe entropy) + - Can also pass any and FactorGraph or RegionGraph object + """ + self.domain = domain + self.backend = backend + self.metric = metric + self.log = log + self.iters = iters + self.warm_start = warm_start + self.history = [] + self.marginal_oracle = marginal_oracle + self.inner_iters = inner_iters + if backend == "torch": + # third party + from mbi.torch_factor import Factor + + self.Factor = Factor + else: + # third party + from mbi import Factor + + self.Factor = Factor + + self.structural_zeros = CliqueVector({}) + for cl in structural_zeros: + dom = self.domain.project(cl) + fact = structural_zeros[cl] + self.structural_zeros[cl] = self.Factor.active(dom, fact) + + def estimate(self, measurements, total=None, callback=None, options={}): + """ + Estimate a GraphicalModel from the given measurements + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param total: The total number of records (if known) + :param callback: a function to be called after each iteration of optimization + :param options: solver specific options passed as a dictionary + { param_name : param_value } + + :return model: A GraphicalModel that best matches the measurements taken + """ + options["callback"] = callback + if callback is None and self.log: + options["callback"] = callbacks.Logger(self) + self.mirror_descent(measurements, total, **options) + return self.model + + def mirror_descent_auto(self, alpha, iters, callback=None): + model = self.model + theta0 = model.potentials + messages0 = deepcopy(model.messages) + theta = theta0 + mu = model.belief_propagation(theta) + l0, _ = self._marginal_loss(mu) + + prev_l = np.inf + for t in range(iters): + if callback is not None: + callback(mu) + l, dL = self._marginal_loss(mu) + theta = theta - alpha * dL + mu = model.belief_propagation(theta) + if l > prev_l: + if t <= 50: + if self.log: + logger.debug( + f"Reducing learning rate and restarting. alpha/2: {alpha / 2}" + ) + model.potentials = theta0 + model.messages = messages0 + return self.mirror_descent_auto(alpha / 2, iters, callback) + else: + model.damping = (0.9 + model.damping) / 2.0 + if self.log: + logger.debug( + f"Increasing damping and continuing. 
Damping: {model.damping}" + ) + alpha *= 0.5 + prev_l = l + + # run some extra iterations with no gradient update to make sure things are primal feasible + for _ in range(1000): + if model.primal_feasibility(mu) < 1.0: + break + mu = model.belief_propagation(theta) + if callback is not None: + callback(mu) + return l, theta, mu + + def mirror_descent( + self, measurements, total=None, initial_alpha=10.0, callback=None + ): + """Use the mirror descent algorithm to estimate the GraphicalModel + See https://web.iem.technion.ac.il/images/user-files/becka/papers/3.pdf + + :param measurements: a list of (Q, y, noise, proj) tuples, where + Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator) + y is the noisy answers to the measurement queries + noise is the standard deviation of the noise added to y + proj defines the marginal used for this measurement set (a subset of attributes) + :param total: The total number of records (if known) + :param stepsize: the learning rate function + :param callback: a function to be called after each iteration of optimization + """ + self._setup(measurements, total) + l, theta, mu = self.mirror_descent_auto( + alpha=initial_alpha, iters=self.iters, callback=callback + ) + + self.model.potentials = theta + self.model.marginals = mu + + return l + + def _marginal_loss(self, marginals, metric=None): + """Compute the loss and gradient for a given dictionary of marginals + + :param marginals: A dictionary with keys as projections and values as Factors + :return loss: the loss value + :return grad: A dictionary with gradient for each marginal + """ + if metric is None: + metric = self.metric + + if callable(metric): + return metric(marginals) + + loss = 0.0 + gradient = {} + + for cl in marginals: + mu = marginals[cl] + gradient[cl] = self.Factor.zeros(mu.domain) + for Q, y, noise, proj in self.groups[cl]: + c = 1.0 / noise + mu2 = mu.project(proj) + x = mu2.datavector() + diff = c * (Q @ x - y) + if metric == "L1": + loss += abs(diff).sum() + sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff) + grad = c * (Q.T @ sign) + else: + loss += 0.5 * (diff @ diff) + grad = c * (Q.T @ diff) + gradient[cl] += self.Factor(mu2.domain, grad) + return float(loss), CliqueVector(gradient) + + def _setup(self, measurements, total): + """Perform necessary setup for running estimation algorithms + + 1. If total is None, find the minimum variance unbiased estimate for total and use that + 2. Construct the GraphicalModel + * If there are structural_zeros in the distribution, initialize factors appropriately + 3. 
Pre-process measurements into groups so that _marginal_loss may be evaluated efficiently + """ + if total is None: + # find the minimum variance estimate of the total given the measurements + variances = np.array([]) + estimates = np.array([]) + for Q, y, noise, proj in measurements: + o = np.ones(Q.shape[1]) + v = lsmr(Q.T, o, atol=0, btol=0)[0] + if np.allclose(Q.T.dot(v), o): + variances = np.append(variances, noise**2 * np.dot(v, v)) + estimates = np.append(estimates, np.dot(v, y)) + if estimates.size == 0: + total = 1 + else: + variance = 1.0 / np.sum(1.0 / variances) + estimate = variance * np.sum(estimates / variances) + total = max(1, estimate) + + # if not self.warm_start or not hasattr(self, 'model'): + # initialize the model and parameters + cliques = [m[3] for m in measurements] + if self.structural_zeros is not None: + cliques += list(self.structural_zeros.keys()) + if self.marginal_oracle == "approx": + model = RegionGraph( + self.domain, cliques, total, convex=False, iters=self.inner_iters + ) + elif self.marginal_oracle == "convex": + model = RegionGraph( + self.domain, cliques, total, convex=True, iters=self.inner_iters + ) + elif self.marginal_oracle == "pairwise": + model = FactorGraph( + self.domain, cliques, total, convex=False, iters=self.inner_iters + ) + elif self.marginal_oracle == "pairwise-convex": + model = FactorGraph( + self.domain, cliques, total, convex=True, iters=self.inner_iters + ) + else: + model = self.marginal_oracle + model.total = total + + if type(self.marginal_oracle) is str: + model.potentials = CliqueVector.zeros(self.domain, model.cliques) + model.potentials.combine(self.structural_zeros) + if self.warm_start and hasattr(self, "model"): + model.potentials.combine(self.model.potentials) + self.model = model + + # group the measurements into model cliques + cliques = self.model.cliques + # self.groups = { cl : [] for cl in cliques } + self.groups = defaultdict(lambda: []) + for Q, y, noise, proj in measurements: + if self.backend == "torch": + # third party + import torch + + device = self.Factor.device + y = torch.tensor(y, dtype=torch.float32, device=device) + if isinstance(Q, np.ndarray): + Q = torch.tensor(Q, dtype=torch.float32, device=device) + elif sparse.issparse(Q): + Q = Q.tocoo() + idx = torch.LongTensor([Q.row, Q.col]) + vals = torch.FloatTensor(Q.data) + Q = torch.sparse.FloatTensor(idx, vals).to(device) + + # else Q is a Linear Operator, must be compatible with torch + m = (Q, y, noise, proj) + for cl in sorted(cliques, key=model.domain.size): + # (Q, y, noise, proj) tuple + if set(proj) <= set(cl): + self.groups[cl].append(m) + break diff --git a/src/synthcity/plugins/core/models/mbi/mechanism.py b/src/synthcity/plugins/core/models/mbi/mechanism.py new file mode 100644 index 00000000..2bdd28ed --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/mechanism.py @@ -0,0 +1,96 @@ +# third party +import numpy as np +from scipy import sparse +from scipy.stats import laplace, norm + +# synthcity relative +from .callbacks import Logger +from .inference import FactoredInference +from .local_inference import LocalInference + + +def run( + dataset, + measurements, + eps=1.0, + delta=0.0, + bounded=True, + engine="MD", + options={}, + iters=10000, + seed=None, + metric="L2", + elim_order=None, + frequency=1, + workload=None, + oracle="exact", +): + """ + Run a mechanism that measures the given measurements and runs inference. + This is a convenience method for running end-to-end experiments. 
+ """ + + domain = dataset.domain + total = None + + state = np.random.RandomState(seed) + + if len(measurements) >= 1 and type(measurements[0][0]) is str: + + def matrix(proj): + return sparse.eye(domain.project(proj).size()) + + measurements = [(proj, matrix(proj)) for proj in measurements] + + l1 = 0 + l2 = 0 + for _, Q in measurements: + l1 += np.abs(Q).sum(axis=0).max() + try: + l2 += Q.power(2).sum(axis=0).max() # for spares matrices + except BaseException: + l2 += np.square(Q).sum(axis=0).max() # for dense matrices + + if bounded: + total = dataset.df.shape[0] + l1 *= 2 + l2 *= 2 + + if delta > 0: + noise = norm(loc=0, scale=np.sqrt(l2 * 2 * np.log(2 / delta)) / eps) + else: + noise = laplace(loc=0, scale=l1 / eps) + + if workload is None: + workload = measurements + + truth = [] + for ( + proj, + W, + ) in workload: + x = dataset.project(proj).datavector() + y = W.dot(x) + truth.append((W, y, proj)) + + answers = [] + for proj, Q in measurements: + x = dataset.project(proj).datavector() + z = noise.rvs(size=Q.shape[0], random_state=state) + y = Q.dot(x) + answers.append((Q, y + z, 1.0, proj)) + + if oracle == "exact": + estimator = FactoredInference( + domain, metric=metric, iters=iters, warm_start=False, elim_order=elim_order + ) + else: + estimator = LocalInference( + domain, metric=metric, iters=iters, warm_start=False, marginal_oracle=oracle + ) + logger = Logger(estimator, true_answers=truth, frequency=frequency) + model = estimator.estimate( + answers, total, engine=engine, callback=logger, options=options + ) + + return model, logger, answers diff --git a/src/synthcity/plugins/core/models/mbi/mixture_inference.py b/src/synthcity/plugins/core/models/mbi/mixture_inference.py new file mode 100644 index 00000000..6b365d45 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/mixture_inference.py @@ -0,0 +1,206 @@ +# third party +import jax.numpy as jnp +import numpy as np +import pandas as pd +from jax import vjp +from jax.nn import softmax as jax_softmax +from mbi import Dataset +from scipy.sparse.linalg import lsmr + +""" This file is experimental. + +It is a close approximation to the method described in RAP (https://arxiv.org/abs/2103.06641) +and an even closer approximation to RAP^{softmax} (https://arxiv.org/abs/2106.07153) + +Notable differences: +- Code now shares the same interface as Private-PGM (see FactoredInference) +- Named model "MixtureOfProducts", as that is one interpretation for the relaxed tabular format +(at least when softmax is used). +- Added support for unbounded-DP, with automatic estimate of total. 
+""" + + +def estimate_total(measurements): + # find the minimum variance estimate of the total given the measurements + variances = np.array([]) + estimates = np.array([]) + for Q, y, noise, proj in measurements: + o = np.ones(Q.shape[1]) + v = lsmr(Q.T, o, atol=0, btol=0)[0] + if np.allclose(Q.T.dot(v), o): + variances = np.append(variances, noise**2 * np.dot(v, v)) + estimates = np.append(estimates, np.dot(v, y)) + if estimates.size == 0: + return 1 + else: + variance = 1.0 / np.sum(1.0 / variances) + estimate = variance * np.sum(estimates / variances) + return max(1, estimate) + + +def adam(loss_and_grad, x0, iters=250): + a = 1.0 + b1, b2 = 0.9, 0.999 + eps = 10e-8 + + x = x0 + m = np.zeros_like(x) + v = np.zeros_like(x) + for t in range(1, iters + 1): + l, g = loss_and_grad(x) + m = b1 * m + (1 - b1) * g + v = b2 * v + (1 - b2) * g**2 + mhat = m / (1 - b1**t) + vhat = v / (1 - b2**t) + x = x - a * mhat / (np.sqrt(vhat) + eps) + return x + + +def synthetic_col(counts, total): + counts *= total / counts.sum() + frac, integ = np.modf(counts) + integ = integ.astype(int) + extra = total - integ.sum() + if extra > 0: + idx = np.random.choice(counts.size, extra, False, frac / frac.sum()) + integ[idx] += 1 + vals = np.repeat(np.arange(counts.size), integ) + np.random.shuffle(vals) + return vals + + +class MixtureOfProducts: + def __init__(self, products, domain, total): + self.products = products + self.domain = domain + self.total = total + self.num_components = next(iter(products.values())).shape[0] + + def project(self, cols): + products = {col: self.products[col] for col in cols} + domain = self.domain.project(cols) + return MixtureOfProducts(products, domain, self.total) + + def datavector(self, flatten=True): + letters = "bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"[ + : len(self.domain) + ] + formula = ( + ",".join(["a%s" % letter for letter in letters]) + "->" + "".join(letters) + ) + components = [self.products[col] for col in self.domain] + ans = np.einsum(formula, *components) * self.total / self.num_components + return ans.flatten() if flatten else ans + + def synthetic_data(self, rows=None): + total = rows or int(self.total) + subtotal = total // self.num_components + 1 + + dfs = [] + for i in range(self.num_components): + df = pd.DataFrame() + for col in self.products: + counts = self.products[col][i] + df[col] = synthetic_col(counts, subtotal) + dfs.append(df) + + df = pd.concat(dfs).sample(frac=1).reset_index(drop=True)[:total] + return Dataset(df, self.domain) + + +class MixtureInference: + def __init__( + self, domain, components=10, metric="L2", iters=2500, warm_start=False + ): + """ + :param domain: A Domain object + :param components: The number of mixture components + :metric: The metric to use for the loss function (can be callable) + """ + self.domain = domain + self.components = components + self.metric = metric + self.iters = iters + self.warm_start = warm_start + self.params = np.random.normal( + loc=0, scale=0.25, size=sum(domain.shape) * components + ) + + def estimate(self, measurements, total=None, alpha=0.1): + if total is None: + total = estimate_total(measurements) + self.measurements = measurements + cliques = [M[-1] for M in measurements] + letters = "bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + def get_products(params): + products = {} + idx = 0 + for col in self.domain: + n = self.domain[col] + k = self.components + products[col] = jax_softmax( + params[idx : idx + k * n].reshape(k, n), axis=1 + ) + idx += k * n + return products + + 
def marginals_from_params(params): + products = get_products(params) + mu = {} + for cl in cliques: + let = letters[: len(cl)] + formula = ( + ",".join(["a%s" % letter for letter in let]) + "->" + "".join(let) + ) + components = [products[col] for col in cl] + ans = jnp.einsum(formula, *components) * total / self.components + mu[cl] = ans.flatten() + return mu + + def loss_and_grad(params): + # For computing dL / dmu we will use ordinary numpy so as to support scipy sparse and linear operator inputs + # For computing dL / dparams we will use jax to avoid manually deriving gradients + params = jnp.array(params) + mu, backprop = vjp(marginals_from_params, params) + mu = {cl: np.array(mu[cl]) for cl in cliques} + loss, dL = self._marginal_loss(mu) + dL = {cl: jnp.array(dL[cl]) for cl in cliques} + dparams = backprop(dL) + return loss, np.array(dparams[0]) + + if not self.warm_start: + self.params = np.random.normal( + loc=0, scale=0.25, size=sum(self.domain.shape) * self.components + ) + self.params = adam(loss_and_grad, self.params, iters=self.iters) + products = get_products(self.params) + return MixtureOfProducts(products, self.domain, total) + + def _marginal_loss(self, marginals, metric=None): + """Compute the loss and gradient for a given dictionary of marginals + + :param marginals: A dictionary with keys as projections and values as Factors + :return loss: the loss value + :return grad: A dictionary with gradient for each marginal + """ + if metric is None: + metric = self.metric + + loss = 0.0 + gradient = {cl: np.zeros_like(marginals[cl]) for cl in marginals} + + for Q, y, noise, cl in self.measurements: + x = marginals[cl] + c = 1.0 / noise + diff = c * (Q @ x - y) + if metric == "L1": + loss += abs(diff).sum() + sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff) + grad = c * (Q.T @ sign) + else: + loss += 0.5 * (diff @ diff) + grad = c * (Q.T @ diff) + gradient[cl] += grad + + return float(loss), gradient diff --git a/src/synthcity/plugins/core/models/mbi/public_inference.py b/src/synthcity/plugins/core/models/mbi/public_inference.py new file mode 100644 index 00000000..d94b73a2 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/public_inference.py @@ -0,0 +1,128 @@ +# third party +import numpy as np +from scipy.sparse.linalg import lsmr +from scipy.special import logsumexp + +# synthcity relative +from .clique_vector import CliqueVector +from .dataset import Dataset +from .factor import Factor + +""" This file is experimental. +It is an attempt to re-implement and generalize the technique used in PMW^{Pub}. +https://arxiv.org/pdf/2102.08598.pdf + +Notable differences: +- Shares the same interface as Private-PGM (see FactoredInference) +- Supports unbounded differential privacy, with automatic estimate of total +- Supports arbitrary measurements over the data marginals +- Solves an L2 minimization problem (by default), but can pass other loss functions if desired. 
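# Illustration (not part of the diff): the jax.vjp pattern used by
# loss_and_grad above, on a tiny dict-valued function. The function f and its
# input are toy assumptions.
import jax.numpy as jnp
from jax import vjp

def f(p):
    return {"ab": jnp.outer(p, p).flatten()}     # dict-valued, like marginals_from_params

p = jnp.array([0.2, 0.8])
out, backprop = vjp(f, p)                        # forward pass plus pullback
cotangent = {"ab": jnp.ones_like(out["ab"])}     # plays the role of dL / d(marginal)
(dp,) = backprop(cotangent)                      # gradient with respect to p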
+""" + + +def entropic_mirror_descent(loss_and_grad, x0, total, iters=250): + logP = np.log(x0 + np.nextafter(0, 1)) + np.log(total) - np.log(x0.sum()) + P = np.exp(logP) + P = x0 * total / x0.sum() + loss, dL = loss_and_grad(P) + alpha = 1.0 + begun = False + + for _ in range(iters): + logQ = logP - alpha * dL + logQ += np.log(total) - logsumexp(logQ) + Q = np.exp(logQ) + # Q = P * np.exp(-alpha*dL) + # Q *= total / Q.sum() + new_loss, new_dL = loss_and_grad(Q) + + if loss - new_loss >= 0.5 * alpha * dL.dot(P - Q): + logP = logQ + loss, dL = new_loss, new_dL + # increase step size if we haven't already decreased it at least once + if not begun: + alpha *= 2 + else: + alpha *= 0.5 + begun = True + + return np.exp(logP) + + +def estimate_total(measurements): + # find the minimum variance estimate of the total given the measurements + variances = np.array([]) + estimates = np.array([]) + for Q, y, noise, proj in measurements: + o = np.ones(Q.shape[1]) + v = lsmr(Q.T, o, atol=0, btol=0)[0] + if np.allclose(Q.T.dot(v), o): + variances = np.append(variances, noise**2 * np.dot(v, v)) + estimates = np.append(estimates, np.dot(v, y)) + if estimates.size == 0: + return 1 + else: + variance = 1.0 / np.sum(1.0 / variances) + estimate = variance * np.sum(estimates / variances) + return max(1, estimate) + + +class PublicInference: + def __init__(self, public_data, metric="L2"): + self.public_data = public_data + self.metric = metric + self.weights = np.ones(self.public_data.records) + + def estimate(self, measurements, total=None): + if total is None: + total = estimate_total(measurements) + self.measurements = measurements + cliques = [M[-1] for M in measurements] + + def loss_and_grad(weights): + est = Dataset(self.public_data.df, self.public_data.domain, weights) + mu = CliqueVector.from_data(est, cliques) + loss, dL = self._marginal_loss(mu) + dweights = np.zeros(weights.size) + for cl in dL: + idx = est.project(cl).df.values + dweights += dL[cl].values[tuple(idx.T)] + return loss, dweights + + # bounds = [(0,None) for _ in self.weights] + # res = minimize(loss_and_grad, x0=self.weights, method='L-BFGS-B', jac=True, bounds=bounds) + # self.weights = res.x + + self.weights = entropic_mirror_descent(loss_and_grad, self.weights, total) + return Dataset(self.public_data.df, self.public_data.domain, self.weights) + + def _marginal_loss(self, marginals, metric=None): + """Compute the loss and gradient for a given dictionary of marginals + + :param marginals: A dictionary with keys as projections and values as Factors + :return loss: the loss value + :return grad: A dictionary with gradient for each marginal + """ + if metric is None: + metric = self.metric + + if callable(metric): + return metric(marginals) + + loss = 0.0 + gradient = {cl: Factor.zeros(marginals[cl].domain) for cl in marginals} + + for Q, y, noise, cl in self.measurements: + mu = marginals[cl] + c = 1.0 / noise + x = mu.datavector() + diff = c * (Q @ x - y) + if metric == "L1": + loss += abs(diff).sum() + sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff) + grad = c * (Q.T @ sign) + else: + loss += 0.5 * (diff @ diff) + grad = c * (Q.T @ diff) + gradient[cl] += Factor(mu.domain, grad) + return float(loss), CliqueVector(gradient) diff --git a/src/synthcity/plugins/core/models/mbi/region_graph.py b/src/synthcity/plugins/core/models/mbi/region_graph.py new file mode 100644 index 00000000..a2cb7a3e --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/region_graph.py @@ -0,0 +1,544 @@ +# stdlib +import itertools +from 
collections import defaultdict + +# third party +import networkx as nx +import numpy as np +from disjoint_set import DisjointSet + +# synthcity relative +from .clique_vector import CliqueVector +from .factor import Factor + + +class RegionGraph: + def __init__( + self, + domain, + cliques, + total=1.0, + minimal=True, + convex=True, + iters=25, + convergence=1e-3, + damping=0.5, + ): + self.domain = domain + self.cliques = cliques + if not convex: + self.cliques = [] + for r in cliques: + if not any(set(r) < set(s) for s in cliques): + self.cliques.append(r) + self.total = total + self.minimal = minimal + self.convex = convex + self.iters = iters + self.convergence = convergence + self.damping = damping + if convex: + self.belief_propagation = self.hazan_peng_shashua + else: + self.belief_propagation = self.generalized_belief_propagation + self.build_graph() + self.cliques = sorted(self.regions, key=len) + self.potentials = CliqueVector.zeros(domain, self.cliques) + self.marginals = CliqueVector.uniform(domain, self.cliques) * total + + def show(self): + # third party + import matplotlib.pyplot as plt + + labels = {r: "".join(r) for r in self.regions} + + pos = {} + xloc = defaultdict(lambda: 0) + for r in sorted(self.regions): + y = len(r) + pos[r] = (xloc[y] + 0.5 * (y % 2), y) + xloc[y] += 1 + + nx.draw(self.G, pos=pos, node_color="orange", node_size=1000) + nx.draw( + self.G, pos=pos, nodelist=self.cliques, node_color="green", node_size=1000 + ) + nx.draw_networkx_labels(self.G, pos=pos, labels=labels) + plt.show() + + def project(self, attrs, maxiter=100, alpha=None): + if type(attrs) is list: + attrs = tuple(attrs) + + for cl in self.cliques: + if set(attrs) <= set(cl): + return self.marginals[cl].project(attrs) + + # Use multiplicative weights/entropic mirror descent to solve projection problem + intersections = [set(cl) & set(attrs) for cl in self.cliques] + target_cliques = [ + tuple(t) for t in intersections if not any(t < s for s in intersections) + ] + target_cliques = list(set(target_cliques)) + target_mu = CliqueVector.from_data(self, target_cliques) + + if len(target_cliques) == 0: + return Factor.uniform(self.domain.project(attrs)) * self.total + # P = Factor.uniform(self.domain.project(attrs))*self.total + # Use a smart initialization + P = estimate_kikuchi_marginal(self.domain.project(attrs), self.total, target_mu) + if alpha is None: + # start with a safe step size + alpha = 1.0 / (self.total * len(target_cliques)) + + curr_mu = CliqueVector.from_data(P, target_cliques) + diff = curr_mu - target_mu + curr_loss, dL = diff.dot(diff), sum(diff.values()).expand(P.domain) + begun = False + + for _ in range(maxiter): + if curr_loss <= 1e-8: + return P # stop early if marginals are almost exactly realized + Q = P * (-alpha * dL).exp() + Q *= self.total / Q.sum() + curr_mu = CliqueVector.from_data(Q, target_cliques) + diff = curr_mu - target_mu + loss = diff.dot(diff) + + if curr_loss - loss >= 0.5 * alpha * dL.dot(P - Q): + P = Q + curr_loss = loss + dL = sum(diff.values()).expand(P.domain) + # increase step size if we haven't already decreased it at least once + if not begun: + alpha *= 2 + else: + alpha *= 0.5 + begun = True + + return P + + def primal_feasibility(self, mu): + ans = 0 + count = 0 + for r in self.cliques: + for s in self.children[r]: + x = mu[r].project(s).datavector() + y = mu[s].datavector() + err = np.linalg.norm(x - y, 1) + ans += err + count += 1 + return 0 if count == 0 else ans / count + + def is_converged(self, mu): + return self.primal_feasibility(mu) 
<= self.convergence + + def build_graph(self): + # Alg 11.3 of Koller & Friedman + regions = set(self.cliques) + size = 0 + while len(regions) > size: + size = len(regions) + for r1, r2 in itertools.combinations(regions, 2): + z = tuple(sorted(set(r1) & set(r2))) + if len(z) > 0 and z not in regions: + regions.update({z}) + + G = nx.DiGraph() + G.add_nodes_from(regions) + for r1 in regions: + for r2 in regions: + if set(r2) < set(r1) and not any( + set(r2) < set(r3) and set(r3) < set(r1) for r3 in regions + ): + G.add_edge(r1, r2) + + H = G.reverse() + G1, H1 = nx.transitive_closure(G), nx.transitive_closure(H) + + self.children = {r: list(G.neighbors(r)) for r in regions} + self.parents = {r: list(H.neighbors(r)) for r in regions} + self.descendants = {r: list(G1.neighbors(r)) for r in regions} + self.ancestors = {r: list(H1.neighbors(r)) for r in regions} + self.forebears = {r: set([r] + self.ancestors[r]) for r in regions} + self.downp = {r: set([r] + self.descendants[r]) for r in regions} + + if self.minimal: + min_edges = [] + for r in regions: + ds = DisjointSet() + for u in self.parents[r]: + ds.find(u) + for u, v in itertools.combinations(self.parents[r], 2): + uv = set(self.ancestors[u]) & set(self.ancestors[v]) + if len(uv) > 0: + ds.union(u, v) + canonical = set() + for u in self.parents[r]: + canonical.update({ds.find(u)}) + # if len(canonical) > 1:# or r in self.cliques: + min_edges.extend([(u, r) for u in canonical]) + # G = nx.DiGraph(min_edges) + # regions = list(G.nodes) + G = nx.DiGraph() + G.add_nodes_from(regions) + G.add_edges_from(min_edges) + + H = G.reverse() + G1, H1 = nx.transitive_closure(G), nx.transitive_closure(H) + + self.children = {r: list(G.neighbors(r)) for r in regions} + self.parents = {r: list(H.neighbors(r)) for r in regions} + # self.descendants = { r : list(G1.neighbors(r)) for r in regions } + # self.ancestors = { r : list(H1.neighbors(r)) for r in regions } + # self.forebears = { r : set([r] + self.ancestors[r]) for r in regions } + # self.downp = { r : set([r] + self.descendants[r]) for r in regions } + + self.G = G + self.regions = regions + + if self.convex: + self.counting_numbers = {r: 1.0 for r in regions} + else: + moebius = {} + + def get_counting_number(r): + if r not in moebius: + moebius[r] = 1 - sum( + get_counting_number(s) for s in self.ancestors[r] + ) + return moebius[r] + + for r in regions: + get_counting_number(r) + self.counting_numbers = moebius + + if self.minimal: + # https://people.eecs.berkeley.edu/~ananth/2002+/Payam/submittedkikuchi.pdf + # Eq. 30 and 31 + N, D, B = {}, {}, {} + for r in regions: + B[r] = set() + for p in self.parents[r]: + B[r].add((p, r)) + for d in self.descendants[r]: + for p in set(self.parents[d]) - {r} - set(self.descendants[r]): + B[r].add((p, d)) + + for p in self.regions: + for r in self.children[p]: + N[p, r], D[p, r] = set(), set() + for s in self.parents[p]: + N[p, r].add((s, p)) + for d in self.descendants[p]: + for s in ( + set(self.parents[d]) - {p} - set(self.descendants[p]) + ): + N[p, r].add((s, d)) + for s in set(self.parents[r]) - {p}: + D[p, r].add((s, r)) + for d in self.descendants[r]: + for p1 in ( + set(self.parents[d]) - {r} - set(self.descendants[r]) + ): + D[p, r].add((p1, d)) + cancel = N[p, r] & D[p, r] + N[p, r] = N[p, r] - cancel + D[p, r] = D[p, r] - cancel + + self.N, self.D, self.B = N, D, B + + else: + # From Yedida et al. 
for fully saturated region graphs + # for sending messages ru --> rd and computing beliefs B_r + N, D, B = {}, {}, {} + for r in regions: + B[r] = [(ru, r) for ru in self.parents[r]] + for rd in self.descendants[r]: + for ru in set(self.parents[rd]) - self.downp[r]: + B[r].append((ru, rd)) + + for ru in regions: + for rd in self.children[ru]: + fu, fd = self.downp[ru], self.downp[rd] + + def cond(r): + return not r[0] in fu and r[1] in (fu - fd) + + N[ru, rd] = [e for e in G.edges if cond(e)] + + def cond(r): + return r[0] in (fu - fd) and r[1] in fd and r != (ru, rd) + + D[ru, rd] = [e for e in G.edges if cond(e)] + + self.N, self.D, self.B = N, D, B + + self.messages = {} + self.message_order = [] + for ru in sorted( + regions, key=len + ): # nx.topological_sort(H): # should be G or H? + for rd in self.children[ru]: + self.message_order.append((ru, rd)) + self.messages[ru, rd] = Factor.zeros(self.domain.project(rd)) + self.messages[rd, ru] = Factor.zeros( + self.domain.project(rd) + ) # only for hazan et al + + def generalized_belief_propagation(self, potentials, callback=None): + # https://users.cs.duke.edu/~brd/Teaching/Bio/asmb/current/4paper/4-2.pdf + pot = {} + for r in self.regions: + if r in self.cliques: + pot[r] = potentials[r] + else: + pot[r] = Factor.zeros(self.domain.project(r)) + + for _ in range(self.iters): + new = {} + for ru, rd in self.message_order: + # Yedida et al. strongly recommend using updated messages for LHS (denom in our case) + # num = sum(pot[c] for c in self.downp[ru] if c != rd) + num = pot[ru] + num = num + sum(self.messages[r1, r2] for r1, r2 in self.N[ru, rd]) + denom = sum(new[r1, r2] for r1, r2 in self.D[ru, rd]) + diff = tuple(set(ru) - set(rd)) + new[ru, rd] = num.logsumexp(diff) - denom + new[ru, rd] -= new[ru, rd].logsumexp() + + # self.messages = new + for ru, rd in self.message_order: + self.messages[ru, rd] = 0.5 * self.messages[ru, rd] + 0.5 * new[ru, rd] + # ru, rd = self.message_order[0] + + marginals = {} + for r in self.cliques: + # belief = sum(potentials[c] for c in self.downp[r]) + sum(self.messages[r1,r2] for r1,r2 in self.B[r]) + belief = potentials[r] + sum(self.messages[r1, r2] for r1, r2 in self.B[r]) + belief += np.log(self.total) - belief.logsumexp() + marginals[r] = belief.exp() + + return CliqueVector(marginals) + + def hazan_peng_shashua(self, potentials, callback=None): + # https://arxiv.org/pdf/1210.4881.pdf + c0 = self.counting_numbers + pot = {} + for r in self.regions: + if r in self.cliques: + pot[r] = potentials[r] + else: + pot[r] = Factor.zeros(self.domain.project(r)) + + messages = self.messages + # for p in sorted(self.regions, key=len): #nx.topological_sort(H): # should be G or H? 
+ # for r in self.children[p]: + # messages[p,r] = Factor.zeros(self.domain.project(r)) + # messages[r,p] = Factor.zeros(self.domain.project(r)) + + cc = {} + for r in self.regions: + for p in self.parents[r]: + cc[p, r] = c0[p] / (c0[r] + sum(c0[p1] for p1 in self.parents[r])) + + for _ in range(self.iters): + new = {} + for r in self.regions: + for p in self.parents[r]: + new[p, r] = ( + pot[p] + + sum(messages[c, p] for c in self.children[p] if c != r) + - sum(messages[p, p1] for p1 in self.parents[p]) + ) / c0[p] + new[p, r] = c0[p] * new[p, r].logsumexp(tuple(set(p) - set(r))) + new[p, r] -= new[p, r].logsumexp() + + for r in self.regions: + for p in self.parents[r]: + new[r, p] = ( + cc[p, r] + * ( + pot[r] + + sum(messages[c, r] for c in self.children[r]) + + sum(messages[p1, r] for p1 in self.parents[r]) + ) + - messages[p, r] + ) + # new[r,p] = cc[p,r]*(pot[r] + sum(messages[c,r] for c in self.children[r]) + sum(new[p1,r] for p1 in self.parents[r])) - new[p,r] + new[r, p] -= new[r, p].logsumexp() + + # messages = new + # Damping is not described in paper, but is needed to get convergence for dense graphs + rho = self.damping + for p in self.regions: + for r in self.children[p]: + messages[p, r] = rho * messages[p, r] + (1.0 - rho) * new[p, r] + messages[r, p] = rho * messages[r, p] + (1.0 - rho) * new[r, p] + mu = {} + for r in self.regions: + belief = ( + pot[r] + + sum(messages[c, r] for c in self.children[r]) + - sum(messages[r, p] for p in self.parents[r]) + ) / c0[r] + belief += np.log(self.total) - belief.logsumexp() + mu[r] = belief.exp() + + if callback is not None: + callback(mu) + + if self.is_converged(mu): + self.messages = messages + return CliqueVector(mu) + + self.messages = messages + return CliqueVector(mu) + + def wiegerinck(self, potentials, callback=None): + c = self.counting_numbers + m = {} + for delta in self.regions: + m[delta] = 0 + for alpha in self.ancestors[delta]: + m[delta] += c[alpha] + + Q = {} + for r in self.regions: + if r in self.cliques: + Q[r] = potentials[r] / c[r] + else: + Q[r] = Factor.zeros(self.domain.project(r)) + + inner = [r for r in self.regions if len(self.parents[r]) > 0] + + def diff(r, s): + return tuple(set(r) - set(s)) + + for _ in range(self.iters): + for r in inner: + A = c[r] / (m[r] + c[r]) + B = m[r] / (m[r] + c[r]) + Qbar = ( + sum(c[s] * Q[s].logsumexp(diff(s, r)) for s in self.ancestors[r]) + / m[r] + ) + Q[r] = Q[r] * A + Qbar * B + Q[r] -= Q[r].logsumexp() + for s in self.ancestors[r]: + Q[s] = Q[s] + Q[r] - Q[s].logsumexp(diff(s, r)) + Q[s] -= Q[s].logsumexp() + + marginals = {} + for r in self.regions: + marginals[r] = (Q[r] + np.log(self.total) - Q[r].logsumexp()).exp() + if callback is not None: + callback(marginals) + + return CliqueVector(marginals) + + def loh_wibisono(self, potentials, callback=None): + # https://papers.nips.cc/paper/2014/file/39027dfad5138c9ca0c474d71db915c3-Paper.pdf + pot = {} + for r in self.regions: + if r in self.cliques: + pot[r] = potentials[r] + else: + pot[r] = Factor.zeros(self.domain.project(r)) + + rho = self.counting_numbers + + for _ in range(self.iters): + new = {} + for s, r in self.message_order: + diff = tuple(set(s) - set(r)) + num = pot[s] / rho[s] + for v in self.parents[s]: + num += self.messages[v, s] * rho[v] / rho[s] + for w in self.children[s]: + if w != r: + num -= self.messages[s, w] + num = num.logsumexp(diff) + denom = pot[r] / rho[r] + for u in self.parents[r]: + if u != s: + denom += self.messages[u, r] * rho[u] / rho[r] + for t in self.children[r]: + denom -= 
self.messages[r, t] + + new[s, r] = rho[r] / (rho[r] + rho[s]) * (num - denom) + new[s, r] -= new[s, r].logsumexp() + + for ru, rd in self.message_order: + self.messages[ru, rd] = 0.5 * self.messages[ru, rd] + 0.5 * new[ru, rd] + + # ru, rd = self.message_order[0] + + marginals = {} + for r in self.regions: + belief = pot[r] / rho[r] + for s in self.parents[r]: + belief += self.messages[s, r] * rho[s] / rho[r] + for t in self.children[r]: + belief -= self.messages[r, t] + belief += np.log(self.total) - belief.logsumexp() + marginals[r] = belief.exp() + if callback is not None: + callback(marginals) + + return CliqueVector(marginals) + + def kikuchi_entropy(self, marginals): + """ + Return the Bethe Entropy and the gradient with respect to the marginals + + """ + weights = self.counting_numbers + entropy = 0 + dmarginals = {} + for cl in self.regions: + mu = marginals[cl] / self.total + entropy += weights[cl] * (mu * mu.log()).sum() + dmarginals[cl] = weights[cl] * (1 + mu.log()) / self.total + return -entropy, -1 * CliqueVector(dmarginals) + + def mle(self, mu): + return -1 * self.kikuchi_entropy(mu)[1] + + +def estimate_kikuchi_marginal(domain, total, marginals): + marginals = dict(marginals) + regions = set(marginals.keys()) + size = 0 + while len(regions) > size: + size = len(regions) + for r1, r2 in itertools.combinations(regions, 2): + z = tuple(sorted(set(r1) & set(r2))) + if len(z) > 0 and z not in regions: + marginals[z] = marginals[r1].project(z) + regions.update({z}) + + G = nx.DiGraph() + G.add_nodes_from(regions) + for r1 in regions: + for r2 in regions: + if set(r2) < set(r1) and not any( + set(r2) < set(r3) and set(r3) < set(r1) for r3 in regions + ): + G.add_edge(r1, r2) + + H1 = nx.transitive_closure(G.reverse()) + ancestors = {r: list(H1.neighbors(r)) for r in regions} + moebius = {} + + def get_counting_number(r): + if r not in moebius: + moebius[r] = 1 - sum(get_counting_number(s) for s in ancestors[r]) + return moebius[r] + + logP = Factor.zeros(domain) + for r in regions: + kr = get_counting_number(r) + logP += kr * marginals[r].log() + logP += np.log(total) - logP.logsumexp() + return logP.exp() diff --git a/src/synthcity/plugins/core/models/mbi/torch_factor.py b/src/synthcity/plugins/core/models/mbi/torch_factor.py new file mode 100644 index 00000000..06c9c905 --- /dev/null +++ b/src/synthcity/plugins/core/models/mbi/torch_factor.py @@ -0,0 +1,206 @@ +# third party +import numpy as np +import torch + + +class Factor: + device = "cuda" if torch.cuda.is_available() else "cpu" + + def __init__(self, domain, values): + """Initialize a factor over the given domain + + :param domain: the domain of the factor + :param values: the ndarray or tensor of factor values (for each element of the domain) + + Note: values may be a flattened 1d array or a ndarray with same shape as domain + """ + if type(values) == np.ndarray: + values = torch.tensor(values, dtype=torch.float32, device=Factor.device) + if domain.size() != values.nelement(): + raise ValueError("domain size does not match values size") + if len(values.shape) != 1 and values.shape != domain.shape: + raise ValueError("invalid shape for values array") + self.domain = domain + self.values = values.reshape(domain.shape).to(Factor.device) + + @staticmethod + def zeros(domain): + return Factor(domain, torch.zeros(domain.shape, device=Factor.device)) + + @staticmethod + def ones(domain): + return Factor(domain, torch.ones(domain.shape, device=Factor.device)) + + @staticmethod + def random(domain): + return Factor(domain, 
torch.rand(domain.shape, device=Factor.device)) + + @staticmethod + def uniform(domain): + return Factor.ones(domain) / domain.size() + + @staticmethod + def active(domain, structural_zeros): + """create a factor that is 0 everywhere except in positions present in + 'structural_zeros', where it is -infinity + + :param: domain: the domain of this factor + :param: structural_zeros: a list of values that are not possible + """ + idx = tuple(np.array(structural_zeros).T) + vals = torch.zeros(domain.shape, device=Factor.device) + vals[idx] = -np.inf + return Factor(domain, vals) + + def expand(self, domain): + if not domain.contains(self.domain): + raise AssertionError("expanded domain must contain current domain") + dims = len(domain) - len(self.domain) + values = self.values.view(self.values.size() + tuple([1] * dims)) + ax = domain.axes(self.domain.attrs) + # need to find replacement for moveaxis + ax = ax + tuple(i for i in range(len(domain)) if i not in ax) + ax = tuple(np.argsort(ax)) + values = values.permute(ax) + values = values.expand(domain.shape) + return Factor(domain, values) + + def transpose(self, attrs): + if set(attrs) != set(self.domain.attrs): + raise AssertionError("attrs must be same as domain attributes") + newdom = self.domain.project(attrs) + ax = newdom.axes(self.domain.attrs) + ax = tuple(np.argsort(ax)) + values = self.values.permute(ax) + return Factor(newdom, values) + + def project(self, attrs, agg="sum"): + """ + project the factor onto a list of attributes (in order) + using either sum or logsumexp to aggregate along other attributes + """ + if agg not in ["sum", "logsumexp"]: + raise ValueError("agg must be sum or logsumexp") + marginalized = self.domain.marginalize(attrs) + if agg == "sum": + ans = self.sum(marginalized.attrs) + elif agg == "logsumexp": + ans = self.logsumexp(marginalized.attrs) + return ans.transpose(attrs) + + def sum(self, attrs=None): + if attrs is None: + return float(self.values.sum()) + elif attrs == tuple(): + return self + axes = self.domain.axes(attrs) + values = self.values.sum(dim=axes) + newdom = self.domain.marginalize(attrs) + return Factor(newdom, values) + + def logsumexp(self, attrs=None): + if attrs is None: + return float( + self.values.logsumexp(dim=tuple(range(len(self.values.shape)))) + ) + elif attrs == tuple(): + return self + axes = self.domain.axes(attrs) + values = self.values.logsumexp(dim=axes) + newdom = self.domain.marginalize(attrs) + return Factor(newdom, values) + + def logaddexp(self, other): + return NotImplementedError + + def max(self, attrs=None): + if attrs is None: + return float(self.values.max()) + return NotImplementedError # torch.max does not behave like numpy + + def condition(self, evidence): + """evidence is a dictionary where + keys are attributes, and + values are elements of the domain for that attribute""" + slices = [evidence[a] if a in evidence else slice(None) for a in self.domain] + newdom = self.domain.marginalize(evidence.keys()) + values = self.values[tuple(slices)] + return Factor(newdom, values) + + def copy(self, out=None): + if out is None: + return Factor(self.domain, self.values.clone()) + np.copyto(out.values, self.values) + return out + + def __mul__(self, other): + if np.isscalar(other): + return Factor(self.domain, other * self.values) + newdom = self.domain.merge(other.domain) + factor1 = self.expand(newdom) + factor2 = other.expand(newdom) + return Factor(newdom, factor1.values * factor2.values) + + def __add__(self, other): + if np.isscalar(other): + return 
Factor(self.domain, other + self.values) + newdom = self.domain.merge(other.domain) + factor1 = self.expand(newdom) + factor2 = other.expand(newdom) + return Factor(newdom, factor1.values + factor2.values) + + def __iadd__(self, other): + if np.isscalar(other): + self.values += other + return self + factor2 = other.expand(self.domain) + self.values += factor2.values + return self + + def __imul__(self, other): + if np.isscalar(other): + self.values *= other + return self + factor2 = other.expand(self.domain) + self.values *= factor2.values + return self + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __sub__(self, other): + if np.isscalar(other): + return Factor(self.domain, self.values - other) + zero = torch.tensor(0.0, device=Factor.device) + inf = torch.tensor(np.inf, device=Factor.device) + values = torch.where(other.values == -inf, zero, -other.values) + other = Factor(other.domain, values) + return self + other + + def __truediv__(self, other): + if np.isscalar(other): + return self * (1.0 / other) + tmp = other.expand(self.domain) + vals = torch.div(self.values, tmp.values) + vals[tmp.values <= 0] = 0.0 + return Factor(self.domain, vals) + + def exp(self, out=None): + if out is None: + return Factor(self.domain, self.values.exp()) + torch.exp(self.values, out=out.values) + return out + + def log(self, out=None): + if out is None: + return Factor(self.domain, torch.log(self.values + 1e-100)) + torch.log(self.values, out=out.values) + return out + + def datavector(self, flatten=True): + """Materialize the data vector as a numpy array""" + ans = self.values.to("cpu").numpy() + return ans.flatten() if flatten else ans diff --git a/src/synthcity/plugins/core/models/tabular_aim.py b/src/synthcity/plugins/core/models/tabular_aim.py new file mode 100644 index 00000000..1b2c68cd --- /dev/null +++ b/src/synthcity/plugins/core/models/tabular_aim.py @@ -0,0 +1,133 @@ +# stdlib +import itertools +from typing import Any, Optional, Union + +# third party +import numpy as np +import pandas as pd +import torch +from pydantic import validate_arguments + +# synthcity absolute +from synthcity.utils.constants import DEVICE + +# synthcity relative +from .aim import AIM +from .mbi.dataset import Dataset +from .mbi.domain import Domain + + +class TabularAIM: + """ + .. inheritance-diagram:: synthcity.plugins.core.models.tabular_aim.TabularAIM + :parts: 1 + + + Adaptive and Iterative Mechanism (AIM) implementation, based on: + - code: https://github.com/ryan112358/private-pgm/blob/master/mechanisms/aim.py + - paper: https://www.vldb.org/pvldb/vol15/p2599-mckenna.pdf. + + + Args: + X (pd.DataFrame): Reference dataset, used for training the tabular encoder + # AIM parameters + + # core plugin arguments + encoder_max_clusters (int = 20): The max number of clusters to create for continuous columns when encoding with TabularEncoder. Defaults to 20. + encoder_whitelist (list = []): Ignore columns from encoding with TabularEncoder. Defaults to []. + device: Union[str, torch.device] = DEVICE, # This is not used for this model, as it is built with sklearn, which is cpu only + random_state (int, optional): _description_. Defaults to 0. # This is not used for this model + **kwargs (Any): The keyword arguments are passed to a SKLearn RandomForestClassifier - https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html. 
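+
+    Example:
+        A minimal illustrative sketch (assuming ``X`` is a pandas DataFrame of categorical columns):
+
+        >>> from synthcity.plugins.core.models.tabular_aim import TabularAIM
+        >>> model = TabularAIM(X, epsilon=1.0, delta=1e-9)
+        >>> model.fit(X)
+        >>> X_syn = model.generate(count=100)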
+ """ + + def __init__( + self, + X: pd.DataFrame, + # AIM parameters + epsilon: float = 1.0, + delta: float = 1e-9, + max_model_size: int = 80, + degree: int = 2, + num_marginals: Optional[int] = None, + max_cells: int = 1000, + # core plugin arguments + encoder_max_clusters: int = 20, + encoder_whitelist: list = [], + device: Union[str, torch.device] = DEVICE, + learning_rate: float = 5e-3, + weight_decay: float = 1e-3, + logging_epoch: int = 100, + random_state: int = 0, + **kwargs: Any, + ): + super(TabularAIM, self).__init__() + self.columns = X.columns + self.epsilon = epsilon + self.delta = delta + self.max_model_size = max_model_size + self.degree = degree + self.num_marginals = num_marginals + self.max_cells = max_cells + self.prng = np.random + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def fit( + self, + X: pd.DataFrame, + **kwargs: Any, + ) -> Any: + """ + + Args: + data: Pandas DataFrame that contains the tabular data + + Returns: + AIMTrainer used for the fine-tuning process + """ + domain_shapes = X.nunique().to_dict() + mbi_domain = Domain(self.columns, domain_shapes.values()) + self.dataset = Dataset(X, mbi_domain) + + workload = list(itertools.combinations(self.dataset.domain, self.degree)) + if len(workload) == 0: + raise ValueError("No workload found. Is the dataset empty?") + workload = [ + cl for cl in workload if self.dataset.domain.size(cl) <= self.max_cells + ] + if len(workload) == 0: + raise ValueError( + "Domain sizes for the cells are too large. Increase max_cells values or further discretize the data." + ) + if self.num_marginals is not None: + workload = [ + workload[i] + for i in self.prng.choice( + len(workload), self.num_marginals, replace=False + ) + ] + + self.workload = [(cl, 1.0) for cl in workload] + self.model = AIM(self.epsilon, self.delta, max_model_size=self.max_model_size) + return self + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def generate( + self, + count: int, + start_col: Optional[str] = "", + start_col_dist: Optional[Union[dict, list]] = None, + temperature: float = 0.7, + k: int = 100, + max_length: int = 100, + ) -> pd.DataFrame: + """ + Generates tabular data using the trained AIM model. + + Args: + count (int): The number of samples to generate + + Returns: + pd.DataFrame: n_samples rows of generated data + """ + synth_dataset = self.model.run(self.dataset, self.workload) + return synth_dataset.df diff --git a/src/synthcity/plugins/privacy/plugin_aim.py b/src/synthcity/plugins/privacy/plugin_aim.py new file mode 100644 index 00000000..c32d6bf9 --- /dev/null +++ b/src/synthcity/plugins/privacy/plugin_aim.py @@ -0,0 +1,166 @@ +""" +Reference: "Adversarial random forests for density estimation and generative modeling" Authors: David S. Watson, Kristin Blesch, Jan Kapar, and Marvin N. Wright +""" + +# stdlib +from pathlib import Path +from typing import Any, List, Optional, Union + +# third party +import pandas as pd +import torch + +# Necessary packages +from pydantic import validate_arguments + +# synthcity absolute +from synthcity.plugins.core.dataloader import DataLoader +from synthcity.plugins.core.distribution import ( + Distribution, + FloatDistribution, + IntegerDistribution, + LogDistribution, +) +from synthcity.plugins.core.models.tabular_aim import TabularAIM +from synthcity.plugins.core.plugin import Plugin +from synthcity.plugins.core.schema import Schema +from synthcity.utils.constants import DEVICE + + +class AIMPlugin(Plugin): + """ + .. 
inheritance-diagram:: synthcity.plugins.privacy.plugin_adsgan.AdsGANPlugin + :parts: 1 + + Args: + + + Example: + >>> from sklearn.datasets import load_iris + >>> from synthcity.plugins import Plugins + >>> + >>> X, y = load_iris(as_frame = True, return_X_y = True) + >>> X["target"] = y + >>> + >>> plugin = Plugins().get("aim") + >>> plugin.fit(X) + >>> + >>> plugin.generate(50) + + """ + + @validate_arguments(config=dict(arbitrary_types_allowed=True)) + def __init__( + self, + # AIM plugin arguments + epsilon: float = 1.0, + delta: float = 1e-9, + max_model_size: int = 80, + degree: int = 2, + num_marginals: Optional[int] = None, + max_cells: int = 1000, + # core plugin arguments + device: Union[str, torch.device] = DEVICE, + random_state: int = 0, + sampling_patience: int = 500, + workspace: Path = Path("workspace"), + compress_dataset: bool = False, + **kwargs: Any, + ) -> None: + """ + .. inheritance-diagram:: synthcity.plugins.generic.plugin_aim.AIMPlugin + :parts: 1 + + Adversarial Random Forest implementation. + + Args: + degree: int = 2 + Degree of marginals to use. Defaults to 2. + device: Union[str, torch.device] = synthcity.utils.constants.DEVICE + The device that the model is run on. Defaults to "cuda" if cuda is available else "cpu". + random_state: int = 0 + random_state used. Defaults to 0. + sampling_patience: int = 500 + Max inference iterations to wait for the generated data to match the training schema. Defaults to 500. + workspace: Path + Path for caching intermediary results. Defaults to Path("workspace"). + compress_dataset: bool. Default = False + Drop redundant features before training the generator. Defaults to False. + dataloader_sampler: Any = None + Optional sampler for the dataloader. Defaults to None. + """ + super().__init__( + device=device, + random_state=random_state, + sampling_patience=sampling_patience, + workspace=workspace, + compress_dataset=compress_dataset, + **kwargs, + ) + self.epsilon = epsilon + self.delta = delta + self.max_model_size = max_model_size + self.degree = degree + self.num_marginals = num_marginals + self.max_cells = max_cells + + @staticmethod + def name() -> str: + return "aim" + + @staticmethod + def type() -> str: + return "generic" + + @staticmethod + def hyperparameter_space(**kwargs: Any) -> List[Distribution]: + return [ + FloatDistribution(name="epsilon", low=0.5, high=3.0, step=0.5), + LogDistribution(name="delta", low=1e-10, high=1e-5), + IntegerDistribution(name="max_model_size", low=50, high=200, step=50), + IntegerDistribution(name="degree", low=2, high=5, step=1), + IntegerDistribution(name="num_marginals", low=0, high=5, step=1), + IntegerDistribution(name="max_cells", low=5000, high=25000, step=5000), + ] + + def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "AIMPlugin": + """ + Internal function to fit the model to the data. + + Args: + X (DataLoader): The data to fit the model to. + + Raises: + NotImplementedError: _description_ + + Returns: + AIMPlugin: _description_ + """ + + self.model = TabularAIM( + X.dataframe(), + self.epsilon, + self.delta, + self.max_model_size, + self.degree, + self.num_marginals, + self.max_cells, + **kwargs, + ) + if "cond" in kwargs and kwargs["cond"] is not None: + raise NotImplementedError( + "conditional generation is not currently available for the Adaptive and Iterative Mechanism for Differentially Private Synthetic Data(AIM) plugin." 
+ ) + self.model.fit(X.dataframe(), **kwargs) + return self + + def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFrame: + if "cond" in kwargs and kwargs["cond"] is not None: + raise NotImplementedError( + "conditional generation is not currently available for the Adaptive and Iterative Mechanism for Differentially Private Synthetic Data (AIM) plugin." + ) + + return self._safe_generate(self.model.generate, count, syn_schema) + + +plugin = AIMPlugin diff --git a/src/synthcity/utils/datasets/categorical/__init__.py b/src/synthcity/utils/datasets/categorical/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/synthcity/utils/datasets/categorical/categorical_adult.py b/src/synthcity/utils/datasets/categorical/categorical_adult.py new file mode 100644 index 00000000..4a53f6ad --- /dev/null +++ b/src/synthcity/utils/datasets/categorical/categorical_adult.py @@ -0,0 +1,33 @@ +# stdlib +import io +from pathlib import Path +from typing import Union + +# third party +import numpy as np +import pandas as pd +import requests + +URL = "https://raw.githubusercontent.com/ryan112358/private-pgm/master/data/adult.csv" +df_path = Path(__file__).parent / "data/adult.csv" + + +class CategoricalAdultDataloader: + def __init__(self, as_numpy: bool = False) -> None: + self.as_numpy = as_numpy + + def load( + self, + ) -> Union[pd.DataFrame, np.ndarray]: + # Load Google Data + if not df_path.exists(): + s = requests.get(URL, timeout=5).content + df = pd.read_csv(io.StringIO(s.decode("utf-8"))) + + df.to_csv(df_path, index=None) + else: + df = pd.read_csv(df_path) + + if self.as_numpy: + return df.to_numpy() + return df diff --git a/src/synthcity/utils/datasets/categorical/data/__init__ .py b/src/synthcity/utils/datasets/categorical/data/__init__ .py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nb_eval.py b/tests/nb_eval.py index 6561e3b8..f2f4bb75 100644 --- a/tests/nb_eval.py +++ b/tests/nb_eval.py @@ -72,6 +72,7 @@ def run_notebook(notebook_path: Path) -> None: "plugin_ctgan(generic)", "plugin_fourier_flows", "plugin_timegan", + "plugin_aim", ] if not goggle_disabled: diff --git a/tests/plugins/privacy/test_aim.py b/tests/plugins/privacy/test_aim.py new file mode 100644 index 00000000..e1691319 --- /dev/null +++ b/tests/plugins/privacy/test_aim.py @@ -0,0 +1,177 @@ +# stdlib +import random +from datetime import datetime, timedelta + +# third party +import numpy as np +import pandas as pd +import pytest +from fhelpers import generate_fixtures +from sklearn.datasets import load_iris + +# synthcity absolute +from synthcity.metrics.eval import PerformanceEvaluatorXGB +from synthcity.plugins import Plugin +from synthcity.plugins.core.constraints import Constraints +from synthcity.plugins.core.dataloader import GenericDataLoader +from synthcity.plugins.privacy.plugin_aim import plugin +from synthcity.utils.datasets.categorical.categorical_adult import ( + CategoricalAdultDataloader, +) +from synthcity.utils.serialization import load, save + +plugin_name = "aim" +plugin_args = { + "epsilon": 1.0, + "delta": 1e-9, + "max_model_size": 80, + "degree": 2, + "num_marginals": None, + "max_cells": 1000, +} + + +@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +def test_plugin_sanity(test_plugin: Plugin) -> None: + assert test_plugin is not None + + +@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +def test_plugin_name(test_plugin: Plugin) -> None: + assert test_plugin.name() == plugin_name + + 
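+
+# AIM is designed for categorical data, so the fit/generate tests below use the
+# categorical adult dataset, and the performance test discretizes the iris features
+# before fitting.
+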
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +def test_plugin_type(test_plugin: Plugin) -> None: + assert test_plugin.type() == "generic" + + +@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin)) +def test_plugin_hyperparams(test_plugin: Plugin) -> None: + assert len(test_plugin.hyperparameter_space()) == 6 + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_fit(test_plugin: Plugin) -> None: + X = CategoricalAdultDataloader().load().head() + test_plugin.fit(GenericDataLoader(X)) + + +@pytest.mark.parametrize( + "test_plugin", + generate_fixtures(plugin_name, plugin, plugin_args), +) +@pytest.mark.parametrize("serialize", [True, False]) +def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: + X = CategoricalAdultDataloader().load().head() + test_plugin.fit(GenericDataLoader(X)) + + if serialize: + saved = save(test_plugin) + test_plugin = load(saved) + + X_gen = test_plugin.generate() + assert len(X_gen) == len(X) + assert X_gen.shape[1] == X.shape[1] + assert test_plugin.schema_includes(X_gen) + + X_gen = test_plugin.generate(50) + assert len(X_gen) == 50 + assert test_plugin.schema_includes(X_gen) + + # generate with random seed + X_gen1 = test_plugin.generate(50, random_state=0) + X_gen2 = test_plugin.generate(50, random_state=0) + X_gen3 = test_plugin.generate(50) + assert (X_gen1.numpy() == X_gen2.numpy()).all() + assert (X_gen1.numpy() != X_gen3.numpy()).any() + + +@pytest.mark.parametrize( + "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args) +) +def test_plugin_generate_constraints_aim(test_plugin: Plugin) -> None: + X = CategoricalAdultDataloader().load().head() + test_plugin.fit(GenericDataLoader(X, target_column="income>50K")) + + constraints = Constraints( + rules=[ + ("income>50K", "eq", 1), + ] + ) + + X_gen = test_plugin.generate(constraints=constraints).dataframe() + assert len(X_gen) == len(X) + assert test_plugin.schema_includes(X_gen) + assert constraints.filter(X_gen).sum() == len(X_gen) + assert (X_gen["income>50K"] == 1).all() + + X_gen = test_plugin.generate(count=50, constraints=constraints).dataframe() + assert len(X_gen) == 50 + assert test_plugin.schema_includes(X_gen) + assert constraints.filter(X_gen).sum() == len(X_gen) + assert list(X_gen.columns) == list(X.columns) + + +def test_sample_hyperparams() -> None: + assert plugin is not None + for i in range(100): + args = plugin.sample_hyperparameters() + assert plugin(**args) is not None + + +@pytest.mark.slow +@pytest.mark.parametrize("compress_dataset", [True, False]) +def test_eval_performance_aim(compress_dataset: bool) -> None: + assert plugin is not None + results = [] + + X_raw, y = load_iris(as_frame=True, return_X_y=True) + X_raw["target"] = y + # Descretize the data + num_bins = 3 + for col in X_raw.columns: + X_raw[col] = pd.cut(X_raw[col], bins=num_bins, labels=list(range(num_bins))) + + X = GenericDataLoader(X_raw, target_column="target") + + for retry in range(2): + test_plugin = plugin(**plugin_args) + evaluator = PerformanceEvaluatorXGB() + + test_plugin.fit(X) + X_syn = test_plugin.generate(count=1000) + + results.append(evaluator.evaluate(X, X_syn)["syn_id"]) + print(results) + assert np.mean(results) > 0.7 + + +def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> datetime: + # generate a datetime in format yyyy-mm-dd hh:mm:ss.000000 + start = datetime(min_year, 1, 1, 00, 00, 00) + years = max_year - min_year 
+ 1 + end = start + timedelta(days=365 * years) + return start + (end - start) * random.random() + + +def test_plugin_encoding() -> None: + assert plugin is not None + data = [[gen_datetime(), i % 2 == 0, i] for i in range(10)] + + df = pd.DataFrame(data, columns=["date", "bool", "int"]) + X = GenericDataLoader(df) + test_plugin = plugin(**plugin_args) + test_plugin.fit(X) + + syn = test_plugin.generate(10) + + assert len(syn) == 10 + assert test_plugin.schema_includes(syn) + + syn_df = syn.dataframe() + + assert syn_df["date"].infer_objects().dtype.kind == "M" + assert syn_df["bool"].infer_objects().dtype.kind == "b" diff --git a/tutorials/plugins/generic/plugin_great.ipynb b/tutorials/plugins/generic/plugin_great.ipynb index 89212557..355821fa 100644 --- a/tutorials/plugins/generic/plugin_great.ipynb +++ b/tutorials/plugins/generic/plugin_great.ipynb @@ -5,7 +5,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Adversarial Random Forest" + "# GReaT" ] }, { diff --git a/tutorials/plugins/privacy/plugin_aim.ipynb b/tutorials/plugins/privacy/plugin_aim.ipynb new file mode 100644 index 00000000..dc8a821d --- /dev/null +++ b/tutorials/plugins/privacy/plugin_aim.ipynb @@ -0,0 +1,181 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Adaptive and Iterative Mechanism for Differentially Private Synthetic Data\n", + "\n", + "This method is designed for categorical data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# stdlib\n", + "import warnings\n", + "import sys\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "\n", + "# synthcity absolute\n", + "from synthcity.plugins import Plugins\n", + "from synthcity.utils.datasets.categorical.categorical_adult import CategoricalAdultDataloader\n", + "import synthcity.logger as log\n", + "log.add(sink=sys.stderr, level=\"INFO\")\n", + "\n", + "eval_plugin = \"aim\"" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# synthcity absolute\n", + "from synthcity.plugins.core.dataloader import GenericDataLoader\n", + "\n", + "X = CategoricalAdultDataloader().load()\n", + "loader = GenericDataLoader(X, target_column=\"income>50K\", sensitive_columns=[\"sex\", \"race\"])\n", + "\n", + "loader.dataframe()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train the generator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# synthcity absolute\n", + "from synthcity.plugins import Plugins\n", + "\n", + "syn_model = Plugins().get(eval_plugin)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "syn_model.fit(loader)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate new samples" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "syn_model.generate(count=10).dataframe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "import matplotlib.pyplot as plt\n", + "\n", + "syn_model.plot(plt, loader, count=100)\n", + "\n", + "plt.show()" + ] + }, + { + 
"attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Benchmarks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# synthcity absolute\n", + "from synthcity.benchmark import Benchmarks\n", + "\n", + "score = Benchmarks.evaluate(\n", + " [\n", + " (eval_plugin, eval_plugin, {\"epsilon\": 1.0, \"delta\": 1e-7, \"max_model_size\": 80, \"degree\": 2, \"num_marginals\": None, \"max_cells\": 1000}),\n", + " ], # (testname, plugin, plugin_args) The plugin_args are given are simply to illustrate some of the paramters that can be passed to the plugin\n", + " loader,\n", + " repeats=2,\n", + " metrics={\n", + " \"detection\": [\"detection_mlp\"],\n", + " \"privacy\": [\"distinct l-diversity\", \"k-anonymization\", \"k-map\"],\n", + " },\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Benchmarks.print(score)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "synthcity-all", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}