diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml index e258db95..0f977ba1 100644 --- a/.github/workflows/test_full.yml +++ b/.github/workflows/test_full.yml @@ -27,7 +27,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_pr.yml b/.github/workflows/test_pr.yml index fec1126d..bac8905e 100644 --- a/.github/workflows/test_pr.yml +++ b/.github/workflows/test_pr.yml @@ -54,7 +54,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt - name: Test Core run: | diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 69195a09..c93a0c35 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -32,7 +32,7 @@ jobs: if: ${{ matrix.os == 'macos-latest' }} - name: Install dependencies run: | - pip install pip==23.0.1 + python -m pip install -U pip pip install -r prereq.txt pip install .[all] diff --git a/prereq.txt b/prereq.txt index 0d7eb1f0..6554078b 100644 --- a/prereq.txt +++ b/prereq.txt @@ -1,3 +1,4 @@ numpy -torch<2.0 +torch>=1.10.0,<2.0 tsai +wheel>=0.40 diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index 485c9a5a..788db4e0 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints: @abstractmethod def min(self) -> Any: - "Get the min value of the distribution" + """Get the min value of the distribution.""" ... @abstractmethod def max(self) -> Any: - "Get the max value of the distribution" + """Get the max value of the distribution.""" ... - @abstractmethod def __eq__(self, other: Any) -> bool: - ... 
+ return type(self) == type(other) and self.get() == other.get() + + def __contains__(self, item: Any) -> bool: + """ + Example: + >>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"]) + >>> "a" in dist + True + """ + return self.has(item) @abstractmethod def dtype(self) -> str: @@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List: raise ValueError( "Invalid choices for CategoricalDistribution. Provide data or choices params" ) - return v + return sorted(set(v)) def get(self) -> List[Any]: return [self.name, self.choices] @@ -176,12 +184,6 @@ def min(self) -> Any: def max(self) -> Any: return max(self.choices) - def __eq__(self, other: Any) -> bool: - if not isinstance(other, CategoricalDistribution): - return False - - return self.name == other.name and set(self.choices) == set(other.choices) - def dtype(self) -> str: types = { "object": 0, @@ -259,16 +261,6 @@ def min(self) -> Any: def max(self) -> Any: return self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, type(self)): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "float" @@ -276,16 +268,17 @@ def dtype(self) -> str: class LogDistribution(FloatDistribution): low: float = np.finfo(np.float64).tiny high: float = np.finfo(np.float64).max - base: float = 2.0 + + def get(self) -> List[Any]: + return [self.name, self.low, self.high] def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples - lo = np.log2(self.low) / np.log2(self.base) - hi = np.log2(self.high) / np.log2(self.base) - return self.base ** np.random.uniform(lo, hi, count) + lo, hi = np.log2(self.low), np.log2(self.high) + return 2.0 ** np.random.uniform(lo, hi, count) class IntegerDistribution(Distribution): @@ -313,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: 
Dict) -> int: return int(values[mkey].index.max()) return v + @validator("step", always=True) + def _validate_step(cls: Any, v: int, values: Dict) -> int: + if v < 1: + raise ValueError("Step must be greater than 0") + return v + def get(self) -> List[Any]: return [self.name, self.low, self.high, self.step] @@ -322,9 +321,9 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - high = (self.high + 1 - self.low) // self.step - s = np.random.choice(high, count) - return s * self.step + self.low + steps = (self.high - self.low) // self.step + samples = np.random.choice(steps + 1, count) + return samples * self.step + self.low def has(self, val: Any) -> bool: return self.low <= val and val <= self.high @@ -347,34 +346,31 @@ def min(self) -> Any: def max(self) -> Any: return self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, IntegerDistribution): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "int" -class LogIntDistribution(FloatDistribution): - low: float = 1.0 - high: float = float(np.iinfo(np.int64).max) - base: float = 2.0 +class IntLogDistribution(IntegerDistribution): + low: int = 1 + high: int = np.iinfo(np.int64).max + + @validator("step", always=True) + def _validate_step(cls: Any, v: int, values: Dict) -> int: + if v != 1: + raise ValueError("Step must be 1 for IntLogDistribution") + return v + + def get(self) -> List[Any]: + return [self.name, self.low, self.high] def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples - lo = np.log2(self.low) / np.log2(self.base) - hi = np.log2(self.high) / np.log2(self.base) - s = self.base ** np.random.uniform(lo, hi, count) - return s.astype(int) + lo, hi = np.log2(self.low), np.log2(self.high) + samples = 2.0 ** np.random.uniform(lo, hi, count) + return 
samples.astype(int) class DatetimeDistribution(Distribution): @@ -383,32 +379,27 @@ class DatetimeDistribution(Distribution): :parts: 1 """ - offset: int = 120 low: datetime = datetime.utcfromtimestamp(0) high: datetime = datetime.now() - - @validator("offset", always=True) - def _validate_offset(cls: Any, v: int) -> int: - if v < 0: - raise ValueError("offset must be greater than 0") - return v + step: timedelta = timedelta(microseconds=1) + offset: timedelta = timedelta(seconds=120) @validator("low", always=True) def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.min() - return v - timedelta(seconds=values["offset"]) + return v @validator("high", always=True) def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: mkey = "marginal_distribution" if mkey in values and values[mkey] is not None: v = values[mkey].index.max() - return v + timedelta(seconds=values["offset"]) + return v def get(self) -> List[Any]: - return [self.name, self.low, self.high] + return [self.name, self.low, self.high, self.step, self.offset] def sample(self, count: int = 1) -> Any: np.random.seed(self.random_state) @@ -416,16 +407,18 @@ def sample(self, count: int = 1) -> Any: if msamples is not None: return msamples - delta = self.high - self.low - return self.low + delta * np.random.rand(count) + n = (self.high - self.low) // self.step + 1 + samples = np.round(np.random.rand(count) * n - 0.5) + return self.low + samples * self.step def has(self, val: datetime) -> bool: return self.low <= val and val <= self.high def includes(self, other: "Distribution") -> bool: - return self.min() - timedelta( - seconds=self.offset - ) <= other.min() and other.max() <= self.max() + timedelta(seconds=self.offset) + return ( + self.min() - self.offset <= other.min() + and other.max() <= self.max() + self.offset + ) def as_constraint(self) -> Constraints: return 
Constraints( @@ -442,16 +435,6 @@ def min(self) -> Any: def max(self) -> Any: return self.high - def __eq__(self, other: Any) -> bool: - if not isinstance(other, DatetimeDistribution): - return False - - return ( - self.name == other.name - and self.low == other.low - and self.high == other.high - ) - def dtype(self) -> str: return "datetime" diff --git a/src/synthcity/plugins/core/models/factory.py b/src/synthcity/plugins/core/models/factory.py index 30b993be..fefd23ef 100644 --- a/src/synthcity/plugins/core/models/factory.py +++ b/src/synthcity/plugins/core/models/factory.py @@ -20,6 +20,7 @@ DatetimeEncoder, FeatureEncoder, GaussianQuantileTransformer, + LabelEncoder, MinMaxScaler, OneHotEncoder, OrdinalEncoder, @@ -75,6 +76,7 @@ datetime=DatetimeEncoder, onehot=OneHotEncoder, ordinal=OrdinalEncoder, + label=LabelEncoder, standard=StandardScaler, minmax=MinMaxScaler, robust=RobustScaler, diff --git a/src/synthcity/plugins/generic/plugin_ddpm.py b/src/synthcity/plugins/generic/plugin_ddpm.py index cc851a8e..8789d7d6 100644 --- a/src/synthcity/plugins/generic/plugin_ddpm.py +++ b/src/synthcity/plugins/generic/plugin_ddpm.py @@ -18,8 +18,8 @@ from synthcity.plugins.core.distribution import ( Distribution, IntegerDistribution, + IntLogDistribution, LogDistribution, - LogIntDistribution, ) from synthcity.plugins.core.models.tabular_ddpm import TabDDPM from synthcity.plugins.core.models.tabular_encoder import TabularEncoder @@ -180,11 +180,11 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]: """ return [ LogDistribution(name="lr", low=1e-5, high=1e-1), - LogIntDistribution(name="batch_size", low=256, high=4096), + IntLogDistribution(name="batch_size", low=256, high=4096), IntegerDistribution(name="num_timesteps", low=10, high=1000), - LogIntDistribution(name="n_iter", low=1000, high=10000), + IntLogDistribution(name="n_iter", low=1000, high=10000), # IntegerDistribution(name="n_layers_hidden", low=2, high=8), - # LogIntDistribution(name="dim_hidden", 
def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any:
    """Ask an Optuna trial for one value drawn from a synthcity distribution.

    Args:
        trial: Optuna trial; the suggestion is registered under ``dist.name``.
        dist: Hyperparameter-space entry to translate into an Optuna suggestion.

    Returns:
        The value proposed by ``trial`` for ``dist``.

    Raises:
        ValueError: If ``dist`` is not one of the supported distribution types.
    """
    # LogDistribution derives from FloatDistribution and IntLogDistribution
    # derives from IntegerDistribution, so the log variants must be tested
    # first; otherwise isinstance() matches the parent class and the
    # log-scaled branches can never run.
    if isinstance(dist, D.LogDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high, log=True)
    if isinstance(dist, D.FloatDistribution):
        return trial.suggest_float(dist.name, dist.low, dist.high)
    if isinstance(dist, D.IntLogDistribution):
        return trial.suggest_int(dist.name, dist.low, dist.high, log=True)
    if isinstance(dist, D.IntegerDistribution):
        # Pass ``step`` by keyword so the call does not depend on the
        # positional order of suggest_int's optional arguments.
        return trial.suggest_int(dist.name, dist.low, dist.high, step=dist.step)
    if isinstance(dist, D.CategoricalDistribution):
        return trial.suggest_categorical(dist.name, dist.choices)
    raise ValueError(f"Unknown dist: {dist}")


def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict:
    """Suggest a value for every distribution in a hyperparameter space.

    Args:
        trial: Optuna trial used for every suggestion.
        distributions: Hyperparameter-space entries, one per parameter.

    Returns:
        Mapping from each distribution's ``name`` to the value suggested for it.

    Raises:
        ValueError: If any entry has an unsupported distribution type.
    """
    return {dist.name: suggest(trial, dist) for dist in distributions}
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# stdlib\n", + "import sys\n", + "import warnings\n", + "\n", + "# third party\n", + "import optuna\n", + "from sklearn.datasets import load_diabetes\n", + "\n", + "# synthcity absolute\n", + "import synthcity.logger as log\n", + "from synthcity.plugins import Plugins\n", + "from synthcity.plugins.core.dataloader import GenericDataLoader\n", + "\n", + "log.add(sink=sys.stderr, level=\"INFO\")\n", + "warnings.filterwarnings(\"ignore\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "X, y = load_diabetes(return_X_y=True, as_frame=True)\n", + "X[\"target\"] = y\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loader = GenericDataLoader(\n", + " X,\n", + " target_column=\"target\",\n", + " sensitive_columns=[\"sex\"],\n", + ")\n", + "train_loader, test_loader = loader.train(), loader.test()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load the plugin class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PLUGIN = \"tvae\"\n", + "plugin_cls = type(Plugins().get(PLUGIN))\n", + "plugin_cls" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Display the hyperparameter space" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plugin_cls.hyperparameter_space()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use a trial to suggest a set of hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": {}, + "outputs": [], + "source": [ + "from synthcity.utils.optuna_sample import suggest_all\n", + "\n", + "trial = optuna.create_study().ask()\n", + "params = suggest_all(trial, plugin_cls.hyperparameter_space())\n", + "params['n_iter'] = 100 # speed up\n", + "params" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate the plugin with the suggested hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from synthcity.benchmark import Benchmarks\n", + "\n", + "plugin = plugin_cls(**params).fit(train_loader)\n", + "report = Benchmarks.evaluate(\n", + " [(\"trial\", PLUGIN, params)],\n", + " train_loader, # Benchmarks.evaluate will split out a validation set\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + ")\n", + "report['trial']" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create an Optuna study and optimize the hyperparameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def objective(trial: optuna.Trial):\n", + " hp_space = Plugins().get(PLUGIN).hyperparameter_space()\n", + " hp_space[0].high = 100 # speed up for now\n", + " params = suggest_all(trial, hp_space)\n", + " ID = f\"trial_{trial.number}\"\n", + " try:\n", + " report = Benchmarks.evaluate(\n", + " [(ID, PLUGIN, params)],\n", + " train_loader,\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + " )\n", + " except Exception as e: # invalid set of params\n", + " print(f\"{type(e).__name__}: {e}\")\n", + " print(params)\n", + " raise optuna.TrialPruned()\n", + " score = report[ID].query('direction == \"minimize\"')['mean'].mean()\n", + " # average score across all metrics with direction=\"minimize\"\n", + " return 
score\n", + "\n", + "study = optuna.create_study(direction=\"minimize\")\n", + "study.optimize(objective, n_trials=2)\n", + "study.best_params" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualize the study" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from optuna.visualization import plot_contour\n", + "from optuna.visualization import plot_edf\n", + "from optuna.visualization import plot_optimization_history\n", + "from optuna.visualization import plot_parallel_coordinate\n", + "from optuna.visualization import plot_param_importances\n", + "from optuna.visualization import plot_slice\n", + "\n", + "plot_optimization_history(study)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize high-dimensional parameter relationships. \n", + "plot_parallel_coordinate(study)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize hyperparameter relationships.\n", + "fig = plot_contour(study, params=['batch_size', 'lr', 'encoder_dropout', 'decoder_dropout'])\n", + "fig.update_layout(width=800, height=800)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize individual hyperparameters as slice plot.\n", + "plot_slice(study)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize parameter importances.\n", + "plot_param_importances(study)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Learn which hyperparameters are affecting the trial duration with hyperparameter importance.\n", + "optuna.visualization.plot_param_importances(\n", + " study, target=lambda t: t.duration.total_seconds(), 
target_name=\"duration\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Visualize empirical distribution function of the objective.\n", + "plot_edf(study)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Test performance of the optimized plugin" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "best_params = study.best_params\n", + "report = Benchmarks.evaluate(\n", + " [(\"test\", PLUGIN, best_params)],\n", + " train_loader,\n", + " test_loader,\n", + " repeats=1,\n", + " metrics={\"detection\": [\"detection_mlp\", \"detection_xgb\"]}, # DELETE THIS LINE FOR ALL METRICS\n", + ")\n", + "Benchmarks.print(report)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Congratulations!\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!\n", + "\n", + "### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub\n", + "\n", + "- The easiest way to help our community is just by starring the Repos! 
This helps raise awareness of the tools we're building.\n", + "\n", + "\n", + "### Check out other projects from vanderschaarlab\n", + "- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n", + "- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "synthcity", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}