diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b278dc32..fda9ec64 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,52 +1,51 @@
-exclude: '^docs/conf.py'
+exclude: "^docs/conf.py"
repos:
-- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.0.1
- hooks:
- - id: trailing-whitespace
- - id: check-ast
- - id: check-json
- - id: check-merge-conflict
- - id: check-xml
- - id: check-yaml
- - id: debug-statements
- - id: end-of-file-fixer
- - id: requirements-txt-fixer
- - id: mixed-line-ending
- args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows
-- repo: https://github.com/psf/black
- rev: 22.3.0
- hooks:
- - id: black
- language_version: python3
-- repo: https://github.com/pycqa/isort
- rev: 5.12.0
- hooks:
- - id: isort
-- repo: https://github.com/PyCQA/flake8
- rev: 6.0.0
- hooks:
- - id: flake8
- args: [
- "--max-line-length=480",
- "--extend-ignore=E203,W503"
- ]
-- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v0.991
- hooks:
- - id: mypy
- args: [
- "--ignore-missing-imports",
- "--scripts-are-modules",
- "--disallow-incomplete-defs",
- "--no-implicit-optional",
- "--warn-unused-ignores",
- "--warn-redundant-casts",
- "--strict-equality",
- "--warn-unreachable",
- "--disallow-untyped-defs",
- "--disallow-untyped-calls",
- "--install-types",
- "--non-interactive",
- ]
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.0.1
+ hooks:
+ - id: trailing-whitespace
+ - id: check-ast
+ - id: check-json
+ - id: check-merge-conflict
+ - id: check-xml
+ - id: check-yaml
+ - id: debug-statements
+ - id: end-of-file-fixer
+ - id: requirements-txt-fixer
+ - id: mixed-line-ending
+ args: ["--fix=auto"] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows
+ - repo: https://github.com/psf/black
+ rev: 22.3.0
+ hooks:
+ - id: black
+ language_version: python3
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.0.0
+ hooks:
+ - id: flake8
+ args: ["--max-line-length=480", "--extend-ignore=E203,W503"]
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v0.991
+ hooks:
+ - id: mypy
+ args: [
+ "--ignore-missing-imports",
+ "--scripts-are-modules",
+ "--disallow-incomplete-defs",
+ "--no-implicit-optional",
+ "--warn-unused-ignores",
+ "--warn-redundant-casts",
+ "--strict-equality",
+ "--warn-unreachable",
+ "--disallow-untyped-defs",
+ "--disallow-untyped-calls",
+ "--install-types",
+ "--non-interactive",
+ "--follow-imports=skip", # This is temporary until the mbi directory is not excluded
+ ]
+ exclude: ^src/synthcity/plugins/core/models/mbi/ # This is temporary until the mbi directory is fully typed
diff --git a/README.md b/README.md
index b4fcae94..23912fc9 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
- A library for generating and evaluating synthetic tabular data..
+ A library for generating and evaluating synthetic tabular data.
@@ -20,7 +20,7 @@
[](https://pypi.org/project/synthcity/)
[](https://github.com/vanderschaarlab/synthcity/blob/main/LICENSE)
-[](https://www.python.org/downloads/release/python-370/)
+[](https://www.python.org/downloads/release/python-380/)
[](https://www.vanderschaar-lab.com/)
[](https://join.slack.com/t/vanderschaarlab/shared_invite/zt-1pzy8z7ti-zVsUPHAKTgCd1UoY8XtTEw)
diff --git a/docs/examples.rst b/docs/examples.rst
index 034281c2..e247b6fd 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -26,6 +26,9 @@ General-purpose generators
CTGAN
Normalizing Flows
TVAE
+ GOGGLE
+ ARF
+ GReaT
Time-series generators
@@ -49,6 +52,7 @@ Privacy-related generators
AdsGAN
PATEGAN
PrivBayes
+ AIM
Domain adaptation generators
------------------------------
diff --git a/docs/generators.rst b/docs/generators.rst
index d78a30fd..b18e35c4 100644
--- a/docs/generators.rst
+++ b/docs/generators.rst
@@ -12,6 +12,9 @@ General purpose
Normalizing Flows
RTVAE
TVAE
+ GOGGLE
+ ARF
+ GReaT
Privacy-focused
-----------------
@@ -24,6 +27,7 @@ Privacy-focused
PrivBayes
DP-GAN
DECAF
+ AIM
Domain adaptation
-------------------
diff --git a/src/synthcity/plugins/core/models/aim.py b/src/synthcity/plugins/core/models/aim.py
new file mode 100644
index 00000000..e5165ad9
--- /dev/null
+++ b/src/synthcity/plugins/core/models/aim.py
@@ -0,0 +1,367 @@
+# stdlib
+import itertools
+import math
+import sys
+from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast
+
+if sys.version_info < (3, 9):
+ # stdlib
+ from typing import Iterable
+else:
+ from collections.abc import Iterable
+
+# third party
+import numpy as np
+from scipy.special import softmax
+
+# synthcity absolute
+import synthcity.logger as log
+
+# synthcity relative
+from .mbi.dataset import Dataset
+from .mbi.domain import Domain
+from .mbi.graphical_model import GraphicalModel
+from .mbi.identity import Identity
+from .mbi.inference import FactoredInference
+
+
+def powerset(iterable: Iterable) -> Iterable:
+ "powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
+ s = list(iterable)
+ return itertools.chain.from_iterable(
+ itertools.combinations(s, r) for r in range(1, len(s) + 1)
+ )
+
+
+def downward_closure(Ws: List[Tuple]) -> List:
+ ans: Set = set()
+ for proj in Ws:
+ ans.update(powerset(proj))
+ return list(sorted(ans, key=len))
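+
+# Example (hypothetical attribute names): for the workload
+# [("age", "sex"), ("sex", "income")], downward_closure returns every non-empty
+# subset of each clique, sorted by length:
+#     [("age",), ("sex",), ("income",), ("age", "sex"), ("sex", "income")]
+# (tuples of equal length may appear in any order).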
+
+
+def hypothetical_model_size(domain: Domain, cliques: List[Union[Tuple, List]]) -> float:
+ model = GraphicalModel(domain, cliques)
+ return model.size * 8 / 2**20
+
+
+def compile_workload(workload: List[Tuple]) -> Dict:
+    """Score every subset of the workload cliques by the total number of attributes
+    it shares with the workload cliques."""
+
+    def score(cl: Tuple) -> int:
+ return sum(len(set(cl) & set(ax)) for ax in workload)
+
+ return {cl: score(cl) for cl in downward_closure(workload)}
+
+
+def filter_candidates(
+ candidates: Dict, model: GraphicalModel, size_limit: Union[float, int]
+) -> Dict:
+    """Keep candidates that keep the hypothetical model within the size limit, plus
+    candidates already covered by the model's existing cliques (which are free)."""
+    ans = {}
+ free_cliques = downward_closure(model.cliques)
+ for cl in candidates:
+ cond1 = (
+ hypothetical_model_size(model.domain, model.cliques + [cl]) <= size_limit
+ )
+ cond2 = cl in free_cliques
+ if cond1 or cond2:
+ ans[cl] = candidates[cl]
+ return ans
+
+
+def cdp_delta(rho: Union[float, int], eps: Union[float, int]) -> Union[float, int]:
+ if rho < 0:
+ raise ValueError("rho must be positive")
+ if eps < 0:
+ raise ValueError("eps must be positive")
+ if rho == 0:
+ return 0 # degenerate case
+
+ # search for best alpha
+ # Note that any alpha in (1,infty) yields a valid upper bound on delta
+ # Thus if this search is slightly "incorrect" it will only result in larger delta (still valid)
+ # This code has two "hacks".
+    # First, the binary search is run for a pre-specified number of iterations.
+    # 1000 iterations should be sufficient to converge to a good solution.
+    # Second, we set a minimum value of alpha to avoid numerical stability issues.
+ # Note that the optimal alpha is at least (1+eps/rho)/2. Thus we only hit this constraint
+ # when eps<=rho or close to it. This is not an interesting parameter regime, as you will
+ # inherently get large delta in this regime.
+ amin = 1.01 # don't let alpha be too small, due to numerical stability
+ amax = (eps + 1) / (2 * rho) + 2
+ for i in range(1000): # should be enough iterations
+ alpha = (amin + amax) / 2
+ derivative = (2 * alpha - 1) * rho - eps + math.log1p(-1.0 / alpha)
+ if derivative < 0:
+ amin = alpha
+ else:
+ amax = alpha
+ # now calculate delta
+ delta = math.exp(
+ (alpha - 1) * (alpha * rho - eps) + alpha * math.log1p(-1 / alpha)
+ ) / (alpha - 1.0)
+ return min(delta, 1.0) # delta<=1 always
+
+
+def cdp_rho(eps: float, delta: float) -> float:
+ if eps < 0:
+ raise ValueError("eps must be positive")
+ if delta <= 0:
+ raise ValueError("delta must be positive")
+
+ if delta >= 1:
+ return 0.0 # if delta>=1 anything goes
+ rho_min = 0.0 # maintain cdp_delta(rho,eps)<=delta
+ rho_max = eps + 1 # maintain cdp_delta(rho_max,eps)>delta
+ for i in range(1000):
+ rho = (rho_min + rho_max) / 2
+ if cdp_delta(rho, eps) <= delta:
+ rho_min = rho
+ else:
+ rho_max = rho
+ return rho_min
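+
+# Usage note: cdp_rho and cdp_delta convert between approximate DP and zCDP budgets.
+# A quick sanity check (illustrative):
+#     rho = cdp_rho(1.0, 1e-9)         # zCDP parameter implied by (1.0, 1e-9)-DP
+#     cdp_delta(rho, 1.0) <= 1e-9      # holds by construction of the binary search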
+
+
+class Mechanism:
+ def __init__(self, epsilon: float, delta: float):
+ """
+        Base class for a differentially private mechanism.
+
+        :param epsilon: privacy parameter
+        :param delta: privacy parameter
+
+        The pseudo-random number generator defaults to ``np.random``.
+ """
+ self.epsilon = epsilon
+ self.delta = delta
+ self.rho = 0 if delta == 0 else cdp_rho(epsilon, delta)
+ self.prng = np.random
+
+ def run(self, dataset: Dataset, workload: List[Tuple]) -> Any:
+ pass
+
+ # def generalized_exponential_mechanism(
+ # self, qualities, sensitivities, epsilon, t=None, base_measure=None
+ # ):
+ # def generalized_em_scores(q, ds, t):
+ # def pareto_efficient(costs: np.ndarray) -> int:
+ # eff = np.ones(costs.shape[0], dtype=bool)
+ # for i, c in enumerate(costs):
+ # if eff[i]:
+ # eff[eff] = np.any(
+ # costs[eff] <= c, axis=1
+ # ) # Keep any point with a lower cost
+ # return np.nonzero(eff)[0]
+
+ # q = -q
+ # idx = pareto_efficient(np.vstack([q, ds]).T)
+ # r = q + t * ds
+ # r = r[:, None] - r[idx][None, :]
+ # z = ds[:, None] + ds[idx][None, :]
+ # s = (r / z).max(axis=1)
+ # return -s
+
+ # if t is None:
+ # t = 2 * np.log(len(qualities) / 0.5) / epsilon
+ # if isinstance(qualities, dict):
+ # keys = list(qualities.keys())
+ # qualities = np.array([qualities[key] for key in keys])
+ # sensitivities = np.array([sensitivities[key] for key in keys])
+ # if base_measure is not None:
+ # base_measure = np.log([base_measure[key] for key in keys])
+ # else:
+ # keys = np.arange(qualities.size)
+ # scores = generalized_em_scores(qualities, sensitivities, t)
+ # key = self.exponential_mechanism(
+ # scores, epsilon, 1.0, base_measure=base_measure
+ # )
+ # return keys[key]
+
+ # def permute_and_flip(self, qualities, epsilon, sensitivity=1.0):
+ # """Sample a candidate from the permute-and-flip mechanism"""
+ # q = qualities - qualities.max()
+ # p = np.exp(0.5 * epsilon / sensitivity * q)
+ # for i in np.random.permutation(p.size):
+ # if np.random.rand() <= p[i]:
+ # return i
+
+ def exponential_mechanism(
+ self,
+ qualities: Union[Dict, np.ndarray, Any],
+ epsilon: float,
+ sensitivity: Union[float, int] = 1.0,
+ base_measure: Optional[Dict] = None,
+ ) -> np.ndarray:
+ if isinstance(qualities, dict):
+ keys = list(qualities.keys())
+ qualities = cast(np.ndarray, np.array([qualities[key] for key in keys]))
+ if base_measure is not None:
+ base_measure = np.log([base_measure[key] for key in keys])
+ else:
+ qualities = cast(np.ndarray, np.array(qualities))
+ keys = np.arange(qualities.size)
+
+
+        # Sample a candidate from the exponential mechanism:
+        # P(i) is proportional to exp(0.5 * epsilon * q[i] / sensitivity), times the base measure if given
+ if base_measure is None:
+ p = softmax(0.5 * epsilon / sensitivity * q)
+ else:
+ p = softmax(0.5 * epsilon / sensitivity * q + base_measure)
+
+ return keys[self.prng.choice(p.size, p=p)]
+
+ # def gaussian_noise_scale(self, l2_sensitivity, epsilon, delta):
+ # """Return the Gaussian noise necessary to attain (epsilon, delta)-DP"""
+ # if self.bounded:
+ # l2_sensitivity *= 2.0
+ # return (
+ # l2_sensitivity
+ # * privacy_calibrator.ana_gaussian_mech(epsilon, delta)["sigma"]
+ # )
+
+ # def laplace_noise_scale(self, l1_sensitivity, epsilon):
+ # """Return the Laplace noise necessary to attain epsilon-DP"""
+ # if self.bounded:
+ # l1_sensitivity *= 2.0
+ # return l1_sensitivity / epsilon
+
+ def gaussian_noise(self, sigma: float, size: Union[int, Tuple]) -> np.ndarray:
+ """Generate iid Gaussian noise of a given scale and size"""
+ return self.prng.normal(0, sigma, size)
+
+ # def laplace_noise(self, b, size):
+ # """Generate iid Laplace noise of a given scale and size"""
+ # return self.prng.laplace(0, b, size)
+
+ # def best_noise_distribution(self, l1_sensitivity, l2_sensitivity, epsilon, delta):
+ # """Adaptively determine if Laplace or Gaussian noise will be better, and
+ # return a function that samples from the appropriate distribution"""
+ # b = self.laplace_noise_scale(l1_sensitivity, epsilon)
+ # sigma = self.gaussian_noise_scale(l2_sensitivity, epsilon, delta)
+ # if np.sqrt(2) * b < sigma:
+ # return partial(self.laplace_noise, b)
+ # return partial(self.gaussian_noise, sigma)
+
+
+class AIM(Mechanism):
+ def __init__(
+ self,
+ epsilon: float,
+ delta: float,
+ rounds: Optional[Union[int, float]] = None,
+ max_model_size: int = 80,
+ structural_zeros: Dict = {},
+ ):
+ super(AIM, self).__init__(epsilon, delta)
+ self.rounds = rounds
+ self.max_model_size = max_model_size
+ self.structural_zeros = structural_zeros
+
+ def worst_approximated(
+ self,
+ candidates: Dict,
+ answers: Dict,
+ model: GraphicalModel,
+ eps: float,
+ sigma: float,
+    ) -> np.ndarray:
+        """Select, via the exponential mechanism, the candidate clique whose true
+        marginal is currently worst approximated by the model."""
+        errors = {}
+ sensitivity = {}
+ for cl in candidates:
+ wgt = candidates[cl]
+ x = answers[cl]
+ bias = np.sqrt(2 / np.pi) * sigma * model.domain.size(cl)
+ xest = model.project(cl).datavector()
+ errors[cl] = wgt * (np.linalg.norm(x - xest, 1) - bias)
+ sensitivity[cl] = abs(wgt)
+ max_sensitivity = max(
+ sensitivity.values()
+ ) # if all weights are 0, could be a problem
+ return self.exponential_mechanism(errors, eps, max_sensitivity)
+
+    def run(self, data: Dataset, W: List) -> Dataset:
+        """Run the AIM mechanism.
+
+        Noisily measure all one-way marginals, then repeatedly select the worst
+        approximated candidate clique via the exponential mechanism, measure it with
+        Gaussian noise, and refit the graphical model; sigma is halved (and epsilon
+        doubled) whenever a round makes too little progress, and the last round uses
+        whatever privacy budget remains.
+        """
+        rounds = self.rounds or 16 * len(data.domain)
+ workload = [cl for cl, _ in W]
+ candidates = compile_workload(workload)
+ answers = {cl: data.project(cl).datavector() for cl in candidates}
+
+ oneway = [cl for cl in candidates if len(cl) == 1]
+
+ sigma = np.sqrt(rounds / (2 * 0.9 * self.rho))
+ epsilon = np.sqrt(8 * 0.1 * self.rho / rounds)
+
+ measurements = []
+ log.info("Initial Sigma", sigma)
+ rho_used = len(oneway) * 0.5 / sigma**2
+ for cl in oneway:
+ x = data.project(cl).datavector()
+ y = x + self.gaussian_noise(sigma, x.size)
+ identity_I = Identity(y.size)
+ measurements.append((identity_I, y, sigma, cl))
+
+ # backend = "torch" if torch.cuda.is_available() else "cpu" # TODO: fix torch backend option
+ zeros = self.structural_zeros
+ engine = FactoredInference(
+ data.domain,
+ # backend=backend,
+ iters=1000,
+ warm_start=True,
+ structural_zeros=zeros,
+ )
+ model = engine.estimate(measurements)
+
+ t = 0
+ terminate = False
+ while not terminate:
+ t += 1
+ if self.rho - rho_used < 2 * (0.5 / sigma**2 + 1.0 / 8 * epsilon**2):
+ # Just use up whatever remaining budget there is for one last round
+ remaining = self.rho - rho_used
+ sigma = np.sqrt(1 / (2 * 0.9 * remaining))
+ epsilon = np.sqrt(8 * 0.1 * remaining)
+ terminate = True
+
+ rho_used += 1.0 / 8 * epsilon**2 + 0.5 / sigma**2
+ size_limit = self.max_model_size * rho_used / self.rho
+ small_candidates = filter_candidates(candidates, model, size_limit)
+ cl = self.worst_approximated(
+ small_candidates, answers, model, epsilon, sigma
+ )
+
+ n = data.domain.size(cl)
+ Q = Identity(n)
+ x = data.project(cl).datavector()
+ y = x + self.gaussian_noise(sigma, n)
+ measurements.append((Q, y, sigma, cl))
+ z = model.project(cl).datavector()
+
+ model = engine.estimate(measurements)
+ w = model.project(cl).datavector()
+ log.info("Selected", cl, "Size", n, "Budget Used", rho_used / self.rho)
+ if np.linalg.norm(w - z, 1) <= sigma * np.sqrt(2 / np.pi) * n:
+ log.warning(f"Reducing sigma: {sigma/2}")
+ sigma /= 2
+ epsilon *= 2
+
+ log.info("Generating Data...")
+ engine.iters = 2500
+ model = engine.estimate(measurements)
+ synth = model.synthetic_data()
+
+ return synth
+
+
+def default_params() -> Dict[str, Any]:
+ """
+ Return default parameters to run this program
+
+ :returns: a dictionary of default parameter settings for each command line argument
+ """
+ params: Dict[str, Any] = {}
+ params["dataset"] = "../data/adult.csv" # TODO: Generalize
+ params["domain"] = "../data/adult-domain.json" # TODO: Generalize
+ params["epsilon"] = 1.0
+ params["delta"] = 1e-9
+ params["noise"] = "laplace"
+ params["max_model_size"] = 80
+ params["degree"] = 2
+ params["num_marginals"] = None
+ params["max_cells"] = 10000
+
+ return params
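+
+
+# Minimal usage sketch (illustrative; the file paths and the all-pairs workload below
+# are assumptions, not part of this module):
+#
+#     data = Dataset.load("adult.csv", "adult-domain.json")
+#     workload = [(cl, 1.0) for cl in itertools.combinations(data.domain.attrs, 2)]
+#     mech = AIM(epsilon=1.0, delta=1e-9)
+#     synth = mech.run(data, workload)  # returns an mbi Dataset of synthetic records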
diff --git a/src/synthcity/plugins/core/models/mbi/__init__.py b/src/synthcity/plugins/core/models/mbi/__init__.py
new file mode 100644
index 00000000..606066c6
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/__init__.py
@@ -0,0 +1,12 @@
+"""
+This module contains the core functionality of the MBI package, from
+https://github.com/ryan112358/private-pgm/tree/master/src/mbi, which is
+licensed under Apache-2.0.
+
+It is further modified to include mbi.identity, which contains some functionality
+from Ektelo (https://github.com/ektelo/ektelo/blob/master/ektelo/matrix.py), which
+is also licensed under Apache-2.0.
+
+The code has also been edited to be compatible with the synthcity codebase and its
+thorough type checking and code style.
+"""
diff --git a/src/synthcity/plugins/core/models/mbi/callbacks.py b/src/synthcity/plugins/core/models/mbi/callbacks.py
new file mode 100644
index 00000000..dd340c67
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/callbacks.py
@@ -0,0 +1,113 @@
+# stdlib
+import time
+
+# third party
+import numpy as np
+import pandas as pd
+
+# synthcity absolute
+import synthcity.logger as log
+
+
+class CallBack:
+ """A CallBack is a function called after every iteration of an iterative optimization procedure
+ It is useful for tracking loss and other metrics over time.
+ """
+
+ def __init__(self, engine, frequency=50):
+ """Initialize the callback objet
+
+ :param engine: the FactoredInference object that is performing the optimization
+ :param frequency: the number of iterations to perform before computing the callback function
+ """
+ self.engine = engine
+ self.frequency = frequency
+ self.calls = 0
+
+ def run(self, marginals):
+ pass
+
+ def __call__(self, marginals):
+ if self.calls == 0:
+ self.start = time.time()
+ if self.calls % self.frequency == 0:
+ self.run(marginals)
+ self.calls += 1
+
+
+class Logger(CallBack):
+ """Logger is the default callback function. It tracks the time, L1 loss, L2 loss, and
+ optionally the total variation distance to the true query answers (when available).
+    The latter is for debugging purposes only - in practice the true answers cannot be observed.
+ """
+
+ def __init__(self, engine, true_answers=None, frequency=50):
+ """Initialize the callback objet
+
+ :param engine: the FactoredInference object that is performing the optimization
+ :param true_answers: a dictionary containing true answers to the measurement queries.
+ :param frequency: the number of iterations to perform before computing the callback function
+ """
+ CallBack.__init__(self, engine, frequency)
+ self.true_answers = true_answers
+ self.idx = 0
+
+ def setup(self):
+ model = self.engine.model
+ total = sum(model.domain.size(cl) for cl in model.cliques)
+ log.debug("Total clique size:", total, flush=True)
+ # cl = max(model.cliques, key=lambda cl: model.domain.size(cl))
+ cols = ["iteration", "time", "l1_loss", "l2_loss", "feasibility"]
+ if self.true_answers is not None:
+ cols.append("variation")
+ self.results = pd.DataFrame(columns=cols)
+ log.debug("\t\t".join(cols), flush=True)
+
+ def variational_distances(self, marginals):
+ errors = []
+ for Q, y, proj in self.true_answers:
+ for cl in marginals:
+ if set(proj) <= set(cl):
+ mu = marginals[cl].project(proj)
+ x = mu.values.flatten()
+ diff = Q.dot(x) - y
+ err = 0.5 * np.abs(diff).sum() / y.sum()
+ errors.append(err)
+ break
+ return errors
+
+ def primal_feasibility(self, mu):
+ ans = 0
+ count = 0
+ for r in mu:
+ for s in mu:
+ if r == s:
+ break
+ d = tuple(set(r) & set(s))
+ if len(d) > 0:
+ x = mu[r].project(d).datavector()
+ y = mu[s].project(d).datavector()
+ err = np.linalg.norm(x - y, 1)
+ ans += err
+ count += 1
+ try:
+ return ans / count
+ except BaseException:
+ return 0
+
+ def run(self, marginals):
+ if self.idx == 0:
+ self.setup()
+
+ t = time.time() - self.start
+ l1_loss = self.engine._marginal_loss(marginals, metric="L1")[0]
+ l2_loss = self.engine._marginal_loss(marginals, metric="L2")[0]
+ feasibility = self.primal_feasibility(marginals)
+ row = [self.calls, t, l1_loss, l2_loss, feasibility]
+ if self.true_answers is not None:
+ variational = np.mean(self.variational_distances(marginals))
+ row.append(100 * variational)
+ self.results.loc[self.idx] = row
+ self.idx += 1
+
+ log.debug("\t\t".join(["%.2f" % v for v in row]), flush=True)
diff --git a/src/synthcity/plugins/core/models/mbi/clique_vector.py b/src/synthcity/plugins/core/models/mbi/clique_vector.py
new file mode 100644
index 00000000..8005a83f
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/clique_vector.py
@@ -0,0 +1,107 @@
+# third party
+import numpy as np
+
+
+class CliqueVector(dict):
+ """This is a convenience class for simplifying arithmetic over the
+ concatenated vector of marginals and potentials.
+
+ These vectors are represented as a dictionary mapping cliques (subsets of attributes)
+ to marginals/potentials (Factor objects)
+ """
+
+ def __init__(self, dictionary):
+ self.dictionary = dictionary
+ dict.__init__(self, dictionary)
+
+ @staticmethod
+ def zeros(domain, cliques):
+ # synthcity relative
+ from .factor import Factor
+
+ return CliqueVector({cl: Factor.zeros(domain.project(cl)) for cl in cliques})
+
+ @staticmethod
+ def ones(domain, cliques):
+ # synthcity relative
+ from .factor import Factor
+
+ return CliqueVector({cl: Factor.ones(domain.project(cl)) for cl in cliques})
+
+ @staticmethod
+ def uniform(domain, cliques):
+ # synthcity relative
+ from .factor import Factor
+
+ return CliqueVector({cl: Factor.uniform(domain.project(cl)) for cl in cliques})
+
+ @staticmethod
+ def random(domain, cliques, prng=np.random):
+ # synthcity relative
+ from .factor import Factor
+
+ return CliqueVector(
+ {cl: Factor.random(domain.project(cl), prng) for cl in cliques}
+ )
+
+ @staticmethod
+ def normal(domain, cliques, prng=np.random):
+ # synthcity relative
+ from .factor import Factor
+
+ return CliqueVector(
+ {cl: Factor.normal(domain.project(cl), prng) for cl in cliques}
+ )
+
+ @staticmethod
+ def from_data(data, cliques):
+ # synthcity relative
+ from .factor import Factor
+
+ ans = {}
+ for cl in cliques:
+ mu = data.project(cl)
+ ans[cl] = Factor(mu.domain, mu.datavector())
+ return CliqueVector(ans)
+
+ def combine(self, other):
+ # combines this CliqueVector with other, even if they do not share the same set of factors
+ # used for warm-starting optimization
+ # Important note: if other contains factors not defined within this CliqueVector, they
+ # are ignored and *not* combined into this CliqueVector
+ for cl in other:
+ for cl2 in self:
+ if set(cl) <= set(cl2):
+ self[cl2] += other[cl]
+ break
+
+ def __mul__(self, const):
+ ans = {cl: const * self[cl] for cl in self}
+ return CliqueVector(ans)
+
+ def __rmul__(self, const):
+ return self.__mul__(const)
+
+ def __add__(self, other):
+ if np.isscalar(other):
+ ans = {cl: self[cl] + other for cl in self}
+ else:
+ ans = {cl: self[cl] + other[cl] for cl in self}
+ return CliqueVector(ans)
+
+ def __sub__(self, other):
+ return self + -1 * other
+
+ def exp(self):
+ ans = {cl: self[cl].exp() for cl in self}
+ return CliqueVector(ans)
+
+ def log(self):
+ ans = {cl: self[cl].log() for cl in self}
+ return CliqueVector(ans)
+
+ def dot(self, other):
+ return sum((self[cl] * other[cl]).sum() for cl in self)
+
+ def size(self):
+ return sum(self[cl].domain.size() for cl in self)
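+
+
+# Usage sketch (illustrative, assuming the Domain class from .domain):
+#
+#     dom = Domain(["a", "b"], [2, 3])
+#     u = CliqueVector.uniform(dom, [("a", "b")])
+#     z = CliqueVector.zeros(dom, [("a", "b")])
+#     w = 2 * u + z          # arithmetic is applied clique-by-clique
+#     w.dot(u)               # scalar: sum of element-wise products over all cliques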
diff --git a/src/synthcity/plugins/core/models/mbi/dataset.py b/src/synthcity/plugins/core/models/mbi/dataset.py
new file mode 100644
index 00000000..ad1347a8
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/dataset.py
@@ -0,0 +1,72 @@
+# stdlib
+import json
+
+# third party
+import numpy as np
+import pandas as pd
+
+# synthcity relative
+from .domain import Domain
+
+
+class Dataset:
+ def __init__(self, df, domain, weights=None):
+ """create a Dataset object
+
+ :param df: a pandas dataframe
+ :param domain: a domain object
+        :param weights: weight for each row
+ """
+ if set(domain.attrs) > set(df.columns):
+ raise AssertionError("data must contain domain attributes")
+ if weights is not None and df.shape[0] != weights.size:
+ raise AssertionError("weights must be the same size as the data")
+ self.domain = domain
+ self.df = df.loc[:, domain.attrs]
+ self.weights = weights
+
+ @staticmethod
+ def synthetic(domain, N):
+ """Generate synthetic data conforming to the given domain
+
+ :param domain: The domain object
+ :param N: the number of individuals
+ """
+ arr = [np.random.randint(low=0, high=n, size=N) for n in domain.shape]
+ values = np.array(arr).T
+ df = pd.DataFrame(values, columns=domain.attrs)
+ return Dataset(df, domain)
+
+ @staticmethod
+ def load(path, domain):
+ """Load data into a dataset object
+
+ :param path: path to csv file
+ :param domain: path to json file encoding the domain information
+ """
+ df = pd.read_csv(path)
+ config = json.load(open(domain))
+ domain = Domain(config.keys(), config.values())
+ return Dataset(df, domain)
+
+ def project(self, cols):
+ """project dataset onto a subset of columns"""
+ if type(cols) in [str, int]:
+ cols = [cols]
+ data = self.df.loc[:, cols]
+ domain = self.domain.project(cols)
+ return Dataset(data, domain, self.weights)
+
+ def drop(self, cols):
+ proj = [c for c in self.domain if c not in cols]
+ return self.project(proj)
+
+ @property
+ def records(self):
+ return self.df.shape[0]
+
+ def datavector(self, flatten=True):
+ """return the database in vector-of-counts form"""
+ bins = [range(n + 1) for n in self.domain.shape]
+ ans = np.histogramdd(self.df.values, bins, weights=self.weights)[0]
+ return ans.flatten() if flatten else ans
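+
+
+# Usage sketch (illustrative):
+#
+#     dom = Domain(["a", "b"], [2, 3])
+#     data = Dataset.synthetic(dom, 100)
+#     counts = data.datavector()     # length-6 vector of cell counts, sums to 100
+#     col_a = data.project(["a"])    # one-dimensional marginal dataset over "a"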
diff --git a/src/synthcity/plugins/core/models/mbi/domain.py b/src/synthcity/plugins/core/models/mbi/domain.py
new file mode 100644
index 00000000..5bc700bd
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/domain.py
@@ -0,0 +1,121 @@
+# stdlib
+from functools import reduce
+
+
+class Domain:
+ def __init__(self, attrs, shape):
+ """Construct a Domain object
+
+ :param attrs: a list or tuple of attribute names
+ :param shape: a list or tuple of domain sizes for each attribute
+ """
+ if len(attrs) != len(shape):
+ raise AssertionError("dimensions must be equal")
+ self.attrs = tuple(attrs)
+ self.shape = tuple(shape)
+ self.config = dict(zip(attrs, shape))
+
+ @staticmethod
+ def fromdict(config):
+ """Construct a Domain object from a dictionary of { attr : size } values"""
+ return Domain(config.keys(), config.values())
+
+ def project(self, attrs):
+ """project the domain onto a subset of attributes
+
+ :param attrs: the attributes to project onto
+ :return: the projected Domain object
+ """
+ # return the projected domain
+ if type(attrs) is str:
+ attrs = [attrs]
+ shape = tuple(self.config[a] for a in attrs)
+ return Domain(attrs, shape)
+
+ def marginalize(self, attrs):
+ """marginalize out some attributes from the domain (opposite of project)
+
+ :param attrs: the attributes to marginalize out
+ :return: the marginalized Domain object
+ """
+ proj = [a for a in self.attrs if a not in attrs]
+ return self.project(proj)
+
+ def axes(self, attrs):
+ """return the axes tuple for the given attributes
+
+ :param attrs: the attributes
+ :return: a tuple with the corresponding axes
+ """
+ return tuple(self.attrs.index(a) for a in attrs)
+
+ def transpose(self, attrs):
+ """reorder the attributes in the domain object"""
+ return self.project(attrs)
+
+ def invert(self, attrs):
+ """returns the attributes in the domain not in the list"""
+ return [a for a in self.attrs if a not in attrs]
+
+ def merge(self, other):
+ """merge this domain object with another
+
+ :param other: another Domain object
+ :return: a new domain object covering the full domain
+
+ Example:
+ >>> D1 = Domain(['a','b'], [10,20])
+ >>> D2 = Domain(['b','c'], [20,30])
+ >>> D1.merge(D2)
+ Domain(['a','b','c'], [10,20,30])
+ """
+ extra = other.marginalize(self.attrs)
+ return Domain(self.attrs + extra.attrs, self.shape + extra.shape)
+
+ def contains(self, other):
+ """determine if this domain contains another"""
+ return set(other.attrs) <= set(self.attrs)
+
+ def size(self, attrs=None):
+ """return the total size of the domain"""
+ if attrs is None:
+ return reduce(lambda x, y: x * y, self.shape, 1)
+ return self.project(attrs).size()
+
+ def sort(self, how="size"):
+ """return a new domain object, sorted by attribute size or attribute name"""
+ if how == "size":
+ attrs = sorted(self.attrs, key=self.size)
+ elif how == "name":
+ attrs = sorted(self.attrs)
+ return self.project(attrs)
+
+ def canonical(self, attrs):
+ """return the canonical ordering of the attributes"""
+ return tuple(a for a in self.attrs if a in attrs)
+
+ def __contains__(self, attr):
+ return attr in self.attrs
+
+ def __getitem__(self, a):
+ """return the size of an individual attribute
+ :param a: the attribute
+ """
+ return self.config[a]
+
+ def __iter__(self):
+ """iterator for the attributes in the domain"""
+ return self.attrs.__iter__()
+
+ def __len__(self):
+ return len(self.attrs)
+
+ def __eq__(self, other):
+ return self.attrs == other.attrs and self.shape == other.shape
+
+ def __repr__(self):
+ inner = ", ".join(["%s: %d" % x for x in zip(self.attrs, self.shape)])
+ return "Domain(%s)" % inner
+
+ def __str__(self):
+ return self.__repr__()
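+
+
+# Usage sketch (illustrative):
+#
+#     dom = Domain(["a", "b", "c"], [2, 3, 4])
+#     dom.size()                        # 24
+#     dom.project(["a", "c"]).shape     # (2, 4)
+#     dom.marginalize(["b"]).attrs      # ('a', 'c')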
diff --git a/src/synthcity/plugins/core/models/mbi/factor.py b/src/synthcity/plugins/core/models/mbi/factor.py
new file mode 100644
index 00000000..5cd5d850
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/factor.py
@@ -0,0 +1,208 @@
+# third party
+import numpy as np
+import torch
+from scipy.special import logsumexp
+
+
+class Factor:
+ def __init__(self, domain, values):
+ """Initialize a factor over the given domain
+
+ :param domain: the domain of the factor
+ :param values: the ndarray of factor values (for each element of the domain)
+
+        Note: values may be a flattened 1d array or an ndarray with the same shape as the domain
+ """
+ if isinstance(domain, torch.Tensor):
+ domain = domain.detach().cpu().numpy()
+ if isinstance(values, torch.Tensor):
+ values = values.detach().cpu().numpy()
+ if domain.size() != values.size:
+ raise AssertionError(
+ f"domain size ({domain.size()}) does not match values size ({values.size})"
+ )
+ if values.ndim != 1 and values.shape != domain.shape:
+ raise AssertionError("invalid shape for values array")
+ self.domain = domain
+ self.values = values.reshape(domain.shape)
+
+ @staticmethod
+ def zeros(domain):
+ return Factor(domain, np.zeros(domain.shape))
+
+ @staticmethod
+ def ones(domain):
+ return Factor(domain, np.ones(domain.shape))
+
+ @staticmethod
+ def random(domain):
+ return Factor(domain, np.random.rand(*domain.shape))
+
+ @staticmethod
+ def uniform(domain):
+ return Factor.ones(domain) / domain.size()
+
+ @staticmethod
+ def active(domain, structural_zeros):
+ """create a factor that is 0 everywhere except in positions present in
+ 'structural_zeros', where it is -infinity
+
+        :param domain: the domain of this factor
+        :param structural_zeros: a list of values that are not possible
+ """
+ idx = tuple(np.array(structural_zeros).T)
+ vals = np.zeros(domain.shape)
+ vals[idx] = -np.inf
+ return Factor(domain, vals)
+
+ def expand(self, domain):
+ if not domain.contains(self.domain):
+ raise AssertionError("expanded domain must contain current domain")
+ dims = len(domain) - len(self.domain)
+ values = self.values.reshape(self.domain.shape + tuple([1] * dims))
+ ax = domain.axes(self.domain.attrs)
+ values = np.moveaxis(values, range(len(ax)), ax)
+ values = np.broadcast_to(values, domain.shape)
+ return Factor(domain, values)
+
+ def transpose(self, attrs):
+ if set(attrs) != set(self.domain.attrs):
+ raise AssertionError("attrs must be same as domain attributes")
+ newdom = self.domain.project(attrs)
+ ax = newdom.axes(self.domain.attrs)
+ values = np.moveaxis(self.values, range(len(ax)), ax)
+ return Factor(newdom, values)
+
+ def project(self, attrs, agg="sum"):
+ """
+ project the factor onto a list of attributes (in order)
+ using either sum or logsumexp to aggregate along other attributes
+ """
+ if agg not in ["sum", "logsumexp"]:
+ raise AssertionError("agg must be sum or logsumexp")
+ marginalized = self.domain.marginalize(attrs)
+ if agg == "sum":
+ ans = self.sum(marginalized.attrs)
+ elif agg == "logsumexp":
+ ans = self.logsumexp(marginalized.attrs)
+ return ans.transpose(attrs)
+
+ def sum(self, attrs=None):
+ if attrs is None:
+ return np.sum(self.values)
+ axes = self.domain.axes(attrs)
+ values = np.sum(self.values, axis=axes)
+ newdom = self.domain.marginalize(attrs)
+ return Factor(newdom, values)
+
+ def logsumexp(self, attrs=None):
+ if attrs is None:
+ return logsumexp(self.values)
+ axes = self.domain.axes(attrs)
+ values = logsumexp(self.values, axis=axes)
+ newdom = self.domain.marginalize(attrs)
+ return Factor(newdom, values)
+
+ def logaddexp(self, other):
+ newdom = self.domain.merge(other.domain)
+ factor1 = self.expand(newdom)
+        factor2 = other.expand(newdom)
+ return Factor(newdom, np.logaddexp(factor1.values, factor2.values))
+
+ def max(self, attrs=None):
+ if attrs is None:
+ return self.values.max()
+ axes = self.domain.axes(attrs)
+ values = np.max(self.values, axis=axes)
+ newdom = self.domain.marginalize(attrs)
+ return Factor(newdom, values)
+
+ def condition(self, evidence):
+ """evidence is a dictionary where
+ keys are attributes, and
+ values are elements of the domain for that attribute"""
+ slices = [evidence[a] if a in evidence else slice(None) for a in self.domain]
+ newdom = self.domain.marginalize(evidence.keys())
+ values = self.values[tuple(slices)]
+ return Factor(newdom, values)
+
+ def copy(self, out=None):
+ if out is None:
+ return Factor(self.domain, self.values.copy())
+ np.copyto(out.values, self.values)
+ return out
+
+ def __mul__(self, other):
+ if np.isscalar(other):
+ new_values = np.nan_to_num(other * self.values)
+ return Factor(self.domain, new_values)
+ newdom = self.domain.merge(other.domain)
+ factor1 = self.expand(newdom)
+ factor2 = other.expand(newdom)
+ return Factor(newdom, factor1.values * factor2.values)
+
+ def __add__(self, other):
+ if np.isscalar(other):
+ return Factor(self.domain, other + self.values)
+ newdom = self.domain.merge(other.domain)
+ factor1 = self.expand(newdom)
+ factor2 = other.expand(newdom)
+ return Factor(newdom, np.add(factor1.values, factor2.values))
+
+ def __iadd__(self, other):
+ if np.isscalar(other):
+ self.values += other
+ return self
+ factor2 = other.expand(self.domain)
+ self.values += factor2.values
+ return self
+
+ def __imul__(self, other):
+ if np.isscalar(other):
+ self.values *= other
+ return self
+ factor2 = other.expand(self.domain)
+ self.values *= factor2.values
+ return self
+
+ def __radd__(self, other):
+ return self.__add__(other)
+
+ def __rmul__(self, other):
+ return self.__mul__(other)
+
+ def __sub__(self, other):
+ if np.isscalar(other):
+ return Factor(self.domain, self.values - other)
+ other = Factor(
+ other.domain, np.where(other.values == -np.inf, 0, -other.values)
+ )
+ return self + other
+
+ def __truediv__(self, other):
+ if np.isscalar(other):
+ new_values = self.values / other
+ new_values = np.nan_to_num(new_values)
+ return Factor(self.domain, new_values)
+ tmp = other.expand(self.domain)
+ vals = np.divide(self.values, tmp.values, where=tmp.values > 0)
+ vals[tmp.values <= 0] = 0.0
+ return Factor(self.domain, vals)
+
+ def exp(self, out=None):
+ if out is None:
+ return Factor(self.domain, np.exp(self.values))
+ np.exp(self.values, out=out.values)
+ return out
+
+ def log(self, out=None):
+ if out is None:
+ return Factor(self.domain, np.log(self.values + 1e-100))
+ np.log(self.values, out=out.values)
+ return out
+
+ def datavector(self, flatten=True):
+ """Materialize the data vector"""
+ if flatten:
+ return self.values.flatten()
+ return self.values
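+
+
+# Usage sketch (illustrative, assuming the Domain class from .domain):
+#
+#     dom = Domain(["a", "b"], [2, 3])
+#     f = Factor.uniform(dom)       # all entries equal to 1/6
+#     g = f.project(["a"])          # sum out "b"
+#     h = f * g                     # factors broadcast over the merged domain
+#     h.datavector().shape          # (6,)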
diff --git a/src/synthcity/plugins/core/models/mbi/factor_graph.py b/src/synthcity/plugins/core/models/mbi/factor_graph.py
new file mode 100644
index 00000000..cd18ab13
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/factor_graph.py
@@ -0,0 +1,287 @@
+# stdlib
+from collections import defaultdict
+
+# third party
+import numpy as np
+
+# synthcity relative
+from .clique_vector import CliqueVector
+from .factor import Factor
+
+
+class FactorGraph:
+ def __init__(self, domain, cliques, total=1.0, convex=False, iters=25):
+ self.domain = domain
+ self.cliques = cliques
+ self.total = total
+ self.convex = convex
+ self.iters = iters
+
+ if convex:
+ self.counting_numbers = self.get_counting_numbers()
+ self.belief_propagation = self.convergent_belief_propagation
+ else:
+ counting_numbers = {}
+ for cl in cliques:
+ counting_numbers[cl] = 1.0
+ for a in domain:
+ counting_numbers[a] = 1.0 - len([cl for cl in cliques if a in cl])
+ self.counting_numbers = None, None, counting_numbers
+ self.belief_propagation = self.loopy_belief_propagation
+
+ self.potentials = None
+ self.marginals = None
+ self.messages = self.init_messages()
+ self.beliefs = {i: Factor.zeros(domain.project(i)) for i in domain}
+
+ def datavector(self, flatten=True):
+ """Materialize the explicit representation of the distribution as a data vector."""
+ logp = sum(self.potentials[cl] for cl in self.cliques)
+ ans = np.exp(logp - logp.logsumexp())
+ wgt = ans.domain.size() / self.domain.size()
+ return ans.expand(self.domain).datavector(flatten) * wgt * self.total
+
+ def init_messages(self):
+ mu_n = defaultdict(dict)
+ mu_f = defaultdict(dict)
+ for cl in self.cliques:
+ for v in cl:
+ mu_f[cl][v] = Factor.zeros(self.domain.project(v))
+ mu_n[v][cl] = Factor.zeros(self.domain.project(v))
+ return mu_n, mu_f
+
+ def primal_feasibility(self, mu):
+ ans = 0
+ count = 0
+ for r in mu:
+ for s in mu:
+ if r == s:
+ break
+ d = tuple(set(r) & set(s))
+ if len(d) > 0:
+ x = mu[r].project(d).datavector()
+ y = mu[s].project(d).datavector()
+ err = np.linalg.norm(x - y, 1)
+ ans += err
+ count += 1
+ try:
+ return ans / count
+ except BaseException:
+ return 0
+
+ def project(self, attrs):
+ if type(attrs) is list:
+ attrs = tuple(attrs)
+
+ if self.marginals is not None:
+ # we will average all ways to obtain the given marginal,
+ # since there may be more than one
+ ans = Factor.zeros(self.domain.project(attrs))
+ terminate = False
+ for cl in self.cliques:
+ if set(attrs) <= set(cl):
+ ans += self.marginals[cl].project(attrs)
+ terminate = True
+ if terminate:
+ return ans * (self.total / ans.sum())
+
+ belief = sum(self.beliefs[i] for i in attrs)
+ belief += np.log(self.total) - belief.logsumexp()
+ return belief.transpose(attrs).exp()
+
+ def loopy_belief_propagation(self, potentials, callback=None):
+ mu_n, mu_f = self.messages
+ self.potentials = potentials
+
+ for i in range(self.iters):
+ # factor to variable BP
+ for cl in self.cliques:
+ pre = sum(mu_n[c][cl] for c in cl)
+ for v in cl:
+ complement = [var for var in cl if var is not v]
+ mu_f[cl][v] = potentials[cl] + pre - mu_n[v][cl]
+ mu_f[cl][v] = mu_f[cl][v].logsumexp(complement)
+ mu_f[cl][v] -= mu_f[cl][v].logsumexp()
+
+ # variable to factor BP
+ for v in self.domain:
+ fac = [cl for cl in self.cliques if v in cl]
+ pre = sum(mu_f[cl][v] for cl in fac)
+ for f in fac:
+ complement = [var for var in fac if var is not f]
+ # mu_n[v][f] = Factor.zeros(self.domain.project(v))
+ mu_n[v][f] = pre - mu_f[f][v] # sum(mu_f[c][v] for c in complement)
+ # mu_n[v][f] += sum(mu_f[c][v] for c in complement)
+ # mu_n[v][f] -= mu_n[v][f].logsumexp()
+
+ if callback is not None:
+ mg = self.clique_marginals(mu_n, mu_f, potentials)
+ callback(mg)
+
+ self.beliefs = {
+ v: sum(mu_f[cl][v] for cl in self.cliques if v in cl) for v in self.domain
+ }
+ self.messages = mu_n, mu_f
+ self.marginals = self.clique_marginals(mu_n, mu_f, potentials)
+ return self.marginals
+
+ def convergent_belief_propagation(self, potentials, callback=None):
+ # Algorithm 11.2 in Koller & Friedman (modified to work in log space)
+
+ v, vhat, k = self.counting_numbers
+ sigma, delta = self.messages
+ # sigma, delta = self.init_messages()
+
+ for it in range(self.iters):
+ # pre = {}
+ # for r in self.cliques:
+ # pre[r] = sum(sigma[j][r] for j in r)
+
+ for i in self.domain:
+ nbrs = [r for r in self.cliques if i in r]
+ for r in nbrs:
+ comp = [j for j in r if i != j]
+ delta[r][i] = potentials[r] + sum(sigma[j][r] for j in comp)
+ # delta[r][i] = potentials[r] + pre[r] - sigma[i][r]
+ delta[r][i] /= vhat[i, r]
+ delta[r][i] = delta[r][i].logsumexp(comp)
+ belief = Factor.zeros(self.domain.project(i))
+ belief += sum(delta[r][i] * vhat[i, r] for r in nbrs) / vhat[i]
+ belief -= belief.logsumexp()
+ self.beliefs[i] = belief
+ for r in nbrs:
+ comp = [j for j in r if i != j]
+ A = -v[i, r] / vhat[i, r]
+ B = v[r]
+ sigma[i][r] = A * (potentials[r] + sum(sigma[j][r] for j in comp))
+ # sigma[i][r] = A*(potentials[r] + pre[r] - sigma[i][r])
+ sigma[i][r] += B * (belief - delta[r][i])
+ if callback is not None:
+ mg = self.clique_marginals(sigma, delta, potentials)
+ callback(mg)
+
+ self.messages = sigma, delta
+ return self.clique_marginals(sigma, delta, potentials)
+
+ def clique_marginals(self, mu_n, mu_f, potentials):
+ if self.convex:
+ v, _, _ = self.counting_numbers
+ marginals = {}
+ for cl in self.cliques:
+ belief = potentials[cl] + sum(mu_n[n][cl] for n in cl)
+ if self.convex:
+ belief *= 1.0 / v[cl]
+ belief += np.log(self.total) - belief.logsumexp()
+ marginals[cl] = belief.exp()
+ return CliqueVector(marginals)
+
+ def mle(self, marginals):
+ return -self.bethe_entropy(marginals)[1]
+
+ def bethe_entropy(self, marginals):
+ """
+ Return the Bethe Entropy and the gradient with respect to the marginals
+
+ """
+ _, _, weights = self.counting_numbers
+ entropy = 0
+ dmarginals = {}
+ attributes = set()
+ for cl in self.cliques:
+ mu = marginals[cl] / self.total
+ entropy += weights[cl] * (mu * mu.log()).sum()
+ dmarginals[cl] = weights[cl] * (1 + mu.log()) / self.total
+ for a in set(cl) - set(attributes):
+ p = mu.project(a)
+ entropy += weights[a] * (p * p.log()).sum()
+ dmarginals[cl] += weights[a] * (1 + p.log()) / self.total
+ attributes.update(a)
+
+ return -entropy, -1 * CliqueVector(dmarginals)
+
+ def get_counting_numbers(self):
+ # third party
+ from cvxopt import matrix, solvers
+
+ solvers.options["show_progress"] = False
+ index = {}
+ idx = 0
+
+ for i in self.domain:
+ index[i] = idx
+ idx += 1
+ for r in self.cliques:
+ index[r] = idx
+ idx += 1
+
+ for r in self.cliques:
+ for i in r:
+ index[r, i] = idx
+ idx += 1
+
+ vectors = {}
+ for r in self.cliques:
+ v = np.zeros(idx)
+ v[index[r]] = 1
+ for i in r:
+ v[index[r, i]] = 1
+ vectors[r] = v
+
+ for i in self.domain:
+ v = np.zeros(idx)
+ v[index[i]] = 1
+ for r in self.cliques:
+ if i in r:
+ v[index[r, i]] = -1
+ vectors[i] = v
+
+ constraints = []
+ for i in self.domain:
+ con = vectors[i].copy()
+ for r in self.cliques:
+ if i in r:
+ con += vectors[r]
+ constraints.append(con)
+ A = np.array(constraints)
+ b = np.ones(len(self.domain))
+
+ X = np.vstack([vectors[r] for r in self.cliques])
+ y = np.ones(len(self.cliques))
+ P = X.T @ X
+ q = -X.T @ y
+ G = -np.eye(q.size)
+ h = np.zeros(q.size)
+ minBound = 1.0 / len(self.domain)
+ for r in self.cliques:
+ h[index[r]] = -minBound
+
+ P = matrix(P)
+ q = matrix(q)
+ G = matrix(G)
+ h = matrix(h)
+ A = matrix(A)
+ b = matrix(b)
+
+ ans = solvers.qp(P, q, G, h, A, b)
+ x = np.array(ans["x"]).flatten()
+
+ counting_v = {}
+ for r in self.cliques:
+ counting_v[r] = x[index[r]]
+ for i in r:
+ counting_v[i, r] = x[index[r, i]]
+ for i in self.domain:
+ counting_v[i] = x[index[i]]
+
+ counting_vhat = {}
+ counting_k = {}
+ for i in self.domain:
+ nbrs = [r for r in self.cliques if i in r]
+ counting_vhat[i] = counting_v[i] + sum(counting_v[r] for r in nbrs)
+ counting_k[i] = counting_v[i] - sum(counting_v[i, r] for r in nbrs)
+ for r in nbrs:
+ counting_vhat[i, r] = counting_v[r] + counting_v[i, r]
+ for r in self.cliques:
+ counting_k[r] = counting_v[r] + sum(counting_v[i, r] for i in r)
+
+ return counting_v, counting_vhat, counting_k
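+
+
+# Usage sketch (illustrative): loopy belief propagation over a small chain model.
+#
+#     dom = Domain(["a", "b", "c"], [2, 2, 2])
+#     cliques = [("a", "b"), ("b", "c")]
+#     fg = FactorGraph(dom, cliques, total=100.0, iters=10)
+#     theta = CliqueVector.zeros(dom, cliques)   # log-space potentials (uniform model)
+#     mu = fg.belief_propagation(theta)          # CliqueVector of clique marginals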
diff --git a/src/synthcity/plugins/core/models/mbi/graphical_model.py b/src/synthcity/plugins/core/models/mbi/graphical_model.py
new file mode 100644
index 00000000..bbe913de
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/graphical_model.py
@@ -0,0 +1,332 @@
+# stdlib
+import itertools
+from functools import reduce
+
+# third party
+import cloudpickle
+import networkx as nx
+import numpy as np
+import pandas as pd
+
+# synthcity relative
+from .clique_vector import CliqueVector
+from .dataset import Dataset
+from .domain import Domain
+from .junction_tree import JunctionTree
+
+
+class GraphicalModel:
+ def __init__(self, domain, cliques, total=1.0, elimination_order=None):
+ """Constructor for a GraphicalModel
+
+ :param domain: a Domain object
+ :param total: the normalization constant for the distribution
+        :param cliques: a list of cliques (not necessarily maximal cliques)
+            - each clique is a subset of attributes, represented as a tuple or list
+        :param elimination_order: an elimination order for the JunctionTree algorithm
+            - the elimination order impacts efficiency but not correctness.
+              By default, a greedy elimination order is used
+ """
+ self.domain = domain
+ self.total = total
+ tree = JunctionTree(domain, cliques, elimination_order)
+ self.junction_tree = tree
+
+ self.cliques = tree.maximal_cliques() # maximal cliques
+ self.message_order = tree.mp_order()
+ self.sep_axes = tree.separator_axes()
+ self.neighbors = tree.neighbors()
+ self.elimination_order = tree.elimination_order
+
+ self.size = sum(domain.size(cl) for cl in self.cliques)
+ if self.size * 8 > 4 * 10**9:
+ # stdlib
+ import warnings
+
+ message = "Size of parameter vector is %.2f GB. " % (
+ self.size * 8 / 10**9
+ )
+ message += "Consider removing some measurements or finding a better elimination order"
+ warnings.warn(message)
+
+ @staticmethod
+ def save(model, path):
+ cloudpickle.dump(model, open(path, "wb"))
+
+ @staticmethod
+ def load(path):
+ return cloudpickle.load(open(path, "rb"))
+
+ def project(self, attrs):
+ """Project the distribution onto a subset of attributes.
+ I.e., compute the marginal of the distribution
+
+ :param attrs: a subset of attributes in the domain, represented as a list or tuple
+ :return: a Factor object representing the marginal distribution
+ """
+ # use precalculated marginals if possible
+ if type(attrs) is list:
+ attrs = tuple(attrs)
+ if hasattr(self, "marginals"):
+ for cl in self.cliques:
+ if set(attrs) <= set(cl):
+ return self.marginals[cl].project(attrs)
+
+ elim = self.domain.invert(attrs)
+ elim_order = greedy_order(self.domain, self.cliques + [attrs], elim)
+ pots = list(self.potentials.values())
+ ans = variable_elimination_logspace(pots, elim_order, self.total)
+ return ans.project(attrs)
+
+ def krondot(self, matrices):
+ """Compute the answer to the set of queries Q1 x Q2 X ... x Qd, where
+ Qi is a query matrix on the ith attribute and "x" is the Kronecker product
+ This may be more efficient than computing a supporting marginal then multiplying that by Q.
+ In particular, if each Qi has only a few rows.
+
+ :param matrices: a list of matrices for each attribute in the domain
+ :return: the vector of query answers
+ """
+ if any(M.shape[1] != n for M, n in zip(matrices, self.domain.shape)):
+ raise ValueError("matrices must conform to the shape of the domain")
+ logZ = self.belief_propagation(self.potentials, logZ=True)
+ factors = [self.potentials[cl].exp() for cl in self.cliques]
+ Factor = type(factors[0]) # infer the type of the factors
+ elim = self.domain.attrs
+ for attr, Q in zip(elim, matrices):
+ d = Domain(["%s-answer" % attr, attr], Q.shape)
+ factors.append(Factor(d, Q))
+ result = variable_elimination(factors, elim)
+ result = result.transpose(["%s-answer" % a for a in elim])
+ return result.datavector(flatten=False) * self.total / np.exp(logZ)
+
+ def calculate_many_marginals(self, projections):
+ """Calculates marginals for all the projections in the list using
+ Algorithm for answering many out-of-clique queries (section 10.3 in Koller and Friedman)
+
+ This method may be faster than calling project many times
+
+ :param projections: a list of projections, where
+ each projection is a subset of attributes (represented as a list or tuple)
+ :return: a list of marginals, where each marginal is represented as a Factor
+ """
+
+ self.marginals = self.belief_propagation(self.potentials)
+ sep = self.sep_axes
+ neighbors = self.neighbors
+ # first calculate P(Cj | Ci) for all neighbors Ci, Cj
+ conditional = {}
+ for Ci in neighbors:
+ for Cj in neighbors[Ci]:
+ Sij = sep[(Cj, Ci)]
+ Z = self.marginals[Cj]
+ conditional[(Cj, Ci)] = Z / Z.project(Sij)
+
+ # now iterate through pairs of cliques in order of distance
+ pred, dist = nx.floyd_warshall_predecessor_and_distance(
+ self.junction_tree.tree, weight=False
+ )
+ results = {}
+ for Ci, Cj in sorted(
+ itertools.combinations(self.cliques, 2), key=lambda X: dist[X[0]][X[1]]
+ ):
+ Cl = pred[Ci][Cj]
+ Y = conditional[(Cj, Cl)]
+ if Cl == Ci:
+ X = self.marginals[Ci]
+ results[(Ci, Cj)] = results[(Cj, Ci)] = X * Y
+ else:
+ X = results[(Ci, Cl)]
+ S = set(Cl) - set(Ci) - set(Cj)
+ results[(Ci, Cj)] = results[(Cj, Ci)] = (X * Y).sum(S)
+
+ results = {
+ self.domain.canonical(key[0] + key[1]): results[key] for key in results
+ }
+
+ answers = {}
+ for proj in projections:
+ for attr in results:
+ if set(proj) <= set(attr):
+ answers[proj] = results[attr].project(proj)
+ break
+ if proj not in answers:
+ # just use variable elimination
+ answers[proj] = self.project(proj)
+
+ return answers
+
+ def datavector(self, flatten=True):
+ """Materialize the explicit representation of the distribution as a data vector."""
+ logp = sum(self.potentials[cl] for cl in self.cliques)
+ ans = np.exp(logp - logp.logsumexp())
+ wgt = ans.domain.size() / self.domain.size()
+ return ans.expand(self.domain).datavector(flatten) * wgt * self.total
+
+ def belief_propagation(self, potentials, logZ=False):
+ """Compute the marginals of the graphical model with given parameters
+
+ Note this is an efficient, numerically stable implementation of belief propagation
+
+ :param potentials: the (log-space) parameters of the graphical model
+ :param logZ: flag to return logZ instead of marginals
+ :return marginals: the marginals of the graphical model
+ """
+ beliefs = {cl: potentials[cl].copy() for cl in potentials}
+ messages = {}
+ for i, j in self.message_order:
+ sep = beliefs[i].domain.invert(self.sep_axes[(i, j)])
+ if (j, i) in messages:
+ tau = beliefs[i] - messages[(j, i)]
+ else:
+ tau = beliefs[i]
+ messages[(i, j)] = tau.logsumexp(sep)
+ beliefs[j] += messages[(i, j)]
+
+ cl = self.cliques[0]
+ if logZ:
+ return beliefs[cl].logsumexp()
+
+ logZ = beliefs[cl].logsumexp()
+ for cl in self.cliques:
+ beliefs[cl] += np.log(self.total) - logZ
+ beliefs[cl] = beliefs[cl].exp(out=beliefs[cl])
+
+ return CliqueVector(beliefs)
+
+ def mle(self, marginals):
+ """Compute the model parameters from the given marginals
+
+ :param marginals: target marginals of the distribution
+        :return: the potentials of the graphical model with the given marginals
+ """
+ potentials = {}
+ variables = set()
+ for cl in self.cliques:
+ new = tuple(variables & set(cl))
+ # factor = marginals[cl] / marginals[cl].project(new)
+ variables.update(cl)
+ potentials[cl] = marginals[cl].log() - marginals[cl].project(new).log()
+ return CliqueVector(potentials)
+
+ def fit(self, data):
+ # synthcity relative
+ from .factor import Factor
+
+ if not data.domain.contains(self.domain):
+ raise ValueError("data domain not compatible with model domain")
+ marginals = {}
+ for cl in self.cliques:
+ x = data.project(cl).datavector()
+ dom = self.domain.project(cl)
+ marginals[cl] = Factor(dom, x)
+ self.potentials = self.mle(marginals)
+
+ def synthetic_data(self, rows=None, method="round"):
+ """Generate synthetic tabular data from the distribution.
+ Valid options for method are 'round' and 'sample'."""
+ total = int(self.total) if rows is None else rows
+ cols = self.domain.attrs
+ data = np.zeros((total, len(cols)), dtype=int)
+ df = pd.DataFrame(data, columns=cols)
+ cliques = [set(cl) for cl in self.cliques]
+
+ def synthetic_col(counts, total):
+ if method == "sample":
+ probas = counts / counts.sum()
+ return np.random.choice(counts.size, total, True, probas)
+ counts *= total / counts.sum()
+ frac, integ = np.modf(counts)
+ integ = integ.astype(int)
+ extra = total - integ.sum()
+ if extra > 0:
+ idx = np.random.choice(counts.size, extra, False, frac / frac.sum())
+ integ[idx] += 1
+ vals = np.repeat(np.arange(counts.size), integ)
+ np.random.shuffle(vals)
+ return vals
+
+ order = self.elimination_order[::-1]
+ col = order[0]
+ marg = self.project([col]).datavector(flatten=False)
+ df.loc[:, col] = synthetic_col(marg, total)
+ used = {col}
+
+ for col in order[1:]:
+ relevant = [cl for cl in cliques if col in cl]
+ relevant = used.intersection(set.union(*relevant))
+ proj = tuple(relevant)
+ used.add(col)
+ marg = self.project(proj + (col,)).datavector(flatten=False)
+
+ def foo(group):
+ idx = group.name
+ vals = synthetic_col(marg[idx], group.shape[0])
+ group[col] = vals
+ return group
+
+ if len(proj) >= 1:
+ df = df.groupby(list(proj), group_keys=False).apply(foo)
+ else:
+ df[col] = synthetic_col(marg, df.shape[0])
+
+ return Dataset(df, self.domain)
+
+
+def variable_elimination_logspace(potentials, elim, total):
+ """run variable elimination on a list of **logspace** factors"""
+ k = len(potentials)
+ psi = dict(zip(range(k), potentials))
+ for z in elim:
+ psi2 = [psi.pop(i) for i in list(psi.keys()) if z in psi[i].domain]
+ phi = reduce(lambda x, y: x + y, psi2, 0)
+ tau = phi.logsumexp([z])
+ psi[k] = tau
+ k += 1
+ ans = reduce(lambda x, y: x + y, psi.values(), 0)
+ return (ans - ans.logsumexp() + np.log(total)).exp()
+
+
+def variable_elimination(factors, elim):
+ """run variable elimination on a list of (non-logspace) factors"""
+ k = len(factors)
+ psi = dict(zip(range(k), factors))
+ for z in elim:
+ psi2 = [psi.pop(i) for i in list(psi.keys()) if z in psi[i].domain]
+ phi = reduce(lambda x, y: x * y, psi2, 1)
+ tau = phi.sum([z])
+ psi[k] = tau
+ k += 1
+ return reduce(lambda x, y: x * y, psi.values(), 1)
+
+
+def greedy_order(domain, cliques, elim):
+ order = []
+ unmarked = set(elim)
+ cliques = set(cliques)
+ total_cost = 0
+ for k in range(len(elim)):
+ cost = {}
+ for a in unmarked:
+ # all cliques that have a
+ neighbors = list(filter(lambda cl: a in cl, cliques))
+ # variables in this "super-clique"
+ variables = tuple(set.union(set(), *map(set, neighbors)))
+ # domain for the resulting factor
+ newdom = domain.project(variables)
+ # cost of removing a
+ cost[a] = newdom.size()
+
+ # find the best variable to eliminate
+ a = min(cost, key=lambda a: cost[a])
+
+ # do some cleanup
+ order.append(a)
+ unmarked.remove(a)
+ neighbors = list(filter(lambda cl: a in cl, cliques))
+ variables = tuple(set.union(set(), *map(set, neighbors)) - {a})
+ cliques -= set(neighbors)
+ cliques.add(variables)
+ total_cost += cost[a]
+
+ return order
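+
+
+# Usage sketch (illustrative): fit a model to data and sample synthetic records.
+#
+#     data = Dataset.synthetic(Domain(["a", "b"], [2, 3]), 1000)
+#     model = GraphicalModel(data.domain, cliques=[("a", "b")], total=1000)
+#     model.fit(data)                   # maximum-likelihood potentials from the marginals
+#     synth = model.synthetic_data()    # Dataset with 1000 synthetic rows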
diff --git a/src/synthcity/plugins/core/models/mbi/identity.py b/src/synthcity/plugins/core/models/mbi/identity.py
new file mode 100644
index 00000000..5d97306d
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/identity.py
@@ -0,0 +1,346 @@
+# third party
+import numpy as np
+from scipy import sparse
+from scipy.sparse.linalg import LinearOperator
+
+
+# util function
+def class_to_dict(inst, ignore_list=[], attr_prefix=""):
+ """Writes state of class instance as a dict
+ Includes both attributes and properties (i.e. those methods labeled with @property)
+    Note: because this captures properties, it should be viewed as a snapshot of instance state
+    :param inst: instance to represent as dict
+    :param ignore_list: list of attribute names to omit from the output
+    :param attr_prefix: string prepended to each remaining key
+ :return: dict
+ """
+ output = vars(inst).copy() # captures regular variables
+ cls = type(inst) # class of instance
+ properties = [p for p in dir(cls) if isinstance(getattr(cls, p), property)]
+ for p in properties:
+ prop = getattr(cls, p) # get property object by name
+ output[p] = prop.fget(inst) # call its fget
+ for k in list(output.keys()): # filter out dict keys mentioned in ignore-list
+ if k in ignore_list:
+ del output[k]
+ else: # prepend attr_prefix
+ output[attr_prefix + k] = output.pop(k)
+
+ output[attr_prefix + "class"] = cls.__name__
+ return output
+
+
+# Main Matrix classes
+class EkteloMatrix(LinearOperator):
+ """
+ An EkteloMatrix is a linear transformation that can compute matrix-vector products
+ """
+
+ # must implement: _matmat, _transpose, matrix
+ # can implement: gram, sensitivity, sum, dense_matrix, sparse_matrix, __abs__
+
+ def __init__(self, matrix):
+ """Instantiate an EkteloMatrix from an explicitly represented backing matrix
+
+ :param matrix: a 2d numpy array or a scipy sparse matrix
+ """
+ self.matrix = matrix
+ self.dtype = matrix.dtype
+ self.shape = matrix.shape
+
+ def asDict(self):
+ d = class_to_dict(self, ignore_list=[])
+ return d
+
+ def _transpose(self):
+ return EkteloMatrix(self.matrix.T)
+
+ def _matmat(self, V):
+ """
+        Matrix multiplication of an m x n matrix Q
+
+        :param V: an n x p numpy array
+        :return Q*V: an m x p numpy array
+ """
+ return self.matrix @ V
+
+ def gram(self):
+ """
+ Compute the Gram matrix of the given matrix.
+ For a matrix Q, the gram matrix is defined as Q^T Q
+ """
+ return self.T @ self # works for subclasses too
+
+ def sensitivity(self):
+ # note: this works because np.abs calls self.__abs__
+ return np.max(np.abs(self).sum(axis=0))
+
+ def sum(self, axis=None):
+ # this implementation works for all subclasses too
+ # (as long as they define _matmat and _transpose)
+ if axis == 0:
+ return self.T.dot(np.ones(self.shape[0]))
+ ans = self.dot(np.ones(self.shape[1]))
+ return ans if axis == 1 else np.sum(ans)
+
+ def inv(self):
+ return EkteloMatrix(np.linalg.inv(self.dense_matrix()))
+
+ def pinv(self):
+ return EkteloMatrix(np.linalg.pinv(self.dense_matrix()))
+
+ def trace(self):
+ return self.diag().sum()
+
+ def diag(self):
+ return np.diag(self.dense_matrix())
+
+ def _adjoint(self):
+ return self._transpose()
+
+ def __mul__(self, other):
+ if np.isscalar(other):
+            return Weighted(self, other)  # noqa: F821
+ if type(other) == np.ndarray:
+ return self.dot(other)
+ if isinstance(other, EkteloMatrix):
+ return Product(self, other)
+ # note: this expects both matrix types to be compatible (e.g., sparse and sparse)
+ # todo: make it work for different backing representations
+ else:
+ raise TypeError(
+ "incompatible type %s for multiplication with EkteloMatrix"
+ % type(other)
+ )
+
+ def __add__(self, other):
+ if np.isscalar(other):
+            other = Weighted(Ones(*self.shape), other)  # noqa: F821
+ return Sum([self, other])
+
+ def __sub__(self, other):
+ return self + -1 * other
+
+ def __rmul__(self, other):
+ if np.isscalar(other):
+            return Weighted(self, other)  # noqa: F821
+ return NotImplemented
+
+ def __getitem__(self, key):
+ """
+ return a given row from the matrix
+
+ :param key: the index of the row to return
+ :return: a 1xN EkteloMatrix
+ """
+ # row indexing, subclasses may provide more efficient implementation
+ m = self.shape[0]
+ v = np.zeros(m)
+ v[key] = 1.0
+ return EkteloMatrix(self.T.dot(v).reshape(1, self.shape[1]))
+
+ def dense_matrix(self):
+ """
+ return the dense representation of this matrix, as a 2D numpy array
+ """
+ if sparse.issparse(self.matrix):
+ return self.matrix.toarray()
+ return self.matrix
+
+ def sparse_matrix(self):
+ """
+ return the sparse representation of this matrix, as a scipy matrix
+ """
+ if sparse.issparse(self.matrix):
+ return self.matrix
+ return sparse.csr_matrix(self.matrix)
+
+ @property
+ def ndim(self):
+ # todo: deprecate if possible
+ return 2
+
+ def __abs__(self):
+ return EkteloMatrix(self.matrix.__abs__())
+
+ def __sqr__(self):
+ if sparse.issparse(self.matrix):
+ return EkteloMatrix(self.matrix.power(2))
+ return EkteloMatrix(self.matrix**2)
+
+
+class Ones(EkteloMatrix):
+ """A m x n matrix of all ones"""
+
+ def __init__(self, m, n, dtype=np.float64):
+ self.m = m
+ self.n = n
+ self.shape = (m, n)
+ self.dtype = dtype
+
+ def _matmat(self, V):
+ ans = V.sum(axis=0, keepdims=True)
+ return np.repeat(ans, self.m, axis=0)
+
+ def _transpose(self):
+ return Ones(self.n, self.m, self.dtype)
+
+ def gram(self):
+ return self.m * Ones(self.n, self.n, self.dtype)
+
+ def pinv(self):
+ c = 1.0 / (self.m * self.n)
+ return c * Ones(self.n, self.m, self.dtype)
+
+ def trace(self):
+ if self.n != self.m:
+ raise ValueError("matrix is not square")
+ return self.n
+
+ @property
+ def matrix(self):
+ return np.ones(self.shape, dtype=self.dtype)
+
+ def __abs__(self):
+ return self
+
+ def __sqr__(self):
+ return self
+
+
+class Sum(EkteloMatrix):
+ """Class for the Sum of matrices"""
+
+ def __init__(self, matrices):
+ # all must have same shape
+ self.matrices = matrices
+ self.shape = matrices[0].shape
+ self.dtype = np.result_type(*[Q.dtype for Q in matrices])
+
+ def _matmat(self, V):
+ return sum(Q.dot(V) for Q in self.matrices)
+
+ def _transpose(self):
+ return Sum([Q.T for Q in self.matrices])
+
+ def __mul__(self, other):
+ if isinstance(other, EkteloMatrix):
+ return Sum(
+ [Q @ other for Q in self.matrices]
+ ) # should use others rmul though
+ return EkteloMatrix.__mul__(self, other)
+
+ def diag(self):
+ return sum(Q.diag() for Q in self.matrices)
+
+ @property
+ def matrix(self):
+ def _any_sparse(matrices):
+ return any(sparse.issparse(Q.matrix) for Q in matrices)
+
+ if _any_sparse(self.matrices):
+ return sum(Q.sparse_matrix() for Q in self.matrices)
+ return sum(Q.dense_matrix() for Q in self.matrices)
+
+
+class Weighted(EkteloMatrix):
+ """Class for multiplication by a constant"""
+
+ def __init__(self, base, weight):
+ if isinstance(base, Weighted):
+ weight *= base.weight
+ base = base.base
+ self.base = base
+ self.weight = weight
+ self.shape = base.shape
+ self.dtype = base.dtype
+
+ def _matmat(self, V):
+ return self.weight * self.base.dot(V)
+
+ def _transpose(self):
+ return Weighted(self.base.T, self.weight)
+
+ def gram(self):
+ return Weighted(self.base.gram(), self.weight**2)
+
+ def pinv(self):
+ return Weighted(self.base.pinv(), 1.0 / self.weight)
+
+ def inv(self):
+ return Weighted(self.base.inv(), 1.0 / self.weight)
+
+ def trace(self):
+ return self.weight * self.base.trace()
+
+ def __abs__(self):
+ return Weighted(self.base.__abs__(), np.abs(self.weight))
+
+ def __sqr__(self):
+ return Weighted(self.base.__sqr__(), self.weight**2)
+
+ @property
+ def matrix(self):
+ return self.weight * self.base.matrix
+
+
+class Product(EkteloMatrix):
+ def __init__(self, A, B):
+ if A.shape[1] != B.shape[0]:
+ raise ValueError("inner dimensions do not match")
+ self._A = A
+ self._B = B
+ self.shape = (A.shape[0], B.shape[1])
+ self.dtype = np.result_type(A.dtype, B.dtype)
+
+ def _matmat(self, X):
+ return self._A.dot(self._B.dot(X))
+
+ def _transpose(self):
+ return Product(self._B.T, self._A.T)
+
+ @property
+ def matrix(self):
+ return self._A.matrix @ self._B.matrix
+
+ def gram(self):
+ return Product(self.T, self)
+
+ def inv(self):
+ return Product(self._B.inv(), self._A.inv())
+
+
+class Identity(EkteloMatrix):
+ def __init__(self, n, dtype=np.float64):
+ self.n = n
+ self.shape = (n, n)
+ self.dtype = dtype
+
+ def _matmat(self, V):
+ return V
+
+ def _transpose(self):
+ return self
+
+ @property
+ def matrix(self):
+ return sparse.eye(self.n, dtype=self.dtype)
+
+ def __mul__(self, other):
+ if other.shape[0] != self.n:
+ raise ValueError("dimension mismatch")
+ return other
+
+ def inv(self):
+ return self
+
+ def pinv(self):
+ return self
+
+ def trace(self):
+ return self.n
+
+ def __abs__(self):
+ return self
+
+ def __sqr__(self):
+ return self
diff --git a/src/synthcity/plugins/core/models/mbi/inference.py b/src/synthcity/plugins/core/models/mbi/inference.py
new file mode 100644
index 00000000..4fb3dd6d
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/inference.py
@@ -0,0 +1,419 @@
+# stdlib
+from collections import defaultdict
+
+# third party
+import numpy as np
+from scipy import sparse
+from scipy.sparse.linalg import aslinearoperator, eigsh, lsmr
+
+# synthcity absolute
+from synthcity import logger
+
+# synthcity relative
+from . import callbacks
+from .clique_vector import CliqueVector
+from .factor import Factor
+from .graphical_model import GraphicalModel
+
+
+class FactoredInference:
+ def __init__(
+ self,
+ domain,
+ # backend="numpy",
+ structural_zeros={},
+ metric="L2",
+ log=False,
+ iters=1000,
+ warm_start=False,
+ elim_order=None,
+ ):
+ """
+ Class for learning a GraphicalModel from noisy measurements on a data distribution
+
+ :param domain: The domain information (A Domain object)
+ :param backend: numpy or torch backend
+ :param structural_zeros: An encoding of the known (structural) zeros in the distribution.
+ Specified as a dictionary where
+ - each key is a subset of attributes of size r
+ - each value is a list of r-tuples corresponding to impossible attribute settings
+ :param metric: The optimization metric. May be L1, L2 or a custom callable function
+ - custom callable function must consume the marginals and produce the loss and gradient
+ - see FactoredInference._marginal_loss for more information
+ :param log: flag to log iterations of optimization
+ :param iters: number of iterations to optimize for
+ :param warm_start: initialize new model or reuse last model when calling infer multiple times
+ :param elim_order: an elimination order for the JunctionTree algorithm
+            - The elimination order impacts efficiency but not correctness.
+              By default, a greedy elimination order is used
+ """
+ self.domain = domain
+ self.backend = "numpy"
+ # self.backend = backend
+ self.metric = metric
+ self.log = log
+ self.iters = iters
+ self.warm_start = warm_start
+ self.history = []
+ self.elim_order = elim_order
+ self.Factor = Factor
+ # if backend == "torch":
+ # # synthcity relative
+ # from .torch_factor import Factor
+
+ # self.Factor = Factor
+ # else:
+ # # synthcity relative
+ # from .factor import Factor
+
+ # self.Factor = Factor
+
+ self.structural_zeros = CliqueVector({})
+ for cl in structural_zeros:
+ dom = self.domain.project(cl)
+ fact = structural_zeros[cl]
+ self.structural_zeros[cl] = self.Factor.active(dom, fact)
+
+ def estimate(
+ self, measurements, total=None, engine="MD", callback=None, options={}
+ ):
+ """
+ Estimate a GraphicalModel from the given measurements
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param total: The total number of records (if known)
+ :param engine: the optimization algorithm to use, options include:
+ MD - Mirror Descent with armijo line search
+ RDA - Regularized Dual Averaging
+ IG - Interior Gradient
+ :param callback: a function to be called after each iteration of optimization
+ :param options: solver specific options passed as a dictionary
+ { param_name : param_value }
+
+ :return model: A GraphicalModel that best matches the measurements taken
+ """
+ measurements = self.fix_measurements(measurements)
+ options["callback"] = callback
+ if callback is None and self.log:
+ options["callback"] = callbacks.Logger(self)
+ if engine == "MD":
+ self.mirror_descent(measurements, total, **options)
+ elif engine == "RDA":
+ self.dual_averaging(measurements, total, **options)
+ elif engine == "IG":
+ self.interior_gradient(measurements, total, **options)
+ return self.model
+
+ def fix_measurements(self, measurements):
+ if not isinstance(measurements, list):
+ raise TypeError("measurements must be a list")
+ if any(len(m) != 4 for m in measurements):
+ raise ValueError("each measurement must be a 4-tuple (Q, y, noise,proj)")
+ ans = []
+ for Q, y, noise, proj in measurements:
+ if Q is not None and Q.shape[0] != y.size:
+ raise ValueError("shapes of Q and y are not compatible")
+ if type(proj) is list:
+ proj = tuple(proj)
+ if type(proj) is not tuple:
+ proj = (proj,)
+ if Q is None:
+ Q = sparse.eye(self.domain.size(proj))
+ if not np.isscalar(noise):
+ raise TypeError("noise must be a real value, given " + str(noise))
+ if any(a not in self.domain for a in proj):
+ raise ValueError(str(proj) + " not contained in domain")
+ if Q.shape[1] != self.domain.size(proj):
+ raise ValueError("shapes of Q and proj are not compatible")
+ ans.append((Q, y, noise, proj))
+ return ans
+
+ def interior_gradient(
+ self, measurements, total, lipschitz=None, c=1, sigma=1, callback=None
+ ):
+ """Use the interior gradient algorithm to estimate the GraphicalModel
+ See https://epubs.siam.org/doi/pdf/10.1137/S1052623403427823 for more information
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param total: The total number of records (if known)
+        :param lipschitz: the Lipschitz constant of grad L(mu)
+ - automatically calculated for metric=L2
+ - doesn't exist for metric=L1
+ - must be supplied for custom callable metrics
+ :param c, sigma: parameters of the algorithm
+ :param callback: a function to be called after each iteration of optimization
+ """
+ if self.metric == "L1":
+ raise ValueError("dual_averaging cannot be used with metric=L1")
+ if callable(self.metric) and lipschitz is None:
+ raise ValueError("lipschitz constant must be supplied")
+ self._setup(measurements, total)
+ # what are c and sigma? For now using 1
+ model = self.model
+ total = model.total
+ L = self._lipschitz(measurements) if lipschitz is None else lipschitz
+ if self.log:
+ logger.debug(f"Lipchitz constant: {L}")
+
+ theta = model.potentials
+ x = y = z = model.belief_propagation(theta)
+ sigma_over_L = sigma / L
+ for k in range(1, self.iters + 1):
+ a = (
+ np.sqrt((c * sigma_over_L) ** 2 + 4 * c * sigma_over_L)
+ - sigma_over_L * c
+ ) / 2
+ y = (1 - a) * x + a * z
+ c *= 1 - a
+ _, g = self._marginal_loss(y)
+ theta = theta - a / c / total * g
+ z = model.belief_propagation(theta)
+ x = (1 - a) * x + a * z
+ if callback is not None:
+ callback(x)
+
+ model.marginals = x
+ model.potentials = model.mle(x)
+
+ def dual_averaging(self, measurements, total=None, lipschitz=None, callback=None):
+ """Use the regularized dual averaging algorithm to estimate the GraphicalModel
+ See https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/xiao10JMLR.pdf
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param total: The total number of records (if known)
+        :param lipschitz: the Lipschitz constant of grad L(mu)
+ - automatically calculated for metric=L2
+ - doesn't exist for metric=L1
+ - must be supplied for custom callable metrics
+ :param callback: a function to be called after each iteration of optimization
+ """
+ if self.metric == "L1":
+ raise ValueError("dual_averaging cannot be used with metric=L1")
+ if callable(self.metric) and lipschitz is None:
+ raise ValueError("lipschitz constant must be supplied")
+ self._setup(measurements, total)
+ model = self.model
+ domain, cliques, total = model.domain, model.cliques, model.total
+ L = self._lipschitz(measurements) if lipschitz is None else lipschitz
+ logger.debug(f"Lipchitz constant: {L}")
+ if L == 0:
+ return
+
+ theta = model.potentials
+ gbar = CliqueVector(
+ {cl: self.Factor.zeros(domain.project(cl)) for cl in cliques}
+ )
+ w = v = model.belief_propagation(theta)
+ beta = 0
+
+ for t in range(1, self.iters + 1):
+ c = 2.0 / (t + 1)
+ u = (1 - c) * w + c * v
+ _, g = self._marginal_loss(u) # not interested in loss of this query point
+ gbar = (1 - c) * gbar + c * g
+ theta = -t * (t + 1) / (4 * L + beta) / self.model.total * gbar
+ v = model.belief_propagation(theta)
+ w = (1 - c) * w + c * v
+
+ if callback is not None:
+ callback(w)
+
+ model.marginals = w
+ model.potentials = model.mle(w)
+
+ def mirror_descent(self, measurements, total=None, stepsize=None, callback=None):
+ """Use the mirror descent algorithm to estimate the GraphicalModel
+ See https://web.iem.technion.ac.il/images/user-files/becka/papers/3.pdf
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param stepsize: The step size function for the optimization (None or scalar or function)
+ if None, will perform line search at each iteration (requires smooth objective)
+ if scalar, will use constant step size
+ if function, will be called with the iteration number
+ :param total: The total number of records (if known)
+ :param callback: a function to be called after each iteration of optimization
+ """
+ if self.metric == "L1" and stepsize is None:
+ raise ValueError("loss function not smooth, cannot use line search")
+
+ self._setup(measurements, total)
+ model = self.model
+ theta = model.potentials
+ mu = model.belief_propagation(theta)
+ ans = self._marginal_loss(mu)
+ if ans[0] == 0:
+ return ans[0]
+
+ nols = stepsize is not None
+ if np.isscalar(stepsize):
+ alpha = float(stepsize)
+ stepsize = alpha # changed from lambda func: stepsize = lambda t: alpha
+
+ if stepsize is None:
+ alpha = 1.0 / self.model.total**2
+ stepsize = (
+ 2.0 * alpha
+ ) # changed from lambda func: stepsize = lambda t: 2.0 * alpha
+
+ for t in range(1, self.iters + 1):
+ if callback is not None:
+ callback(mu)
+ omega, nu = theta, mu
+ curr_loss, dL = ans
+ alpha = stepsize
+ for i in range(25):
+ theta = omega - alpha * dL
+ mu = model.belief_propagation(theta)
+ ans = self._marginal_loss(mu)
+ if nols or curr_loss - ans[0] >= 0.5 * alpha * dL.dot(nu - mu):
+ break
+ alpha *= 0.5
+
+ model.potentials = theta
+ model.marginals = mu
+
+ return ans[0]
+
+ def _marginal_loss(self, marginals, metric=None):
+ """Compute the loss and gradient for a given dictionary of marginals
+
+ :param marginals: A dictionary with keys as projections and values as Factors
+ :return loss: the loss value
+ :return grad: A dictionary with gradient for each marginal
+ """
+ if metric is None:
+ metric = self.metric
+
+ if callable(metric):
+ return metric(marginals)
+
+ loss = 0.0
+ gradient = {}
+
+ for cl in marginals:
+ mu = marginals[cl]
+ gradient[cl] = self.Factor.zeros(mu.domain)
+ for Q, y, noise, proj in self.groups[cl]:
+ c = 1.0 / noise
+ mu2 = mu.project(proj)
+ x = mu2.datavector()
+ diff = c * (Q @ x - y)
+ if metric == "L1":
+ loss += abs(diff).sum()
+ sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff)
+ grad = c * (Q.T @ sign)
+ else:
+ loss += 0.5 * (diff @ diff)
+ grad = c * (Q.T @ diff)
+ gradient[cl] += self.Factor(mu2.domain, grad)
+ return float(loss), CliqueVector(gradient)
+
+ def _setup(self, measurements, total):
+ """Perform necessary setup for running estimation algorithms
+
+ 1. If total is None, find the minimum variance unbiased estimate for total and use that
+ 2. Construct the GraphicalModel
+ * If there are structural_zeros in the distribution, initialize factors appropriately
+ 3. Pre-process measurements into groups so that _marginal_loss may be evaluated efficiently
+ """
+ if total is None:
+ # find the minimum variance estimate of the total given the measurements
+ variances = np.array([])
+ estimates = np.array([])
+ for Q, y, noise, proj in measurements:
+ o = np.ones(Q.shape[1])
+ v = lsmr(Q.T, o, atol=0, btol=0)[0]
+ if np.allclose(Q.T.dot(v), o):
+ variances = np.append(variances, noise**2 * np.dot(v, v))
+ estimates = np.append(estimates, np.dot(v, y))
+ if estimates.size == 0:
+ total = 1
+ else:
+ variance = 1.0 / np.sum(1.0 / variances)
+ estimate = variance * np.sum(estimates / variances)
+ total = max(1, estimate)
+
+ # if not self.warm_start or not hasattr(self, 'model'):
+ # initialize the model and parameters
+ cliques = [m[3] for m in measurements]
+ if self.structural_zeros is not None:
+ cliques += list(self.structural_zeros.keys())
+
+ model = GraphicalModel(
+ self.domain, cliques, total, elimination_order=self.elim_order
+ )
+
+ model.potentials = CliqueVector.zeros(self.domain, model.cliques)
+ model.potentials.combine(self.structural_zeros)
+ if self.warm_start and hasattr(self, "model"):
+ model.potentials.combine(self.model.potentials)
+ self.model = model
+
+ # group the measurements into model cliques
+ cliques = self.model.cliques
+ # self.groups = { cl : [] for cl in cliques }
+ self.groups = defaultdict(lambda: [])
+ for Q, y, noise, proj in measurements:
+ if self.backend == "torch":
+ # third party
+ import torch
+
+ device = self.Factor.device
+ y = torch.tensor(y, dtype=torch.float32, device=device)
+ if isinstance(Q, np.ndarray):
+ Q = torch.tensor(Q, dtype=torch.float32, device=device)
+ elif sparse.issparse(Q):
+ Q = Q.tocoo()
+ idx = torch.LongTensor([Q.row, Q.col])
+ vals = torch.FloatTensor(Q.data)
+ Q = torch.sparse.FloatTensor(idx, vals).to(device)
+
+ # else Q is a Linear Operator, must be compatible with torch
+ m = (Q, y, noise, proj)
+ for cl in sorted(cliques, key=model.domain.size):
+ # (Q, y, noise, proj) tuple
+ if set(proj) <= set(cl):
+ self.groups[cl].append(m)
+ break
+
+ def _lipschitz(self, measurements):
+ """compute lipschitz constant for L2 loss
+
+ Note: must be called after _setup
+ """
+ eigs = {cl: 0.0 for cl in self.model.cliques}
+ for Q, _, noise, proj in measurements:
+ for cl in self.model.cliques:
+ if set(proj) <= set(cl):
+ n = self.domain.size(cl)
+ p = self.domain.size(proj)
+ Q = aslinearoperator(Q)
+ Q.dtype = np.dtype(Q.dtype)
+ eig = eigsh(Q.H * Q, 1)[0][0]
+ eigs[cl] += eig * n / p / noise**2
+ break
+ return max(eigs.values())
+
+ def infer(self, measurements, total=None, engine="MD", callback=None, options={}):
+ # stdlib
+ import warnings
+
+ message = "Function infer is deprecated. Please use estimate instead."
+ warnings.warn(message, DeprecationWarning)
+ return self.estimate(measurements, total, engine, callback, options)
diff --git a/src/synthcity/plugins/core/models/mbi/junction_tree.py b/src/synthcity/plugins/core/models/mbi/junction_tree.py
new file mode 100644
index 00000000..4656c505
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/junction_tree.py
@@ -0,0 +1,131 @@
+# stdlib
+import itertools
+from collections import OrderedDict
+
+# third party
+import networkx as nx
+import numpy as np
+
+
+class JunctionTree:
+ """A JunctionTree is a transformation of a GraphicalModel into a tree structure. It is used
+ to find the maximal cliques in the graphical model, and for specifying the message passing
+ order for belief propagation. The JunctionTree is characterized by an elimination_order,
+ which is chosen greedily by default, but may be passed in if desired.
+ """
+
+ def __init__(self, domain, cliques, elimination_order=None):
+ self.cliques = [tuple(cl) for cl in cliques]
+ self.domain = domain
+ self.graph = self._make_graph()
+ self.tree, self.order = self._make_tree(elimination_order)
+
+ def maximal_cliques(self):
+ """return the list of maximal cliques in the model"""
+ # return list(self.tree.nodes())
+ return list(nx.dfs_preorder_nodes(self.tree))
+
+ def mp_order(self):
+ """return a valid message passing order"""
+ edges = set()
+ messages = [(a, b) for a, b in self.tree.edges()] + [
+ (b, a) for a, b in self.tree.edges()
+ ]
+ for m1 in messages:
+ for m2 in messages:
+ if m1[1] == m2[0] and m1[0] != m2[1]:
+ edges.add((m1, m2))
+ G = nx.DiGraph()
+ G.add_nodes_from(messages)
+ G.add_edges_from(edges)
+ return list(nx.topological_sort(G))
+
+ def separator_axes(self):
+ return {(i, j): tuple(set(i) & set(j)) for i, j in self.mp_order()}
+
+ def neighbors(self):
+ return {i: set(self.tree.neighbors(i)) for i in self.maximal_cliques()}
+
+ def _make_graph(self):
+ G = nx.Graph()
+ G.add_nodes_from(self.domain.attrs)
+ for cl in self.cliques:
+ G.add_edges_from(itertools.combinations(cl, 2))
+ return G
+
+ def _triangulated(self, order):
+ edges = set()
+ G = nx.Graph(self.graph)
+ for node in order:
+ tmp = set(itertools.combinations(G.neighbors(node), 2))
+ edges |= tmp
+ G.add_edges_from(tmp)
+ G.remove_node(node)
+ tri = nx.Graph(self.graph)
+ tri.add_edges_from(edges)
+ cliques = [tuple(c) for c in nx.find_cliques(tri)]
+ cost = sum(self.domain.project(cl).size() for cl in cliques)
+ return tri, cost
+
+ def _greedy_order(self, stochastic=True):
+ order = []
+ domain, cliques = self.domain, self.cliques
+ unmarked = list(domain.attrs)
+ cliques = set(cliques)
+ total_cost = 0
+ for k in range(len(domain)):
+ cost = OrderedDict()
+ for a in unmarked:
+ # all cliques that have a
+ neighbors = list(filter(lambda cl: a in cl, cliques))
+ # variables in this "super-clique"
+ variables = tuple(set.union(set(), *map(set, neighbors)))
+ # domain for the resulting factor
+ newdom = domain.project(variables)
+ # cost of removing a
+ cost[a] = newdom.size()
+
+ # find the best variable to eliminate
+ if stochastic:
+ choices = list(unmarked)
+ costs = np.array([cost[a] for a in choices], dtype=float)
+ probas = np.max(costs) - costs + 1
+ probas /= probas.sum()
+ i = np.random.choice(probas.size, p=probas)
+ a = choices[i]
+ else:
+ a = min(cost, key=lambda a: cost[a])
+
+ # do some cleanup
+ order.append(a)
+ unmarked.remove(a)
+ neighbors = list(filter(lambda cl: a in cl, cliques))
+ variables = tuple(set.union(set(), *map(set, neighbors)) - {a})
+ cliques -= set(neighbors)
+ cliques.add(variables)
+ total_cost += cost[a]
+
+ return order, total_cost
+
+ def _make_tree(self, order=None):
+ if order is None:
+ # orders = [self._greedy_order(stochastic=True) for _ in range(1000)]
+ # orders.append(self._greedy_order(stochastic=False))
+ # order = min(orders, key=lambda x: x[1])[0]
+ order = self._greedy_order(stochastic=False)[0]
+ elif type(order) is int:
+ orders = [self._greedy_order(stochastic=False)] + [
+ self._greedy_order(stochastic=True) for _ in range(order)
+ ]
+ order = min(orders, key=lambda x: x[1])[0]
+ self.elimination_order = order
+ tri, cost = self._triangulated(order)
+ # cliques = [tuple(c) for c in nx.find_cliques(tri)]
+ cliques = sorted([self.domain.canonical(c) for c in nx.find_cliques(tri)])
+ complete = nx.Graph()
+ complete.add_nodes_from(cliques)
+ for c1, c2 in itertools.combinations(cliques, 2):
+ wgt = len(set(c1) & set(c2))
+ complete.add_edge(c1, c2, weight=-wgt)
+ spanning = nx.minimum_spanning_tree(complete)
+ return spanning, order
diff --git a/src/synthcity/plugins/core/models/mbi/local_inference.py b/src/synthcity/plugins/core/models/mbi/local_inference.py
new file mode 100644
index 00000000..5acc0ac7
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/local_inference.py
@@ -0,0 +1,293 @@
+# stdlib
+from collections import defaultdict
+from copy import deepcopy
+
+# third party
+import numpy as np
+from scipy import sparse
+from scipy.sparse.linalg import lsmr
+
+# synthcity absolute
+from synthcity import logger
+
+# synthcity relative
+from . import callbacks
+from .clique_vector import CliqueVector
+from .factor_graph import FactorGraph
+from .region_graph import RegionGraph
+
+"""
+This file implements Approx-Private-PGM from the following paper:
+
+Relaxed Marginal Consistency for Differentially Private Query Answering
+https://arxiv.org/pdf/2109.06153.pdf
+"""
+
+
+class LocalInference:
+ def __init__(
+ self,
+ domain,
+ backend="numpy",
+ structural_zeros={},
+ metric="L2",
+ log=False,
+ iters=1000,
+ warm_start=False,
+ marginal_oracle="convex",
+ inner_iters=1,
+ ):
+ """
+ Class for learning a GraphicalModel from noisy measurements on a data distribution
+
+ :param domain: The domain information (A Domain object)
+ :param backend: numpy or torch backend
+ :param structural_zeros: An encoding of the known (structural) zeros in the distribution.
+ Specified as a dictionary where
+ - each key is a subset of attributes of size r
+ - each value is a list of r-tuples corresponding to impossible attribute settings
+ :param metric: The optimization metric. May be L1, L2 or a custom callable function
+ - custom callable function must consume the marginals and produce the loss and gradient
+ - see FactoredInference._marginal_loss for more information
+ :param log: flag to log iterations of optimization
+ :param iters: number of iterations to optimize for
+ :param warm_start: initialize new model or reuse last model when calling infer multiple times
+ :param marginal_oracle: One of
+ - convex (Region graph, convex Kikuchi entropy)
+ - approx (Region graph, Kikuchi entropy)
+ - pairwise-convex (Factor graph, convex Bethe entropy)
+ - pairwise (Factor graph, Bethe entropy)
+            - Can also pass in any FactorGraph or RegionGraph object
+ """
+ self.domain = domain
+ self.backend = backend
+ self.metric = metric
+ self.log = log
+ self.iters = iters
+ self.warm_start = warm_start
+ self.history = []
+ self.marginal_oracle = marginal_oracle
+ self.inner_iters = inner_iters
+ if backend == "torch":
+            # synthcity relative
+            from .torch_factor import Factor
+
+ self.Factor = Factor
+ else:
+            # synthcity relative
+            from .factor import Factor
+
+ self.Factor = Factor
+
+ self.structural_zeros = CliqueVector({})
+ for cl in structural_zeros:
+ dom = self.domain.project(cl)
+ fact = structural_zeros[cl]
+ self.structural_zeros[cl] = self.Factor.active(dom, fact)
+
+ def estimate(self, measurements, total=None, callback=None, options={}):
+ """
+ Estimate a GraphicalModel from the given measurements
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param total: The total number of records (if known)
+ :param callback: a function to be called after each iteration of optimization
+ :param options: solver specific options passed as a dictionary
+ { param_name : param_value }
+
+ :return model: A GraphicalModel that best matches the measurements taken
+ """
+ options["callback"] = callback
+ if callback is None and self.log:
+ options["callback"] = callbacks.Logger(self)
+ self.mirror_descent(measurements, total, **options)
+ return self.model
+
+ def mirror_descent_auto(self, alpha, iters, callback=None):
+ model = self.model
+ theta0 = model.potentials
+ messages0 = deepcopy(model.messages)
+ theta = theta0
+ mu = model.belief_propagation(theta)
+ l0, _ = self._marginal_loss(mu)
+
+ prev_l = np.inf
+ for t in range(iters):
+ if callback is not None:
+ callback(mu)
+ l, dL = self._marginal_loss(mu)
+ theta = theta - alpha * dL
+ mu = model.belief_propagation(theta)
+ if l > prev_l:
+ if t <= 50:
+ if self.log:
+ logger.debug(
+ f"Reducing learning rate and restarting. alpha/2: {alpha / 2}"
+ )
+ model.potentials = theta0
+ model.messages = messages0
+ return self.mirror_descent_auto(alpha / 2, iters, callback)
+ else:
+ model.damping = (0.9 + model.damping) / 2.0
+ if self.log:
+ logger.debug(
+ f"Increasing damping and continuing. Damping: {model.damping}"
+ )
+ alpha *= 0.5
+ prev_l = l
+
+ # run some extra iterations with no gradient update to make sure things are primal feasible
+ for _ in range(1000):
+ if model.primal_feasibility(mu) < 1.0:
+ break
+ mu = model.belief_propagation(theta)
+ if callback is not None:
+ callback(mu)
+ return l, theta, mu
+
+ def mirror_descent(
+ self, measurements, total=None, initial_alpha=10.0, callback=None
+ ):
+ """Use the mirror descent algorithm to estimate the GraphicalModel
+ See https://web.iem.technion.ac.il/images/user-files/becka/papers/3.pdf
+
+ :param measurements: a list of (Q, y, noise, proj) tuples, where
+ Q is the measurement matrix (a numpy array or scipy sparse matrix or LinearOperator)
+ y is the noisy answers to the measurement queries
+ noise is the standard deviation of the noise added to y
+ proj defines the marginal used for this measurement set (a subset of attributes)
+ :param total: The total number of records (if known)
+        :param initial_alpha: the initial learning rate (reduced automatically if the loss diverges)
+ :param callback: a function to be called after each iteration of optimization
+ """
+ self._setup(measurements, total)
+ l, theta, mu = self.mirror_descent_auto(
+ alpha=initial_alpha, iters=self.iters, callback=callback
+ )
+
+ self.model.potentials = theta
+ self.model.marginals = mu
+
+ return l
+
+ def _marginal_loss(self, marginals, metric=None):
+ """Compute the loss and gradient for a given dictionary of marginals
+
+ :param marginals: A dictionary with keys as projections and values as Factors
+ :return loss: the loss value
+ :return grad: A dictionary with gradient for each marginal
+ """
+ if metric is None:
+ metric = self.metric
+
+ if callable(metric):
+ return metric(marginals)
+
+ loss = 0.0
+ gradient = {}
+
+ for cl in marginals:
+ mu = marginals[cl]
+ gradient[cl] = self.Factor.zeros(mu.domain)
+ for Q, y, noise, proj in self.groups[cl]:
+ c = 1.0 / noise
+ mu2 = mu.project(proj)
+ x = mu2.datavector()
+ diff = c * (Q @ x - y)
+ if metric == "L1":
+ loss += abs(diff).sum()
+ sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff)
+ grad = c * (Q.T @ sign)
+ else:
+ loss += 0.5 * (diff @ diff)
+ grad = c * (Q.T @ diff)
+ gradient[cl] += self.Factor(mu2.domain, grad)
+ return float(loss), CliqueVector(gradient)
+
+ def _setup(self, measurements, total):
+ """Perform necessary setup for running estimation algorithms
+
+ 1. If total is None, find the minimum variance unbiased estimate for total and use that
+ 2. Construct the GraphicalModel
+ * If there are structural_zeros in the distribution, initialize factors appropriately
+ 3. Pre-process measurements into groups so that _marginal_loss may be evaluated efficiently
+ """
+ if total is None:
+ # find the minimum variance estimate of the total given the measurements
+ variances = np.array([])
+ estimates = np.array([])
+ for Q, y, noise, proj in measurements:
+ o = np.ones(Q.shape[1])
+ v = lsmr(Q.T, o, atol=0, btol=0)[0]
+ if np.allclose(Q.T.dot(v), o):
+ variances = np.append(variances, noise**2 * np.dot(v, v))
+ estimates = np.append(estimates, np.dot(v, y))
+ if estimates.size == 0:
+ total = 1
+ else:
+ variance = 1.0 / np.sum(1.0 / variances)
+ estimate = variance * np.sum(estimates / variances)
+ total = max(1, estimate)
+
+ # if not self.warm_start or not hasattr(self, 'model'):
+ # initialize the model and parameters
+ cliques = [m[3] for m in measurements]
+ if self.structural_zeros is not None:
+ cliques += list(self.structural_zeros.keys())
+ if self.marginal_oracle == "approx":
+ model = RegionGraph(
+ self.domain, cliques, total, convex=False, iters=self.inner_iters
+ )
+ elif self.marginal_oracle == "convex":
+ model = RegionGraph(
+ self.domain, cliques, total, convex=True, iters=self.inner_iters
+ )
+ elif self.marginal_oracle == "pairwise":
+ model = FactorGraph(
+ self.domain, cliques, total, convex=False, iters=self.inner_iters
+ )
+ elif self.marginal_oracle == "pairwise-convex":
+ model = FactorGraph(
+ self.domain, cliques, total, convex=True, iters=self.inner_iters
+ )
+ else:
+ model = self.marginal_oracle
+ model.total = total
+
+ if type(self.marginal_oracle) is str:
+ model.potentials = CliqueVector.zeros(self.domain, model.cliques)
+ model.potentials.combine(self.structural_zeros)
+ if self.warm_start and hasattr(self, "model"):
+ model.potentials.combine(self.model.potentials)
+ self.model = model
+
+ # group the measurements into model cliques
+ cliques = self.model.cliques
+ # self.groups = { cl : [] for cl in cliques }
+ self.groups = defaultdict(lambda: [])
+ for Q, y, noise, proj in measurements:
+ if self.backend == "torch":
+ # third party
+ import torch
+
+ device = self.Factor.device
+ y = torch.tensor(y, dtype=torch.float32, device=device)
+ if isinstance(Q, np.ndarray):
+ Q = torch.tensor(Q, dtype=torch.float32, device=device)
+ elif sparse.issparse(Q):
+ Q = Q.tocoo()
+ idx = torch.LongTensor([Q.row, Q.col])
+ vals = torch.FloatTensor(Q.data)
+ Q = torch.sparse.FloatTensor(idx, vals).to(device)
+
+ # else Q is a Linear Operator, must be compatible with torch
+ m = (Q, y, noise, proj)
+ for cl in sorted(cliques, key=model.domain.size):
+ # (Q, y, noise, proj) tuple
+ if set(proj) <= set(cl):
+ self.groups[cl].append(m)
+ break
diff --git a/src/synthcity/plugins/core/models/mbi/mechanism.py b/src/synthcity/plugins/core/models/mbi/mechanism.py
new file mode 100644
index 00000000..2bdd28ed
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/mechanism.py
@@ -0,0 +1,96 @@
+# third party
+import numpy as np
+from scipy import sparse
+from scipy.stats import laplace, norm
+
+# synthcity relative
+from .callbacks import Logger
+from .inference import FactoredInference
+from .local_inference import LocalInference
+
+
+def run(
+ dataset,
+ measurements,
+ eps=1.0,
+ delta=0.0,
+ bounded=True,
+ engine="MD",
+ options={},
+ iters=10000,
+ seed=None,
+ metric="L2",
+ elim_order=None,
+ frequency=1,
+ workload=None,
+ oracle="exact",
+):
+ """
+ Run a mechanism that measures the given measurements and runs inference.
+ This is a convenience method for running end-to-end experiments.
+ """
+
+ domain = dataset.domain
+ total = None
+
+ state = np.random.RandomState(seed)
+
+ if len(measurements) >= 1 and type(measurements[0][0]) is str:
+
+ def matrix(proj):
+ return sparse.eye(domain.project(proj).size())
+
+ measurements = [(proj, matrix(proj)) for proj in measurements]
+
+ l1 = 0
+ l2 = 0
+ for _, Q in measurements:
+ l1 += np.abs(Q).sum(axis=0).max()
+ try:
+            l2 += Q.power(2).sum(axis=0).max()  # for sparse matrices
+ except BaseException:
+ l2 += np.square(Q).sum(axis=0).max() # for dense matrices
+
+ if bounded:
+ total = dataset.df.shape[0]
+ l1 *= 2
+ l2 *= 2
+
+ if delta > 0:
+ noise = norm(loc=0, scale=np.sqrt(l2 * 2 * np.log(2 / delta)) / eps)
+ else:
+ noise = laplace(loc=0, scale=l1 / eps)
+
+ if workload is None:
+ workload = measurements
+
+ truth = []
+ for (
+ proj,
+ W,
+ ) in workload:
+ x = dataset.project(proj).datavector()
+ y = W.dot(x)
+ truth.append((W, y, proj))
+
+ answers = []
+ for proj, Q in measurements:
+ x = dataset.project(proj).datavector()
+ z = noise.rvs(size=Q.shape[0], random_state=state)
+ y = Q.dot(x)
+ answers.append((Q, y + z, 1.0, proj))
+
+ if oracle == "exact":
+ estimator = FactoredInference(
+ domain, metric=metric, iters=iters, warm_start=False, elim_order=elim_order
+ )
+ else:
+ estimator = LocalInference(
+ domain, metric=metric, iters=iters, warm_start=False, marginal_oracle=oracle
+ )
+ logger = Logger(estimator, true_answers=truth, frequency=frequency)
+ model = estimator.estimate(
+ answers, total, engine=engine, callback=logger, options=options
+ )
+
+ return model, logger, answers
diff --git a/src/synthcity/plugins/core/models/mbi/mixture_inference.py b/src/synthcity/plugins/core/models/mbi/mixture_inference.py
new file mode 100644
index 00000000..6b365d45
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/mixture_inference.py
@@ -0,0 +1,206 @@
+# third party
+import jax.numpy as jnp
+import numpy as np
+import pandas as pd
+from jax import vjp
+from jax.nn import softmax as jax_softmax
+from scipy.sparse.linalg import lsmr
+
+# synthcity relative
+from .dataset import Dataset
+
+""" This file is experimental.
+
+It is a close approximation to the method described in RAP (https://arxiv.org/abs/2103.06641)
+and an even closer approximation to RAP^{softmax} (https://arxiv.org/abs/2106.07153)
+
+Notable differences:
+- Code now shares the same interface as Private-PGM (see FactoredInference)
+- Named model "MixtureOfProducts", as that is one interpretation for the relaxed tabular format
+(at least when softmax is used).
+- Added support for unbounded-DP, with automatic estimate of total.
+"""
+
+
+def estimate_total(measurements):
+ # find the minimum variance estimate of the total given the measurements
+ variances = np.array([])
+ estimates = np.array([])
+ for Q, y, noise, proj in measurements:
+ o = np.ones(Q.shape[1])
+ v = lsmr(Q.T, o, atol=0, btol=0)[0]
+ if np.allclose(Q.T.dot(v), o):
+ variances = np.append(variances, noise**2 * np.dot(v, v))
+ estimates = np.append(estimates, np.dot(v, y))
+ if estimates.size == 0:
+ return 1
+ else:
+ variance = 1.0 / np.sum(1.0 / variances)
+ estimate = variance * np.sum(estimates / variances)
+ return max(1, estimate)
+
+
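+# A plain-numpy Adam optimizer over a flat parameter vector; ``loss_and_grad``
+# must return a (loss, gradient) pair for the current iterate.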
+def adam(loss_and_grad, x0, iters=250):
+ a = 1.0
+ b1, b2 = 0.9, 0.999
+ eps = 10e-8
+
+ x = x0
+ m = np.zeros_like(x)
+ v = np.zeros_like(x)
+ for t in range(1, iters + 1):
+ l, g = loss_and_grad(x)
+ m = b1 * m + (1 - b1) * g
+ v = b2 * v + (1 - b2) * g**2
+ mhat = m / (1 - b1**t)
+ vhat = v / (1 - b2**t)
+ x = x - a * mhat / (np.sqrt(vhat) + eps)
+ return x
+
+
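+# Round a vector of (possibly fractional) expected counts so that they sum to
+# ``total``, distributing the remainder at random, and return a shuffled column
+# of synthetic values with those counts.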
+def synthetic_col(counts, total):
+ counts *= total / counts.sum()
+ frac, integ = np.modf(counts)
+ integ = integ.astype(int)
+ extra = total - integ.sum()
+ if extra > 0:
+ idx = np.random.choice(counts.size, extra, False, frac / frac.sum())
+ integ[idx] += 1
+ vals = np.repeat(np.arange(counts.size), integ)
+ np.random.shuffle(vals)
+ return vals
+
+
+class MixtureOfProducts:
+ def __init__(self, products, domain, total):
+ self.products = products
+ self.domain = domain
+ self.total = total
+ self.num_components = next(iter(products.values())).shape[0]
+
+ def project(self, cols):
+ products = {col: self.products[col] for col in cols}
+ domain = self.domain.project(cols)
+ return MixtureOfProducts(products, domain, self.total)
+
+ def datavector(self, flatten=True):
+ letters = "bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"[
+ : len(self.domain)
+ ]
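+        # The einsum contracts over the mixture axis "a": e.g. for two attributes
+        # the formula is "ab,ac->bc", averaging the per-component outer products
+        # and scaling by the total.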
+ formula = (
+ ",".join(["a%s" % letter for letter in letters]) + "->" + "".join(letters)
+ )
+ components = [self.products[col] for col in self.domain]
+ ans = np.einsum(formula, *components) * self.total / self.num_components
+ return ans.flatten() if flatten else ans
+
+ def synthetic_data(self, rows=None):
+ total = rows or int(self.total)
+ subtotal = total // self.num_components + 1
+
+ dfs = []
+ for i in range(self.num_components):
+ df = pd.DataFrame()
+ for col in self.products:
+ counts = self.products[col][i]
+ df[col] = synthetic_col(counts, subtotal)
+ dfs.append(df)
+
+ df = pd.concat(dfs).sample(frac=1).reset_index(drop=True)[:total]
+ return Dataset(df, self.domain)
+
+
+class MixtureInference:
+ def __init__(
+ self, domain, components=10, metric="L2", iters=2500, warm_start=False
+ ):
+ """
+ :param domain: A Domain object
+ :param components: The number of mixture components
+        :param metric: The metric to use for the loss function (can be callable)
+        :param iters: number of optimization iterations
+        :param warm_start: if True, reuse parameters from the previous call to estimate
+ """
+ self.domain = domain
+ self.components = components
+ self.metric = metric
+ self.iters = iters
+ self.warm_start = warm_start
+ self.params = np.random.normal(
+ loc=0, scale=0.25, size=sum(domain.shape) * components
+ )
+
+ def estimate(self, measurements, total=None, alpha=0.1):
+ if total is None:
+ total = estimate_total(measurements)
+ self.measurements = measurements
+ cliques = [M[-1] for M in measurements]
+ letters = "bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+ def get_products(params):
+ products = {}
+ idx = 0
+ for col in self.domain:
+ n = self.domain[col]
+ k = self.components
+ products[col] = jax_softmax(
+ params[idx : idx + k * n].reshape(k, n), axis=1
+ )
+ idx += k * n
+ return products
+
+ def marginals_from_params(params):
+ products = get_products(params)
+ mu = {}
+ for cl in cliques:
+ let = letters[: len(cl)]
+ formula = (
+ ",".join(["a%s" % letter for letter in let]) + "->" + "".join(let)
+ )
+ components = [products[col] for col in cl]
+ ans = jnp.einsum(formula, *components) * total / self.components
+ mu[cl] = ans.flatten()
+ return mu
+
+ def loss_and_grad(params):
+ # For computing dL / dmu we will use ordinary numpy so as to support scipy sparse and linear operator inputs
+ # For computing dL / dparams we will use jax to avoid manually deriving gradients
+ params = jnp.array(params)
+ mu, backprop = vjp(marginals_from_params, params)
+ mu = {cl: np.array(mu[cl]) for cl in cliques}
+ loss, dL = self._marginal_loss(mu)
+ dL = {cl: jnp.array(dL[cl]) for cl in cliques}
+ dparams = backprop(dL)
+ return loss, np.array(dparams[0])
+
+ if not self.warm_start:
+ self.params = np.random.normal(
+ loc=0, scale=0.25, size=sum(self.domain.shape) * self.components
+ )
+ self.params = adam(loss_and_grad, self.params, iters=self.iters)
+ products = get_products(self.params)
+ return MixtureOfProducts(products, self.domain, total)
+
+ def _marginal_loss(self, marginals, metric=None):
+ """Compute the loss and gradient for a given dictionary of marginals
+
+ :param marginals: A dictionary with keys as projections and values as Factors
+ :return loss: the loss value
+ :return grad: A dictionary with gradient for each marginal
+ """
+ if metric is None:
+ metric = self.metric
+
+ loss = 0.0
+ gradient = {cl: np.zeros_like(marginals[cl]) for cl in marginals}
+
+ for Q, y, noise, cl in self.measurements:
+ x = marginals[cl]
+ c = 1.0 / noise
+ diff = c * (Q @ x - y)
+ if metric == "L1":
+ loss += abs(diff).sum()
+ sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff)
+ grad = c * (Q.T @ sign)
+ else:
+ loss += 0.5 * (diff @ diff)
+ grad = c * (Q.T @ diff)
+ gradient[cl] += grad
+
+ return float(loss), gradient
diff --git a/src/synthcity/plugins/core/models/mbi/public_inference.py b/src/synthcity/plugins/core/models/mbi/public_inference.py
new file mode 100644
index 00000000..d94b73a2
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/public_inference.py
@@ -0,0 +1,128 @@
+# third party
+import numpy as np
+from scipy.sparse.linalg import lsmr
+from scipy.special import logsumexp
+
+# synthcity relative
+from .clique_vector import CliqueVector
+from .dataset import Dataset
+from .factor import Factor
+
+""" This file is experimental.
+It is an attempt to re-implement and generalize the technique used in PMW^{Pub}.
+https://arxiv.org/pdf/2102.08598.pdf
+
+Notable differences:
+- Shares the same interface as Private-PGM (see FactoredInference)
+- Supports unbounded differential privacy, with automatic estimate of total
+- Supports arbitrary measurements over the data marginals
+- Solves an L2 minimization problem by default, but other loss functions can be passed if desired.
+"""
+
+
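+# Entropic mirror descent (multiplicative weights) over non-negative vectors summing
+# to ``total``, with a backtracking step-size search on the sufficient-decrease
+# condition tested below.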
+def entropic_mirror_descent(loss_and_grad, x0, total, iters=250):
+ logP = np.log(x0 + np.nextafter(0, 1)) + np.log(total) - np.log(x0.sum())
+ P = np.exp(logP)
+ P = x0 * total / x0.sum()
+ loss, dL = loss_and_grad(P)
+ alpha = 1.0
+ begun = False
+
+ for _ in range(iters):
+ logQ = logP - alpha * dL
+ logQ += np.log(total) - logsumexp(logQ)
+ Q = np.exp(logQ)
+ # Q = P * np.exp(-alpha*dL)
+ # Q *= total / Q.sum()
+ new_loss, new_dL = loss_and_grad(Q)
+
+ if loss - new_loss >= 0.5 * alpha * dL.dot(P - Q):
+ logP = logQ
+ loss, dL = new_loss, new_dL
+ # increase step size if we haven't already decreased it at least once
+ if not begun:
+ alpha *= 2
+ else:
+ alpha *= 0.5
+ begun = True
+
+ return np.exp(logP)
+
+
+def estimate_total(measurements):
+ # find the minimum variance estimate of the total given the measurements
+ variances = np.array([])
+ estimates = np.array([])
+ for Q, y, noise, proj in measurements:
+ o = np.ones(Q.shape[1])
+ v = lsmr(Q.T, o, atol=0, btol=0)[0]
+ if np.allclose(Q.T.dot(v), o):
+ variances = np.append(variances, noise**2 * np.dot(v, v))
+ estimates = np.append(estimates, np.dot(v, y))
+ if estimates.size == 0:
+ return 1
+ else:
+ variance = 1.0 / np.sum(1.0 / variances)
+ estimate = variance * np.sum(estimates / variances)
+ return max(1, estimate)
+
+
+class PublicInference:
+ def __init__(self, public_data, metric="L2"):
+ self.public_data = public_data
+ self.metric = metric
+ self.weights = np.ones(self.public_data.records)
+
+ def estimate(self, measurements, total=None):
+ if total is None:
+ total = estimate_total(measurements)
+ self.measurements = measurements
+ cliques = [M[-1] for M in measurements]
+
+ def loss_and_grad(weights):
+ est = Dataset(self.public_data.df, self.public_data.domain, weights)
+ mu = CliqueVector.from_data(est, cliques)
+ loss, dL = self._marginal_loss(mu)
+ dweights = np.zeros(weights.size)
+ for cl in dL:
+ idx = est.project(cl).df.values
+ dweights += dL[cl].values[tuple(idx.T)]
+ return loss, dweights
+
+ # bounds = [(0,None) for _ in self.weights]
+ # res = minimize(loss_and_grad, x0=self.weights, method='L-BFGS-B', jac=True, bounds=bounds)
+ # self.weights = res.x
+
+ self.weights = entropic_mirror_descent(loss_and_grad, self.weights, total)
+ return Dataset(self.public_data.df, self.public_data.domain, self.weights)
+
+ def _marginal_loss(self, marginals, metric=None):
+ """Compute the loss and gradient for a given dictionary of marginals
+
+ :param marginals: A dictionary with keys as projections and values as Factors
+ :return loss: the loss value
+ :return grad: A dictionary with gradient for each marginal
+ """
+ if metric is None:
+ metric = self.metric
+
+ if callable(metric):
+ return metric(marginals)
+
+ loss = 0.0
+ gradient = {cl: Factor.zeros(marginals[cl].domain) for cl in marginals}
+
+ for Q, y, noise, cl in self.measurements:
+ mu = marginals[cl]
+ c = 1.0 / noise
+ x = mu.datavector()
+ diff = c * (Q @ x - y)
+ if metric == "L1":
+ loss += abs(diff).sum()
+ sign = diff.sign() if hasattr(diff, "sign") else np.sign(diff)
+ grad = c * (Q.T @ sign)
+ else:
+ loss += 0.5 * (diff @ diff)
+ grad = c * (Q.T @ diff)
+ gradient[cl] += Factor(mu.domain, grad)
+ return float(loss), CliqueVector(gradient)
diff --git a/src/synthcity/plugins/core/models/mbi/region_graph.py b/src/synthcity/plugins/core/models/mbi/region_graph.py
new file mode 100644
index 00000000..a2cb7a3e
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/region_graph.py
@@ -0,0 +1,544 @@
+# stdlib
+import itertools
+from collections import defaultdict
+
+# third party
+import networkx as nx
+import numpy as np
+from disjoint_set import DisjointSet
+
+# synthcity relative
+from .clique_vector import CliqueVector
+from .factor import Factor
+
+
+class RegionGraph:
+ def __init__(
+ self,
+ domain,
+ cliques,
+ total=1.0,
+ minimal=True,
+ convex=True,
+ iters=25,
+ convergence=1e-3,
+ damping=0.5,
+ ):
+ self.domain = domain
+ self.cliques = cliques
+ if not convex:
+ self.cliques = []
+ for r in cliques:
+ if not any(set(r) < set(s) for s in cliques):
+ self.cliques.append(r)
+ self.total = total
+ self.minimal = minimal
+ self.convex = convex
+ self.iters = iters
+ self.convergence = convergence
+ self.damping = damping
+ if convex:
+ self.belief_propagation = self.hazan_peng_shashua
+ else:
+ self.belief_propagation = self.generalized_belief_propagation
+ self.build_graph()
+ self.cliques = sorted(self.regions, key=len)
+ self.potentials = CliqueVector.zeros(domain, self.cliques)
+ self.marginals = CliqueVector.uniform(domain, self.cliques) * total
+
+ def show(self):
+ # third party
+ import matplotlib.pyplot as plt
+
+ labels = {r: "".join(r) for r in self.regions}
+
+ pos = {}
+ xloc = defaultdict(lambda: 0)
+ for r in sorted(self.regions):
+ y = len(r)
+ pos[r] = (xloc[y] + 0.5 * (y % 2), y)
+ xloc[y] += 1
+
+ nx.draw(self.G, pos=pos, node_color="orange", node_size=1000)
+ nx.draw(
+ self.G, pos=pos, nodelist=self.cliques, node_color="green", node_size=1000
+ )
+ nx.draw_networkx_labels(self.G, pos=pos, labels=labels)
+ plt.show()
+
+ def project(self, attrs, maxiter=100, alpha=None):
+ if type(attrs) is list:
+ attrs = tuple(attrs)
+
+ for cl in self.cliques:
+ if set(attrs) <= set(cl):
+ return self.marginals[cl].project(attrs)
+
+ # Use multiplicative weights/entropic mirror descent to solve projection problem
+ intersections = [set(cl) & set(attrs) for cl in self.cliques]
+ target_cliques = [
+ tuple(t) for t in intersections if not any(t < s for s in intersections)
+ ]
+ target_cliques = list(set(target_cliques))
+ target_mu = CliqueVector.from_data(self, target_cliques)
+
+ if len(target_cliques) == 0:
+ return Factor.uniform(self.domain.project(attrs)) * self.total
+ # P = Factor.uniform(self.domain.project(attrs))*self.total
+ # Use a smart initialization
+ P = estimate_kikuchi_marginal(self.domain.project(attrs), self.total, target_mu)
+ if alpha is None:
+ # start with a safe step size
+ alpha = 1.0 / (self.total * len(target_cliques))
+
+ curr_mu = CliqueVector.from_data(P, target_cliques)
+ diff = curr_mu - target_mu
+ curr_loss, dL = diff.dot(diff), sum(diff.values()).expand(P.domain)
+ begun = False
+
+ for _ in range(maxiter):
+ if curr_loss <= 1e-8:
+ return P # stop early if marginals are almost exactly realized
+ Q = P * (-alpha * dL).exp()
+ Q *= self.total / Q.sum()
+ curr_mu = CliqueVector.from_data(Q, target_cliques)
+ diff = curr_mu - target_mu
+ loss = diff.dot(diff)
+
+ if curr_loss - loss >= 0.5 * alpha * dL.dot(P - Q):
+ P = Q
+ curr_loss = loss
+ dL = sum(diff.values()).expand(P.domain)
+ # increase step size if we haven't already decreased it at least once
+ if not begun:
+ alpha *= 2
+ else:
+ alpha *= 0.5
+ begun = True
+
+ return P
+
+ def primal_feasibility(self, mu):
+ ans = 0
+ count = 0
+ for r in self.cliques:
+ for s in self.children[r]:
+ x = mu[r].project(s).datavector()
+ y = mu[s].datavector()
+ err = np.linalg.norm(x - y, 1)
+ ans += err
+ count += 1
+ return 0 if count == 0 else ans / count
+
+ def is_converged(self, mu):
+ return self.primal_feasibility(mu) <= self.convergence
+
+ def build_graph(self):
+ # Alg 11.3 of Koller & Friedman
+ regions = set(self.cliques)
+ size = 0
+ while len(regions) > size:
+ size = len(regions)
+ for r1, r2 in itertools.combinations(regions, 2):
+ z = tuple(sorted(set(r1) & set(r2)))
+ if len(z) > 0 and z not in regions:
+ regions.update({z})
+
+ G = nx.DiGraph()
+ G.add_nodes_from(regions)
+ for r1 in regions:
+ for r2 in regions:
+ if set(r2) < set(r1) and not any(
+ set(r2) < set(r3) and set(r3) < set(r1) for r3 in regions
+ ):
+ G.add_edge(r1, r2)
+
+ H = G.reverse()
+ G1, H1 = nx.transitive_closure(G), nx.transitive_closure(H)
+
+ self.children = {r: list(G.neighbors(r)) for r in regions}
+ self.parents = {r: list(H.neighbors(r)) for r in regions}
+ self.descendants = {r: list(G1.neighbors(r)) for r in regions}
+ self.ancestors = {r: list(H1.neighbors(r)) for r in regions}
+ self.forebears = {r: set([r] + self.ancestors[r]) for r in regions}
+ self.downp = {r: set([r] + self.descendants[r]) for r in regions}
+
+ if self.minimal:
+ min_edges = []
+ for r in regions:
+ ds = DisjointSet()
+ for u in self.parents[r]:
+ ds.find(u)
+ for u, v in itertools.combinations(self.parents[r], 2):
+ uv = set(self.ancestors[u]) & set(self.ancestors[v])
+ if len(uv) > 0:
+ ds.union(u, v)
+ canonical = set()
+ for u in self.parents[r]:
+ canonical.update({ds.find(u)})
+ # if len(canonical) > 1:# or r in self.cliques:
+ min_edges.extend([(u, r) for u in canonical])
+ # G = nx.DiGraph(min_edges)
+ # regions = list(G.nodes)
+ G = nx.DiGraph()
+ G.add_nodes_from(regions)
+ G.add_edges_from(min_edges)
+
+ H = G.reverse()
+ G1, H1 = nx.transitive_closure(G), nx.transitive_closure(H)
+
+ self.children = {r: list(G.neighbors(r)) for r in regions}
+ self.parents = {r: list(H.neighbors(r)) for r in regions}
+ # self.descendants = { r : list(G1.neighbors(r)) for r in regions }
+ # self.ancestors = { r : list(H1.neighbors(r)) for r in regions }
+ # self.forebears = { r : set([r] + self.ancestors[r]) for r in regions }
+ # self.downp = { r : set([r] + self.descendants[r]) for r in regions }
+
+ self.G = G
+ self.regions = regions
+
+ if self.convex:
+ self.counting_numbers = {r: 1.0 for r in regions}
+ else:
+ moebius = {}
+
+ def get_counting_number(r):
+ if r not in moebius:
+ moebius[r] = 1 - sum(
+ get_counting_number(s) for s in self.ancestors[r]
+ )
+ return moebius[r]
+
+ for r in regions:
+ get_counting_number(r)
+ self.counting_numbers = moebius
+
+ if self.minimal:
+ # https://people.eecs.berkeley.edu/~ananth/2002+/Payam/submittedkikuchi.pdf
+ # Eq. 30 and 31
+ N, D, B = {}, {}, {}
+ for r in regions:
+ B[r] = set()
+ for p in self.parents[r]:
+ B[r].add((p, r))
+ for d in self.descendants[r]:
+ for p in set(self.parents[d]) - {r} - set(self.descendants[r]):
+ B[r].add((p, d))
+
+ for p in self.regions:
+ for r in self.children[p]:
+ N[p, r], D[p, r] = set(), set()
+ for s in self.parents[p]:
+ N[p, r].add((s, p))
+ for d in self.descendants[p]:
+ for s in (
+ set(self.parents[d]) - {p} - set(self.descendants[p])
+ ):
+ N[p, r].add((s, d))
+ for s in set(self.parents[r]) - {p}:
+ D[p, r].add((s, r))
+ for d in self.descendants[r]:
+ for p1 in (
+ set(self.parents[d]) - {r} - set(self.descendants[r])
+ ):
+ D[p, r].add((p1, d))
+ cancel = N[p, r] & D[p, r]
+ N[p, r] = N[p, r] - cancel
+ D[p, r] = D[p, r] - cancel
+
+ self.N, self.D, self.B = N, D, B
+
+ else:
+ # From Yedida et al. for fully saturated region graphs
+ # for sending messages ru --> rd and computing beliefs B_r
+ N, D, B = {}, {}, {}
+ for r in regions:
+ B[r] = [(ru, r) for ru in self.parents[r]]
+ for rd in self.descendants[r]:
+ for ru in set(self.parents[rd]) - self.downp[r]:
+ B[r].append((ru, rd))
+
+ for ru in regions:
+ for rd in self.children[ru]:
+ fu, fd = self.downp[ru], self.downp[rd]
+
+ def cond(r):
+ return not r[0] in fu and r[1] in (fu - fd)
+
+ N[ru, rd] = [e for e in G.edges if cond(e)]
+
+ def cond(r):
+ return r[0] in (fu - fd) and r[1] in fd and r != (ru, rd)
+
+ D[ru, rd] = [e for e in G.edges if cond(e)]
+
+ self.N, self.D, self.B = N, D, B
+
+ self.messages = {}
+ self.message_order = []
+ for ru in sorted(
+ regions, key=len
+ ): # nx.topological_sort(H): # should be G or H?
+ for rd in self.children[ru]:
+ self.message_order.append((ru, rd))
+ self.messages[ru, rd] = Factor.zeros(self.domain.project(rd))
+ self.messages[rd, ru] = Factor.zeros(
+ self.domain.project(rd)
+ ) # only for hazan et al
+
+ def generalized_belief_propagation(self, potentials, callback=None):
+ # https://users.cs.duke.edu/~brd/Teaching/Bio/asmb/current/4paper/4-2.pdf
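+ # Parent-to-child updates in log-space: marginalize pot[ru] plus the incoming
+ # messages in N[ru, rd] over the variables not shared with rd, subtract the
+ # messages in D[ru, rd], normalize, then damp 50/50 against the previous messages.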
+ pot = {}
+ for r in self.regions:
+ if r in self.cliques:
+ pot[r] = potentials[r]
+ else:
+ pot[r] = Factor.zeros(self.domain.project(r))
+
+ for _ in range(self.iters):
+ new = {}
+ for ru, rd in self.message_order:
+ # Yedidia et al. strongly recommend using updated messages for the LHS (denom in our case)
+ # num = sum(pot[c] for c in self.downp[ru] if c != rd)
+ num = pot[ru]
+ num = num + sum(self.messages[r1, r2] for r1, r2 in self.N[ru, rd])
+ denom = sum(new[r1, r2] for r1, r2 in self.D[ru, rd])
+ diff = tuple(set(ru) - set(rd))
+ new[ru, rd] = num.logsumexp(diff) - denom
+ new[ru, rd] -= new[ru, rd].logsumexp()
+
+ # self.messages = new
+ for ru, rd in self.message_order:
+ self.messages[ru, rd] = 0.5 * self.messages[ru, rd] + 0.5 * new[ru, rd]
+ # ru, rd = self.message_order[0]
+
+ marginals = {}
+ for r in self.cliques:
+ # belief = sum(potentials[c] for c in self.downp[r]) + sum(self.messages[r1,r2] for r1,r2 in self.B[r])
+ belief = potentials[r] + sum(self.messages[r1, r2] for r1, r2 in self.B[r])
+ belief += np.log(self.total) - belief.logsumexp()
+ marginals[r] = belief.exp()
+
+ return CliqueVector(marginals)
+
+ def hazan_peng_shashua(self, potentials, callback=None):
+ # https://arxiv.org/pdf/1210.4881.pdf
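+ # Two sweeps per iteration: parent->child messages via counting-number-weighted
+ # logsumexp over the dropped variables, then child->parent messages as convex
+ # combinations with weights cc[p, r] derived from the counting numbers,
+ # followed by damping.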
+ c0 = self.counting_numbers
+ pot = {}
+ for r in self.regions:
+ if r in self.cliques:
+ pot[r] = potentials[r]
+ else:
+ pot[r] = Factor.zeros(self.domain.project(r))
+
+ messages = self.messages
+ # for p in sorted(self.regions, key=len): #nx.topological_sort(H): # should be G or H?
+ # for r in self.children[p]:
+ # messages[p,r] = Factor.zeros(self.domain.project(r))
+ # messages[r,p] = Factor.zeros(self.domain.project(r))
+
+ cc = {}
+ for r in self.regions:
+ for p in self.parents[r]:
+ cc[p, r] = c0[p] / (c0[r] + sum(c0[p1] for p1 in self.parents[r]))
+
+ for _ in range(self.iters):
+ new = {}
+ for r in self.regions:
+ for p in self.parents[r]:
+ new[p, r] = (
+ pot[p]
+ + sum(messages[c, p] for c in self.children[p] if c != r)
+ - sum(messages[p, p1] for p1 in self.parents[p])
+ ) / c0[p]
+ new[p, r] = c0[p] * new[p, r].logsumexp(tuple(set(p) - set(r)))
+ new[p, r] -= new[p, r].logsumexp()
+
+ for r in self.regions:
+ for p in self.parents[r]:
+ new[r, p] = (
+ cc[p, r]
+ * (
+ pot[r]
+ + sum(messages[c, r] for c in self.children[r])
+ + sum(messages[p1, r] for p1 in self.parents[r])
+ )
+ - messages[p, r]
+ )
+ # new[r,p] = cc[p,r]*(pot[r] + sum(messages[c,r] for c in self.children[r]) + sum(new[p1,r] for p1 in self.parents[r])) - new[p,r]
+ new[r, p] -= new[r, p].logsumexp()
+
+ # messages = new
+ # Damping is not described in paper, but is needed to get convergence for dense graphs
+ rho = self.damping
+ for p in self.regions:
+ for r in self.children[p]:
+ messages[p, r] = rho * messages[p, r] + (1.0 - rho) * new[p, r]
+ messages[r, p] = rho * messages[r, p] + (1.0 - rho) * new[r, p]
+ mu = {}
+ for r in self.regions:
+ belief = (
+ pot[r]
+ + sum(messages[c, r] for c in self.children[r])
+ - sum(messages[r, p] for p in self.parents[r])
+ ) / c0[r]
+ belief += np.log(self.total) - belief.logsumexp()
+ mu[r] = belief.exp()
+
+ if callback is not None:
+ callback(mu)
+
+ if self.is_converged(mu):
+ self.messages = messages
+ return CliqueVector(mu)
+
+ self.messages = messages
+ return CliqueVector(mu)
+
+ def wiegerinck(self, potentials, callback=None):
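+ # Iteratively makes the region pseudo-marginals Q consistent: each inner region
+ # is set to a counting-number-weighted combination of its own Q and its
+ # ancestors' projections (weights A and B below), and the ancestors are then
+ # rescaled to agree with it.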
+ c = self.counting_numbers
+ m = {}
+ for delta in self.regions:
+ m[delta] = 0
+ for alpha in self.ancestors[delta]:
+ m[delta] += c[alpha]
+
+ Q = {}
+ for r in self.regions:
+ if r in self.cliques:
+ Q[r] = potentials[r] / c[r]
+ else:
+ Q[r] = Factor.zeros(self.domain.project(r))
+
+ inner = [r for r in self.regions if len(self.parents[r]) > 0]
+
+ def diff(r, s):
+ return tuple(set(r) - set(s))
+
+ for _ in range(self.iters):
+ for r in inner:
+ A = c[r] / (m[r] + c[r])
+ B = m[r] / (m[r] + c[r])
+ Qbar = (
+ sum(c[s] * Q[s].logsumexp(diff(s, r)) for s in self.ancestors[r])
+ / m[r]
+ )
+ Q[r] = Q[r] * A + Qbar * B
+ Q[r] -= Q[r].logsumexp()
+ for s in self.ancestors[r]:
+ Q[s] = Q[s] + Q[r] - Q[s].logsumexp(diff(s, r))
+ Q[s] -= Q[s].logsumexp()
+
+ marginals = {}
+ for r in self.regions:
+ marginals[r] = (Q[r] + np.log(self.total) - Q[r].logsumexp()).exp()
+ if callback is not None:
+ callback(marginals)
+
+ return CliqueVector(marginals)
+
+ def loh_wibisono(self, potentials, callback=None):
+ # https://papers.nips.cc/paper/2014/file/39027dfad5138c9ca0c474d71db915c3-Paper.pdf
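+ # Reweighted message passing: messages and beliefs are scaled by the region
+ # counting numbers rho; each parent->child message is formed from a ratio of
+ # weighted numerator/denominator beliefs, then damped 50/50.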
+ pot = {}
+ for r in self.regions:
+ if r in self.cliques:
+ pot[r] = potentials[r]
+ else:
+ pot[r] = Factor.zeros(self.domain.project(r))
+
+ rho = self.counting_numbers
+
+ for _ in range(self.iters):
+ new = {}
+ for s, r in self.message_order:
+ diff = tuple(set(s) - set(r))
+ num = pot[s] / rho[s]
+ for v in self.parents[s]:
+ num += self.messages[v, s] * rho[v] / rho[s]
+ for w in self.children[s]:
+ if w != r:
+ num -= self.messages[s, w]
+ num = num.logsumexp(diff)
+ denom = pot[r] / rho[r]
+ for u in self.parents[r]:
+ if u != s:
+ denom += self.messages[u, r] * rho[u] / rho[r]
+ for t in self.children[r]:
+ denom -= self.messages[r, t]
+
+ new[s, r] = rho[r] / (rho[r] + rho[s]) * (num - denom)
+ new[s, r] -= new[s, r].logsumexp()
+
+ for ru, rd in self.message_order:
+ self.messages[ru, rd] = 0.5 * self.messages[ru, rd] + 0.5 * new[ru, rd]
+
+ # ru, rd = self.message_order[0]
+
+ marginals = {}
+ for r in self.regions:
+ belief = pot[r] / rho[r]
+ for s in self.parents[r]:
+ belief += self.messages[s, r] * rho[s] / rho[r]
+ for t in self.children[r]:
+ belief -= self.messages[r, t]
+ belief += np.log(self.total) - belief.logsumexp()
+ marginals[r] = belief.exp()
+ if callback is not None:
+ callback(marginals)
+
+ return CliqueVector(marginals)
+
+ def kikuchi_entropy(self, marginals):
+ """
+ Return the Kikuchi entropy and its gradient with respect to the marginals.
+
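+ Concretely, H(mu) = -sum_r c_r * sum_x mu_r(x) * log(mu_r(x)), where c_r are the
+ region counting numbers and mu_r are the region marginals normalized by self.total.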
+ """
+ weights = self.counting_numbers
+ entropy = 0
+ dmarginals = {}
+ for cl in self.regions:
+ mu = marginals[cl] / self.total
+ entropy += weights[cl] * (mu * mu.log()).sum()
+ dmarginals[cl] = weights[cl] * (1 + mu.log()) / self.total
+ return -entropy, -1 * CliqueVector(dmarginals)
+
+ def mle(self, mu):
+ return -1 * self.kikuchi_entropy(mu)[1]
+
+
+def estimate_kikuchi_marginal(domain, total, marginals):
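+ """Reconstruct a joint distribution over `domain` from a collection of clique marginals.
+
+ Sketch of the approach used below: the cliques are closed under intersection to form a
+ region set, Moebius counting numbers k_r are computed from the region containment graph,
+ and the joint is estimated as log P = sum_r k_r * log(mu_r), normalized to sum to `total`.
+ """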
+ marginals = dict(marginals)
+ regions = set(marginals.keys())
+ size = 0
+ while len(regions) > size:
+ size = len(regions)
+ for r1, r2 in itertools.combinations(regions, 2):
+ z = tuple(sorted(set(r1) & set(r2)))
+ if len(z) > 0 and z not in regions:
+ marginals[z] = marginals[r1].project(z)
+ regions.update({z})
+
+ G = nx.DiGraph()
+ G.add_nodes_from(regions)
+ for r1 in regions:
+ for r2 in regions:
+ if set(r2) < set(r1) and not any(
+ set(r2) < set(r3) and set(r3) < set(r1) for r3 in regions
+ ):
+ G.add_edge(r1, r2)
+
+ H1 = nx.transitive_closure(G.reverse())
+ ancestors = {r: list(H1.neighbors(r)) for r in regions}
+ moebius = {}
+
+ def get_counting_number(r):
+ if r not in moebius:
+ moebius[r] = 1 - sum(get_counting_number(s) for s in ancestors[r])
+ return moebius[r]
+
+ logP = Factor.zeros(domain)
+ for r in regions:
+ kr = get_counting_number(r)
+ logP += kr * marginals[r].log()
+ logP += np.log(total) - logP.logsumexp()
+ return logP.exp()
diff --git a/src/synthcity/plugins/core/models/mbi/torch_factor.py b/src/synthcity/plugins/core/models/mbi/torch_factor.py
new file mode 100644
index 00000000..06c9c905
--- /dev/null
+++ b/src/synthcity/plugins/core/models/mbi/torch_factor.py
@@ -0,0 +1,206 @@
+# third party
+import numpy as np
+import torch
+
+
+class Factor:
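+ """A torch-backed factor (multi-dimensional table of values) over a discrete Domain.
+
+ Values are stored as a tensor shaped like the domain, placed on GPU when available.
+ Minimal usage sketch (assumes the companion mbi Domain class):
+
+ >>> dom = Domain(["A", "B"], [2, 3])
+ >>> f = Factor.zeros(dom)
+ >>> f.sum()
+ 0.0
+ """
+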
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def __init__(self, domain, values):
+ """Initialize a factor over the given domain
+
+ :param domain: the domain of the factor
+ :param values: the ndarray or tensor of factor values (for each element of the domain)
+
+ Note: values may be a flattened 1d array or a ndarray with same shape as domain
+ """
+ if type(values) == np.ndarray:
+ values = torch.tensor(values, dtype=torch.float32, device=Factor.device)
+ if domain.size() != values.nelement():
+ raise ValueError("domain size does not match values size")
+ if len(values.shape) != 1 and values.shape != domain.shape:
+ raise ValueError("invalid shape for values array")
+ self.domain = domain
+ self.values = values.reshape(domain.shape).to(Factor.device)
+
+ @staticmethod
+ def zeros(domain):
+ return Factor(domain, torch.zeros(domain.shape, device=Factor.device))
+
+ @staticmethod
+ def ones(domain):
+ return Factor(domain, torch.ones(domain.shape, device=Factor.device))
+
+ @staticmethod
+ def random(domain):
+ return Factor(domain, torch.rand(domain.shape, device=Factor.device))
+
+ @staticmethod
+ def uniform(domain):
+ return Factor.ones(domain) / domain.size()
+
+ @staticmethod
+ def active(domain, structural_zeros):
+ """create a factor that is 0 everywhere except in positions present in
+ 'structural_zeros', where it is -infinity
+
+ :param: domain: the domain of this factor
+ :param: structural_zeros: a list of values that are not possible
+ """
+ idx = tuple(np.array(structural_zeros).T)
+ vals = torch.zeros(domain.shape, device=Factor.device)
+ vals[idx] = -np.inf
+ return Factor(domain, vals)
+
+ def expand(self, domain):
+ if not domain.contains(self.domain):
+ raise AssertionError("expanded domain must contain current domain")
+ dims = len(domain) - len(self.domain)
+ values = self.values.view(self.values.size() + tuple([1] * dims))
+ ax = domain.axes(self.domain.attrs)
+ # need to find replacement for moveaxis
+ ax = ax + tuple(i for i in range(len(domain)) if i not in ax)
+ ax = tuple(np.argsort(ax))
+ values = values.permute(ax)
+ values = values.expand(domain.shape)
+ return Factor(domain, values)
+
+ def transpose(self, attrs):
+ if set(attrs) != set(self.domain.attrs):
+ raise AssertionError("attrs must be same as domain attributes")
+ newdom = self.domain.project(attrs)
+ ax = newdom.axes(self.domain.attrs)
+ ax = tuple(np.argsort(ax))
+ values = self.values.permute(ax)
+ return Factor(newdom, values)
+
+ def project(self, attrs, agg="sum"):
+ """
+ project the factor onto a list of attributes (in order)
+ using either sum or logsumexp to aggregate along other attributes
+ """
+ if agg not in ["sum", "logsumexp"]:
+ raise ValueError("agg must be sum or logsumexp")
+ marginalized = self.domain.marginalize(attrs)
+ if agg == "sum":
+ ans = self.sum(marginalized.attrs)
+ elif agg == "logsumexp":
+ ans = self.logsumexp(marginalized.attrs)
+ return ans.transpose(attrs)
+
+ def sum(self, attrs=None):
+ if attrs is None:
+ return float(self.values.sum())
+ elif attrs == tuple():
+ return self
+ axes = self.domain.axes(attrs)
+ values = self.values.sum(dim=axes)
+ newdom = self.domain.marginalize(attrs)
+ return Factor(newdom, values)
+
+ def logsumexp(self, attrs=None):
+ if attrs is None:
+ return float(
+ self.values.logsumexp(dim=tuple(range(len(self.values.shape))))
+ )
+ elif attrs == tuple():
+ return self
+ axes = self.domain.axes(attrs)
+ values = self.values.logsumexp(dim=axes)
+ newdom = self.domain.marginalize(attrs)
+ return Factor(newdom, values)
+
+ def logaddexp(self, other):
+ raise NotImplementedError
+
+ def max(self, attrs=None):
+ if attrs is None:
+ return float(self.values.max())
+ raise NotImplementedError  # torch.max does not behave like numpy
+
+ def condition(self, evidence):
+ """evidence is a dictionary where
+ keys are attributes, and
+ values are elements of the domain for that attribute"""
+ slices = [evidence[a] if a in evidence else slice(None) for a in self.domain]
+ newdom = self.domain.marginalize(evidence.keys())
+ values = self.values[tuple(slices)]
+ return Factor(newdom, values)
+
+ def copy(self, out=None):
+ if out is None:
+ return Factor(self.domain, self.values.clone())
+ np.copyto(out.values, self.values)
+ return out
+
+ def __mul__(self, other):
+ if np.isscalar(other):
+ return Factor(self.domain, other * self.values)
+ newdom = self.domain.merge(other.domain)
+ factor1 = self.expand(newdom)
+ factor2 = other.expand(newdom)
+ return Factor(newdom, factor1.values * factor2.values)
+
+ def __add__(self, other):
+ if np.isscalar(other):
+ return Factor(self.domain, other + self.values)
+ newdom = self.domain.merge(other.domain)
+ factor1 = self.expand(newdom)
+ factor2 = other.expand(newdom)
+ return Factor(newdom, factor1.values + factor2.values)
+
+ def __iadd__(self, other):
+ if np.isscalar(other):
+ self.values += other
+ return self
+ factor2 = other.expand(self.domain)
+ self.values += factor2.values
+ return self
+
+ def __imul__(self, other):
+ if np.isscalar(other):
+ self.values *= other
+ return self
+ factor2 = other.expand(self.domain)
+ self.values *= factor2.values
+ return self
+
+ def __radd__(self, other):
+ return self.__add__(other)
+
+ def __rmul__(self, other):
+ return self.__mul__(other)
+
+ def __sub__(self, other):
+ if np.isscalar(other):
+ return Factor(self.domain, self.values - other)
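+ # Negate `other` before adding; entries equal to -inf (structural zeros) are
+ # mapped to 0 instead of +inf so subtraction leaves them untouched.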
+ zero = torch.tensor(0.0, device=Factor.device)
+ inf = torch.tensor(np.inf, device=Factor.device)
+ values = torch.where(other.values == -inf, zero, -other.values)
+ other = Factor(other.domain, values)
+ return self + other
+
+ def __truediv__(self, other):
+ if np.isscalar(other):
+ return self * (1.0 / other)
+ tmp = other.expand(self.domain)
+ vals = torch.div(self.values, tmp.values)
+ vals[tmp.values <= 0] = 0.0
+ return Factor(self.domain, vals)
+
+ def exp(self, out=None):
+ if out is None:
+ return Factor(self.domain, self.values.exp())
+ torch.exp(self.values, out=out.values)
+ return out
+
+ def log(self, out=None):
+ if out is None:
+ return Factor(self.domain, torch.log(self.values + 1e-100))
+ torch.log(self.values, out=out.values)
+ return out
+
+ def datavector(self, flatten=True):
+ """Materialize the data vector as a numpy array"""
+ ans = self.values.to("cpu").numpy()
+ return ans.flatten() if flatten else ans
diff --git a/src/synthcity/plugins/core/models/tabular_aim.py b/src/synthcity/plugins/core/models/tabular_aim.py
new file mode 100644
index 00000000..1b2c68cd
--- /dev/null
+++ b/src/synthcity/plugins/core/models/tabular_aim.py
@@ -0,0 +1,133 @@
+# stdlib
+import itertools
+from typing import Any, Optional, Union
+
+# third party
+import numpy as np
+import pandas as pd
+import torch
+from pydantic import validate_arguments
+
+# synthcity absolute
+from synthcity.utils.constants import DEVICE
+
+# synthcity relative
+from .aim import AIM
+from .mbi.dataset import Dataset
+from .mbi.domain import Domain
+
+
+class TabularAIM:
+ """
+ .. inheritance-diagram:: synthcity.plugins.core.models.tabular_aim.TabularAIM
+ :parts: 1
+
+
+ Adaptive and Iterative Mechanism (AIM) implementation, based on:
+ - code: https://github.com/ryan112358/private-pgm/blob/master/mechanisms/aim.py
+ - paper: https://www.vldb.org/pvldb/vol15/p2599-mckenna.pdf.
+
+
+ Args:
+ X (pd.DataFrame): Reference dataset, used for training the tabular encoder
+ # AIM parameters
+ epsilon (float = 1.0): Differential privacy budget. Defaults to 1.0.
+ delta (float = 1e-9): Differential privacy delta parameter. Defaults to 1e-9.
+ max_model_size (int = 80): Maximum allowed size of the underlying graphical model. Defaults to 80.
+ degree (int = 2): Degree of the workload marginals. Defaults to 2.
+ num_marginals (Optional[int] = None): Number of marginals to sample from the workload. Defaults to None (use all).
+ max_cells (int = 1000): Maximum number of cells allowed in any workload marginal. Defaults to 1000.
+
+ # core plugin arguments
+ encoder_max_clusters (int = 20): The max number of clusters to create for continuous columns when encoding with TabularEncoder. Defaults to 20.
+ encoder_whitelist (list = []): Ignore columns from encoding with TabularEncoder. Defaults to [].
+ device (Union[str, torch.device] = DEVICE): Not used by this model. Defaults to DEVICE.
+ random_state (int, optional): Random seed. Not currently used by this model. Defaults to 0.
+ **kwargs (Any): Additional keyword arguments (currently unused by this model).
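+
+ Example:
+ A minimal usage sketch (illustrative only; assumes `X` is a small, fully categorical pd.DataFrame):
+
+ >>> model = TabularAIM(X, epsilon=1.0, delta=1e-9)
+ >>> model.fit(X)
+ >>> X_syn = model.generate(count=len(X))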
+ """
+
+ def __init__(
+ self,
+ X: pd.DataFrame,
+ # AIM parameters
+ epsilon: float = 1.0,
+ delta: float = 1e-9,
+ max_model_size: int = 80,
+ degree: int = 2,
+ num_marginals: Optional[int] = None,
+ max_cells: int = 1000,
+ # core plugin arguments
+ encoder_max_clusters: int = 20,
+ encoder_whitelist: list = [],
+ device: Union[str, torch.device] = DEVICE,
+ learning_rate: float = 5e-3,
+ weight_decay: float = 1e-3,
+ logging_epoch: int = 100,
+ random_state: int = 0,
+ **kwargs: Any,
+ ):
+ super(TabularAIM, self).__init__()
+ self.columns = X.columns
+ self.epsilon = epsilon
+ self.delta = delta
+ self.max_model_size = max_model_size
+ self.degree = degree
+ self.num_marginals = num_marginals
+ self.max_cells = max_cells
+ self.prng = np.random
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def fit(
+ self,
+ X: pd.DataFrame,
+ **kwargs: Any,
+ ) -> Any:
+ """
+
+ Args:
+ data: Pandas DataFrame that contains the tabular data
+
+ Returns:
+ AIMTrainer used for the fine-tuning process
+ """
+ domain_shapes = X.nunique().to_dict()
+ mbi_domain = Domain(self.columns, domain_shapes.values())
+ self.dataset = Dataset(X, mbi_domain)
+
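+ # The candidate workload is every `degree`-way combination of attributes,
+ # filtered below to marginals whose joint domain has at most `max_cells` cells.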
+ workload = list(itertools.combinations(self.dataset.domain, self.degree))
+ if len(workload) == 0:
+ raise ValueError("No workload found. Is the dataset empty?")
+ workload = [
+ cl for cl in workload if self.dataset.domain.size(cl) <= self.max_cells
+ ]
+ if len(workload) == 0:
+ raise ValueError(
+ "Domain sizes for the cells are too large. Increase max_cells values or further discretize the data."
+ )
+ if self.num_marginals is not None:
+ workload = [
+ workload[i]
+ for i in self.prng.choice(
+ len(workload), self.num_marginals, replace=False
+ )
+ ]
+
+ self.workload = [(cl, 1.0) for cl in workload]
+ self.model = AIM(self.epsilon, self.delta, max_model_size=self.max_model_size)
+ return self
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def generate(
+ self,
+ count: int,
+ start_col: Optional[str] = "",
+ start_col_dist: Optional[Union[dict, list]] = None,
+ temperature: float = 0.7,
+ k: int = 100,
+ max_length: int = 100,
+ ) -> pd.DataFrame:
+ """
+ Generates tabular data using the trained AIM model.
+
+ Args:
+ count (int): The number of samples to generate
+
+ Returns:
+ pd.DataFrame: n_samples rows of generated data
+ """
+ synth_dataset = self.model.run(self.dataset, self.workload)
+ return synth_dataset.df
diff --git a/src/synthcity/plugins/privacy/plugin_aim.py b/src/synthcity/plugins/privacy/plugin_aim.py
new file mode 100644
index 00000000..c32d6bf9
--- /dev/null
+++ b/src/synthcity/plugins/privacy/plugin_aim.py
@@ -0,0 +1,166 @@
+"""
+Reference: "Adversarial random forests for density estimation and generative modeling" Authors: David S. Watson, Kristin Blesch, Jan Kapar, and Marvin N. Wright
+"""
+
+# stdlib
+from pathlib import Path
+from typing import Any, List, Optional, Union
+
+# third party
+import pandas as pd
+import torch
+
+# Necessary packages
+from pydantic import validate_arguments
+
+# synthcity absolute
+from synthcity.plugins.core.dataloader import DataLoader
+from synthcity.plugins.core.distribution import (
+ Distribution,
+ FloatDistribution,
+ IntegerDistribution,
+ LogDistribution,
+)
+from synthcity.plugins.core.models.tabular_aim import TabularAIM
+from synthcity.plugins.core.plugin import Plugin
+from synthcity.plugins.core.schema import Schema
+from synthcity.utils.constants import DEVICE
+
+
+class AIMPlugin(Plugin):
+ """
+ .. inheritance-diagram:: synthcity.plugins.privacy.plugin_aim.AIMPlugin
+ :parts: 1
+
+ Args:
+
+
+ Example:
+ >>> from sklearn.datasets import load_iris
+ >>> from synthcity.plugins import Plugins
+ >>>
+ >>> X, y = load_iris(as_frame = True, return_X_y = True)
+ >>> X["target"] = y
+ >>>
+ >>> plugin = Plugins().get("aim")
+ >>> plugin.fit(X)
+ >>>
+ >>> plugin.generate(50)
+
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ # AIM plugin arguments
+ epsilon: float = 1.0,
+ delta: float = 1e-9,
+ max_model_size: int = 80,
+ degree: int = 2,
+ num_marginals: Optional[int] = None,
+ max_cells: int = 1000,
+ # core plugin arguments
+ device: Union[str, torch.device] = DEVICE,
+ random_state: int = 0,
+ sampling_patience: int = 500,
+ workspace: Path = Path("workspace"),
+ compress_dataset: bool = False,
+ **kwargs: Any,
+ ) -> None:
+ """
+ .. inheritance-diagram:: synthcity.plugins.generic.plugin_aim.AIMPlugin
+ :parts: 1
+
+ Adaptive and Iterative Mechanism (AIM) implementation for differentially private synthetic data.
+
+ Args:
+ epsilon: float = 1.0
+ Differential privacy budget. Defaults to 1.0.
+ delta: float = 1e-9
+ Differential privacy delta parameter. Defaults to 1e-9.
+ max_model_size: int = 80
+ Maximum allowed size of the underlying graphical model. Defaults to 80.
+ degree: int = 2
+ Degree of the workload marginals to use. Defaults to 2.
+ num_marginals: Optional[int] = None
+ Number of marginals to sample from the workload. Defaults to None (use all).
+ max_cells: int = 1000
+ Maximum number of cells allowed in any workload marginal. Defaults to 1000.
+ device: Union[str, torch.device] = synthcity.utils.constants.DEVICE
+ The device that the model is run on. Defaults to "cuda" if cuda is available else "cpu".
+ random_state: int = 0
+ random_state used. Defaults to 0.
+ sampling_patience: int = 500
+ Max inference iterations to wait for the generated data to match the training schema. Defaults to 500.
+ workspace: Path
+ Path for caching intermediary results. Defaults to Path("workspace").
+ compress_dataset: bool. Default = False
+ Drop redundant features before training the generator. Defaults to False.
+ """
+ super().__init__(
+ device=device,
+ random_state=random_state,
+ sampling_patience=sampling_patience,
+ workspace=workspace,
+ compress_dataset=compress_dataset,
+ **kwargs,
+ )
+ self.epsilon = epsilon
+ self.delta = delta
+ self.max_model_size = max_model_size
+ self.degree = degree
+ self.num_marginals = num_marginals
+ self.max_cells = max_cells
+
+ @staticmethod
+ def name() -> str:
+ return "aim"
+
+ @staticmethod
+ def type() -> str:
+ return "generic"
+
+ @staticmethod
+ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
+ return [
+ FloatDistribution(name="epsilon", low=0.5, high=3.0, step=0.5),
+ LogDistribution(name="delta", low=1e-10, high=1e-5),
+ IntegerDistribution(name="max_model_size", low=50, high=200, step=50),
+ IntegerDistribution(name="degree", low=2, high=5, step=1),
+ IntegerDistribution(name="num_marginals", low=0, high=5, step=1),
+ IntegerDistribution(name="max_cells", low=5000, high=25000, step=5000),
+ ]
+
+ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "AIMPlugin":
+ """
+ Internal function to fit the model to the data.
+
+ Args:
+ X (DataLoader): The data to fit the model to.
+
+ Raises:
+ NotImplementedError: _description_
+
+ Returns:
+ AIMPlugin: _description_
+ """
+
+ self.model = TabularAIM(
+ X.dataframe(),
+ self.epsilon,
+ self.delta,
+ self.max_model_size,
+ self.degree,
+ self.num_marginals,
+ self.max_cells,
+ **kwargs,
+ )
+ if "cond" in kwargs and kwargs["cond"] is not None:
+ raise NotImplementedError(
+ "conditional generation is not currently available for the Adaptive and Iterative Mechanism for Differentially Private Synthetic Data(AIM) plugin."
+ )
+ self.model.fit(X.dataframe(), **kwargs)
+ return self
+
+ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> pd.DataFrame:
+ if "cond" in kwargs and kwargs["cond"] is not None:
+ raise NotImplementedError(
+ "conditional generation is not currently available for the Adaptive and Iterative Mechanism for Differentially Private Synthetic Data (AIM) plugin."
+ )
+
+ return self._safe_generate(self.model.generate, count, syn_schema)
+
+
+plugin = AIMPlugin
diff --git a/src/synthcity/utils/datasets/categorical/__init__.py b/src/synthcity/utils/datasets/categorical/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/synthcity/utils/datasets/categorical/categorical_adult.py b/src/synthcity/utils/datasets/categorical/categorical_adult.py
new file mode 100644
index 00000000..4a53f6ad
--- /dev/null
+++ b/src/synthcity/utils/datasets/categorical/categorical_adult.py
@@ -0,0 +1,33 @@
+# stdlib
+import io
+from pathlib import Path
+from typing import Union
+
+# third party
+import numpy as np
+import pandas as pd
+import requests
+
+URL = "https://raw.githubusercontent.com/ryan112358/private-pgm/master/data/adult.csv"
+df_path = Path(__file__).parent / "data/adult.csv"
+
+
+class CategoricalAdultDataloader:
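+ """Loader for a discretized (categorical) version of the UCI Adult dataset.
+
+ On first use the CSV is downloaded from the private-pgm repository and cached under
+ data/adult.csv next to this module; subsequent calls read the cached copy.
+ Usage: `CategoricalAdultDataloader().load()`.
+ """
+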
+ def __init__(self, as_numpy: bool = False) -> None:
+ self.as_numpy = as_numpy
+
+ def load(
+ self,
+ ) -> Union[pd.DataFrame, np.ndarray]:
+ # Download the categorical Adult dataset and cache it locally
+ if not df_path.exists():
+ s = requests.get(URL, timeout=5).content
+ df = pd.read_csv(io.StringIO(s.decode("utf-8")))
+
+ df.to_csv(df_path, index=None)
+ else:
+ df = pd.read_csv(df_path)
+
+ if self.as_numpy:
+ return df.to_numpy()
+ return df
diff --git a/src/synthcity/utils/datasets/categorical/data/__init__ .py b/src/synthcity/utils/datasets/categorical/data/__init__ .py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/nb_eval.py b/tests/nb_eval.py
index 6561e3b8..f2f4bb75 100644
--- a/tests/nb_eval.py
+++ b/tests/nb_eval.py
@@ -72,6 +72,7 @@ def run_notebook(notebook_path: Path) -> None:
"plugin_ctgan(generic)",
"plugin_fourier_flows",
"plugin_timegan",
+ "plugin_aim",
]
if not goggle_disabled:
diff --git a/tests/plugins/privacy/test_aim.py b/tests/plugins/privacy/test_aim.py
new file mode 100644
index 00000000..e1691319
--- /dev/null
+++ b/tests/plugins/privacy/test_aim.py
@@ -0,0 +1,177 @@
+# stdlib
+import random
+from datetime import datetime, timedelta
+
+# third party
+import numpy as np
+import pandas as pd
+import pytest
+from fhelpers import generate_fixtures
+from sklearn.datasets import load_iris
+
+# synthcity absolute
+from synthcity.metrics.eval import PerformanceEvaluatorXGB
+from synthcity.plugins import Plugin
+from synthcity.plugins.core.constraints import Constraints
+from synthcity.plugins.core.dataloader import GenericDataLoader
+from synthcity.plugins.privacy.plugin_aim import plugin
+from synthcity.utils.datasets.categorical.categorical_adult import (
+ CategoricalAdultDataloader,
+)
+from synthcity.utils.serialization import load, save
+
+plugin_name = "aim"
+plugin_args = {
+ "epsilon": 1.0,
+ "delta": 1e-9,
+ "max_model_size": 80,
+ "degree": 2,
+ "num_marginals": None,
+ "max_cells": 1000,
+}
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_sanity(test_plugin: Plugin) -> None:
+ assert test_plugin is not None
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_name(test_plugin: Plugin) -> None:
+ assert test_plugin.name() == plugin_name
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_type(test_plugin: Plugin) -> None:
+ assert test_plugin.type() == "generic"
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_hyperparams(test_plugin: Plugin) -> None:
+ assert len(test_plugin.hyperparameter_space()) == 6
+
+
+@pytest.mark.parametrize(
+ "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
+)
+def test_plugin_fit(test_plugin: Plugin) -> None:
+ X = CategoricalAdultDataloader().load().head()
+ test_plugin.fit(GenericDataLoader(X))
+
+
+@pytest.mark.parametrize(
+ "test_plugin",
+ generate_fixtures(plugin_name, plugin, plugin_args),
+)
+@pytest.mark.parametrize("serialize", [True, False])
+def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
+ X = CategoricalAdultDataloader().load().head()
+ test_plugin.fit(GenericDataLoader(X))
+
+ if serialize:
+ saved = save(test_plugin)
+ test_plugin = load(saved)
+
+ X_gen = test_plugin.generate()
+ assert len(X_gen) == len(X)
+ assert X_gen.shape[1] == X.shape[1]
+ assert test_plugin.schema_includes(X_gen)
+
+ X_gen = test_plugin.generate(50)
+ assert len(X_gen) == 50
+ assert test_plugin.schema_includes(X_gen)
+
+ # generate with random seed
+ X_gen1 = test_plugin.generate(50, random_state=0)
+ X_gen2 = test_plugin.generate(50, random_state=0)
+ X_gen3 = test_plugin.generate(50)
+ assert (X_gen1.numpy() == X_gen2.numpy()).all()
+ assert (X_gen1.numpy() != X_gen3.numpy()).any()
+
+
+@pytest.mark.parametrize(
+ "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
+)
+def test_plugin_generate_constraints_aim(test_plugin: Plugin) -> None:
+ X = CategoricalAdultDataloader().load().head()
+ test_plugin.fit(GenericDataLoader(X, target_column="income>50K"))
+
+ constraints = Constraints(
+ rules=[
+ ("income>50K", "eq", 1),
+ ]
+ )
+
+ X_gen = test_plugin.generate(constraints=constraints).dataframe()
+ assert len(X_gen) == len(X)
+ assert test_plugin.schema_includes(X_gen)
+ assert constraints.filter(X_gen).sum() == len(X_gen)
+ assert (X_gen["income>50K"] == 1).all()
+
+ X_gen = test_plugin.generate(count=50, constraints=constraints).dataframe()
+ assert len(X_gen) == 50
+ assert test_plugin.schema_includes(X_gen)
+ assert constraints.filter(X_gen).sum() == len(X_gen)
+ assert list(X_gen.columns) == list(X.columns)
+
+
+def test_sample_hyperparams() -> None:
+ assert plugin is not None
+ for i in range(100):
+ args = plugin.sample_hyperparameters()
+ assert plugin(**args) is not None
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("compress_dataset", [True, False])
+def test_eval_performance_aim(compress_dataset: bool) -> None:
+ assert plugin is not None
+ results = []
+
+ X_raw, y = load_iris(as_frame=True, return_X_y=True)
+ X_raw["target"] = y
+ # Discretize the data
+ num_bins = 3
+ for col in X_raw.columns:
+ X_raw[col] = pd.cut(X_raw[col], bins=num_bins, labels=list(range(num_bins)))
+
+ X = GenericDataLoader(X_raw, target_column="target")
+
+ for retry in range(2):
+ test_plugin = plugin(**plugin_args)
+ evaluator = PerformanceEvaluatorXGB()
+
+ test_plugin.fit(X)
+ X_syn = test_plugin.generate(count=1000)
+
+ results.append(evaluator.evaluate(X, X_syn)["syn_id"])
+ print(results)
+ assert np.mean(results) > 0.7
+
+
+def gen_datetime(min_year: int = 2000, max_year: int = datetime.now().year) -> datetime:
+ # generate a datetime in format yyyy-mm-dd hh:mm:ss.000000
+ start = datetime(min_year, 1, 1, 00, 00, 00)
+ years = max_year - min_year + 1
+ end = start + timedelta(days=365 * years)
+ return start + (end - start) * random.random()
+
+
+def test_plugin_encoding() -> None:
+ assert plugin is not None
+ data = [[gen_datetime(), i % 2 == 0, i] for i in range(10)]
+
+ df = pd.DataFrame(data, columns=["date", "bool", "int"])
+ X = GenericDataLoader(df)
+ test_plugin = plugin(**plugin_args)
+ test_plugin.fit(X)
+
+ syn = test_plugin.generate(10)
+
+ assert len(syn) == 10
+ assert test_plugin.schema_includes(syn)
+
+ syn_df = syn.dataframe()
+
+ assert syn_df["date"].infer_objects().dtype.kind == "M"
+ assert syn_df["bool"].infer_objects().dtype.kind == "b"
diff --git a/tutorials/plugins/generic/plugin_great.ipynb b/tutorials/plugins/generic/plugin_great.ipynb
index 89212557..355821fa 100644
--- a/tutorials/plugins/generic/plugin_great.ipynb
+++ b/tutorials/plugins/generic/plugin_great.ipynb
@@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Adversarial Random Forest"
+ "# GReaT"
]
},
{
diff --git a/tutorials/plugins/privacy/plugin_aim.ipynb b/tutorials/plugins/privacy/plugin_aim.ipynb
new file mode 100644
index 00000000..dc8a821d
--- /dev/null
+++ b/tutorials/plugins/privacy/plugin_aim.ipynb
@@ -0,0 +1,181 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Adaptive and Iterative Mechanism for Differentially Private Synthetic Data\n",
+ "\n",
+ "This method is designed for categorical data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# stdlib\n",
+ "import warnings\n",
+ "import sys\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "from synthcity.utils.datasets.categorical.categorical_adult import CategoricalAdultDataloader\n",
+ "import synthcity.logger as log\n",
+ "log.add(sink=sys.stderr, level=\"INFO\")\n",
+ "\n",
+ "eval_plugin = \"aim\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.dataloader import GenericDataLoader\n",
+ "\n",
+ "X = CategoricalAdultDataloader().load()\n",
+ "loader = GenericDataLoader(X, target_column=\"income>50K\", sensitive_columns=[\"sex\", \"race\"])\n",
+ "\n",
+ "loader.dataframe()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Train the generator"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "syn_model = Plugins().get(eval_plugin)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "syn_model.fit(loader)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Generate new samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "syn_model.generate(count=10).dataframe()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "syn_model.plot(plt, loader, count=100)\n",
+ "\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Benchmarks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.benchmark import Benchmarks\n",
+ "\n",
+ "score = Benchmarks.evaluate(\n",
+ " [\n",
+ " (eval_plugin, eval_plugin, {\"epsilon\": 1.0, \"delta\": 1e-7, \"max_model_size\": 80, \"degree\": 2, \"num_marginals\": None, \"max_cells\": 1000}),\n",
+ " ], # (testname, plugin, plugin_args) The plugin_args are given are simply to illustrate some of the paramters that can be passed to the plugin\n",
+ " loader,\n",
+ " repeats=2,\n",
+ " metrics={\n",
+ " \"detection\": [\"detection_mlp\"],\n",
+ " \"privacy\": [\"distinct l-diversity\", \"k-anonymization\", \"k-map\"],\n",
+ " },\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Benchmarks.print(score)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "synthcity-all",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}