Optuna hyperparameter optimization tutorial #178

Merged: 115 commits, Apr 25, 2023

Commits
2ecba4f
first commit for the addition of the TabDDPM plugin
Mar 1, 2023
fed898b
Add DDPM test script and update DDPM plugin
Mar 3, 2023
34979cf
add TabDDPM class and refactor
Mar 5, 2023
0abdc01
handle discrete cols and label generation
Mar 7, 2023
405a052
add hparam space and update tests of DDPM
Mar 7, 2023
0e36041
debug and test DDPM
Mar 7, 2023
fc9cee0
update TensorDataLoader and training loop
Mar 7, 2023
d8b57ad
clear bugs
Mar 7, 2023
92dcc32
debug for regression tasks
Mar 7, 2023
0b9d0e3
debug for regression tasks; ALL TESTS PASSED
Mar 7, 2023
6e6fe41
Merge branch 'tab_ddpm' of https://github.com/TZTsai/synthcity into t…
Mar 7, 2023
bb98229
remove the official repo of TabDDPM
Mar 7, 2023
b4486a4
passed all pre-commit checks
Mar 8, 2023
2a9aa2a
convert assert to conditional AssertionErrors
Mar 8, 2023
246cd5b
added an auto annotation tool
Mar 10, 2023
f458bb4
update auto-anno and generate annotations
Mar 10, 2023
137c176
remove auto-anno and flake8 noqa
Mar 10, 2023
6c4af11
add python<3.9 compatible annotations
Mar 10, 2023
191cdcc
remove star import
Mar 10, 2023
9349a66
replace builtin type annos to typing annos
Mar 12, 2023
02579e9
resolve py38 compatibility issue
Mar 12, 2023
f930bc0
tests/plugins/generic/test_ddpm.py
Mar 12, 2023
3cf73d7
change TabDDPM method signatures
Mar 13, 2023
5d37c4b
remove Iterator subscription
Mar 13, 2023
681ba60
update AssertionErrors, add EarlyStop callback, removed additional ML…
Mar 15, 2023
bcbc131
Merge branch 'main' into tab_ddpm
bcebere Mar 15, 2023
a9438dc
remove TensorDataLoader, update test_ddpm
Mar 16, 2023
52be80f
update EarlyStopping
Mar 16, 2023
794ebd6
add TabDDPM tutorial, update TabDDPM plugin and encoders
Mar 27, 2023
bcdce4b
add TabDDPM tutorial
Mar 27, 2023
8120e97
major update of FeatureEncoder and TabularEncoder
Mar 30, 2023
2750791
add LogDistribution and LogIntDistribution
Mar 30, 2023
52011d3
update DDPM to use TabularEncoder
Mar 30, 2023
0ee6c8b
update test_tabular_encoder and debug
Mar 30, 2023
244854d
debug and DDPM tutorial OK
Mar 30, 2023
e336d3c
Merge branch 'main' of https://github.com/vanderschaarlab/synthcity
Mar 30, 2023
c847c95
Merge branch 'main' into tab_ddpm
Mar 30, 2023
428177b
debug LogDistribution and LogIntDistribution
Mar 31, 2023
3377a95
Merge branch 'main' into tab_ddpm
Mar 31, 2023
4705319
change discrete encoding of BinEncoder to passthrough; passed all te…
Apr 1, 2023
d9d73f1
add tabnet to plugins/core/models
Apr 2, 2023
d29ef37
add factory.py, let DDPM use TabNet, refactor
Apr 2, 2023
6e58cf3
update docstrings and refactor
Apr 2, 2023
2a6ca6f
fix type annotation compatibility
Apr 2, 2023
36acaa0
make SkipConnection serializable
Apr 3, 2023
de15b9b
fix TabularEncoder.activation_layout
Apr 3, 2023
694cd22
remove unnecessary code
Apr 3, 2023
a459785
fix minor bug and add more nn models in factory
Apr 6, 2023
57816b6
update pandas and torch version requirement
Apr 6, 2023
cc7e8fb
update pandas and torch version requirement
Apr 6, 2023
f20db25
Merge branch 'main' into tabnet
Apr 6, 2023
7b0c19a
Merge branch 'main' into tab_ddpm
Apr 6, 2023
8a58996
update ddpm tutorial
Apr 6, 2023
31b5f13
Merge branch 'tab_ddpm' of https://github.com/TZTsai/synthcity into t…
Apr 6, 2023
cef348e
restore setup.cfg
Apr 6, 2023
9cb5da1
restore setup.cfg
Apr 6, 2023
fe5ff25
replace LabelEncoder with OrdinalEncoder
Apr 7, 2023
2922a1d
update setup.cfg
Apr 7, 2023
11fb825
update setup.cfg
Apr 7, 2023
9222b4e
debug datetimeDistribution
Apr 7, 2023
7d55c65
Merge branch 'tab_ddpm' into tabnet
Apr 7, 2023
95302b9
clean
Apr 7, 2023
785db82
update setup.cfg and goggle test
Apr 7, 2023
44ead6d
Merge branch 'tab_ddpm' into tabnet
Apr 7, 2023
27cc95c
move DDPM tutorial to tutorials/plugins
Apr 7, 2023
1d7c77c
update tabnet.py reference
Apr 7, 2023
6c25377
update tab_ddpm
Apr 7, 2023
3623d37
update distribution, add optuna utils and tutorial
Apr 8, 2023
2fb8508
update
Apr 8, 2023
5adfabf
Fix plugin type of static_model of fflows
Apr 8, 2023
a2a88c5
update intlogdistribution and tutorial
Apr 8, 2023
4a7e73b
try fixing goggle
Apr 8, 2023
8051caa
add more activations
Apr 8, 2023
3cd9917
minor fix
Apr 8, 2023
42cbe8c
update
Apr 9, 2023
101c76f
Merge branch 'tab_ddpm' into tabnet
Apr 9, 2023
7c58f2d
update
Apr 9, 2023
104e3a3
update
Apr 9, 2023
7b4e04a
update
Apr 9, 2023
fede549
Update tabular_encoder.py
Apr 10, 2023
539effa
Update test_goggle.py
Apr 10, 2023
0cb9f25
Update tabular_encoder.py
Apr 10, 2023
42c6941
update
Apr 10, 2023
d7d966d
update tutorial 8
Apr 10, 2023
e20e581
update
Apr 10, 2023
472ad52
default cat nonlin of goggle <- gumbel_softmax
Apr 10, 2023
5dbe666
get_nonlin('softmax') <- GumbelSoftmax()
Apr 10, 2023
74e897b
remove debug logging
Apr 10, 2023
27553e9
update
Apr 10, 2023
b5eb2e7
Merge branch 'tab_ddpm' into tabnet
Apr 10, 2023
7fc5ce4
update
Apr 10, 2023
7aeba49
Merge branch 'main' into optuna_tutorial
Apr 16, 2023
8af4966
Merge branch 'main' into tabnet
robsdavis Apr 18, 2023
b8c9522
fix merge
Apr 18, 2023
ecc9d08
fix merge
Apr 18, 2023
c2775ba
update pip upgrade commands in workflows
Apr 19, 2023
1d9c7a4
update pip upgrade commands in workflows
Apr 19, 2023
385d2ed
keep pip version to 23.0.1 in workflows
Apr 19, 2023
81fb12b
keep pip version to 23.0.1 in workflows
Apr 19, 2023
68d6911
Merge branch 'tabnet' into optuna_tutorial
Apr 20, 2023
3884fc4
update
Apr 20, 2023
90f60be
Merge branch 'main' into optuna_tutorial
Apr 20, 2023
7640f35
update
Apr 20, 2023
c91246b
update
Apr 20, 2023
38fc796
update
Apr 20, 2023
899a9d8
update
Apr 20, 2023
60fa08d
update
Apr 20, 2023
50a77c5
fix distribution
Apr 20, 2023
7ed4ab2
Merge branch 'main' into tabnet
Apr 20, 2023
b0036e9
Merge branch 'tabnet' into optuna_tutorial
Apr 20, 2023
05eee67
resolve merge conflicts
Apr 20, 2023
fbf5aad
Merge branch 'tab_ddpm' into optuna_tutorial
Apr 20, 2023
727662f
update
Apr 20, 2023
212d7cb
move upgrading of wheel to prereq.txt
Apr 24, 2023
d8e63c3
update
Apr 24, 2023
132 changes: 71 additions & 61 deletions src/synthcity/plugins/core/distribution.py
@@ -111,17 +111,25 @@ def as_constraint(self) -> Constraints:

@abstractmethod
def min(self) -> Any:
"Get the min value of the distribution"
"""Get the min value of the distribution."""
...

@abstractmethod
def max(self) -> Any:
"Get the max value of the distribution"
"""Get the max value of the distribution."""
...

@abstractmethod
def __eq__(self, other: Any) -> bool:
...
return type(self) == type(other) and self.get() == other.get()

def __contains__(self, item: Any) -> bool:
"""
Example:
>>> dist = CategoricalDistribution(name="foo", choices=["a", "b", "c"])
>>> "a" in dist
True
"""
return self.has(item)

@abstractmethod
def dtype(self) -> str:
@@ -146,7 +154,7 @@ def _validate_choices(cls: Any, v: List, values: Dict) -> List:
raise ValueError(
"Invalid choices for CategoricalDistribution. Provide data or choices params"
)
return v
return sorted(set(v))

def get(self) -> List[Any]:
return [self.name, self.choices]
@@ -176,12 +184,6 @@ def min(self) -> Any:
def max(self) -> Any:
return max(self.choices)

def __eq__(self, other: Any) -> bool:
if not isinstance(other, CategoricalDistribution):
return False

return self.name == other.name and set(self.choices) == set(other.choices)

def dtype(self) -> str:
types = {
"object": 0,
@@ -259,20 +261,26 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, FloatDistribution):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "float"


class LogDistribution(FloatDistribution):
low: float = np.finfo(np.float64).tiny
high: float = np.finfo(np.float64).max

def get(self) -> List[Any]:
return [self.name, self.low, self.high]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples
lo, hi = np.log2(self.low), np.log2(self.high)
return 2.0 ** np.random.uniform(lo, hi, count)


class IntegerDistribution(Distribution):
"""
.. inheritance-diagram:: synthcity.plugins.core.distribution.IntegerDistribution
@@ -298,6 +306,12 @@ def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int:
return int(values[mkey].index.max())
return v

@validator("step", always=True)
def _validate_step(cls: Any, v: int, values: Dict) -> int:
if v < 1:
raise ValueError("Step must be greater than 0")
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high, self.step]

@@ -307,8 +321,9 @@ def sample(self, count: int = 1) -> Any:
if msamples is not None:
return msamples

choices = [val for val in range(self.low, self.high + 1, self.step)]
return np.random.choice(choices, count).tolist()
steps = (self.high - self.low) // self.step
samples = np.random.choice(steps + 1, count)
return samples * self.step + self.low

def has(self, val: Any) -> bool:
return self.low <= val and val <= self.high
@@ -331,21 +346,31 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, IntegerDistribution):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "int"


OFFSET = 120
class IntLogDistribution(IntegerDistribution):
low: int = 1
high: int = np.iinfo(np.int64).max

@validator("step", always=True)
def _validate_step(cls: Any, v: int, values: Dict) -> int:
if v != 1:
raise ValueError("Step must be 1 for IntLogDistribution")
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples
lo, hi = np.log2(self.low), np.log2(self.high)
samples = 2.0 ** np.random.uniform(lo, hi, count)
return samples.astype(int)


class DatetimeDistribution(Distribution):
@@ -356,49 +381,44 @@ class DatetimeDistribution(Distribution):

low: datetime = datetime.utcfromtimestamp(0)
high: datetime = datetime.now()
step: timedelta = timedelta(microseconds=1)
offset: timedelta = timedelta(seconds=120)

@validator("low", always=True)
def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
mkey = "marginal_distribution"
if mkey in values and values[mkey] is not None:
v = values[mkey].index.min()

return v - timedelta(seconds=OFFSET)
return v

@validator("high", always=True)
def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime:
mkey = "marginal_distribution"
if mkey in values and values[mkey] is not None:
v = values[mkey].index.max()

return v + timedelta(seconds=OFFSET)
return v

def get(self) -> List[Any]:
return [self.name, self.low, self.high]
return [self.name, self.low, self.high, self.step, self.offset]

def sample(self, count: int = 1) -> Any:
np.random.seed(self.random_state)
msamples = self.sample_marginal(count)
if msamples is not None:
return msamples

samples = np.random.uniform(
datetime.timestamp(self.low), datetime.timestamp(self.high), count
)

samples_dt = []
for s in samples:
samples_dt.append(datetime.fromtimestamp(s))

return samples_dt
n = (self.high - self.low) // self.step + 1
samples = np.round(np.random.rand(count) * n - 0.5)
return self.low + samples * self.step

def has(self, val: datetime) -> bool:
return self.low <= val and val <= self.high

def includes(self, other: "Distribution") -> bool:
return self.min() - timedelta(
seconds=OFFSET
) <= other.min() and other.max() <= self.max() + timedelta(seconds=OFFSET)
return (
self.min() - self.offset <= other.min()
and other.max() <= self.max() + self.offset
)

def as_constraint(self) -> Constraints:
return Constraints(
@@ -415,16 +435,6 @@ def min(self) -> Any:
def max(self) -> Any:
return self.high

def __eq__(self, other: Any) -> bool:
if not isinstance(other, DatetimeDistribution):
return False

return (
self.name == other.name
and self.low == other.low
and self.high == other.high
)

def dtype(self) -> str:
return "datetime"

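For context, a minimal usage sketch (not part of the diff) of how the new log-scaled distributions behave; it assumes only the constructor arguments and sample() method shown above.

    # Hypothetical usage sketch, not code from this PR.
    from synthcity.plugins.core.distribution import IntLogDistribution, LogDistribution

    lr_dist = LogDistribution(name="lr", low=1e-5, high=1e-1)
    batch_dist = IntLogDistribution(name="batch_size", low=256, high=4096)

    # Both classes draw uniformly in log2-space, so samples cover small and large
    # magnitudes evenly instead of clustering near the upper bound.
    print(lr_dist.sample(5))     # e.g. floats spread across 1e-5 .. 1e-1
    print(batch_dist.sample(5))  # e.g. integers spread across 256 .. 4096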
20 changes: 12 additions & 8 deletions src/synthcity/plugins/generic/plugin_ddpm.py
@@ -15,7 +15,12 @@

# synthcity absolute
from synthcity.plugins.core.dataloader import DataLoader
from synthcity.plugins.core.distribution import CategoricalDistribution, Distribution
from synthcity.plugins.core.distribution import (
Distribution,
IntegerDistribution,
IntLogDistribution,
LogDistribution,
)
from synthcity.plugins.core.models.tabular_ddpm import TabDDPM
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
@@ -174,13 +179,12 @@ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
Gaussian diffusion loss MSE
"""
return [
# TODO: change to loguniform distribution
CategoricalDistribution(name="lr", choices=[1e-5, 1e-4, 1e-3, 2e-3, 3e-3]),
CategoricalDistribution(name="batch_size", choices=[256, 4096]),
CategoricalDistribution(name="num_timesteps", choices=[100, 1000]),
CategoricalDistribution(name="n_iter", choices=[5000, 10000, 20000]),
CategoricalDistribution(name="n_layers_hidden", choices=[2, 4, 6, 8]),
CategoricalDistribution(name="dim_hidden", choices=[128, 256, 512, 1024]),
LogDistribution(name="lr", low=1e-5, high=1e-1),
IntLogDistribution(name="batch_size", low=256, high=4096),
IntegerDistribution(name="num_timesteps", low=10, high=1000),
IntLogDistribution(name="n_iter", low=1000, high=10000),
IntegerDistribution(name="n_layers_hidden", low=2, high=8),
IntLogDistribution(name="dim_hidden", low=128, high=1024),
]

def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "TabDDPMPlugin":
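To illustrate the effect of this change, here is a short hedged sketch (not from the PR) that draws one random configuration straight from the plugin's revised search space; it assumes hyperparameter_space() can be called on the class as a static method, as its signature above suggests.

    # Hypothetical sketch: one random-search draw over the DDPM search space, no Optuna needed.
    from synthcity.plugins.generic.plugin_ddpm import TabDDPMPlugin

    space = TabDDPMPlugin.hyperparameter_space()
    params = {dist.name: dist.sample(1)[0] for dist in space}
    print(params)  # e.g. {"lr": 0.00042, "batch_size": 1024, "num_timesteps": 350, ...}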
6 changes: 2 additions & 4 deletions src/synthcity/plugins/time_series/plugin_fflows.py
@@ -11,6 +11,7 @@
from fflows import FourierFlow

# synthcity absolute
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import DataLoader
from synthcity.plugins.core.distribution import (
CategoricalDistribution,
@@ -24,7 +25,6 @@
from synthcity.plugins.core.models.ts_model import TimeSeriesModel
from synthcity.plugins.core.plugin import Plugin
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.generic import GenericPlugins
from synthcity.utils.constants import DEVICE


@@ -134,9 +134,7 @@ def __init__(
normalize=normalize,
).to(device)

self.static_model = GenericPlugins().get(
self.static_model_name, device=self.device
)
self.static_model = Plugins().get(self.static_model_name, device=self.device)

self.temporal_encoder = TimeSeriesTabularEncoder(
max_clusters=encoder_max_clusters
27 changes: 27 additions & 0 deletions src/synthcity/utils/optuna_sample.py
@@ -0,0 +1,27 @@
# stdlib
from typing import Any, Dict, List

# third party
import optuna

# synthcity absolute
import synthcity.plugins.core.distribution as D


def suggest(trial: optuna.Trial, dist: D.Distribution) -> Any:
if isinstance(dist, D.FloatDistribution):
return trial.suggest_float(dist.name, dist.low, dist.high)
elif isinstance(dist, D.LogDistribution):
return trial.suggest_float(dist.name, dist.low, dist.high, log=True)
elif isinstance(dist, D.IntegerDistribution):
return trial.suggest_int(dist.name, dist.low, dist.high, dist.step)
elif isinstance(dist, D.IntLogDistribution):
return trial.suggest_int(dist.name, dist.low, dist.high, log=True)
elif isinstance(dist, D.CategoricalDistribution):
return trial.suggest_categorical(dist.name, dist.choices)
else:
raise ValueError(f"Unknown dist: {dist}")


def suggest_all(trial: optuna.Trial, distributions: List[D.Distribution]) -> Dict:
return {dist.name: suggest(trial, dist) for dist in distributions}
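The helpers above map each synthcity Distribution onto the matching Optuna suggestion. A minimal end-to-end sketch of how they might drive a study (not part of this PR) follows; the plugin name "ddpm", the GenericDataLoader usage, and the toy scoring function are assumptions rather than code shown in the diff.

    # Hypothetical tuning sketch; the plugin name "ddpm" and the toy metric are assumptions.
    import numpy as np
    import optuna
    import pandas as pd

    from synthcity.plugins import Plugins
    from synthcity.plugins.core.dataloader import GenericDataLoader
    from synthcity.plugins.generic.plugin_ddpm import TabDDPMPlugin
    from synthcity.utils.optuna_sample import suggest_all

    # Toy dataset standing in for real training data.
    df = pd.DataFrame({"a": np.random.randn(200), "b": np.random.randint(0, 3, 200)})
    loader = GenericDataLoader(df)

    def objective(trial: optuna.Trial) -> float:
        params = suggest_all(trial, TabDDPMPlugin.hyperparameter_space())
        model = Plugins().get("ddpm", **params)  # assumes the plugin is registered as "ddpm"
        model.fit(loader)
        synth_df = model.generate(count=len(df)).dataframe()
        # Toy score: negative mean absolute gap between real and synthetic column means.
        return -float((synth_df.mean() - df.mean()).abs().mean())

    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=10)
    print(study.best_params)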